predict_arithmetic.py
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator


def perplexity_score(text, lm_model, lm_tokenizer, device):
    # Score a caption by the language model loss (lower means more fluent text).
    encodings = lm_tokenizer(lm_tokenizer.bos_token + text, return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()
    outputs = lm_model(input_ids, labels=target_ids)
    log_likelihood = outputs[0]
    ll = log_likelihood.item()
    return ll
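
# A minimal sketch of calling perplexity_score outside of cog, assuming a
# Hugging Face GPT-2 checkpoint (the model, tokenizer, and caption below are
# illustrative, not part of this script):
#
#   from transformers import GPT2LMHeadModel, GPT2Tokenizer
#   tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#   model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
#   loss = perplexity_score('dog riding a skateboard.', model, tokenizer, 'cpu')
#   # Lower loss means the caption is more fluent under the language model.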


class Predictor(cog.Predictor):
    def setup(self):
        # Load CLIP and the language model once, before any predictions run.
        self.args = get_args()
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image1",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image2",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image3",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="Conditional text prefix for the generated caption",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=3, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factors",
        type=float,
        default=1.06, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_lengths",
        type=int,
        default=3, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with the unshifted language model",
    )
    def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
        self.args.cond_text = cond_text
        self.text_generator.end_factor = end_factors
        self.text_generator.target_seq_length = max_seq_lengths
        self.text_generator.ce_scale = ce_loss_scale
        self.text_generator.fusion_factor = 0.95
        self.text_generator.grad_norm_factor = 0.95

        # Combine the three images in CLIP embedding space with coefficients
        # [1, 1, -1], i.e. image1 + (image2 - image3), then caption the result.
        image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # CLIP score: pick the caption whose CLIP text embedding best matches the combined image features.
        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
                            for c in captions]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

        # Perplexity score: pick the most fluent caption under the language model.
        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
        best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]

        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
        return final
        # return self.args.cond_text + captions[best_clip_idx]


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token generation")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.95)
    parser.add_argument("--fusion_factor", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)

    # Parse an empty argument list so the defaults are used when running under cog.
    args = parser.parse_args('')
    return args
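

# A rough sketch of exercising this predictor directly in Python, bypassing the
# cog runtime; the image paths are placeholders, not files shipped with the repo:
#
#   predictor = Predictor()
#   predictor.setup()
#   result = predictor.predict(
#       image1='examples/horse.jpg', image2='examples/crown.jpg',
#       image3='examples/grass.jpg', cond_text='Image of a',
#       beam_size=3, end_factors=1.06, max_seq_lengths=3, ce_loss_scale=0.2)
#   print(result)
#
# The returned string reports three candidate captions: the best CLIP match,
# the most fluent caption by language-model perplexity, and the mixed-score winner.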