You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
3.9 KiB
130 lines
3.9 KiB
import sys
|
|
import vim
|
|
import os
|
|
|
|
# The plugin cannot do anything without the OpenAI SDK. If the import fails,
# print a hint first; the hint asks the user to compare the shell's python3
# version with the one embedded in Vim. Then let the original ImportError
# propagate.
try:
    from openai import OpenAI
except ImportError:
    print(
        "Error: openai module not found. Please install with Pip and ensure equality of the versions given by :!python3 -V, and :python3 import sys; print(sys.version)"
    )
    raise
|
|
|
|
def safe_vim_eval(expression):
    """Evaluate a Vim *expression* and return its value, or None on failure.

    Only ``vim.error`` is caught (for example, an undefined ``g:`` variable);
    any other exception propagates to the caller.
    """
    try:
        result = vim.eval(expression)
    except vim.error:
        result = None
    return result
|
|
|
|
def create_client():
    """Build an OpenAI client from environment variables and Vim globals.

    API key, first non-empty of: $OPENAI_API_KEY, g:gpt_key,
    g:gpt_openai_api_key.
    Base URL, first non-empty of: $OPENAI_PROXY, $OPENAI_API_BASE,
    g:gpt_openai_base_url. If none is set, ``base_url=None`` is passed and the
    OpenAI SDK's own default applies.

    The ``or`` chains stop at the first truthy value, so later sources (and
    their Vim lookups) are only consulted when earlier ones are unset or empty.
    """
    key = (
        os.getenv('OPENAI_API_KEY')
        or safe_vim_eval('g:gpt_key')
        or safe_vim_eval('g:gpt_openai_api_key')
    )
    endpoint = (
        os.getenv('OPENAI_PROXY')
        or os.getenv('OPENAI_API_BASE')
        or safe_vim_eval('g:gpt_openai_base_url')
    )
    return OpenAI(api_key=key, base_url=endpoint)
|
|
|
|
|
|
# Fallback context-window sizes, used only when g:gpt_models has no entry for
# the configured model.
_GPT_DEFAULT_TOKEN_LIMITS = {
    "gpt-3.5-turbo": 4097,
    "gpt-3.5-turbo-16k": 16385,
    "gpt-3.5-turbo-1106": 16385,
    "gpt-4": 8192,
    "gpt-4-turbo": 128000,
    "gpt-4-turbo-preview": 128000,
    "gpt-4-32k": 32768,
    "gpt-4o": 128000,
    "gpt-4o-mini": 128000,
}

# Prefixes that chat_gpt() writes in front of each transcript entry, paired
# with the API role the entry is replayed as.
_GPT_ROLE_TAGS = (("USER\n", "user"), ("ASSISTANT\n", "assistant"))


def _gpt_vim_str(value):
    """Render *value* as a single-quoted Vim string literal ('' escapes a quote)."""
    return "'" + str(value).replace("'", "''") + "'"


def _gpt_display(content, finish_reason, session_id):
    """Call the Vim function GptDisplay(content, finish_reason, session_id).

    GptDisplay() is defined outside this file. Presumably it appends *content*
    to the buffer identified by *session_id*; confirm on the Vim side.
    """
    vim.command("call GptDisplay({0}, {1}, {2})".format(
        _gpt_vim_str(content),
        _gpt_vim_str(finish_reason),
        _gpt_vim_str(session_id),
    ))


def _gpt_session_lines(session_id):
    """Return a copy of the lines of the first buffer whose name contains *session_id*, or []."""
    for buf in vim.buffers:
        # Guard against a buffer that reports no name.
        if session_id in (buf.name or ""):
            return buf[:]
    return []


def _gpt_parse_entry(entry):
    """Split one transcript entry into (role, text); (None, None) if it has no role tag."""
    for tag, role in _GPT_ROLE_TAGS:
        if entry.startswith(tag):
            # Slice the tag off instead of using str.replace(), which would
            # also remove any later "USER\n"/"ASSISTANT\n" inside the message.
            return role, entry[len(tag):].strip()
    return None, None


