drizzlezyk committed on
Commit
c4baf91
·
verified ·
1 Parent(s): 1a6fb8a

Upload inference/generate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference/generate.py +56 -0
inference/generate.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
"""Batch inference with the openPangu-7B-Diffusion-Base model.

Loads the tokenizer and model from a local checkpoint, binds the
project's ``diffusion_generate`` routine onto the model instance, runs
block-wise diffusion decoding on a small batch of prompts, and prints
the decoded completions.
"""
import types

import torch

try:
    # Optional Ascend NPU backend. The import is best-effort so the file
    # still imports on hosts without torch_npu; device_map="npu" below
    # will then fail at model-load time instead.
    import torch_npu
except ImportError:
    pass

from transformers import AutoModelForCausalLM, AutoTokenizer

from generation_utils import diffusion_generate

# Path to the locally downloaded checkpoint; replace before running.
model_local_path = "path_to_openPangu-7B-Diffusion-Base"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(
    model_local_path,
    use_fast=False,
    trust_remote_code=True,
    local_files_only=True
)

model = AutoModelForCausalLM.from_pretrained(
    model_local_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="npu",
    local_files_only=True
)

# Attach the custom diffusion decoding loop as a bound method so it can
# be called like a regular generate() method on this model instance.
model.diffusion_generate = types.MethodType(diffusion_generate, model)

# NOTE(review): 45830 is presumably the checkpoint's mask-token id —
# confirm against the tokenizer/model config if the vocabulary changes.
mask_token_id = 45830
eos_token_id = tokenizer.eos_token_id

prompts = ["introduce the china", "hello",
           "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. "
           "How many clips did Natalia sell altogether in April and May?"]
# Left-pad so all prompts align at the right edge for generation.
input_ids = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").input_ids.to(model.device)
# Create attention mask: Mark positions with non-padding tokens as True(attended), and padding tokens as False(ignored).
attention_mask = input_ids.ne(tokenizer.pad_token_id)

# Greedy (temperature=0.0) block-wise diffusion decoding; entropy-based
# token ordering within each block of 32 positions.
output = model.diffusion_generate(
    input_ids,
    block_length=32,
    attention_mask=attention_mask,
    temperature=0.0,
    max_new_tokens=128,
    alg="entropy",
    mask_token_id=mask_token_id,
    eos_token_id=eos_token_id,
    num_small_blocks=4
)
# Drop the prompt tokens, decode, and truncate each sample at the first EOS.
generation = tokenizer.batch_decode(output[:, input_ids.shape[1]:].tolist())
generation = [x.split(tokenizer.eos_token)[0].strip() for x in generation]
print(generation)