/
experiment.py
60 lines (49 loc) · 2.18 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Written by Jason Taylor <jasonrbtaylor@gmail.com> 2017-2018
"""
import argparse
import json
import os
import torch
import numpy as np
import data
import model
import train
def run(batch_size=128, n_features=64, n_layers=6, n_scales=1, n_bins=16,
        exp_name='pixelCNN', exp_dir='/home/jason/experiments/pytorch_pixelcnn/',
        optimizer='adam', learnrate=1e-4, dropout=0.5, cuda=True, resume=False):
    """Set up and train a PixelCNN on MNIST, or resume a previous run.

    A per-configuration subdirectory named
    ``<exp_name>_<n_features>feat_<n_scales>scales_<n_layers>layers_<n_bins>bins``
    is created under ``exp_dir`` if needed. On a fresh run the
    hyperparameters are written to ``params.json`` there and a new
    ``model.PixelCNN`` is built. On resume the network is loaded from
    ``last_checkpoint`` in that directory. Training is then delegated to
    ``train.fit``.

    Args:
        batch_size: batch size passed to ``data.mnist``.
        n_features, n_layers, n_scales, dropout: forwarded to
            ``model.PixelCNN`` on a fresh run.
        n_bins: number of quantization bins. It is also used for the label
            quantization and the length of the loss class-weight vector.
        exp_name: prefix of the experiment subdirectory name.
        exp_dir: parent directory for experiment outputs.
        optimizer, learnrate: forwarded to ``train.fit``.
        cuda: if True, the loss weights are moved to the GPU. If False, a
            resumed checkpoint is loaded onto the CPU.
        resume: continue from the files in an existing experiment directory.

    Raises:
        IOError: if ``resume`` is True and any of ``params.json``,
            ``stats.json`` or ``last_checkpoint`` is missing.
    """
    exp_name += '_%ifeat_%iscales_%ilayers_%ibins' % (
        n_features, n_scales, n_layers, n_bins)
    exp_dir = os.path.join(exp_dir, exp_name)
    if not os.path.isdir(exp_dir):
        os.makedirs(exp_dir)
    if not resume:
        # Store experiment params in params.json
        params = {'batch_size': batch_size, 'n_features': n_features,
                  'n_layers': n_layers, 'n_scales': n_scales,
                  'n_bins': n_bins, 'optimizer': optimizer,
                  'learnrate': learnrate, 'dropout': dropout, 'cuda': cuda}
        with open(os.path.join(exp_dir, 'params.json'), 'w') as f:
            json.dump(params, f)
        # Model
        net = model.PixelCNN(1, n_features, n_layers, n_scales, n_bins, dropout)
    else:
        # If resuming, the params, stats and checkpoint files must all exist.
        # Report exactly which ones are absent so the failure is actionable.
        required = ('params.json', 'stats.json', 'last_checkpoint')
        missing = [name for name in required
                   if not os.path.isfile(os.path.join(exp_dir, name))]
        if missing:
            raise IOError('Missing param, stats or checkpoint file on resume: '
                          '%s not found in %s' % (', '.join(missing), exp_dir))
        checkpoint = os.path.join(exp_dir, 'last_checkpoint')
        # NOTE(review): the checkpoint appears to be a pickled module (it is
        # used directly as `net`). Only load trusted files. PyTorch >= 2.6
        # defaults to torch.load(weights_only=True), which may reject it —
        # confirm against the installed version.
        if cuda:
            net = torch.load(checkpoint)
        else:
            # Remap any GPU-saved storages onto the CPU so a CPU-only resume
            # does not fail on deserialization.
            net = torch.load(checkpoint,
                             map_location=lambda storage, loc: storage)
    # Data loaders
    train_loader, val_loader = data.mnist(batch_size)
    # Up-weight 1s (~8x rarer) to balance loss, interpolate intermediate values:
    # weights rise linearly from 1 (bin 0) to 8 (bin n_bins-1).
    weight = torch.from_numpy(np.linspace(1, 8, n_bins, dtype='float32'))
    if cuda:
        weight = weight.cuda()

    # Define loss fcn, incl. label formatting from input
    def input2label(x):
        # Quantize values to integer bin indices; these land in
        # [0, n_bins-1] assuming x is scaled to [0, 1] (confirm in `data`).
        # squeeze(..., 1) drops dim 1 when it has size 1. .long() keeps x's
        # device, whereas .type(torch.LongTensor) always returned a CPU tensor.
        return torch.squeeze(torch.round((n_bins - 1) * x).long(), 1)

    loss_fcn = torch.nn.NLLLoss2d(torch.autograd.Variable(weight))
    # Train
    train.fit(train_loader, val_loader, net, exp_dir, input2label, loss_fcn,
              optimizer, learnrate=learnrate, cuda=cuda, resume=resume)