diff --git a/src/main.py b/src/main.py
index b29d0f4..4d8c63c 100644
--- a/src/main.py
+++ b/src/main.py
@@ -10,12 +10,12 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-def generate_response(prompt, max_length=50):
+def generate_response(prompt, max_new_tokens=256):
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
     input_ids = input_ids.to('cuda')
 
     with torch.no_grad():
-        output = model.generate(input_ids, max_length=max_length, num_return_sequences=1, pad_token_id=50256)
+        output = model.generate(input_ids, max_new_tokens=max_new_tokens, num_return_sequences=1, pad_token_id=50256)
 
     response = tokenizer.decode(output[0], skip_special_tokens=True)
     return response
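
Note on the rename: in Hugging Face transformers, max_length caps the total sequence length (prompt tokens plus generated tokens), so under the old max_length=50 a prompt of 40+ tokens left little or no room for a reply, whereas max_new_tokens bounds only the generated continuation. A minimal usage sketch follows; it assumes the setup above (a CUDA device and a GPT-2-style checkpoint, which the hard-coded pad_token_id=50256, GPT-2's EOS id, suggests) and the example prompt is illustrative, not part of the change:

# Usage sketch, assuming the model and tokenizer defined above are
# already loaded and the model sits on the CUDA device.
prompt = "Explain beam search in one sentence."

# The reply may now run up to 256 fresh tokens regardless of prompt length.
reply = generate_response(prompt, max_new_tokens=256)

# generate() returns prompt + continuation; strip the prompt to keep only
# the new text (GPT-2-style BPE decoding usually round-trips the prompt
# exactly, so slicing by character length is safe here).
print(reply[len(prompt):])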