def main(args):
    """Train and evaluate a (cluttered-)MNIST classifier.

    Depending on ``args.model`` this builds a plain ``Baseline`` CNN, an
    ``STN`` front-end feeding the baseline, or an ``AttentionNet``.
    With ``args.train_distractors == 0`` it trains on torchvision MNIST
    padded/cropped to 40x40; with a positive count it loads the
    pre-generated cluttered ``.npy`` splits from ``args.mnist_path``.

    Per-epoch statistics are accumulated into ``state`` and appended to
    a JSON log; the final model is optionally saved in the experiment
    directory.  If ``args.load`` names a checkpoint, only the test pass
    is run.
    """
    train_loss_meter = Meter()
    val_loss_meter = Meter()
    val_accuracy_meter = Meter()
    log = JsonLogger(args.log_path, rand_folder=True)
    log.update(args.__dict__)
    # state aliases args.__dict__ on purpose: hyper-parameters and
    # metrics are logged together as one record.
    state = args.__dict__
    state['exp_dir'] = os.path.dirname(log.path)
    print(state)

    if args.train_distractors == 0:
        # Plain MNIST, padded and randomly cropped to 40x40.
        dataset = datasets.MNIST(args.mnist_path,
                                 train=True,
                                 transform=transforms.Compose([
                                     transforms.RandomCrop(40, padding=14),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.1307], [0.3081])
                                 ]))
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=args.num_workers)
        dataset = datasets.MNIST(args.mnist_path,
                                 train=False,
                                 transform=transforms.Compose([
                                     transforms.RandomCrop(40, padding=14),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.1307], [0.3081])
                                 ]))
        val_loader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.num_workers)
    elif args.train_distractors > 0:
        # Pre-generated cluttered-MNIST splits.
        # BUGFIX: the training split must be selected by
        # args.train_distractors; the original used args.test_distractors
        # here, silently training on the wrong clutter level.
        train_dataset = NpyDataset(
            os.path.join(args.mnist_path,
                         'train_%d.npy' % args.train_distractors))
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=args.num_workers)
        val_dataset = NpyDataset(
            os.path.join(args.mnist_path,
                         'valid_%d.npy' % args.train_distractors))
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.num_workers)
        # Test clutter level may differ from training (generalization).
        test_dataset = NpyDataset(
            os.path.join(args.mnist_path,
                         'test_%d.npy' % args.test_distractors))
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  num_workers=args.num_workers)
    else:
        raise ValueError("#train distractors not >= 0")

    if args.model == "baseline":
        model = Baseline().cuda()
    elif args.model == 'stn':
        # Keep references to both halves: the optimizer below assigns
        # the STN its own (smaller) learning rate.
        stn = STN().cuda()
        baseline = Baseline().cuda()
        model = torch.nn.Sequential(stn, baseline)
    else:
        model = AttentionNet(args.att_depth, args.nheads, args.has_gates,
                             args.reg_w).cuda()

    if args.load != "":
        # strict=False tolerates partially-matching checkpoints.
        model.load_state_dict(torch.load(args.load), strict=False)
        model = model.cuda()

    if args.model != "stn":
        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              weight_decay=1e-5,
                              momentum=0.9)
    else:
        # STN parameters get a fixed small lr; the classifier uses the
        # configured one.
        optimizer = optim.SGD([{
            'params': baseline.parameters()
        }, {
            'params': stn.parameters(),
            'lr': 0.001
        }],
                              lr=args.learning_rate,
                              weight_decay=1e-5,
                              momentum=0.9)

    def train():
        """Run one training epoch and record the mean loss in state."""
        model.train()
        for data, label in train_loader:
            data, label = torch.autograd.Variable(data, requires_grad=False).cuda(), \
                          torch.autograd.Variable(label, requires_grad=False).cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, label)
            if args.reg_w > 0:
                loss += model.reg_loss()
            loss.backward()
            optimizer.step()
            train_loss_meter.update(loss.data[0], data.size(0))
        state['train_loss'] = train_loss_meter.mean()

    def val():
        """Evaluate on the validation split; record loss and accuracy."""
        model.eval()
        for data, label in val_loader:
            data, label = torch.autograd.Variable(data, requires_grad=False).cuda(), \
                          torch.autograd.Variable(label, requires_grad=False).cuda()
            output = model(data)
            loss = F.nll_loss(output, label)
            val_loss_meter.update(loss.data[0], data.size(0))
            preds = output.max(1)[1]
            val_accuracy_meter.update((preds == label).float().sum().data[0],
                                      data.size(0))
        state['val_loss'] = val_loss_meter.mean()
        state['val_accuracy'] = val_accuracy_meter.mean()

    def test():
        """Evaluate on the test split; record loss and accuracy.

        BUGFIX: uses its own fresh meters.  The original reused the
        validation meters without resetting them, so the reported test
        numbers were a blend of validation and test statistics.
        """
        test_loss_meter = Meter()
        test_accuracy_meter = Meter()
        model.eval()
        for data, label in test_loader:
            data, label = torch.autograd.Variable(data, requires_grad=False).cuda(), \
                          torch.autograd.Variable(label, requires_grad=False).cuda()
            output = model(data)
            # Consistent with train/val: the model is optimized with
            # nll_loss (log-probability outputs), so cross_entropy here
            # would apply log_softmax a second time and misreport loss.
            loss = F.nll_loss(output, label)
            test_loss_meter.update(loss.data[0], data.size(0))
            preds = output.max(1)[1]
            test_accuracy_meter.update((preds == label).float().sum().data[0],
                                       data.size(0))
        state['test_loss'] = test_loss_meter.mean()
        state['test_accuracy'] = test_accuracy_meter.mean()

    if args.load != "":
        # Checkpoint supplied: evaluation-only run.
        test()
        print(state)
        log.update(state)
    else:
        for epoch in range(args.epochs):
            train()
            val()
            if epoch == args.epochs - 1:
                test()
            state['epoch'] = epoch + 1
            log.update(state)
            print(state)
            # Step-decay the learning rate at the scheduled epochs.
            if (epoch + 1) in args.schedule:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
        if args.save:
            torch.save(model.state_dict(),
                       os.path.join(state["exp_dir"], "model.pytorch"))
