From 56800685df2eb1337110a5edb3abf2e7140d70c7 Mon Sep 17 00:00:00 2001
From: ia
Date: Wed, 13 Nov 2024 14:02:09 +0100
Subject: [PATCH] Change max_new_token

---
 src/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/main.py b/src/main.py
index b29d0f4..4d8c63c 100644
--- a/src/main.py
+++ b/src/main.py
@@ -10,12 +10,12 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-def generate_response(prompt, max_length=50):
+def generate_response(prompt, max_new_tokens=256):
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
     input_ids = input_ids.to('cuda')
 
     with torch.no_grad():
-        output = model.generate(input_ids, max_length=max_length, num_return_sequences=1, pad_token_id=50256)
+        output = model.generate(input_ids, max_new_tokens=max_new_tokens, num_return_sequences=1, pad_token_id=50256)
 
     response = tokenizer.decode(output[0], skip_special_tokens=True)
     return response
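
Note on the change: in Hugging Face transformers, max_length caps the total sequence length (prompt plus completion), so the old max_length=50 left little or no room for the reply once the prompt grew long; max_new_tokens caps only the generated continuation. Below is a minimal sketch of the function as it reads after this patch, assuming a GPT-2-style checkpoint ("gpt2" is a placeholder, since the real model_name is defined above the hunk and not shown here):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # placeholder; src/main.py sets the actual checkpoint earlier
    model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def generate_response(prompt, max_new_tokens=256):
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            # max_new_tokens bounds only the completion; long prompts no
            # longer eat into the generation budget the way max_length=50 did.
            output = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                pad_token_id=50256,  # GPT-2 EOS token id reused for padding
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)

    print(generate_response("Explain the difference between max_length and max_new_tokens:"))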