inference.py
"""
Ad-hoc sanity check to see if the model outputs something coherent.
Not a robust inference platform!
"""
import argparse

import yaml
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, LlamaForCausalLM,
                          LlamaTokenizer, pipeline)


def read_yaml_file(file_path):
    """Load the experiment config; returns None if the YAML cannot be parsed."""
    with open(file_path, 'r') as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as e:
            print(f"Error reading YAML file: {e}")
            return None


def get_prompt(human_prompt):
    """Wrap the raw user prompt in the ### HUMAN / ### RESPONSE template."""
    return f"### HUMAN:\n{human_prompt}\n\n### RESPONSE:\n"
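# For example, get_prompt("What is your favorite movie?") yields the string:
#
#   ### HUMAN:
#   What is your favorite movie?
#
#   ### RESPONSE:
#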
def get_llm_response(prompt):
    """Run the templated prompt through the global `pipe` defined in __main__."""
    raw_output = pipe(get_prompt(prompt))
    return raw_output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", help="Path to the config YAML file")
    args = parser.parse_args()
    config = read_yaml_file(args.config_path)

    # Quantize to 8-bit at load time to keep GPU memory usage low.
    q_config = BitsAndBytesConfig(load_in_8bit=True)

    print("Load model")
    model_path = f"{config['model_output_dir']}/{config['model_name']}"
    if "model_family" in config and config["model_family"] == "llama":
        # Llama checkpoints use the dedicated tokenizer/model classes.
        tokenizer = LlamaTokenizer.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(
            model_path, device_map="auto", quantization_config=q_config
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map="auto", quantization_config=q_config
        )

    # Module-level assignment so get_llm_response() can reach the pipeline.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
    )

    print(get_llm_response("What is your favorite movie?"))
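# ---------------------------------------------------------------------------
# Example usage (a minimal sketch; the file name and values below are
# illustrative assumptions -- only the keys model_output_dir, model_name and
# model_family are actually read above):
#
#   $ cat config.yaml
#   model_output_dir: ./outputs
#   model_name: my-finetuned-model
#   model_family: llama
#
#   $ python inference.py config.yaml
#
# The resolved directory {model_output_dir}/{model_name} is expected to hold a
# standard Hugging Face checkpoint (config, tokenizer files, model weights).
# ---------------------------------------------------------------------------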