mirror of
https://github.com/datalab-to/chandra.git
synced 2026-05-13 23:54:16 +00:00
Add im end
This commit is contained in:
@@ -28,7 +28,18 @@ def generate_hf(
|
||||
)
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=max_output_tokens)
|
||||
# Include both <|endoftext|> and <|im_end|> as stop tokens.
|
||||
# generation_config only has <|endoftext|>, but the model emits <|im_end|> at turn boundaries.
|
||||
eos_token_id = model.generation_config.eos_token_id
|
||||
im_end_id = model.processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
if im_end_id is not None and im_end_id not in eos_token_id:
|
||||
eos_token_id.append(im_end_id)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=max_output_tokens, eos_token_id=eos_token_id
|
||||
)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :]
|
||||
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
|
||||
@@ -15,4 +15,4 @@ def test_inference_image(simple_text_image):
|
||||
assert "Hello, World!" in output.markdown
|
||||
|
||||
chunks = output.chunks
|
||||
assert len(chunks) > 0
|
||||
assert len(chunks) == 1
|
||||
|
||||
Reference in New Issue
Block a user