# Ejemplo n.º 2
# (scraped example separator; commented out so the file parses)
# NOTE(review): this snippet starts mid-script -- `nframes`, `height`
# and `width` are referenced below but not defined in this chunk;
# confirm they are assigned earlier in the original example.
channels = 4
depth = 64

# 5-D input volume: (nframes, channels, height, width, depth).
inputImages = torch.zeros(nframes, channels, height, width, depth)
# Build a sampling-grid array by nesting expand_dims/repeat around a
# linspace over [-1, 1): innermost arange gives `height` values, then
# it is tiled over width, depth, 3 coordinate channels and nframes.
x = np.repeat(np.expand_dims(
    np.repeat(np.expand_dims(
        np.repeat(np.expand_dims(
            np.repeat(np.expand_dims(np.arange(-1, 1, 2.0 / height), 0),
                      repeats=width,
                      axis=0).T, 0),
                  repeats=depth,
                  axis=0), 0),
              repeats=3,
              axis=0), 0),
              repeats=nframes,
              axis=0)
grids = torch.from_numpy(x.astype(np.float32))
# Both inputs require grad so the backward timing below exercises the
# full gradient path.
input1, input2 = Variable(inputImages,
                          requires_grad=True), Variable(grids,
                                                        requires_grad=True)
# Randomize contents in place: images in [0, 1), grid coords in [-1, 1).
input1.data.uniform_()
input2.data.uniform_(-1, 1)

# Time a forward and a backward pass of the (project-defined) STN
# sampler with batch-channel-height-width layout.
s2 = STN(layout='BCHW')
start = time.time()
out = s2(input1, input2)
print('forward:', out.size(), 'time:', time.time() - start)
start = time.time()
# Use the input data itself as the incoming gradient for timing.
out.backward(input1.data)
print('backward', input1.grad.size(), 'time:', time.time() - start)
# Ejemplo n.º 3
# (scraped example separator; commented out so the file parses)
# NOTE(review): `input1` and `input2` are not defined in this snippet;
# in the scraped original they carry over from the previous example.
# Re-randomize them in place before the timing runs below.
input1.data.uniform_()
input2.data.uniform_(-1,1)

# Minimal AffineGridGen check: one batch of two 2x3 affine matrices.
input = Variable(torch.from_numpy(np.array([[[0.8, 0.3, 1], [0.5, 0, 0]]], dtype=np.float32)), requires_grad = True)
print(input)

# Project-defined generator producing a 64x128 sampling grid plus an
# auxiliary loss term (aux_loss=True makes it return a pair).
g = AffineGridGen(64, 128, aux_loss = True)
out, aux = g(input)
print((out.size()))
# Use the output itself as the incoming gradient to exercise backward.
out.backward(out.data)
print(input.grad.size())



#print input2.data
# Time STN forward/backward on CPU ...
s = STN()
start = time.time()
out = s(input1, input2)
print(out.size(), 'time:', time.time() - start)

start = time.time()
out.backward(input1.data)
print(input1.grad.size(), 'time:', time.time() - start)

# ... then repeat the forward timing on GPU.
input1 = input1.cuda()
input2 = input2.cuda()

start = time.time()
out = s(input1, input2)
print(out.size(), 'time:', time.time() - start)
start = time.time()
# Ejemplo n.º 4
# (scraped example separator; commented out so the file parses)
 def __init__(self, dcnn_inputsize):
   """Build the STN module's sub-networks.

   Args:
       dcnn_inputsize: side length of the (square) grid fed to the
           downstream DCNN -- used for both grid dimensions.

   NOTE(review): snippet may continue past the visible chunk; the
   enclosing class header is not shown here.
   """
   super(stn_module, self).__init__()
   # Localization network predicts the transform parameters;
   # STN() performs the sampling; AffineGridGen builds the
   # dcnn_inputsize x dcnn_inputsize sampling grid (aux_loss=True
   # makes it also return an auxiliary loss term).
   self.loc_net = localization_net()
   self.sampler = STN()
   self.grid_generator = AffineGridGen(dcnn_inputsize, dcnn_inputsize, lr=0.001, aux_loss=True)