import os

import torch
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from torchvision import datasets, transforms

# Meter, JsonLogger, NpyDataset, Baseline, STN and AttentionNet are
# project-local helpers, assumed importable from the repository's own modules.


def main(args):
    train_loss_meter = Meter()
    val_loss_meter = Meter()
    val_accuracy_meter = Meter()
    log = JsonLogger(args.log_path, rand_folder=True)
    log.update(args.__dict__)
    state = args.__dict__
    state['exp_dir'] = os.path.dirname(log.path)
    print(state)

    if args.train_distractors == 0:
        # Plain MNIST, randomly shifted inside a 40x40 canvas.
        dataset = datasets.MNIST(
            args.mnist_path, train=True,
            transform=transforms.Compose([
                transforms.RandomCrop(40, padding=14),
                transforms.ToTensor(),
                transforms.Normalize([0.1307], [0.3081]),
            ]))
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=True,
            pin_memory=True, num_workers=args.num_workers)
        dataset = datasets.MNIST(
            args.mnist_path, train=False,
            transform=transforms.Compose([
                transforms.RandomCrop(40, padding=14),
                transforms.ToTensor(),
                transforms.Normalize([0.1307], [0.3081]),
            ]))
        val_loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=args.num_workers)
        # Plain MNIST has no separate distractor test split; reuse the
        # held-out set so test() below stays well-defined.
        test_loader = val_loader
    elif args.train_distractors > 0:
        # Cluttered-MNIST splits are indexed by their distractor count:
        # train/valid by args.train_distractors, test by args.test_distractors.
        train_dataset = NpyDataset(
            os.path.join(args.mnist_path,
                         'train_%d.npy' % args.train_distractors))
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            pin_memory=True, num_workers=args.num_workers)
        val_dataset = NpyDataset(
            os.path.join(args.mnist_path,
                         'valid_%d.npy' % args.train_distractors))
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=args.num_workers)
        test_dataset = NpyDataset(
            os.path.join(args.mnist_path,
                         'test_%d.npy' % args.test_distractors))
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=args.num_workers)
    else:
        raise ValueError("args.train_distractors must be >= 0")

    if args.model == "baseline":
        model = Baseline().cuda()
    elif args.model == 'stn':
        stn = STN().cuda()
        baseline = Baseline().cuda()
        model = torch.nn.Sequential(stn, baseline)
    else:
        model = AttentionNet(args.att_depth, args.nheads,
                             args.has_gates, args.reg_w).cuda()

    if args.load != "":
        model.load_state_dict(torch.load(args.load), strict=False)
        model = model.cuda()

    if args.model != "stn":
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate,
                              weight_decay=1e-5, momentum=0.9)
    else:
        # The localization network gets its own, smaller learning rate.
        optimizer = optim.SGD(
            [{'params': baseline.parameters()},
             {'params': stn.parameters(), 'lr': 0.001}],
            lr=args.learning_rate, weight_decay=1e-5, momentum=0.9)

    def train():
        """Run one training epoch and record the mean loss."""
        model.train()
        for data, label in train_loader:
            data = Variable(data, requires_grad=False).cuda()
            label = Variable(label, requires_grad=False).cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, label)
            if args.reg_w > 0:
                loss += model.reg_loss()
            loss.backward()
            optimizer.step()
            train_loss_meter.update(loss.data[0], data.size(0))
        state['train_loss'] = train_loss_meter.mean()

    def val():
        """Evaluate loss and accuracy on the validation split."""
        model.eval()
        for data, label in val_loader:
            data = Variable(data, requires_grad=False).cuda()
            label = Variable(label, requires_grad=False).cuda()
            output = model(data)
            loss = F.nll_loss(output, label)
            val_loss_meter.update(loss.data[0], data.size(0))
            preds = output.max(1)[1]
            val_accuracy_meter.update((preds == label).float().sum().data[0],
                                      data.size(0))
        state['val_loss'] = val_loss_meter.mean()
        state['val_accuracy'] = val_accuracy_meter.mean()

    def test():
        """Evaluate loss and accuracy on the test split. Uses the same
        criterion as train/val, since the model outputs log-probabilities."""
        model.eval()
        for data, label in test_loader:
            data = Variable(data, requires_grad=False).cuda()
            label = Variable(label, requires_grad=False).cuda()
            output = model(data)
            loss = F.nll_loss(output, label)
            val_loss_meter.update(loss.data[0], data.size(0))
            preds = output.max(1)[1]
            val_accuracy_meter.update((preds == label).float().sum().data[0],
                                      data.size(0))
        state['test_loss'] = val_loss_meter.mean()
        state['test_accuracy'] = val_accuracy_meter.mean()

    if args.load != "":
        # Evaluation-only mode: score a pre-trained checkpoint.
        test()
        print(state)
        log.update(state)
    else:
        for epoch in range(args.epochs):
            train()
            val()
            if epoch == args.epochs - 1:
                test()
            state['epoch'] = epoch + 1
            log.update(state)
            print(state)
            if (epoch + 1) in args.schedule:
                # Decay the learning rate at the scheduled epochs.
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
        if args.save:
            torch.save(model.state_dict(),
                       os.path.join(state["exp_dir"], "model.pytorch"))
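
# A minimal sketch of the command-line entry point that main() above assumes.
# The flag names mirror the attributes read from `args` in main(); the default
# values are illustrative assumptions, not taken from the original repository.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Cluttered-MNIST attention experiments')
    parser.add_argument('--mnist_path', type=str, default='./data')
    parser.add_argument('--log_path', type=str, default='./logs')
    parser.add_argument('--model', type=str, default='baseline',
                        choices=['baseline', 'stn', 'attention'])
    parser.add_argument('--train_distractors', type=int, default=0)
    parser.add_argument('--test_distractors', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--schedule', type=int, nargs='+', default=[60, 80])
    parser.add_argument('--att_depth', type=int, default=2)
    parser.add_argument('--nheads', type=int, default=4)
    parser.add_argument('--has_gates', action='store_true')
    parser.add_argument('--reg_w', type=float, default=0.0)
    parser.add_argument('--load', type=str, default='')
    parser.add_argument('--save', action='store_true')
    main(parser.parse_args())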
import time

import numpy as np
import torch
from torch.autograd import Variable

# STN and AffineGridGen come from this package's own modules.
# nframes, height and width are assumed to be set earlier in the test
# script; the values below are illustrative.
nframes, height, width = 4, 32, 32

# Volumetric (3-D) sampler test in BCHW layout.
channels = 4
depth = 64
inputImages = torch.zeros(nframes, channels, height, width, depth)
# Build a sampling grid that ramps from -1 toward 1 along the height axis
# and is constant along every other axis.
x = np.repeat(np.expand_dims(
    np.repeat(np.expand_dims(
        np.repeat(np.expand_dims(
            np.repeat(np.expand_dims(np.arange(-1, 1, 2.0 / height), 0),
                      repeats=width, axis=0).T,
            0), repeats=depth, axis=0),
        0), repeats=3, axis=0),
    0), repeats=nframes, axis=0)
grids = torch.from_numpy(x.astype(np.float32))

input1 = Variable(inputImages, requires_grad=True)
input2 = Variable(grids, requires_grad=True)
input1.data.uniform_()
input2.data.uniform_(-1, 1)

s2 = STN(layout='BCHW')
start = time.time()
out = s2(input1, input2)
print('forward:', out.size(), 'time:', time.time() - start)
start = time.time()
out.backward(input1.data)
print('backward', input1.grad.size(), 'time:', time.time() - start)
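
# The nested expand_dims/repeat chain above only tiles a single 1-D ramp:
# the grid varies along the height axis and is constant along batch, channel,
# depth and width. An equivalent, arguably clearer construction (illustrative
# sketch, same shapes as above):
x_alt = np.broadcast_to(
    np.arange(-1, 1, 2.0 / height).reshape(1, 1, 1, height, 1),
    (nframes, 3, depth, height, width)).astype(np.float32)
# x_alt matches the ramp grid built above as first constructed; note that the
# uniform_ call overwrites grids in place through the numpy buffer it shares
# with x.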
# Re-randomize the inputs from the previous test.
input1.data.uniform_()
input2.data.uniform_(-1, 1)

# Affine grid generator: a batch of 2x3 affine matrices -> a 64x128 sampling
# grid, with the auxiliary loss enabled.
input = Variable(torch.from_numpy(
    np.array([[[0.8, 0.3, 1], [0.5, 0, 0]]], dtype=np.float32)),
    requires_grad=True)
print(input)
g = AffineGridGen(64, 128, aux_loss=True)
out, aux = g(input)
print(out.size())
out.backward(out.data)
print(input.grad.size())
# print(input2.data)

# Time the sampler on CPU...
s = STN()
start = time.time()
out = s(input1, input2)
print(out.size(), 'time:', time.time() - start)
start = time.time()
out.backward(input1.data)
print(input1.grad.size(), 'time:', time.time() - start)

# ...and on GPU.
input1 = input1.cuda()
input2 = input2.cuda()
start = time.time()
out = s(input1, input2)
print(out.size(), 'time:', time.time() - start)
start = time.time()
out.backward(input1.data)  # mirrors the CPU backward timing above
print('time:', time.time() - start)
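
# Note: CUDA kernels launch asynchronously, so the GPU wall-clock numbers
# above can under-report the real kernel time. A more faithful timing
# pattern (sketch):
torch.cuda.synchronize()
start = time.time()
out = s(input1, input2)
torch.cuda.synchronize()
print(out.size(), 'gpu time (synchronized):', time.time() - start)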
import torch.nn as nn

# localization_net, STN and AffineGridGen come from the project's own modules.


class stn_module(nn.Module):
    def __init__(self, dcnn_inputsize):
        super(stn_module, self).__init__()
        # Predicts the 2x3 affine transform parameters from the input.
        self.loc_net = localization_net()
        # Bilinear sampler that warps the input with the generated grid.
        self.sampler = STN()
        # Turns the affine parameters into a dcnn_inputsize x dcnn_inputsize
        # sampling grid.
        self.grid_generator = AffineGridGen(dcnn_inputsize, dcnn_inputsize,
                                            lr=0.001, aux_loss=True)
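
    # A sketch of the corresponding forward pass, assuming loc_net returns a
    # (B, 2, 3) affine parameter tensor and ignoring any layout transposes the
    # sampler may require; the original forward is not shown in this file.
    def forward(self, x):
        theta = self.loc_net(x)                      # (B, 2, 3) affine params
        grid, aux_loss = self.grid_generator(theta)  # sampling grid + aux loss
        warped = self.sampler(x, grid)               # resample input onto grid
        return warped, aux_loss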