From 023256ee92d06741fa5e6c3f4e1710bff78e9393 Mon Sep 17 00:00:00 2001
From: ia
Date: Wed, 13 Nov 2024 13:51:29 +0100
Subject: [PATCH] Basic chatbot

---
 .gitignore              |  2 ++
 scripts/install_venv.sh | 14 ++++++++++++++
 src/main.py             | 31 +++++++++++++++++++++++++++++++
 venv/.gitignore         |  2 ++
 4 files changed, 49 insertions(+)
 create mode 100644 .gitignore
 create mode 100755 scripts/install_venv.sh
 create mode 100644 src/main.py
 create mode 100644 venv/.gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2251a19
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.pydevproject
+.project
diff --git a/scripts/install_venv.sh b/scripts/install_venv.sh
new file mode 100755
index 0000000..3ad5367
--- /dev/null
+++ b/scripts/install_venv.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+cd "$(dirname "$(realpath "$0")")"
+cd ..
+
+if test -d ./venv/bin; then
+  echo "venv already installed, please remove for new install"
+  exit 1
+fi
+
+python3 -m venv ./venv
+./venv/bin/python -m pip install transformers
+./venv/bin/python -m pip install torch
+./venv/bin/python -m pip install accelerate
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..b29d0f4
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,31 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_name = "Qwen/Qwen2.5-7B-Instruct"
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype="auto",
+    device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+def generate_response(prompt, max_new_tokens=50):
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    input_ids = input_ids.to(model.device)
+
+    with torch.no_grad():
+        output = model.generate(input_ids, max_new_tokens=max_new_tokens, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+
+    response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
+    return response
+
+print("Chatbot: Hi there! How can I help you?")
+while True:
+    user_input = input("You: ")
+    if user_input.lower() == "exit":
+        print("Chatbot: Goodbye!")
+        break
+
+    response = generate_response(user_input)
+    print("Chatbot:", response)
\ No newline at end of file
diff --git a/venv/.gitignore b/venv/.gitignore
new file mode 100644
index 0000000..d6b7ef3
--- /dev/null
+++ b/venv/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore