Basic chatbot
parent
665277ac61
commit
023256ee92
@ -0,0 +1,2 @@
|
||||
.pydevproject
|
||||
.project
|
||||
@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd `dirname "$(realpath $0)"`
|
||||
cd ..
|
||||
|
||||
if test -d ./venv/bin; then
|
||||
echo "venv already installed, please remove for new install"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
python3 -m venv ./venv
|
||||
./venv/bin/python -m pip install transformers
|
||||
./venv/bin/python -m pip install torch
|
||||
./venv/bin/python -m pip install accelerate
|
||||
@ -0,0 +1,31 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "Qwen/Qwen2.5-7B-Instruct"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
def generate_response(prompt, max_length=50):
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
input_ids = input_ids.to('cuda')
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(input_ids, max_length=max_length, num_return_sequences=1, pad_token_id=50256)
|
||||
|
||||
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
return response
|
||||
|
||||
print("Chatbot: Hi there! How can I help you?")
|
||||
while True:
|
||||
user_input = input("You: ")
|
||||
if user_input.lower() == "exit":
|
||||
print("Chatbot: Goodbye!")
|
||||
break
|
||||
|
||||
response = generate_response(user_input)
|
||||
print("Chatbot:", response)
|
||||
@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
||||
Loading…
Reference in New Issue