def __init__(self, data_size, latent_size=128, depth=3):
    super().__init__()

    upmode = 'bilinear'
    c, h, w = data_size

    cs = [c] + [2 ** (d + 4) for d in range(depth)]

    div = 2 ** depth
    cl = lambda x: int(math.ceil(x))

    modules = [
        nn.Linear(latent_size, cs[-1] * cl(h / div) * cl(w / div)),
        nn.ReLU(),
        util.Reshape((cs[-1], cl(h / div), cl(w / div)))
    ]

    for d in range(depth, 0, -1):
        modules += [
            nn.Upsample(scale_factor=2, mode=upmode),
            nn.ConvTranspose2d(cs[d], cs[d], 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(cs[d], cs[d - 1], 3, padding=1),
            nn.ReLU()
        ]

    modules += [
        nn.ConvTranspose2d(c, c, (3, 3), padding=1),
        nn.Sigmoid(),
        util.Lambda(lambda x: x[:, :, :h, :w])  # crop out any extra pixels due to rounding errors
    ]

    self.decoder = nn.Sequential(*modules)
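# A minimal, self-contained sketch (not part of the model) illustrating why the decoder
# above rounds the bottleneck resolution up with ceil() and crops at the end: for inputs
# whose height/width are not divisible by 2**depth, the stack of 2x upsamplings overshoots
# the target size, and the final Lambda crops the surplus rows/columns away.
def _decoder_shape_demo(h=28, w=28, depth=3):
    import math

    div = 2 ** depth
    bh, bw = math.ceil(h / div), math.ceil(w / div)  # bottleneck resolution, rounded up
    uh, uw = bh * div, bw * div                      # resolution after `depth` 2x upsamplings

    print(f'bottleneck {bh}x{bw}, upsampled {uh}x{uw}, cropped to {h}x{w}')
    # e.g. h = w = 28, depth = 3: bottleneck 4x4, upsampled 32x32, cropped back to 28x28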
def __init__(self, data_size, latent_size=(5, 5, 128), depth=3, gadditional=2, radditional=4,
             region=0.2, method='clamp', sigma_scale=1.0, min_sigma=0.01):
    super().__init__()

    self.method, self.gadditional, self.radditional = method, gadditional, radditional
    self.sigma_scale, self.min_sigma = sigma_scale, min_sigma

    # latent space
    self.latent = nn.Parameter(torch.randn(size=latent_size))
    self.region = [int(r * region) for r in latent_size[:-1]]

    ln = len(latent_size)
    emb_size = latent_size[-1]

    c, h, w = data_size
    cs = [c] + [2 ** (d + 4) for d in range(depth)]

    div = 2 ** depth

    modules = []
    for d in range(depth):
        modules += [
            nn.Conv2d(cs[d], cs[d + 1], 3, padding=1), nn.ReLU(),
            nn.Conv2d(cs[d + 1], cs[d + 1], 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 2))
        ]

    modules += [
        util.Flatten(),
        nn.Linear(cs[-1] * (h // div) * (w // div), 1024), nn.ReLU(),
        nn.Linear(1024, ln)  # encoder produces a continuous index tuple (ln - 1 values for the means, 1 for the sigma)
    ]

    self.encoder = nn.Sequential(*modules)

    upmode = 'bilinear'
    cl = lambda x: int(math.ceil(x))

    modules = [
        nn.Linear(emb_size, cs[-1] * cl(h / div) * cl(w / div)), nn.ReLU(),
        util.Reshape((cs[-1], cl(h / div), cl(w / div)))
    ]

    for d in range(depth, 0, -1):
        modules += [
            nn.Upsample(scale_factor=2, mode=upmode),
            nn.ConvTranspose2d(cs[d], cs[d], 3, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(cs[d], cs[d - 1], 3, padding=1), nn.ReLU()
        ]

    modules += [
        nn.ConvTranspose2d(c, c, (3, 3), padding=1), nn.Sigmoid(),
        util.Lambda(lambda x: x[:, :, :h, :w])  # crop out any extra pixels due to rounding errors
    ]

    self.decoder = nn.Sequential(*modules)

    self.smp = True
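# A small sketch (toy numbers, not part of the model) of how `latent_size` is interpreted
# above: the last element is the per-position embedding size that the decoder consumes, the
# leading elements are the spatial extent of the learned latent map, and `region` scales
# those spatial dimensions to get the window used for local sampling.
def _latent_layout_demo(latent_size=(5, 5, 128), region=0.2):
    emb_size = latent_size[-1]                   # embedding size, fed to the decoder
    spatial = latent_size[:-1]                   # indexable dimensions of the latent map
    window = [int(r * region) for r in spatial]  # sampling region, as computed in __init__ above

    print(f'embedding size {emb_size}, spatial dims {spatial}, sampling region {window}')
    # latent_size=(5, 5, 128), region=0.2 -> embedding size 128, spatial dims (5, 5), region [1, 1]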
def go(arg):

    if arg.seed < 0:
        seed = random.randint(0, 1000000)
        print('random seed: ', seed)
        torch.manual_seed(seed)
    else:
        torch.manual_seed(arg.seed)

    tbw = SummaryWriter(log_dir=arg.tb_dir)

    normalize = transforms.Compose([transforms.ToTensor()])

    if arg.task == 'mnist':
        data = arg.data + os.sep + arg.task

        if arg.final:
            train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=normalize)
            trainloader = torch.utils.data.DataLoader(train, batch_size=arg.batch, shuffle=True, num_workers=2)

            test = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=normalize)
            testloader = torch.utils.data.DataLoader(test, batch_size=arg.batch, shuffle=False, num_workers=2)
        else:
            NUM_TRAIN = 45000
            NUM_VAL = 5000
            total = NUM_TRAIN + NUM_VAL

            train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=normalize)

            trainloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total))
            testloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total))

        shape = (1, 28, 28)
        num_classes = 10

    elif arg.task == 'image-folder-bw':

        tr = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

        if arg.final:
            train = torchvision.datasets.ImageFolder(root=arg.data + '/train/', transform=tr)
            test = torchvision.datasets.ImageFolder(root=arg.data + '/test/', transform=tr)

            trainloader = DataLoader(train, batch_size=arg.batch, shuffle=True)
            testloader = DataLoader(test, batch_size=arg.batch, shuffle=True)
        else:
            NUM_TRAIN = 45000
            NUM_VAL = 5000
            total = NUM_TRAIN + NUM_VAL

            train = torchvision.datasets.ImageFolder(root=arg.data + '/train/', transform=tr)

            trainloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total))
            testloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total))

        for im, labels in trainloader:
            shape = im[0].size()
            break

        num_classes = 10
    else:
        raise Exception('Task name {} not recognized'.format(arg.task))

    activation = nn.ReLU()

    hyperlayer = None

    if arg.modelname == 'conv':

        base = prep(*shape, pool=arg.pool)

        model = nn.Sequential(*(
            base + [activation, nn.Linear(HIDLIN, num_classes), nn.Softmax(dim=1)]
        ))

        reinforce = False

    elif arg.modelname == 'reinforce':

        hyperlayer = ReinforceLayer(
            in_shape=shape, glimpses=arg.num_glimpses,
            glimpse_size=(28, 28),
            num_classes=num_classes, pool=arg.pool)

        model = nn.Sequential(
            hyperlayer,
            R(util.Flatten()),
            R(nn.Linear(28 * 28 * shape[0] * arg.num_glimpses, arg.hidden)),
            R(activation),
            R(nn.Linear(arg.hidden, num_classes)),
            R(nn.Softmax(dim=1))
        )

        reinforce = True

    elif arg.modelname == 'ash':

        hyperlayer = BoxAttentionLayer(
            glimpses=arg.num_glimpses,
            in_size=shape, k=arg.k,
            gadditional=arg.gadditional, radditional=arg.radditional,
            region=(arg.region, arg.region),
            min_sigma=arg.min_sigma,
            pool=arg.pool
        )

        model = nn.Sequential(
            hyperlayer,
            util.Flatten(),
            nn.Linear(arg.k * arg.k * shape[0] * arg.num_glimpses, arg.hidden),
            activation,
            nn.Linear(arg.hidden, num_classes),
            nn.Softmax(dim=1)
        )

        reinforce = False

    elif arg.modelname == 'quad':
        """
        Network with quadrangle attention (instead of a bounding box).
        """

        hyperlayer = QuadAttentionLayer(
            glimpses=arg.num_glimpses,
            in_size=shape, k=arg.k,
            gadditional=arg.gadditional, radditional=arg.radditional,
            region=(arg.region, arg.region),
            min_sigma=arg.min_sigma,
            pool=arg.pool
        )

        model = nn.Sequential(
            hyperlayer,
            util.Flatten(),
            nn.Linear(arg.k * arg.k * shape[0] * arg.num_glimpses, arg.hidden),
            activation,
            nn.Linear(arg.hidden, num_classes),
            nn.Softmax(dim=1)
        )

        reinforce = False

    elif arg.modelname == 'aff':
        """
        Network with affine transformation attention (instead of a bounding box).
        """

        hyperlayer = AffineAttentionLayer(
            glimpses=arg.num_glimpses,
            in_size=shape, k=arg.k,
            gadditional=arg.gadditional, radditional=arg.radditional,
            region=(arg.region, arg.region),
            min_sigma=arg.min_sigma,
            scale=arg.stn_scale,
            pool=arg.pool
        )

        model = nn.Sequential(
            hyperlayer,
            util.Flatten(),
            nn.Linear(arg.k * arg.k * shape[0] * arg.num_glimpses, arg.hidden),
            activation,
            nn.Linear(arg.hidden, num_classes),
            nn.Softmax(dim=1)
        )

        reinforce = False

    elif arg.modelname == 'aff-conv':
        """
        Network with affine transformation attention (instead of a bounding box),
        with a convolutional head.
        """

        hyperlayer = AffineAttentionLayer(
            glimpses=arg.num_glimpses,
            in_size=shape, k=arg.k,
            gadditional=arg.gadditional, radditional=arg.radditional,
            region=(arg.region, arg.region),
            min_sigma=arg.min_sigma,
            scale=arg.stn_scale,
            pool=arg.pool
        )

        ch1, ch2, ch3 = 16, 32, 64
        h = (arg.k // 8) ** 2 * 64

        model = nn.Sequential(
            hyperlayer,
            util.Reshape((arg.num_glimpses * shape[0], arg.k, arg.k)),  # fold glimpses into channels
            nn.Conv2d(arg.num_glimpses * shape[0], ch1, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(ch1, ch2, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(ch2, ch2, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(ch2, ch3, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(ch3, ch3, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=2),
            util.Flatten(),
            nn.Linear(h, 128),
            activation,
            nn.Linear(128, num_classes),
            nn.Softmax(dim=1)
        )

        reinforce = False

    elif arg.modelname == 'stn':
        """
        Spatial transformer with an MLP head.
        """

        hyperlayer = STNAttentionLayer(in_size=shape, k=arg.k, glimpses=arg.num_glimpses, scale=arg.stn_scale, pool=arg.pool)

        model = nn.Sequential(
            hyperlayer,
            util.Flatten(),
            nn.Linear(arg.k * arg.k * shape[0] * arg.num_glimpses, arg.hidden),
            activation,
            nn.Linear(arg.hidden, num_classes),
            nn.Softmax(dim=1)
        )

        reinforce = False

    elif arg.modelname == 'stn-conv':
        """
        Spatial transformer with a convolutional head.
        """

        hyperlayer = STNAttentionLayer(in_size=shape, k=arg.k, glimpses=arg.num_glimpses, scale=arg.stn_scale, pool=arg.pool)

        ch1, ch2, ch3 = 16, 32, 64
        h = (arg.k // 8) ** 2 * 64

        model = nn.Sequential(
            hyperlayer,
            util.Reshape((arg.num_glimpses * shape[0], arg.k, arg.k)),  # fold glimpses into channels
            nn.Conv2d(arg.num_glimpses * shape[0], ch1, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(ch1, ch2, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(ch2, ch2, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(ch2, ch3, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(ch3, ch3, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=2),
            util.Flatten(),
            nn.Linear(h, 128),
            activation,
            nn.Linear(128, num_classes),
            nn.Softmax(dim=1)
        )

        reinforce = False

    elif arg.modelname == 'ash-conv':
        """
        Model with a convolutional head. More powerful classification, but more
        difficult to train on top of a hyperlayer.
        """

        hyperlayer = BoxAttentionLayer(
            glimpses=arg.num_glimpses,
            in_size=shape, k=arg.k,
            gadditional=arg.gadditional, radditional=arg.radditional,
            region=(arg.region, arg.region),
            min_sigma=arg.min_sigma,
            pool=arg.pool
        )

        ch1, ch2, ch3 = 16, 32, 64
        h = (arg.k // 8) ** 2 * 64

        model = nn.Sequential(
            hyperlayer,
            util.Reshape((arg.num_glimpses * shape[0], arg.k, arg.k)),  # fold glimpses into channels
            nn.Conv2d(arg.num_glimpses * shape[0], ch1, kernel_size=5, padding=2),
            activation,
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(ch1, ch2, kernel_size=5, padding=2),
            activation,
            nn.Conv2d(ch2, ch2, kernel_size=5, padding=2),
            activation,
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(ch2, ch3, kernel_size=5, padding=2),
            activation,
            nn.Conv2d(ch3, ch3, kernel_size=5, padding=2),
            activation,
            nn.MaxPool2d(kernel_size=2),
            util.Flatten(),
            nn.Linear(h, 128),
            activation,
            nn.Linear(128, num_classes),
            nn.Softmax(dim=1)
        )

        reinforce = False

    else:
        raise Exception('Model name {} not recognized'.format(arg.modelname))

    if arg.cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=arg.lr)

    xent = nn.CrossEntropyLoss()
    mse = nn.MSELoss()

    step = 0

    sigs, vals = [], []

    util.makedirs('./mnist/')

    for epoch in range(arg.epochs):

        model.train(True)

        for i, (inputs, labels) in tqdm(enumerate(trainloader, 0)):

            # if i > 2:
            #     break

            if arg.cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            inputs, labels = Variable(inputs), Variable(labels)

            optimizer.zero_grad()

            if not reinforce:
                outputs = model(inputs)
            else:
                outputs, stoch_nodes, actions = model(inputs)

            mloss = F.cross_entropy(outputs, labels, reduction='none')

            if reinforce:
                rloss = stoch_nodes.log_prob(actions) * -mloss.detach()[:, None, None]

                loss = rloss.sum(dim=1) + mloss[:, None]

                tbw.add_scalar('mnist/train-loss', float(loss.mean().item()), step)
                tbw.add_scalar('mnist/model-loss', float(mloss.mean().item()), step)
                tbw.add_scalar('mnist/reinf-loss', float(rloss.sum(dim=1).mean().item()), step)
            else:
                loss = mloss

                tbw.add_scalar('mnist/train-loss', float(loss.data.sum().item()), step)

            loss = loss.sum()
            loss.backward()  # compute the gradients
            optimizer.step()

            step += inputs.size(0)

            if epoch % arg.plot_every == 0 and i == 0 and hyperlayer is not None:
                hyperlayer.plot(inputs[:10, ...])
                plt.savefig('mnist/attention.{:03}.pdf'.format(epoch))

        total = 0.0
        correct = 0.0

        model.train(False)

        for i, (inputs, labels) in enumerate(testloader, 0):

            if arg.cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # wrap them in Variables
            inputs, labels = Variable(inputs), Variable(labels)

            if not reinforce:
                outputs = model(inputs)
            else:
                outputs, _, _ = model(inputs)

            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct / total

        tbw.add_scalar('mnist/per-epoch-test-acc', accuracy, epoch)
        print('EPOCH {}: {} accuracy '.format(epoch, accuracy))

    LOG.info('Finished Training.')
def __init__(self, in_shape, glimpses, glimpse_size, num_classes, rfboost=2.0, pool=4, coord='none'):
    super().__init__()

    self.rfboost = rfboost
    self.num_glimpses = glimpses
    self.glimpse_size = glimpse_size
    self.in_shape = in_shape

    activation = nn.ReLU()

    ci, hi, wi = in_shape

    modules = prep(ci, hi, wi, pool, coord) + [
        nn.ReLU(),
        nn.Linear(HIDLIN, glimpses * 3),
        util.Reshape((glimpses, 3))
    ]

    self.preprocess = nn.Sequential(*modules)
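# A minimal sketch (toy stand-ins for prep/HIDLIN, which are defined elsewhere in this
# file) of the shape contract of the preprocessing head above: for each input image it
# produces `glimpses` rows of 3 numbers, one parameterization per glimpse.
def _reinforce_preprocess_demo(glimpses=2):
    import torch
    import torch.nn as nn

    hidden = 16  # stand-in for HIDLIN

    class Reshape(nn.Module):  # stand-in for util.Reshape
        def __init__(self, shape):
            super().__init__()
            self.shape = shape
        def forward(self, x):
            return x.view(x.size(0), *self.shape)

    head = nn.Sequential(nn.ReLU(), nn.Linear(hidden, glimpses * 3), Reshape((glimpses, 3)))
    out = head(torch.randn(8, hidden))
    print(out.size())  # torch.Size([8, 2, 3]): one row of 3 glimpse parameters per glimpse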
def __init__(self, in_size, k, gadditional=0, radditional=0, region=None, sigma_scale=0.1,
             num_values=-1, min_sigma=0.0, glimpses=1, pool=4, coord='none'):

    assert(len(in_size) == 3)

    self.in_size = in_size
    self.k = k
    self.sigma_scale = sigma_scale
    self.num_values = num_values
    self.min_sigma = min_sigma

    ci, hi, wi = in_size
    co, ho, wo = ci, k, k

    out_size = glimpses, co, ho, wo
    self.out_size = out_size

    map = (glimpses, co, k, k)
    template = torch.LongTensor(list(np.ndindex(map)))
    assert template.size() == (prod(map), 4)

    template = torch.cat([template, template[:, 1:]], dim=1)
    assert template.size() == (prod(map), 7)

    self.lc = [5, 6]  # learnable columns

    super().__init__(
        in_rank=3, out_size=(glimpses, co, ho, wo),
        temp_indices=template,
        learn_cols=self.lc,
        chunk_size=1,
        gadditional=gadditional, radditional=radditional, region=region,
        bias_type=util.Bias.NONE)

    self.num_glimpses = glimpses

    modules = prep(ci, hi, wi, pool, coord) + [
        nn.ReLU(),
        nn.Linear(HIDLIN, 8 * glimpses),
        util.Reshape((glimpses, 4, 2))
    ]

    self.preprocess = nn.Sequential(*modules)

    self.register_buffer('grid', util.interpolation_grid((k, k)))

    self.register_buffer('quad_offset', torch.FloatTensor([[-1, 1], [1, 1], [1, -1], [-1, -1]]))
    # -- added to the quad, to make sure there's a training signal
    #    from the initial weights (i.e. in case all outputs are close to zero)

    # One sigma per glimpse
    self.sigmas = Parameter(torch.randn((glimpses,)))

    # All values 1, no bias. Glimpses extract only pixel information.
    self.register_buffer('one', torch.FloatTensor([1.0]))
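# A self-contained sketch (toy sizes, not part of the layer) of the index-template
# construction above: np.ndindex enumerates every output coordinate
# (glimpse, channel, row, col), and concatenating columns 1: onto the template yields
# 7 columns, of which the last two (self.lc = [5, 6]) are the learnable input coordinates.
def _template_demo(glimpses=1, co=1, k=2):
    import numpy as np
    import torch

    map = (glimpses, co, k, k)
    template = torch.LongTensor(list(np.ndindex(map)))        # (prod(map), 4) output coordinates
    template = torch.cat([template, template[:, 1:]], dim=1)  # (prod(map), 7)

    print(template)
    # tensor([[0, 0, 0, 0, 0, 0, 0],
    #         [0, 0, 0, 1, 0, 0, 1],
    #         [0, 0, 1, 0, 0, 1, 0],
    #         [0, 0, 1, 1, 0, 1, 1]])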
def __init__(self, in_size, k, gadditional=0, radditional=0, region=None, sigma_scale=0.1,
             num_values=-1, min_sigma=0.0, glimpses=1, scale=0.001, pool=4, coord='none'):

    assert(len(in_size) == 3)

    self.in_size = in_size
    self.k = k
    self.sigma_scale = sigma_scale
    self.num_values = num_values
    self.min_sigma = min_sigma
    self.scale = scale

    ci, hi, wi = in_size
    co, ho, wo = ci, k, k

    out_size = glimpses, co, ho, wo
    self.out_size = out_size

    map = (glimpses, co, k, k)
    template = torch.LongTensor(list(np.ndindex(map)))
    assert template.size() == (prod(map), 4)

    template = torch.cat([template, template[:, 1:]], dim=1)
    assert template.size() == (prod(map), 7)

    self.lc = [5, 6]  # learnable columns

    super().__init__(
        in_rank=3, out_size=(glimpses, co, ho, wo),
        temp_indices=template,
        learn_cols=self.lc,
        chunk_size=1,
        gadditional=gadditional, radditional=radditional, region=region,
        bias_type=util.Bias.NONE)

    self.num_glimpses = glimpses

    modules = prep(ci, hi, wi, pool, coord) + [
        nn.ReLU(),
        nn.Linear(HIDLIN, (2 * 3) * glimpses),
        util.Reshape((glimpses, 2, 3))
    ]

    self.preprocess = nn.Sequential(*modules)

    self.register_buffer('grid', util.interpolation_grid((k, k)))

    self.register_buffer('identity', torch.FloatTensor([0.4, 0, 0, 0, 0.4, 0]).view(2, 3))
    self.register_buffer('corners', torch.FloatTensor([-2, 2, 2, 2, 2, -2, -2, -2]).view(4, 2))

    # One sigma per glimpse
    self.sigmas = Parameter(torch.randn((glimpses,)))

    # All values 1, no bias. Glimpses extract only pixel information.
    self.register_buffer('one', torch.FloatTensor([1.0]))
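# A standalone sketch of the `identity` and `corners` buffers above (the forward pass is
# not shown in this excerpt, so this only illustrates the geometry): applying a 2x3 affine
# matrix to the four corner points in homogeneous coordinates. With the `identity` buffer
# alone, each corner at +/-2 maps to +/-0.8, i.e. a centered, shrunken glimpse quad.
def _affine_corner_demo():
    import torch

    identity = torch.FloatTensor([0.4, 0, 0, 0, 0.4, 0]).view(2, 3)
    corners = torch.FloatTensor([-2, 2, 2, 2, 2, -2, -2, -2]).view(4, 2)

    homog = torch.cat([corners, torch.ones(4, 1)], dim=1)  # (4, 3): append homogeneous 1s
    mapped = homog @ identity.t()                          # (4, 2): affine-transformed corners

    print(mapped)
    # tensor([[-0.8000,  0.8000], [ 0.8000,  0.8000], [ 0.8000, -0.8000], [-0.8000, -0.8000]])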
def __init__(self, in_size, k, glimpses=1, scale=0.001, pool=4, coord='none'):
    super().__init__()

    self.in_size = in_size
    self.k = k
    self.num_glimpses = glimpses
    self.scale = scale

    ci, hi, wi = in_size
    co, ho, wo = ci, k, k

    modules = prep(ci, hi, wi, pool, coord) + [
        nn.ReLU(),
        nn.Linear(HIDLIN, 3 * 2 * glimpses),
        util.Reshape((glimpses, 2, 3))
    ]

    self.preprocess = nn.Sequential(*modules)

    self.register_buffer('identity', torch.FloatTensor([0.4, 0, 0, 0, 0.4, 0]).view(2, 3))
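# A minimal sketch of the standard spatial-transformer sampling that the `identity` buffer
# above suggests (the STN forward pass is not shown in this excerpt): a 2x3 theta of
# [[0.4, 0, 0], [0, 0.4, 0]] samples a centered crop covering the middle 40% of the input,
# resized to k x k.
def _stn_sampling_demo(k=12):
    import torch
    import torch.nn.functional as F

    theta = torch.FloatTensor([0.4, 0, 0, 0, 0.4, 0]).view(1, 2, 3)
    image = torch.randn(1, 1, 28, 28)

    grid = F.affine_grid(theta, size=(1, 1, k, k), align_corners=False)  # sampling coordinates
    glimpse = F.grid_sample(image, grid, align_corners=False)            # bilinear glimpse extraction

    print(glimpse.size())  # torch.Size([1, 1, 12, 12])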