This repository was archived by the owner on Dec 27, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
50 lines (38 loc) · 1.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import argparse
import torch.optim as optim
import torch.nn as nn
from utils import Pairloader, SiameseNet, _tqdm as tqdm
from torch.utils.data import DataLoader
import os
# ---- Command-line configuration -----------------------------------------
parser = argparse.ArgumentParser(description='Train SiameseNet')
parser.add_argument('--save_location', '-sl', type=str, default='model/{}-epoch-{}.pth')
parser.add_argument('--epochs', '-e', type=int, default=50)
parser.add_argument('--save_every', '-se', type=int, default=5)
parser.add_argument('--device', '-d', type=str, default=None)
args = parser.parse_args()
# When no device was given (or an empty string was passed), prefer CUDA if
# it is available, otherwise fall back to the CPU.
args.device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
# ---- Model, data, loss, optimizer ---------------------------------------
model = SiameseNet(mode='train', device=args.device)
datagen = DataLoader(Pairloader(split='train'), shuffle=True)
bce_loss = nn.BCELoss()  # pair labels are expected in [0, 1] for BCE
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# ---- Training loop -------------------------------------------------------
for epoch in range(args.epochs):
    epoch_loss = 0.0
    with tqdm(datagen) as t:
        for i, batch in enumerate(t):
            t.set_description('EPOCH: %i'%(epoch+1))
            # batch[0] is the (input1, input2) pair, batch[1] the pair label
            # — presumably similar/dissimilar; confirm against Pairloader.
            data1 = batch[0][0].to(device=args.device)
            data2 = batch[0][1].to(device=args.device)
            label = batch[1].to(device=args.device)
            optimizer.zero_grad()
            output = model(data1, data2)
            loss = bce_loss(output, label)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # Running mean loss over the batches seen so far this epoch.
            t.set_postfix(loss=epoch_loss/(i+1))
    # BUG FIX: originally printed loss.item()/(i+1) — the LAST batch's loss
    # divided by the batch count. Report the epoch's mean loss instead,
    # consistent with the tqdm postfix above.
    print('Loss-{}'.format(epoch_loss/(i+1)))
    if (epoch+1)%args.save_every == 0:
        save_path = args.save_location.format('model', epoch+1)
        # Create the directory of the ACTUAL save path (the original always
        # made 'model/', which breaks a user-supplied --save_location that
        # points elsewhere). exist_ok avoids the check-then-create race.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), save_path)
        #torch.save(optimizer.state_dict(), args.save_location.format('optimizer', epoch+1))