def _gpt_history_messages(lines, budget):
    """Rebuild API messages from a session transcript.

    lines:  the transcript's lines; entries are separated by a blank line
            followed by the ">>> " marker.
    budget: size allowance measured in characters (a rough stand-in for
            tokens). Entries are taken newest-first while the allowance stays
            positive.
    Returns the kept messages in chronological (oldest-first) order.
    """
    newest_first = []
    for entry in reversed("\n".join(lines).split("\n\n>>> ")):
        role, text = _gpt_parse_entry(entry)
        if not text:
            # Title line ("# GPT"), untagged text, or an empty turn.
            continue
        budget -= len(text)
        if budget <= 0:
            # The budget only shrinks, so no older entry could be kept either.
            break
        newest_first.append({"role": role, "content": text})
    newest_first.reverse()
    return newest_first


def chat_gpt(prompt, persist=0):
    """Send *prompt* to the configured chat model and stream the reply into Vim.

    Settings come from Vim globals: g:gpt_max_tokens, g:gpt_model,
    g:gpt_models (model -> context size), g:gpt_temperature, g:gpt_lang,
    g:gpt_personas and g:gpt_persona.

    prompt:  the user's message.
    persist: 0 selects the shared 'gpt-persistent-session' buffer. Its
             transcript is replayed as history and extended with this
             exchange. Any other value sends no history, and the reply is
             displayed under the response's chunk id instead.

    Errors raised while creating the client or streaming the response are
    printed as "Error: ..." instead of propagating. Errors reading the
    configuration still raise.
    """
    max_tokens = int(vim.eval('g:gpt_max_tokens'))
    model = str(vim.eval('g:gpt_model'))

    # Prefer the user's g:gpt_models table; fall back to the built-in one.
    token_limit = safe_vim_eval('g:gpt_models[g:gpt_model]')
    if token_limit is None:
        token_limit = _GPT_DEFAULT_TOKEN_LIMITS.get(model)
    if token_limit is None:
        print(f"Error: no token limit known for model '{model}'; add it to g:gpt_models")
        return
    token_limit = int(token_limit)

    temperature = float(vim.eval('g:gpt_temperature'))
    lang = str(vim.eval('g:gpt_lang'))
    personas = dict(vim.eval('g:gpt_personas'))
    persona = str(vim.eval('g:gpt_persona'))

    # The string 'None' disables the language hint. The hint carries its own
    # leading space, so no separator is added here.
    lang_hint = f" And respond in {lang}." if lang != 'None' else ""
    system_ctx = {"role": "system", "content": personas[persona] + lang_hint}

    # NOTE(review): persist == 0 selects the *persistent* session, which looks
    # inverted relative to the parameter name. It is kept unchanged because the
    # Vim caller (not visible here) decides what to pass; confirm there first.
    session_id = 'gpt-persistent-session' if persist == 0 else None

    messages = []
    if session_id:
        transcript = _gpt_session_lines(session_id)
        # Leave room for the reply, the new prompt and the system message.
        budget = token_limit - max_tokens - len(prompt) - len(str(system_ctx))
        messages = _gpt_history_messages(transcript, budget)

        # Echo the new exchange into the session buffer. When no transcript
        # exists yet, start it with a "# GPT" heading.
        header = '' if transcript else '# GPT'
        _gpt_display(header + '\n\n>>> USER\n' + prompt + '\n\n>>> ASSISTANT\n', '', session_id)
        vim.command("redraw")

    messages.insert(0, system_ctx)
    messages.append({"role": "user", "content": prompt})

    try:
        client = create_client()
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )

        for chunk in response:
            # Newer Azure API responses contain empty chunks in the first
            # streamed response.
            if not chunk.choices:
                continue

            target = session_id or chunk.id
            choice = chunk.choices[0]
            # delta.content can be None (e.g. role-only deltas); calling
            # .replace() on it used to abort the whole stream.
            content = choice.delta.content if choice.delta else None

            if choice.finish_reason:
                # Some providers put the last piece of text in the finishing
                # chunk; show it before the finish marker rather than drop it.
                if content:
                    _gpt_display(content, '', target)
                _gpt_display('', choice.finish_reason, target)
            elif content is not None:
                _gpt_display(content, '', target)

            vim.command("redraw")
    except Exception as e:
        print("Error:", str(e))
|
|
|
|
# Entry point. a: variables only exist inside a Vim function, so this file is
# expected to be run from one; its arguments are forwarded to chat_gpt().
chat_gpt(prompt=vim.eval('a:prompt'), persist=int(vim.eval('a:persist')))
|
|
|