forked from imgaojun/JunNMT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_vocab.py
30 lines (25 loc) · 818 Bytes
/
build_vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import nmt.IO
import argparse
import nmt.utils.misc_utils as utils
import torch
import json
# Script: build a vocabulary from a parallel training corpus and write it
# to "<save_data>.vocab.pkl".
#
# Command-line options:
#   -train_src  path passed to NMTDataset as src_path
#   -train_tgt  path passed to NMTDataset as tgt_path
#   -save_data  output path prefix; '.vocab.pkl' is appended to it
#   -config     path handed to utils.load_hparams
parser = argparse.ArgumentParser()
parser.add_argument('-train_src', type=str)
parser.add_argument('-train_tgt', type=str)
parser.add_argument('-save_data', type=str)
parser.add_argument('-config', type=str)
args = parser.parse_args()

opt = utils.load_hparams(args.config)

# Seed torch's RNG only when the config carries a positive seed.
if opt.random_seed > 0:
    torch.manual_seed(opt.random_seed)

fields = nmt.IO.get_fields()

print("Building Training...")
train = nmt.IO.NMTDataset(
    src_path=args.train_src,
    tgt_path=args.train_tgt,
    fields=[('src', fields["src"]),
            ('tgt', fields["tgt"])])

print("Building Vocab...")
# NOTE(review): presumably this populates the vocab on the field objects
# attached to `train` (the same objects held in `fields`) -- confirm in nmt.IO.
nmt.IO.build_vocab(train, opt)

print("Saving fields")
# Build the serializable vocab before opening the output file, so a failure
# in save_vocab does not leave an empty file behind.
vocab = nmt.IO.save_vocab(fields)
# The context manager guarantees the file is flushed and closed; the
# previous inline open() call leaked the handle.
with open(args.save_data + '.vocab.pkl', 'wb') as vocab_file:
    torch.save(vocab, vocab_file)