diff --git a/chandra/model/hf.py b/chandra/model/hf.py
index 3f4ebf0..c98e58f 100644
--- a/chandra/model/hf.py
+++ b/chandra/model/hf.py
@@ -28,7 +28,18 @@ def generate_hf(
     )
     inputs = inputs.to(model.device)
 
-    generated_ids = model.generate(**inputs, max_new_tokens=max_output_tokens)
+    # Include both <|endoftext|> and <|im_end|> as stop tokens.
+    # generation_config only has <|endoftext|>, but the model emits <|im_end|> at turn boundaries.
+    eos_token_id = model.generation_config.eos_token_id
+    im_end_id = model.processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if im_end_id is not None and im_end_id not in eos_token_id:
+        eos_token_id.append(im_end_id)
+
+    generated_ids = model.generate(
+        **inputs, max_new_tokens=max_output_tokens, eos_token_id=eos_token_id
+    )
     generated_ids_trimmed = [
         out_ids[len(in_ids) :]
         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
diff --git a/tests/integration/test_image_inference.py b/tests/integration/test_image_inference.py
index bb066d7..46e0c1f 100644
--- a/tests/integration/test_image_inference.py
+++ b/tests/integration/test_image_inference.py
@@ -15,4 +15,4 @@ def test_inference_image(simple_text_image):
     assert "Hello, World!" in output.markdown
 
     chunks = output.chunks
-    assert len(chunks) > 0
+    assert len(chunks) == 1