batch_size=args.batch_size, **kwargs) args.num_class = train_loader.dataset.num_class args.num_channels = train_loader.dataset.num_channels else: test_dataset = get_MNIST_test_dataset(args.data_dir) test_loader = get_test_loader(test_dataset, args.batch_size, **kwargs) args.num_class = test_loader.dataset.num_class args.num_channels = test_loader.dataset.num_channels # build RAM model model = RecurrentAttention(args) if args.use_gpu: model.cuda() optimizer = torch.optim.SGD(model.parameters(), lr=args.init_lr, momentum=args.momentum) logger.info('Number of model parameters: {:,}'.format( sum([p.data.nelement() for p in model.parameters()]))) trainer = Trainer(model, optimizer, watch=['acc'], val_watch=['acc']) if args.is_train: logger.info("Train on {} samples, validate on {} samples".format( len(train_loader.dataset), len(val_loader.dataset))) start_epoch = 0 if args.resume: start_epoch = load_checkpoint(args.ckpt_dir, model, optimizer) trainer.train(train_loader,
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.

    The training/validation data source is a custom loader object that is
    driven through `trainSet()` / `validationSet()`, `hasNext()`,
    `getIteratorInfo()` and `getNext()`; each batch exposes `imgs` and
    `gtTexts`.
    """

    # Label value that is not counted when measuring how many label columns
    # of a batch are in use.
    # NOTE(review): presumably the padding/blank class (num_classes is 83,
    # so 82 is the last class index) - confirm against the data loader.
    PAD_LABEL = 82

    # Number of consecutive glimpses taken before one character is predicted.
    GLIMPSES_PER_CHAR = 4

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When `config.is_train` is true it is
          indexed as `(train_loader, valid_loader)`; otherwise it is used
          directly as the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.trainSamplesSize = len(self.train_loader.trainSamples)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
            # There is no training loader in test mode; reading
            # `self.train_loader` here used to raise AttributeError.
            self.trainSamplesSize = 0
        self.num_classes = 83
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(
            config.num_glimpses, config.patch_size,
            config.patch_size, config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # NOTE: the optimizer uses a fixed learning rate of 3e-4;
        # `config.init_lr` and `config.momentum` are stored on the instance
        # but are not passed to it.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced. It reads `self.batch_size`, which the caller
        must set beforehand.

        Returns
        -------
        - h_t: (batch_size, hidden_size) float tensor, all zeros except
          column 1 which is -1.
        - l_t: (batch_size, 2) float tensor whose every row is (-1, 0).
        """
        dtype = (torch.cuda.FloatTensor
                 if self.use_gpu else torch.FloatTensor)

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        # The original slice was `1:2:self.hidden_size`; with a stop of 2 it
        # selects column 1 only, which is what is written here.
        # NOTE(review): if every other unit was intended (`1::2`), this
        # needs changing - confirm.
        h_t[:, 1:2] = -1
        h_t = h_t.type(dtype)

        # Fixed (not random) start location for every sample: (-1, 0).
        l_t = torch.ones(self.batch_size, 2)
        l_t[:, 0] *= -1
        l_t[:, 1] *= 0
        l_t = l_t.type(dtype)

        return h_t, l_t

    @classmethod
    def _max_label_length(cls, y):
        """Return the largest per-row count of entries != PAD_LABEL in `y`.

        `y` is a 2-D label tensor (one row per sample). Returns 0 for a
        batch with no rows.
        """
        if y.shape[0] == 0:
            return 0
        return int((y != cls.PAD_LABEL).sum(dim=1).max().item())

    @staticmethod
    def _hybrid_loss(log_probas_list, targets, baselines, log_pi, reward):
        """Combine classification, baseline and REINFORCE losses.

        - classification: sum of the NLL losses of every predicted character
        - baseline: MSE between the baselines and the reward
        - reinforce: -log_pi * (reward - baseline), summed over timesteps
          and averaged across the batch
        """
        loss_action = F.nll_loss(log_probas_list[0], targets[0])
        for log_probas, target in zip(log_probas_list[1:], targets[1:]):
            loss_action = loss_action + F.nll_loss(log_probas, target)

        loss_baseline = F.mse_loss(baselines, reward)

        # The baseline is detached so the policy gradient does not flow
        # into the baseline estimator.
        adjusted_reward = reward - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        return loss_action + loss_baseline + loss_reinforce

    def _dump_glimpses(self, imgs, locs, epoch):
        """Pickle images and glimpse locations to `self.plot_dir`.

        Files are named `g_<epoch+1>.p` and `l_<epoch+1>.p` and are
        overwritten on every call.
        """
        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
        locs = [l.cpu().data.numpy() for l in locs]
        # Context managers close the handles; the previous bare `open()`
        # calls leaked one file object per dump.
        with open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb") as f:
            pickle.dump(imgs, f)
        with open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb") as f:
            pickle.dump(locs, f)

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops early once the validation accuracy has failed to
        improve for more than `train_patience` epochs.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        for epoch in range(self.start_epoch, self.epochs):

            # Report the learning rate the optimizer really uses rather
            # than `config.init_lr`, which is not passed to the optimizer.
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                epoch + 1, self.epochs,
                self.optimizer.param_groups[0]['lr']))

            # train for 1 epoch
            train_loss, train_acc = self.trainDataset(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validateDataset(epoch)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def trainDataset(self, epoch):
        """
        Train the model for one pass over the training split.

        For every batch the model takes `GLIMPSES_PER_CHAR` glimpses per
        label column; the class log-probabilities returned on the last
        glimpse of each group are used as the prediction for that column.
        The per-step reward is the 0/1 correctness of the column's
        prediction plus the negated L2 distance between consecutive
        locations.

        Args
        ----
        - epoch: zero-based epoch index (used for file names and logging).

        Returns
        -------
        - (average loss, average accuracy in percent) over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        per_char = self.GLIMPSES_PER_CHAR
        tic = time.time()
        with tqdm(total=self.trainSamplesSize) as pbar:
            self.train_loader.trainSet()
            i = 0
            while self.train_loader.hasNext():
                i += 1
                # Return value unused; call kept in case the loader relies
                # on it - NOTE(review): confirm it is side-effect free.
                self.train_loader.getIteratorInfo()
                batch = self.train_loader.getNext()
                x = torch.tensor(batch.imgs)
                y = torch.tensor(batch.gtTexts)

                # add the channel dimension: (B, H, W) -> (B, 1, H, W)
                x = x[:, None, :, :]
                x = x.type(torch.FloatTensor)
                self.batch_size = x.shape[0]

                # keep only the label columns that are in use in this batch
                bmax = self._max_label_length(y)
                if bmax == 0:
                    # Nothing to predict; the code below needs at least one
                    # label column (it used to fail with IndexError).
                    pbar.update(self.batch_size)
                    continue
                y = y[:, :bmax]

                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # initialize location vector and hidden state
                h_t, l_t = self.reset()

                # save images and the start location
                imgs = [x[0:4]]
                locs = [l_t[0:4]]

                log_pi = []
                baselines = []
                log_probas_list = []
                predicted_list = []
                R_list = []
                targets = []
                Rdist = []
                y0 = None

                for t in range(bmax * per_char):
                    if t % per_char == 0:
                        # first glimpse of a group: select its label column
                        y0 = y[:, t // per_char]
                        targets.append(y0)

                    l_t_prev = l_t
                    h_t, l_t, b_t, log_probas, p = self.model(
                        x, l_t, h_t, last=True)

                    if (t + 1) % per_char == 0:
                        # last glimpse of a group: read out the prediction
                        log_probas_list.append(log_probas)
                        predicted = torch.max(log_probas, 1)[1]
                        predicted_list.append(predicted)
                        R = (predicted.detach() == y0).float()
                        R_list.append(R.unsqueeze(1).repeat(1, per_char))

                    locs.append(l_t[0:4])
                    baselines.append(b_t)
                    # reward term: negated distance moved by this step
                    Rdist.append(
                        -1 * (torch.norm(l_t_prev - l_t, p=2, dim=1)))
                    log_pi.append(p)

                # convert lists to (batch, timesteps) tensors
                R = torch.cat(R_list, 1)
                Rdist = torch.stack(Rdist).transpose(1, 0)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                loss = self._hybrid_loss(
                    log_probas_list, targets, baselines, log_pi, R + Rdist)

                # compute accuracy over every predicted character
                predicted_all = torch.cat(predicted_list)
                targets_all = torch.cat(targets)
                correct = (predicted_all == targets_all).float()
                acc = 100 * (correct.sum() / len(targets_all))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update the parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs (done for every batch, so the
                # files end up holding the last batch of the epoch)
                self._dump_glimpses(imgs, locs, epoch)

                # log to tensorboard
                if self.use_tensorboard:
                    # NOTE(review): requires the loader to implement
                    # __len__ - confirm.
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validateDataset(self, epoch):
        """
        Evaluate the model on the validation set.

        Each image is duplicated `self.M` times; the log-probabilities,
        baselines, log-policies and locations are averaged over those
        copies before the losses and accuracy are computed. No gradients
        are recorded.

        Args
        ----
        - epoch: zero-based epoch index (used for tensorboard logging).

        Returns
        -------
        - (average loss, average accuracy in percent) over the split.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        per_char = self.GLIMPSES_PER_CHAR
        self.train_loader.validationSet()
        i = 0
        # Evaluation only: do not build autograd graphs.
        with torch.no_grad():
            while self.train_loader.hasNext():
                i += 1
                # Return value unused; call kept in case the loader relies
                # on it - NOTE(review): confirm it is side-effect free.
                self.train_loader.getIteratorInfo()
                batch = self.train_loader.getNext()
                x = torch.tensor(batch.imgs)
                y = torch.tensor(batch.gtTexts)

                # keep only the label columns that are in use in this batch
                bmax = self._max_label_length(y)
                if bmax == 0:
                    # Nothing to predict; see trainDataset.
                    continue
                y = y[:, :bmax]

                # add the channel dimension: (B, H, W) -> (B, 1, H, W)
                x = x[:, None, :, :]
                x = x.type(torch.FloatTensor)

                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                self.batch_size = x.shape[0]

                # initialize location vector and hidden state
                h_t, l_t = self.reset()

                log_pi = []
                baselines = []
                log_probas_list = []
                R_list = []
                targets = []
                Rdist = []
                y0 = None

                for t in range(bmax * per_char):
                    if t % per_char == 0:
                        # first glimpse of a group: select its label column
                        y0 = y[:, t // per_char]
                        targets.append(y0)

                    l_t_prev = l_t
                    h_t, l_t, b_t, log_probas, p = self.model(
                        x, l_t, h_t, last=True)

                    # average the class scores over the M copies
                    log_probas = log_probas.view(
                        self.M, -1, log_probas.shape[-1])
                    log_probas = torch.mean(log_probas, dim=0)

                    if (t + 1) % per_char == 0:
                        # last glimpse of a group: read out the prediction
                        log_probas_list.append(log_probas)
                        predicted = torch.max(log_probas, 1)[1]
                        R = (predicted.detach() == y0).float()
                        R_list.append(R.unsqueeze(1).repeat(1, per_char))

                    # store
                    baselines.append(b_t)
                    log_pi.append(p)

                    # average the locations over the M copies before
                    # measuring the distance moved by this step
                    prev_mean = torch.mean(
                        l_t_prev.view(self.M, -1, l_t_prev.shape[-1]),
                        dim=0)
                    cur_mean = torch.mean(
                        l_t.view(self.M, -1, l_t.shape[-1]), dim=0)
                    Rdist.append(
                        -1 * (torch.norm(prev_mean - cur_mean, p=2, dim=1)))

                # convert lists to (batch, timesteps) tensors
                R = torch.cat(R_list, 1)
                Rdist = torch.stack(Rdist).transpose(1, 0)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # average baselines and log-policies over the M copies
                baselines = baselines.contiguous().view(
                    self.M, -1, baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)
                log_pi = log_pi.contiguous().view(
                    self.M, -1, log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                loss = self._hybrid_loss(
                    log_probas_list, targets, baselines, log_pi, R + Rdist)

                # compute accuracy over every predicted character
                predicted_all = torch.cat(
                    [torch.max(lp, 1)[1] for lp in log_probas_list])
                targets_all = torch.cat(targets)
                correct = (predicted_all == targets_all).float()
                acc = 100 * (correct.sum() / len(targets_all))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    # NOTE(review): batches come from `train_loader`, but
                    # the step index is derived from
                    # `len(self.valid_loader)` - confirm this is intended.
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def testData(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        NOTE(review): this method looks unfinished and cannot complete a
        batch as written:
        - it iterates `self.train_loader`, which only exists when the
          trainer was built with `config.is_train` set, while
          `self.num_test` only exists otherwise;
        - `log_probas_list` is a Python list, so `log_probas_list.data`
          raises AttributeError;
        - `y` is the raw `batch.gtTexts`, not a tensor, yet `y.data` is
          read.
        The logic is kept as it was because the intended evaluation
        protocol cannot be determined from this file; use `test()` or
        `validateDataset()` instead until it is completed.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)
        i = 0
        while self.train_loader.hasNext():
            i += 1
            self.train_loader.getIteratorInfo()
            batch = self.train_loader.getNext()
            x = batch.imgs
            x = torch.tensor(x)
            x = x[:, None, :, :]
            x = x.type(torch.FloatTensor)
            y = batch.gtTexts

            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()
            log_pi = []
            locs = []
            baselines = []
            log_probas_list = []
            predicted_list = []
            R_list = []
            y0new = []
            y0 = []

            # extract the glimpses
            for t in range(self.num_glimpses):
                # forward pass through model
                if (t % 8 == 0):
                    y0 = []
                    for b in range(self.batch_size):
                        print(b, t // 8, t)
                        y0.append(y[b][t // 8])
                    y0new += y0
                    y0 = torch.tensor(y0)

                h_t, l_t, b_t, log_probas1, p = self.model(x, l_t, h_t,
                                                           last=True)
                if (t + 1) % 8 == 1:
                    log_probas_list.append(log_probas1)
                    predicted_list.append(torch.max(log_probas1, 1)[1])
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)
                predicted1 = torch.max(log_probas1, 1)[1]
                R1 = (predicted1.detach() == y0).float()
                R_list.append(R1)

            R = R_list
            R = torch.stack(R).transpose(1, 0)

            # NOTE(review): fails here - see the docstring.
            pred = log_probas_list.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Iterates `self.test_loader` as `(x, y)` pairs, shifts the labels
        down by one, duplicates every image `self.M` times, averages the
        log-probabilities over the copies and prints the accuracy.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        # `torch.no_grad()` replaces the removed `volatile=True` flag.
        with torch.no_grad():
            for i, (x, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                y = y - 1

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                # NOTE(review): this expects 4 return values without
                # `last=True`; every other caller in this class passes
                # `last=True` - confirm against the model.
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t,
                                                          last=True)

                # average over the M copies
                log_probas = log_probas.view(self.M, -1,
                                             log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.data.max(1, keepdim=True)[1]
                correct += pred.eq(y.data.view_as(pred)).cpu().sum().item()

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.

        Params
        ------
        - state: dictionary to serialize with `torch.save`.
        - is_best: when true, the checkpoint is also copied to
          `<model_name>_model_best.pth.tar`.
        """
        # Create the directory on demand so the first save cannot fail
        # merely because it does not exist yet.
        os.makedirs(self.ckpt_dir, exist_ok=True)

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path,
                            os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Restores `start_epoch`, `best_valid_acc`, the model weights and the
        optimizer state.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # When running on CPU, remap storages so that a checkpoint written
        # from a GPU run can still be loaded.
        ckpt = torch.load(ckpt_path,
                          map_location=None if self.use_gpu else 'cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
class Trainer: """A Recurrent Attention Model trainer. All hyperparameters are provided by the user in the config file. """ def __init__(self, config, data_loader): """ Construct a new Trainer instance. Args: config: object containing command line arguments. data_loader: A data iterator. """ self.config = config if config.use_gpu and torch.cuda.is_available(): self.device = torch.device("cuda") else: self.device = torch.device("cpu") # glimpse network params self.patch_size = config.patch_size self.glimpse_scale = config.glimpse_scale self.num_patches = config.num_patches self.loc_hidden = config.loc_hidden self.glimpse_hidden = config.glimpse_hidden # core network params self.num_glimpses = config.num_glimpses self.hidden_size = config.hidden_size # reinforce params self.std = config.std self.M = config.M # data params if config.is_train: self.train_loader = data_loader[0] self.valid_loader = data_loader[1] self.num_train = len(self.train_loader.sampler.indices) self.num_valid = len(self.valid_loader.sampler.indices) else: self.test_loader = data_loader self.num_test = len(self.test_loader.dataset) self.num_classes = 25 #10 self.num_channels = 1 # training params self.epochs = config.epochs self.start_epoch = 0 self.momentum = config.momentum self.lr = config.init_lr # misc params self.best = config.best self.ckpt_dir = config.ckpt_dir self.logs_dir = config.logs_dir self.best_valid_acc = 0.0 self.counter = 0 self.lr_patience = config.lr_patience self.train_patience = config.train_patience self.use_tensorboard = config.use_tensorboard self.resume = config.resume self.print_freq = config.print_freq self.plot_freq = config.plot_freq self.model_name = "ram_{}_{}x{}_{}".format( config.num_glimpses, config.patch_size, config.patch_size, config.glimpse_scale, ) self.plot_dir = "./plots/" + self.model_name + "/" if not os.path.exists(self.plot_dir): os.makedirs(self.plot_dir) # configure tensorboard logging if self.use_tensorboard: tensorboard_dir = self.logs_dir + 
self.model_name print("[*] Saving tensorboard logs to {}".format(tensorboard_dir)) if not os.path.exists(tensorboard_dir): os.makedirs(tensorboard_dir) configure(tensorboard_dir) # build RAM model self.model = RecurrentAttention( self.patch_size, self.num_patches, self.glimpse_scale, self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std, self.hidden_size, self.num_classes, ) self.model.to(self.device) # initialize optimizer and scheduler self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.init_lr) self.scheduler = ReduceLROnPlateau(self.optimizer, "min", patience=self.lr_patience) def gmdataset(self): import pandas as pd #data = pd.read_csv("check.csv", index_col ="name") gW = pd.read_csv("goodware.csv", index_col="name") mW = pd.read_csv("malware.csv", index_col="name") out = mW.append(gW) data = out out.drop('(BAD)', axis=1, inplace=True) out.drop('STD', axis=1, inplace=True) out.drop('SHLD', axis=1, inplace=True) out.drop('SETLE', axis=1, inplace=True) out.drop('SETB', axis=1, inplace=True) out.drop('SBB', axis=1, inplace=True) out.drop('RDTSC', axis=1, inplace=True) out.drop('PUSHF', axis=1, inplace=True) out.drop('FSTCW', axis=1, inplace=True) out.drop('FDIVP', axis=1, inplace=True) out.drop('FILD', axis=1, inplace=True) out.drop('RETN', axis=1, inplace=True) out.drop('LEA', axis=1, inplace=True) out.drop('IMUL', axis=1, inplace=True) from sklearn.model_selection import train_test_split #print(data['labels']) M = data.values X = M[:, :-1] Y = M[:, -1] #print(Y) #X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=42) #print("HHHHHHH: ",X_train.shape, y_train.shape) import numpy as np # print(X_train.shape) #x_train = np.reshape(X_train, (X_train.shape[0],5, 5,1)) #padie=np.pad(X_train, ((0,0),(0,759)), 'constant', constant_values=0) padie = np.pad(X, ((0, 0), (0, 1)), 'constant', constant_values=0) print(padie.shape) x = np.reshape(padie, (padie.shape[0], 1, 4, 4)) return x, Y def 
alldatacsv(self): import numpy as np # linear algebra import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) from sklearn.feature_extraction.text import CountVectorizer from keras.preprocessing.text import Tokenizer from keras.preprocessing.sequence import pad_sequences from keras.models import Sequential from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D, Dropout from sklearn.model_selection import train_test_split from keras.utils.np_utils import to_categorical import re # Input data files are available in the "../input/" directory. # For example, running this (by clicking run or pressing Shift+Enter) will list the files in the input directory """Only keeping the necessary columns.""" import pandas as pd data = pd.read_csv("AllData.csv") data = data.drop(['Unnamed: 0'], axis=1) data = data.rename(columns={'Text': 'text', 'Label': 'sentiment'}) data #data = pd.read_csv('../input/Sentiment.csv') # Keeping only the neccessary columns data = data[['text', 'sentiment']] pos = data[data['sentiment'] == 1] pos.shape[0] #data = data[data.sentiment != "Neutral"] data['text'] = data['text'].apply(lambda x: x.lower()) data['text'] = data['text'].apply( (lambda x: re.sub('[^a-zA-z0-9\s]', '', x))) #print(data[ data['sentiment'] == 1].size) #print(data[ data['sentiment'] == 0].size) for idx, row in data.iterrows(): row[0] = row[0].replace('rt', ' ') max_fatures = 2000 tokenizer = Tokenizer(num_words=max_fatures, split=' ') tokenizer.fit_on_texts(data['text'].values) X = tokenizer.texts_to_sequences(data['text'].values) X = pad_sequences(X) data['sentiment'] pd.DataFrame(data=X[1:, -20000:], index=X[1:, 0]) # 1st row as the column name #X=X[0:,-20000:] X.shape """# Train and Test Dataset Declaration""" Y = pd.get_dummies(data['sentiment']).values #X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size = 0.33, random_state = 78) #conu = 0 # for x in Y_test: # if x.argmax()== 1: # conu = conu + 1 # conu import numpy as np import 
matplotlib.pyplot as plt #print(X_train.shape) #x_train = np.reshape(X_train, (X_train.shape[0],5, 5,1)) padie = np.pad(X, ((0, 0), (0, 251)), 'constant', constant_values=0) padie = padie[:, 459 * 459 - 50176:] x = np.reshape(padie, (padie.shape[0], 1, 224, 224)) #x_test = np.reshape(x_test, (x_test.shape[0],2,2, 1)) #print(Y) sk = pd.DataFrame(data=Y, columns=[0, 1]) inverted = sk.idxmax(1).values ss = np.rint(inverted) Y = ss #for i in range(0,Y.shape[0]): # if Y[i] == 1: # #print("YESSSSSSSS") # string = "imgs/" + str(i) + ".png" # plt.imsave(string,x[i][0,:,:]) # qq = i #x = x[:] #Y = Y[int(qq-qq/2):] #print(type(X)) return x, Y def batadal(self): import pandas as pd import numpy as np fD = pd.read_csv("newDatasets/BATADAL_dataset04.csv", header=None) #fD = pd.read_csv("/content/drive/My Drive/newDatasets/BATADAL_dataset02 (1).csv" , header=None) test = pd.read_csv("newDatasets/BATADAL_test_dataset.csv", header=None) test = test.drop(columns=0) test = test.drop([0], axis=0) Data = fD.drop(columns=0) Data = Data.drop([0], axis=0) nData = Data.values nData.shape testData = test.values xData = nData[:, :43] yData = nData[:, 43] xData.shape testData.shape xData = np.pad(xData, ((0, 0), (0, 6)), 'constant', constant_values=0) testData = np.pad(testData, ((0, 0), (0, 6)), 'constant', constant_values=0) test[:] xData.shape xData = xData.reshape(-1, 1, 7, 7) testData = testData.reshape(-1, 1, 7, 7) from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(xData, yData, test_size=0.1, random_state=42) #print(X_train, y_train) #print("JJ: ", type(xData)) return xData, yData def Malimg(self): import tensorflow as tf import keras import numpy as np from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split dataset = np.load('malimg.npz', allow_pickle=True) BATCH_SIZE = 256 CELL_SIZE = 256 DROPOUT_RATE = 0.85 LEARNING_RATE = 1e-3 NODE_SIZE = [512, 256, 128] NUM_LAYERS = 5 features = 
dataset['arr'][:, 0] features = np.array([feature for feature in features]) features = np.reshape( features, (features.shape[0], features.shape[1] * features.shape[2])) r, c = features.shape print("Number of Samples", r) print("Number of Features", c) if 1 == 1: features = StandardScaler().fit_transform(features) labels = dataset['arr'][:, 1] labels = np.array([label for label in labels]) one_hot = np.zeros((labels.shape[0], labels.max() + 1)) one_hot[np.arange(labels.shape[0]), labels] = 1 labels = one_hot labels[labels == 0] = 0 num_features = features.shape[1] num_classes = labels.shape[1] Y = labels X = features print("Shape of Labels", Y.shape) print("Shape of Features", X.shape) train_features, test_features, train_labels, test_labels = train_test_split( features, labels, test_size=0.1, stratify=labels) #10% Test size train_size = int(train_features.shape[0]) train_features = train_features[:train_size - (train_size % BATCH_SIZE)] train_labels = train_labels[:train_size - (train_size % BATCH_SIZE)] test_size = int(test_features.shape[0]) test_features = test_features[:test_size - (test_size % BATCH_SIZE)] test_labels = test_labels[:test_size - (test_size % BATCH_SIZE)] fsize = int(features.shape[0]) features = features[:fsize - (fsize % BATCH_SIZE)] labels = labels[:fsize - (fsize % BATCH_SIZE)] r, c = train_features.shape print("Number of Training Samples", r) print("Number of Training Features", c) r, c = test_features.shape print("Number of Test Samples", r) print("Number of Test Features", c) #print(train_labels.shape) #print(tf.reshape(test_features[1], [32,32])) print(train_features.shape, test_features.shape, train_labels.shape, test_labels.shape) #print(train_labels) train_X = train_features.reshape(-1, 1, 32, 32) feat = features.reshape(-1, 1, 32, 32) test_X = test_features.reshape(-1, 32, 32, 1) Unchanined = X.reshape(-1, 32, 32, 1) y_test_non_category = [np.argmax(t) for t in labels] print("LABELS", np.asarray(y_test_non_category)) return feat, 
np.asarray(y_test_non_category) def SWAT(self): import pandas as pd import numpy as np fD = pd.read_excel("newDatasets/SWaT/SWaT_Dataset_Attack_v0.xlsx", header=None) Data = fD.drop(columns=0) Data = Data.drop([0, 1], axis=0) xData = Data.values[:, :51] yData = Data.values[:, 51] count = 0 for i in yData: if i == 'Normal': yData[count] = 0 else: yData[count] = 1 count = count + 1 xData.shape xData = np.pad(xData, ((0, 0), (0, 13)), 'constant', constant_values=0) xData = xData.reshape(xData.shape[0], 8, 8, 1) return xData[:200], yData[:200] def reset(self): h_t = torch.zeros( self.batch_size, self.hidden_size, dtype=torch.float, device=self.device, requires_grad=True, ) l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1, 1).to(self.device) l_t.requires_grad = True return h_t, l_t def train(self): """Train the model on the training set. A checkpoint of the model is saved after each epoch and if the validation accuracy is improved upon, a separate ckpt is created for use on the test set. """ # load the most recent checkpoint if self.resume: self.load_checkpoint(best=False) print("\n[*] Train on {} samples, validate on {} samples".format( self.num_train, self.num_valid)) for epoch in range(self.start_epoch, self.epochs): print("\nEpoch: {}/{} - LR: {:.6f}".format( epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"])) # train for 1 epoch train_loss, train_acc = self.train_one_epoch(epoch) # evaluate on validation set valid_loss, valid_acc = self.validate(epoch) # # reduce lr if validation loss plateaus self.scheduler.step(-valid_acc) is_best = valid_acc > self.best_valid_acc msg1 = "train loss: {:.3f} - train acc: {:.3f} " msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}" if is_best: self.counter = 0 msg2 += " [*]" msg = msg1 + msg2 print( msg.format(train_loss, train_acc, valid_loss, valid_acc, 100 - valid_acc)) # check for improvement if not is_best: self.counter += 1 if self.counter > self.train_patience: print("[!] 
No improvement in a while, stopping training.") return self.best_valid_acc = max(valid_acc, self.best_valid_acc) self.save_checkpoint( { "epoch": epoch + 1, "model_state": self.model.state_dict(), "optim_state": self.optimizer.state_dict(), "best_valid_acc": self.best_valid_acc, }, is_best, ) def train_one_epoch(self, epoch): """ Train the model for 1 epoch of the training set. An epoch corresponds to one full pass through the entire training set in successive mini-batches. This is used by train() and should not be called manually. """ import pandas as pd import numpy as np self.model.train() batch_time = AverageMeter() losses = AverageMeter() accs = AverageMeter() tic = time.time() with tqdm(total=self.num_train) as pbar: for i, (x, y) in enumerate(self.train_loader): self.optimizer.zero_grad() x, y = x.to(self.device), y.to(self.device) x1, y1 = self.SWAT( ) #self.gmdataset()#self.batadal()#self.alldatacsv()#self.gmdataset() x1 = x1.astype(np.float32) y1 = y1.astype(np.float32) x, y = torch.from_numpy(x1).float(), torch.from_numpy( y1).long() #print("Here", y) plot = False if (epoch % self.plot_freq == 0) and (i == 0): plot = True # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # save images imgs = [] imgs.append(x[0:9]) # extract the glimpses locs = [] log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_pi.append(p) baselines.append(b_t) locs.append(l_t[0:9]) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # calculate reward predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable 
modules loss_action = F.nll_loss(log_probas, y) loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss # summed over timesteps and averaged across batch adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce * 0.01 # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) #print("Predicted: ", predicted, "\nTrue", y) # store losses.update(loss.item(), x.size()[0]) accs.update(acc.item(), x.size()[0]) # compute gradients and update SGD loss.backward() self.optimizer.step() # measure elapsed time toc = time.time() batch_time.update(toc - tic) pbar.set_description( ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( (toc - tic), loss.item(), acc.item()))) pbar.update(self.batch_size) # dump the glimpses and locs if plot: imgs = [g.cpu().data.numpy().squeeze() for g in imgs] locs = [l.cpu().data.numpy() for l in locs] pickle.dump( imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb")) pickle.dump( locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb")) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.train_loader) + i log_value("train_loss", losses.avg, iteration) log_value("train_acc", accs.avg, iteration) return losses.avg, accs.avg @torch.no_grad() def validate(self, epoch): """Evaluate the RAM model on the validation set. 
""" import torch import numpy as np losses = AverageMeter() accs = AverageMeter() for i, (x, y) in enumerate(self.valid_loader): x, y = x.to(self.device), y.to(self.device) x1, y1 = self.gmdataset( ) #self.batadal()#self.alldatacsv()#self.gmdataset() x1 = x1.astype(np.float32) y1 = y1.astype(np.float32) x, y = torch.from_numpy(x1).float(), torch.from_numpy(y1).long() # duplicate M times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # store baselines.append(b_t) log_pi.append(p) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_pi.append(p) baselines.append(b_t) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # average log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1]) baselines = torch.mean(baselines, dim=0) log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1]) log_pi = torch.mean(log_pi, dim=0) # calculate reward predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules loss_action = F.nll_loss(log_probas, y) loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce * 0.01 # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) count = 0 countFP = 0 countFN = 0 countTN = 0 countTP = 0 for 
i in range(len(correct)): if (correct[i] == 0): count = count + 1 if (predicted[i] == 1 and y[i] == 0): #False Positive countFP = countFP + 1 if (predicted[i] == 0 and y[i] == 1): #False Negative countFN = countFN + 1 if (predicted[i] == 0 and y[i] == 0): #True Negative countTN = countTN + 1 if (predicted[i] == 1 and y[i] == 1): #True Positive countTP = countTP + 1 print("Total: ", len(correct), "Wrong: ", count) print("TP: ", countTP, "TN: ", countTN, "FN: ", countFN, "FP: ", countFP) # store losses.update(loss.item(), x.size()[0]) accs.update(acc.item(), x.size()[0]) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.valid_loader) + i log_value("valid_loss", losses.avg, iteration) log_value("valid_acc", accs.avg, iteration) return losses.avg, accs.avg @torch.no_grad() def test(self): """Test the RAM model. This function should only be called at the very end once the model has finished training. """ correct = 0 # load the best checkpoint self.load_checkpoint(best=self.best) for i, (x, y) in enumerate(self.test_loader): x, y = x.to(self.device), y.to(self.device) # duplicate M times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct += pred.eq(y.data.view_as(pred)).cpu().sum() perc = (100.0 * correct) / (self.num_test) error = 100 - perc print("[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format( correct, self.num_test, perc, error)) def save_checkpoint(self, state, is_best): """Saves a checkpoint of the model. 
If this model has reached the best validation accuracy thus far, a seperate file with the suffix `best` is created. """ filename = self.model_name + "_ckpt.pth.tar" ckpt_path = os.path.join(self.ckpt_dir, filename) torch.save(state, ckpt_path) if is_best: filename = self.model_name + "_model_best.pth.tar" shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename)) def load_checkpoint(self, best=False): """Load the best copy of a model. This is useful for 2 cases: - Resuming training with the most recent model checkpoint. - Loading the best validation model to evaluate on the test data. Args: best: if set to True, loads the best model. Use this if you want to evaluate your model on the test data. Else, set to False in which case the most recent version of the checkpoint is used. """ print("[*] Loading model from {}".format(self.ckpt_dir)) filename = self.model_name + "_ckpt.pth.tar" if best: filename = self.model_name + "_model_best.pth.tar" ckpt_path = os.path.join(self.ckpt_dir, filename) ckpt = torch.load(ckpt_path) # load variables from checkpoint self.start_epoch = ckpt["epoch"] self.best_valid_acc = ckpt["best_valid_acc"] self.model.load_state_dict(ckpt["model_state"]) self.optimizer.load_state_dict(ckpt["optim_state"]) if best: print("[*] Loaded {} checkpoint @ epoch {} " "with best valid acc of {:.3f}".format( filename, ckpt["epoch"], ckpt["best_valid_acc"])) else: print("[*] Loaded {} checkpoint @ epoch {}".format( filename, ckpt["epoch"]))
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When `config.is_train` is true it is
          indexed as (train_loader, valid_loader); otherwise it is used
          directly as the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.patience = config.patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        self.optimizer = SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
        )
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           'min',
                                           patience=self.patience)

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.

        Returns a tuple (h_t, l_t): h_t is a zero tensor of shape
        (batch_size, hidden_size); l_t is drawn uniformly from [-1, 1]
        with shape (batch_size, 2). Both live on the GPU when
        `self.use_gpu` is set.
        """
        device = torch.device('cuda' if self.use_gpu else 'cpu')

        h_t = torch.zeros(self.batch_size, self.hidden_size, device=device)
        l_t = torch.empty(self.batch_size, 2).uniform_(-1, 1).to(device)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops once more than `self.patience` consecutive
        epochs fail to improve the validation accuracy.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            # read the rate from the optimizer: the scheduler may have
            # lowered it, which `self.lr` would not reflect
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                epoch + 1, self.epochs, self.optimizer.param_groups[0]['lr']))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus
            self.scheduler.step(valid_loss)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                # an improvement restarts the early-stopping count
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns a tuple (average loss, average accuracy in percent).
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # glimpses are dumped for the first batch of every
                # plot_freq-th epoch
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store; .item() is the supported way to read a 0-dim tensor
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
                    locs = [l.detach().cpu().numpy() for l in locs]
                    # context managers make sure the files get closed
                    with open(self.plot_dir + "g_{}.p".format(epoch + 1),
                              "wb") as fh:
                        pickle.dump(imgs, fh)
                    with open(self.plot_dir + "l_{}.p".format(epoch + 1),
                              "wb") as fh:
                        pickle.dump(locs, fh)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each batch is repeated `self.M` times and the outputs are
        averaged over the repetitions before the loss is computed.

        Returns a tuple (average loss, average accuracy in percent).
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # no parameter update happens here, so autograd is switched off
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # average over the M repetitions
                log_probas = log_probas.view(self.M, -1,
                                             log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(
                    self.M, -1, baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(self.M, -1,
                                                  log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Loads a checkpoint (the best one when `self.best` is true) and
        prints the accuracy over `self.test_loader`.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        # torch.no_grad() replaces the removed `volatile=True` flag
        with torch.no_grad():
            for i, (x, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)

                # average the predictions of the M repetitions
                log_probas = log_probas.view(self.M, -1,
                                             log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.max(1, keepdim=True)[1]
                # .item() keeps `correct` a plain int so it prints cleanly
                correct += pred.eq(y.view_as(pred)).sum().item()

        perc = (100. * correct) / (self.num_test)
        print('[*] Test Acc: {}/{} ({:.2f}%)'.format(correct, self.num_test,
                                                     perc))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date.

        The checkpoint is written to `<ckpt_dir>/<model_name>_ckpt.pth.tar`.
        If this model has reached the best validation accuracy thus
        far, a separate copy with the suffix `_model_best` is created.
        """
        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load a saved copy of the model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Restores `start_epoch`, `best_valid_acc` and the model and
        optimizer state.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        # train() stores `epoch + 1`, so the value is already the 1-based
        # number of the last finished epoch and is printed unchanged
        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When `config.is_train` is true it is
          indexed as (train_loader, valid_loader); otherwise it is used
          directly as the test loader.

        Raises
        ------
        - ValueError: if `config.optimizer` is not a supported name.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]

            # peek at one batch to learn the image height and width
            image_tmp, _ = next(iter(self.train_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

            if 'MNIST' in config.dataset_name or config.dataset_name == 'CIFAR':
                self.num_train = len(self.train_loader.sampler.indices)
                self.num_valid = len(self.valid_loader.sampler.indices)
            elif config.dataset_name == 'ImageNet':
                # hard-coded sizes; they are only valid while the ImageNet
                # loaders are not sub-sampled
                self.num_train = 100000
                self.num_valid = 10000
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)

            image_tmp, _ = next(iter(self.test_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

        # number of channels and classes per dataset
        if 'MNIST' in config.dataset_name:
            self.num_channels = 1
            self.num_classes = 10
        elif config.dataset_name == 'ImageNet':
            self.num_channels = 3
            self.num_classes = 1000
        elif config.dataset_name == 'CIFAR':
            self.num_channels = 3
            self.num_classes = 10

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.loss_fun_baseline = config.loss_fun_baseline
        self.loss_fun_action = config.loss_fun_action
        self.weight_decay = config.weight_decay

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.best_train_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq

        name_prefix = 'ram_gpu' if config.use_gpu else 'ram'
        self.model_name = name_prefix + '_{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}'.format(
            config.PBSarray_ID, config.num_glimpses, config.patch_size,
            config.patch_size, config.hidden_size, config.std,
            config.dropout)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging; the writer is kept on the instance
        # because train_one_epoch() and validate() log through it
        self.writer = None
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)
            self.writer = SummaryWriter(tensorboard_dir)

        # build DRAMBUTD model
        self.model = RecurrentAttention(self.patch_size, self.num_channels,
                                        self.image_size, self.std,
                                        self.hidden_size, self.num_classes,
                                        config)
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and (optional) scheduler
        self.scheduler = None
        if config.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
        elif config.optimizer == 'ReduceLROnPlateau':
            # NOTE(review): ReduceLROnPlateau is a scheduler, not an
            # optimizer. The previous code built it on a not-yet-created
            # optimizer; it is now paired with SGD -- confirm that this is
            # the intended meaning of the option.
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               'min',
                                               patience=self.lr_patience)
        elif config.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(self.model.parameters(),
                                            weight_decay=self.weight_decay)
        elif config.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=3e-4,
                                        weight_decay=self.weight_decay)
        elif config.optimizer == 'AdaBound':
            self.optimizer = adabound.AdaBound(self.model.parameters(),
                                               lr=3e-4,
                                               final_lr=0.1,
                                               weight_decay=self.weight_decay)
        elif config.optimizer == 'Ranger':
            self.optimizer = Ranger(self.model.parameters(),
                                    weight_decay=self.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {!r}'.format(config.optimizer))

    def reset(self, x, SM):
        """
        Initialize the hidden states of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.

        Args
        ----
        - x: image batch, forwarded to `self.model.initialize`.
        - SM: saliency-map batch, forwarded to `self.model.initialize`.

        Returns the tuple
        (h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth).
        h_t2, l_t and SM_local_smooth come from `self.model.initialize`;
        the other three are zero tensors of shape
        (batch_size, hidden_size).
        """
        device = torch.device('cuda' if self.use_gpu else 'cpu')

        h_t2, l_t, SM_local_smooth = self.model.initialize(x, SM)

        # hidden state 1 starts at zero to avoid classifying directly from
        # the context
        state_shape = (self.batch_size, self.hidden_size)
        h_t1 = torch.zeros(state_shape, device=device)
        cell_state1 = torch.zeros(state_shape, device=device)
        cell_state2 = torch.zeros(state_shape, device=device)

        return h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops once more than `self.train_patience` consecutive
        epochs fail to improve the validation accuracy.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            # read the rate from the optimizer: several optimizer choices
            # do not use `self.lr` at all
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                epoch + 1, self.epochs, self.optimizer.param_groups[0]['lr']))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus (only when a scheduler
            # was configured)
            if self.scheduler is not None:
                self.scheduler.step(valid_loss)

            is_best_valid = valid_acc > self.best_valid_acc
            is_best_train = train_acc > self.best_train_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best_train:
                msg1 += " [*]"

            if is_best_valid:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best_valid:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.best_train_acc = max(train_acc, self.best_train_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                    'best_train_acc': self.best_train_acc,
                }, is_best_valid)

    def _hybrid_loss(self, log_probas, y, baselines, log_pi, predicted):
        """Return the sum of the action, baseline and REINFORCE losses."""
        hit = predicted.detach() == y
        R = hit.float().unsqueeze(1).repeat(1, self.num_glimpses)

        # cross entropy needs a long tensor with one entry per sample as
        # target; every other baseline loss gets the per-glimpse reward
        if self.loss_fun_baseline == 'cross_entropy':
            baseline_target = hit.long()
        else:
            baseline_target = R

        # compute losses for differentiable modules
        loss_action, loss_baseline = self.choose_loss_fun(
            log_probas, y, baselines, baseline_target)

        # compute reinforce loss
        # summed over timesteps and averaged across batch
        adjusted_reward = R - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        # sum up into a hybrid loss
        return loss_action + loss_baseline + loss_reinforce

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns a tuple (average loss, average accuracy in percent).
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x_raw, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # channel 0 holds the image, channel 1 its saliency map
                x = x_raw[:, 0, ...].unsqueeze(1)
                SM = x_raw[:, 1, ...].unsqueeze(1)

                # glimpses are dumped for the first batch of every
                # plot_freq-th epoch
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                (h_t1, h_t2, l_t, cell_state1, cell_state2,
                 SM_local_smooth) = self.reset(x, SM)

                # save images
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    (h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2,
                     SM_local_smooth) = self.model(x, l_t, h_t1, h_t2,
                                                   cell_state1, cell_state2,
                                                   SM, SM_local_smooth)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                (h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1,
                 cell_state2, SM_local_smooth) = self.model(x,
                                                            l_t,
                                                            h_t1,
                                                            h_t2,
                                                            cell_state1,
                                                            cell_state2,
                                                            SM,
                                                            SM_local_smooth,
                                                            last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward and hybrid loss
                predicted = torch.max(log_probas, 1)[1]
                loss = self._hybrid_loss(log_probas, y, baselines, log_pi,
                                         predicted)

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update the parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
                    locs = [l.detach().cpu().numpy() for l in locs]
                    # context managers make sure the files get closed
                    with open(self.plot_dir + "g_{}.p".format(epoch + 1),
                              "wb") as fh:
                        pickle.dump(imgs, fh)
                    with open(self.plot_dir + "l_{}.p".format(epoch + 1),
                              "wb") as fh:
                        pickle.dump(locs, fh)
                    sio.savemat(self.plot_dir +
                                "data_train_{}.mat".format(epoch + 1),
                                mdict={
                                    'location': locs,
                                    'patch': imgs
                                })

                # log to tensorboard (the running averages, not the meters)
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    self.writer.add_scalar('Loss/train', losses.avg,
                                           iteration)
                    self.writer.add_scalar('Accuracy/train', accs.avg,
                                           iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each batch is repeated `self.M` times and the outputs are
        averaged over the repetitions before the loss is computed.

        Returns a tuple (average loss, average accuracy in percent).
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x_raw, y) in enumerate(self.valid_loader):
            if self.use_gpu:
                x_raw, y = x_raw.cuda(), y.cuda()

            # channel 0 holds the image, channel 1 its saliency map
            x = x_raw[:, 0, ...].unsqueeze(1)
            SM = x_raw[:, 1, ...].unsqueeze(1)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)
            SM = SM.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            (h_t1, h_t2, l_t, cell_state1, cell_state2,
             SM_local_smooth) = self.reset(x, SM)

            # extract the glimpses
            log_pi = []
            baselines = []
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                (h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2,
                 SM_local_smooth) = self.model(x, l_t, h_t1, h_t2,
                                               cell_state1, cell_state2, SM,
                                               SM_local_smooth)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            (h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2,
             SM_local_smooth) = self.model(x,
                                           l_t,
                                           h_t1,
                                           h_t2,
                                           cell_state1,
                                           cell_state2,
                                           SM,
                                           SM_local_smooth,
                                           last=True)
            # store
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average over the M repetitions
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1,
                                                    baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            # calculate reward and hybrid loss
            predicted = torch.max(log_probas, 1)[1]
            loss = self._hybrid_loss(log_probas, y, baselines, log_pi,
                                     predicted)

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store
            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard (the running averages, not the meters)
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                self.writer.add_scalar('Accuracy/valid', accs.avg, iteration)
                self.writer.add_scalar('Loss/valid', losses.avg, iteration)

        return losses.avg, accs.avg

    def choose_loss_fun(self, log_probas, y, baselines, R):
        """
        Pick the action and baseline loss functions by name.

        A dictionary of function handles replaces a switch-case; the
        handles are selected by `self.loss_fun_action` and
        `self.loss_fun_baseline`. Be careful with the argument dtypes and
        shapes: they must suit the selected functions.

        Returns the tuple (action loss of (log_probas, y),
        baseline loss of (baselines, R)).
        """
        loss_fun_pool = {
            'mse': F.mse_loss,
            'l1': F.l1_loss,
            'nll': F.nll_loss,
            'smooth_l1': F.smooth_l1_loss,
            'kl_div': F.kl_div,
            'cross_entropy': F.cross_entropy
        }

        action_loss = loss_fun_pool[self.loss_fun_action](log_probas, y)
        baseline_loss = loss_fun_pool[self.loss_fun_baseline](baselines, R)
        return action_loss, baseline_loss

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Loads a checkpoint (the best one when `self.best` is true), prints
        accuracy and error over `self.test_loader`, and dumps the images
        and glimpse locations of the last batch to `self.plot_dir`.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        locs = []
        imgs = []

        # torch.no_grad() replaces the removed `volatile=True` flag
        with torch.no_grad():
            for i, (x_raw, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # split image and saliency map the same way train/validate
                # do -- assumes test batches share that two-channel layout
                x = x_raw[:, 0, ...].unsqueeze(1)
                SM = x_raw[:, 1, ...].unsqueeze(1)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                SM = SM.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                (h_t1, h_t2, l_t, cell_state1, cell_state2,
                 SM_local_smooth) = self.reset(x, SM)

                # save images and glimpse locations of this batch
                locs = []
                imgs = [x[0:9]]

                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    (h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2,
                     SM_local_smooth) = self.model(x, l_t, h_t1, h_t2,
                                                   cell_state1, cell_state2,
                                                   SM, SM_local_smooth)

                    # store
                    locs.append(l_t[0:9])

                # last iteration
                (h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1,
                 cell_state2, SM_local_smooth) = self.model(x,
                                                            l_t,
                                                            h_t1,
                                                            h_t2,
                                                            cell_state1,
                                                            cell_state2,
                                                            SM,
                                                            SM_local_smooth,
                                                            last=True)

                # average the predictions of the M repetitions
                log_probas = log_probas.view(self.M, -1,
                                             log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.max(1, keepdim=True)[1]
                # .item() keeps `correct` a plain int so it prints cleanly
                correct += pred.eq(y.view_as(pred)).sum().item()

        # dump test data (of the last batch)
        imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
        locs = [l.detach().cpu().numpy() for l in locs]

        with open(self.plot_dir + "g_test.p", "wb") as fh:
            pickle.dump(imgs, fh)
        with open(self.plot_dir + "l_test.p", "wb") as fh:
            pickle.dump(locs, fh)
        sio.savemat(self.plot_dir + "test_transient.mat",
                    mdict={'location': locs})

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date.

        The checkpoint is written to `<ckpt_dir>/<model_name>_ckpt.pth.tar`.
        If this model has reached the best validation accuracy thus
        far, a separate copy with the suffix `_model_best` is created.
        """
        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load a saved copy of the model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Restores `start_epoch`, `best_valid_acc`, `best_train_acc` and
        the model and optimizer state.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        # train() also stores the best training accuracy; fall back to 0
        # for checkpoints that lack the key
        self.best_train_acc = ckpt.get('best_train_acc', 0.)
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.

    This variant regresses a 2-D location (plus a scale) instead of
    classifying. The quantity the code historically calls "acc" /
    ``best_valid_acc`` is therefore a mean Euclidean distance between the
    predicted and the ground-truth location: LOWER IS BETTER.
    """

    # Upper bounds on the number of mini-batches consumed per epoch.
    MAX_TRAIN_BATCHES = 1000
    MAX_VALID_BATCHES = 50

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. Indexed as `data_loader[0]` (train)
          and `data_loader[1]` (validation, or test when not training).
        """
        self.config = config
        self.dis_R_thres = config.dis_R_thres

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # The device must follow `use_gpu`: picking CUDA merely because it
        # is available put the rewards on the GPU while the model stayed on
        # the CPU whenever the user asked for a CPU run.
        self.use_gpu = config.use_gpu
        self.device = torch.device("cuda" if self.use_gpu else "cpu")

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            # upper bound: the last batch of each loader may be smaller
            self.num_train = len(self.train_loader) * config.batch_size
            self.num_valid = len(self.valid_loader) * config.batch_size
        else:
            self.test_loader = data_loader[1]
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 3

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        # lowest mean validation distance seen so far (lower is better)
        self.best_valid_acc = float('inf')
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        self.model.to(self.device)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # NOTE: the learning rate is fixed here; `config.init_lr` is kept
        # in `self.lr` but is not what the optimizer uses.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _to_device(self, *tensors):
        """Return `tensors` moved to `self.device`, in the same order."""
        return tuple(t.to(self.device) for t in tensors)

    @staticmethod
    def _location_distance(pred, target):
        """
        Per-sample Euclidean distance between the first two columns of
        `pred` and `target` (both indexed as `[batch, coordinate]`).
        """
        diff = pred[:, :2] - target[:, :2]
        return torch.sqrt(torch.sum(diff ** 2, dim=1))

    def _average_over_samples(self, t):
        """Fold the `self.M` stacked Monte Carlo replicas of `t` and average them."""
        return torch.mean(t.contiguous().view(self.M, -1, t.shape[-1]), dim=0)

    def _run_glimpses(self, x, speeds, courses):
        """
        Roll the model out for `self.num_glimpses` steps on one mini-batch.

        Sets `self.batch_size` from `x` and re-initialises the hidden state
        and the location through `reset()`.

        Returns
        -------
        - baselines: per-step baselines stacked as [batch, num_glimpses].
        - log_pi: per-step log-probabilities stacked the same way.
        - l_t_final: location returned by the last (`last=True`) step.
        - scale: scale returned by the last step.
        - locs: list with the first nine locations of every step
          (only consumed by the plotting dump in `train_one_epoch`).
        """
        self.batch_size = x.shape[0]
        h_t, l_t = self.reset()

        locs, log_pi, baselines = [], [], []
        last_step = self.num_glimpses - 1
        for t in range(last_step):
            h_t, l_t, b_t, p = self.model(x, speeds, courses, l_t, h_t, t)
            locs.append(l_t[0:9])
            baselines.append(b_t)
            log_pi.append(p)

        # the last step additionally yields the final location and scale
        h_t, l_t, b_t, l_t_final, p, scale = self.model(
            x, speeds, courses, l_t, h_t, last_step, last=True)
        log_pi.append(p)
        baselines.append(b_t)
        locs.append(l_t[0:9])

        # convert list to tensors and reshape
        baselines = torch.stack(baselines).transpose(1, 0)
        log_pi = torch.stack(log_pi).transpose(1, 0)
        return baselines, log_pi, l_t_final, scale, locs

    def _reinforce_loss(self, log_pi, baselines, R):
        """REINFORCE loss: summed over timesteps, averaged across the batch."""
        adjusted_reward = R - baselines.detach()
        loss = torch.sum(-log_pi * adjusted_reward, dim=1)
        return torch.mean(loss, dim=0), adjusted_reward

    def _save_blends(self, save_dir, l_t_final, y, scale, scale_gt,
                     indexSeq, frameEnd):
        """
        Write one image per sample with the predicted (red) and, when
        available, the ground-truth (green) focus rectangle drawn on it.

        Relies on the module-level `dreyeve_dir`, `read_image` and
        `blend_map_with_focus_rectangle` defined elsewhere in the project.
        """
        # Iterate over the real samples: `l_t_final`, `y`, `indexSeq` and
        # `frameEnd` hold one entry per sample, not one per M replica.
        for idx in range(y.shape[0]):
            seq_dir = '{:02d}'.format(int(indexSeq[idx]))
            frame = int(frameEnd[idx])

            img = read_image(
                os.path.join(dreyeve_dir, seq_dir, 'frames',
                             '{:06d}.jpg'.format(frame)),
                channels_first=False, color=True)
            saliency = read_image(
                os.path.join(dreyeve_dir, seq_dir, 'saliency_fix',
                             '{:06d}.png'.format(frame)),
                channels_first=False, color=False)

            loc = l_t_final[idx, :].detach().cpu().numpy()
            loc_gt = y[idx].detach().cpu().numpy()
            scale_blend = scale[idx].detach().cpu().numpy()
            scale_gt_blend = scale_gt[idx].detach().cpu().numpy()

            # draw target
            blend = blend_map_with_focus_rectangle(
                img, saliency, loc, scale=scale_blend, color=(0, 0, 255))
            # draw gt (skipped when the ground-truth location is NaN)
            if not (np.isnan(loc_gt[0]) or np.isnan(loc_gt[1])):
                blend = blend_map_with_focus_rectangle(
                    blend, saliency, loc_gt, scale=scale_gt_blend,
                    color=(0, 255, 0))

            print('scale is {:.3f} and scale_gt is {:.3f}'.format(
                float(scale_blend), float(scale_gt_blend)))
            cv2.imwrite(
                os.path.join(save_dir, '{:06d}.jpg'.format(frame)), blend)

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.

        Returns
        -------
        - h_t: zeros of shape [self.batch_size, self.hidden_size].
        - l_t: uniform samples in [-1, 1] of shape [self.batch_size, 2].
        """
        h_t = torch.zeros(self.batch_size, self.hidden_size,
                          device=self.device)
        # sampled on the CPU so that seeded runs draw from the same
        # random stream on every device
        l_t = torch.empty(self.batch_size, 2).uniform_(-1, 1)
        l_t = l_t.to(self.device)
        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation distance improves (gets smaller),
        a separate ckpt is created for use on the test set.
        Training stops early once `train_patience` consecutive epochs
        have passed without improvement.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):
            # report the rate the optimizer really uses
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                epoch + 1, self.epochs,
                self.optimizer.param_groups[0]['lr']))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # "acc" is a mean distance, so smaller is better. A NaN
            # distance compares False and is never considered the best.
            is_best = valid_acc < self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            if is_best:
                self.best_valid_acc = valid_acc
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        At most `MAX_TRAIN_BATCHES` mini-batches are consumed.

        This is used by train() and should not be called manually.

        Returns
        -------
        - (average loss, average distance) over the processed batches.
        """
        losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()

        tic = time.time()
        for i, (x, fixation, y, speeds, courses, scale_gt, indexSeq,
                frameEnd) in enumerate(self.train_loader):
            if i >= self.MAX_TRAIN_BATCHES:
                break

            # Keep the batch axis even for a single-sample batch
            # (`squeeze()` would drop it). Assumes the target carries two
            # coordinates per sample, as the distance and MSE below require.
            y = y.reshape(y.shape[0], -1).float()
            x, y, speeds, courses, scale_gt = self._to_device(
                x, y, speeds, courses, scale_gt)

            plot = (epoch % self.plot_freq == 0) and (i == 0)

            # save images
            imgs = [x[0:9]]

            baselines, log_pi, l_t_final, scale, locs = self._run_glimpses(
                x, speeds, courses)

            # calculate reward: the closer to the target, the higher.
            # The reward is a constant signal for REINFORCE, so it must
            # not carry gradients back into the location output.
            distances = self._location_distance(l_t_final, y)
            R = (1 - distances).detach()
            mean_R = torch.mean(R)
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.mse_loss(l_t_final, y)
            loss_scale = F.mse_loss(scale, scale_gt)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            loss_reinforce, adjusted_reward = self._reinforce_loss(
                log_pi, baselines, R)

            # sum up into a hybrid loss
            loss = (loss_action * 10 + loss_baseline * 10 + loss_reinforce
                    + loss_scale)

            # batch metric: mean distance over the batch
            mean_dist = distances.mean().item()

            # store
            losses.update(loss.item(), x.size(0))
            accs.update(mean_dist, x.size(0))

            # compute gradients and update the parameters
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time (since the start of the epoch)
            toc = time.time()
            batch_time.update(toc - tic)

            print(
                "Epoch: {} - {}/{} - {:.1f}s - loss: {:.3f} - dis: {:.3f} - loss_action : {:.3f} - loss_scale: {:.3f} "
                "- mean_R : {:.3f} - mean-baseline: {:.3f} - mean_adjusted_reward: {:.3f} "
                .format(epoch, i + 1, self.MAX_TRAIN_BATCHES, (toc - tic),
                        loss.item(), mean_dist, loss_action.item(),
                        loss_scale.item(), mean_R.item(),
                        torch.mean(baselines).item(),
                        torch.mean(adjusted_reward).item()))

            # dump the glimpses and locs
            if plot:
                imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
                locs = [l.detach().cpu().numpy() for l in locs]
                with open(self.plot_dir + "g_{}.p".format(epoch + 1),
                          "wb") as f:
                    pickle.dump(imgs, f)
                with open(self.plot_dir + "l_{}.p".format(epoch + 1),
                          "wb") as f:
                    pickle.dump(locs, f)

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.train_loader) + i
                log_value('train_loss', losses.avg, iteration)
                log_value('train_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        At most `MAX_VALID_BATCHES` mini-batches are evaluated, and for
        each sample an annotated frame is written to `logs/<epoch>/`.

        NOTE: the validation loss uses a thresholded reward and different
        loss weights than `train_one_epoch`, so the two loss values are
        not directly comparable.

        Returns
        -------
        - (average loss, average distance) over the evaluated batches.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        save_dir = os.path.join('logs', '{:02d}'.format(epoch))
        # `makedirs` so a missing parent `logs` directory is not fatal
        os.makedirs(save_dir, exist_ok=True)

        for i, (x, fixs, y, speeds, courses, scale_gt, indexSeq,
                frameEnd) in enumerate(self.valid_loader):
            if i >= self.MAX_VALID_BATCHES:
                break

            y = y.reshape(y.shape[0], -1).float()
            x, y, speeds, courses = self._to_device(x, y, speeds, courses)

            # duplicate M times
            # NOTE(review): `speeds` / `courses` are not replicated, so
            # M > 1 presumably needs the model to broadcast them - confirm.
            x = x.repeat(self.M, 1, 1, 1, 1)

            baselines, log_pi, l_t_final, scale, _ = self._run_glimpses(
                x, speeds, courses)

            # average over the M replicas
            l_t_final = self._average_over_samples(l_t_final)

            self._save_blends(save_dir, l_t_final, y, scale, scale_gt,
                              indexSeq, frameEnd)

            baselines = self._average_over_samples(baselines)
            log_pi = self._average_over_samples(log_pi)

            # calculate reward: 1 when within the distance threshold
            distances = self._location_distance(l_t_final, y)
            R = (distances < self.dis_R_thres).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.mse_loss(l_t_final, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            loss_reinforce, _ = self._reinforce_loss(log_pi, baselines, R)

            # sum up into a hybrid loss
            loss = loss_action * 100 + loss_baseline + loss_reinforce

            # "accuracy" is the mean distance of the batch
            mean_dist = distances.mean().item()
            print('avg dist is {}'.format(mean_dist))

            # store
            losses.update(loss.item(), x.size(0))
            accs.update(mean_dist, x.size(0))

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value('valid_loss', losses.avg, iteration)
                log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        One line per sample (sequence, frame, ground-truth and predicted
        location, ground-truth and predicted scale) is written to
        `output.txt` in the working directory, and the mean distance of
        every batch is printed.
        """
        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        with open('output.txt', 'w') as f:
            for i, (x, fixs, y, speeds, courses, scale_gt, indexSeq,
                    frameEnd) in enumerate(self.test_loader):
                y = y.reshape(y.shape[0], -1).float()
                x, y, speeds, courses = self._to_device(
                    x, y, speeds, courses)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1, 1)

                _, _, l_t_final, scale, _ = self._run_glimpses(
                    x, speeds, courses)

                # average over the M replicas
                l_t_final = self._average_over_samples(l_t_final)

                # one entry per real sample, not per replica
                for idx in range(y.shape[0]):
                    loc = l_t_final[idx, :].detach().cpu().numpy()
                    loc_gt = y[idx].detach().cpu().numpy()
                    scale_blend = scale[idx].detach().cpu().numpy()
                    scale_gt_blend = scale_gt[idx].detach().cpu().numpy()
                    fields = (int(indexSeq[idx]), int(frameEnd[idx]),
                              loc_gt[0], loc_gt[1], loc[0], loc[1],
                              float(scale_gt_blend), float(scale_blend))

                    line = '{:02} {:04d} {:.3f} {:.3f} {:.3f} {:.3f} {:.3f} {:.3f}\n'\
                        .format(*fields)
                    print('seq: {:02}- frame: {:04d} - loc_gt_h: {:.3f} - loc_gt_w: {:.3f} - loc_h: {:.3f} - loc_w: {:.3f} '
                          '- scale_gt: {:.3f}- scale: {:.3f}'
                          .format(*fields))
                    f.write(line)

                distances = self._location_distance(l_t_final, y)
                print('avg dist is {}'.format(distances.mean().item()))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation result thus
        far, a separate file with the suffix `best` is created.

        Args
        ----
        - state: dict to serialise (epoch, model/optimizer state dicts,
          best validation value).
        - is_best: when truthy, also copy the file to `*_model_best.pth.tar`.
        """
        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map to the CPU first so a checkpoint written on a GPU machine can
        # be restored anywhere; `load_state_dict` copies the values onto
        # whatever device the parameters live on.
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        # older checkpoints may hold a 0-dim tensor here
        self.best_valid_acc = float(ckpt['best_valid_acc'])
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
class Trainer(object): """ Trainer encapsulates all the logic necessary for training the Recurrent Attention Model. All hyperparameters are provided by the user in the config file. """ def __init__(self, config, data_loader): """ Construct a new Trainer instance. Args ---- - config: object containing command line arguments. - data_loader: data iterator """ self.config = config # glimpse network params self.patch_size = config.patch_size self.glimpse_scale = config.glimpse_scale self.num_patches = config.num_patches self.loc_hidden = config.loc_hidden self.glimpse_hidden = config.glimpse_hidden # core network params self.num_glimpses = config.num_glimpses self.hidden_size = config.hidden_size # reinforce params self.std = config.std self.M = config.M # data params if config.is_train: self.train_loader = data_loader[0] self.valid_loader = data_loader[1] self.num_train = len(self.train_loader.sampler.indices) self.num_valid = len(self.valid_loader.sampler.indices) else: self.test_loader = data_loader self.num_test = len(self.test_loader.dataset) self.num_classes = 10 self.num_channels = 1 # training params self.epochs = config.epochs self.start_epoch = 0 self.momentum = config.momentum self.lr = config.init_lr # misc params self.no_tqdm = config.no_tqdm self.use_gpu = config.use_gpu self.best = config.best self.ckpt_dir = config.ckpt_dir self.logs_dir = config.logs_dir self.best_valid_acc = 0. 
self.counter = 0 self.lr_patience = config.lr_patience self.train_patience = config.train_patience self.use_tensorboard = config.use_tensorboard self.resume = config.resume self.print_freq = config.print_freq self.plot_freq = config.plot_freq self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses, config.patch_size, config.patch_size, config.glimpse_scale) if config.uncertainty == True: self.model_name += '_uncertainty_1' else: self.model_name += '_uncertainty_0' if config.intrinsic == True: self.model_name += '_intrinsic_1' else: self.model_name += '_intrinsic_0' self.plot_dir = './plots/' + self.model_name + '/' if not os.path.exists(self.plot_dir): os.makedirs(self.plot_dir) # configure tensorboard logging if self.use_tensorboard: tensorboard_dir = self.logs_dir + self.model_name print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir)) if not os.path.exists(tensorboard_dir): os.makedirs(tensorboard_dir) configure(tensorboard_dir) # build RAM model self.model = RecurrentAttention(self.patch_size, self.num_patches, self.glimpse_scale, self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std, self.hidden_size, self.num_classes, self.config) if self.use_gpu: self.model.cuda() self.dtypeFloat = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor) self.dtypeLong = (torch.cuda.LongTensor if self.use_gpu else torch.LongTensor) print('[*] Number of model parameters: {:,}'.format( sum([p.data.nelement() for p in self.model.parameters()]))) # # initialize optimizer and scheduler self.optimizer = optim.Adam( self.model.parameters(), lr=self.config.init_lr, ) lambda_of_lr = lambda epoch: 0.95**epoch self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda_of_lr) # self.scheduler = StepLR(self.optimizer,step_size=20,gamma=0.1) # self.scheduler = ReduceLROnPlateau( # self.optimizer, 'min', patience=self.lr_patience # ) def reset(self): """ Initialize the hidden state of the core network and the location vector. 
This is called once every time a new minibatch `x` is introduced. """ dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor) h_t = torch.zeros(self.batch_size, self.hidden_size) h_t = Variable(h_t).type(dtype) l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1) l_t = Variable(l_t).type(dtype) return h_t, l_t def train(self): """ Train the model on the training set. A checkpoint of the model is saved after each epoch and if the validation accuracy is improved upon, a separate ckpt is created for use on the test set. """ # load the most recent checkpoint if self.resume: self.load_checkpoint(best=False) print( "\n[*] Train on {} samples, validate on {} samples, learn rate {}". format(self.num_train, self.num_valid, self.scheduler.get_lr())) for epoch in range(self.start_epoch, self.epochs): print('\nEpoch: {}/{} . lr: {:.4e} '.format( epoch + 1, self.epochs, self.scheduler.get_lr()[0])) # train for 1 epoch train_loss, train_acc = self.train_one_epoch(epoch) # evaluate on validation set valid_loss, valid_acc = self.validate(epoch) self.scheduler.step() is_best = valid_acc > self.best_valid_acc msg1 = "train loss: {:.3f} - train acc: {:.3f} " msg2 = "- val loss: {:.3f} - val acc: {:.3f}" if is_best: self.counter = 0 msg2 += " [*]" msg = msg1 + msg2 print(msg.format(train_loss, train_acc, valid_loss, valid_acc)) # check for improvement if not is_best: self.counter += 1 if self.counter > self.train_patience: print("[!] No improvement in a while, stopping training.") return self.best_valid_acc = max(valid_acc, self.best_valid_acc) self.save_checkpoint( { 'epoch': epoch + 1, 'model_state': self.model.state_dict(), 'optim_state': self.optimizer.state_dict(), 'best_valid_acc': self.best_valid_acc, }, is_best) def train_one_epoch(self, epoch): """ Train the model for 1 epoch of the training set. An epoch corresponds to one full pass through the entire training set in successive mini-batches. This is used by train() and should not be called manually. 
""" batch_time = AverageMeter() losses = AverageMeter() accs = AverageMeter() tic = time.time() with tqdm(total=self.num_train, disable=self.no_tqdm) as pbar: for i, (x, y) in enumerate(self.train_loader): if self.config.use_translate: x = translate_function(x, original_dataset=x) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) plot = False if (epoch % self.plot_freq == 0) and (i == 0): plot = True # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # save images imgs = [] imgs.append(x[0:9]) # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] # the prediction at each glimpse step uncertainities = [ ] # the self-uncertainty at each glimpse step uncertainities_baseline = [ ] # the self-uncertainty at each glimpse step, but this baseline is only used for the loss of training self-uncertainty, which only involves the error network. # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ self.num_glimpses - 1 for _ in range(self.batch_size) ] for t in range(self.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) uncertainities_baseline.append(diff_uncertainty_baseline) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # if self.config.uncertainty == True: if self.config.uncertainty == True: uncertainities = torch.stack(uncertainities).transpose( 1, 0) uncertainities_baseline = torch.stack( uncertainities_baseline).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) # calculate reward num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) log_probas = torch.cat([ 
torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules num_glimpses_taken = Variable( torch.LongTensor(num_glimpses_taken), requires_grad=False).type(self.dtypeLong) # the mask is used to take only the result of the last glimpse mask = _sequence_mask(sequence_length=num_glimpses_taken, max_len=self.num_glimpses) loss_action = F.nll_loss(log_probas, y, reduction='none') loss_action = torch.mean(loss_action) loss_baseline = F.mse_loss(baselines, R, reduction='none') loss_baseline = torch.mean(loss_baseline * mask) # loss_baseline = torch.mean( loss_baseline ) # compute reinforce loss # summed over timesteps and averaged across batch adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce if self.config.uncertainty == True: y_real_value = F.one_hot( y, self.num_classes).float().detach() diff_ = Variable(torch.abs( y_real_value.unsqueeze(1).expand( -1, self.num_glimpses, -1).data - torch.exp(all_log_probas).data), requires_grad=False) # loss_self_uncertaintiy_baseline = F.mse_loss(uncertainities_baseline, diff_) loss_self_uncertaintiy_baseline = F.mse_loss( uncertainities_baseline, diff_, reduction='none').mean() loss_self_uncertaintiy_baseline = torch.mean( loss_self_uncertaintiy_baseline) loss += loss_self_uncertaintiy_baseline if self.config.intrinsic == True: # the intrinsic sparsity belief reg = self.config.lambda_intrinsic intrinsic_term = torch.sum(-(1.0 / self.num_classes) * log_probas) loss_intrinsic = reg * intrinsic_term loss += loss_intrinsic if self.config.uncertainty == True: # the second reinforce loss: minimizing the uncertainty reg = 
self.config.lambda_uncertainty loss_self_uncertaintiy_minimizing = reg * torch.sum( uncertainities) loss += loss_self_uncertaintiy_minimizing # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store losses.update(loss.data, list(x.size())[0]) accs.update(acc.data, list(x.size())[0]) # compute gradients and update SGD self.optimizer.zero_grad() loss.backward() self.optimizer.step() # measure elapsed time toc = time.time() batch_time.update(toc - tic) if self.no_tqdm is not True: pbar.set_description( ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( (toc - tic), loss.data, acc.data))) pbar.update(self.batch_size) # dump the glimpses and locs if plot: if self.use_gpu: imgs = [g.cpu().data.numpy().squeeze() for g in imgs] locs = [l.cpu().data.numpy() for l in locs] else: imgs = [g.data.numpy().squeeze() for g in imgs] locs = [l.data.numpy() for l in locs] pickle.dump( imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb")) pickle.dump( locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb")) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.train_loader) + i log_value('train_loss', losses.avg, iteration) log_value('train_acc', accs.avg, iteration) return losses.avg, accs.avg def validate(self, epoch, M=1): """ Evaluate the model on the validation set. 
""" losses = AverageMeter() accs = AverageMeter() for i, (x, y) in enumerate(self.valid_loader): if self.config.use_translate: x = translate_function(x, original_dataset=x) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) # duplicate M times x = x.repeat(M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] uncertainities = [] uncertainities_baseline = [] # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ self.num_glimpses - 1 for _ in range(self.batch_size) ] for t in range(self.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) uncertainities_baseline.append(diff_uncertainty_baseline) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) if self.config.uncertainty == True: uncertainities = torch.stack(uncertainities).transpose(1, 0) uncertainities_baseline = torch.stack( uncertainities_baseline).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) # calculate reward num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) log_probas = torch.cat([ torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() # average the `self.M` times of prediction log_probas = log_probas.view(M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(M, self.num_glimpses) # compute losses for differentiable modules num_glimpses_taken = 
Variable(torch.LongTensor(num_glimpses_taken), requires_grad=False).type( self.dtypeLong) mask = _sequence_mask(sequence_length=num_glimpses_taken, max_len=self.num_glimpses) loss_action = F.nll_loss(log_probas, y, reduction='none') loss_action = torch.mean(loss_action) loss_baseline = F.mse_loss(baselines, R, reduction='none') loss_baseline = torch.mean(loss_baseline * mask) adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce if self.config.uncertainty == True: y_real_value = F.one_hot(y, self.num_classes).float().detach() diff_ = Variable(torch.abs( y_real_value.unsqueeze(1).expand(-1, self.num_glimpses, -1).data - torch.exp(all_log_probas).data), requires_grad=False) loss_self_uncertaintiy_baseline = F.mse_loss( uncertainities_baseline, diff_, reduction='none').mean() loss_self_uncertaintiy_baseline = torch.mean( loss_self_uncertaintiy_baseline) loss += loss_self_uncertaintiy_baseline if self.config.intrinsic == True: # the intrinsic sparsity belief reg = self.config.lambda_intrinsic loss_intrinsic = reg * torch.sum( -(1.0 / self.num_classes) * log_probas) loss += loss_intrinsic if self.config.uncertainty == True: # the second reinforce loss: minimizing the uncertainty reg = self.config.lambda_uncertainty loss_self_uncertaintiy_minimizing = reg * torch.sum( uncertainities) loss += loss_self_uncertaintiy_minimizing # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store losses.update(loss.data, list(x.size())[0]) accs.update(acc.data, list(x.size())[0]) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.valid_loader) + i log_value('valid_loss', losses.avg, iteration) log_value('valid_acc', accs.avg, iteration) return losses.avg, accs.avg def test(self): """ Test the model on the held-out test data. 
This function should only be called at the very end once the model has finished training. """ correct = 0 # load the best checkpoint self.load_checkpoint(best=self.best) self.num_test = len(self.test_loader.sampler) all_num_glimpses_taken = [] for i, (x, y) in enumerate(self.test_loader): torch.manual_seed(self.config.random_seed) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) # duplicate 10 times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] uncertainities = [] # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ self.config.num_glimpses - 1 for _ in range(self.batch_size) ] for t in range(self.config.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) if self.config.dynamic == True: # determine if it has achieve a threshold uncertainty probs_data = torch.exp(log_probas).data.tolist() diff_uncertainty_data = diff_uncertainty.data.tolist() for instance_idx, (prediction, uncertainty) in enumerate( zip(probs_data, diff_uncertainty_data)): a_star_idx = max(enumerate(prediction), key=lambda x: x[1])[0] a_prime_idx = max( [(idx, pred + self.config.exploration_rate * uncertainty[idx]) for idx, pred in enumerate(prediction) if idx != a_star_idx], key=lambda x: x[1])[0] a_star_lower_bound = prediction[ a_star_idx] - self.config.exploration_rate * uncertainty[ a_star_idx] a_prime_upper_bound = prediction[ a_prime_idx] - self.config.exploration_rate * uncertainty[ a_prime_idx] if a_star_lower_bound >= a_prime_upper_bound: num_glimpses_taken[instance_idx] = t if all([ num < self.config.num_glimpses - 1 for num in 
num_glimpses_taken ]): # print(num_glimpses_taken) break # print('strange! end now!:',t) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) if self.config.uncertainty == True or self.config.dynamic == True: uncertainities = torch.stack(uncertainities).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) all_num_glimpses_taken.extend(num_glimpses_taken) # calculate reward num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) log_probas = torch.cat([ torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() # average the `self.M` times of prediction log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct += pred.eq(y.data.view_as(pred)).cpu().sum() perc = (100. * correct) / (self.num_test) error = 100 - perc print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format( correct, self.num_test, perc, error)) if self.config.dynamic == True: print('use dynamic') avg_num_glimpses_taken = sum(all_num_glimpses_taken) / len( all_num_glimpses_taken) + 1 return (avg_num_glimpses_taken, 1.0 * correct.tolist() / self.num_test) return 1.0 * correct.tolist() / self.num_test # return perc.tolist() def test_for_all( self, range_all=100, ): """ Test the model on the held-out test data. 
This is used to run the model under different number of glimpses """ correct = [] for _ in range(range_all): correct.append(0) # load the best checkpoint self.load_checkpoint(best=self.best) self.num_test = len(self.test_loader.sampler) all_num_glimpses_taken = [] for i, (x, y) in enumerate(tqdm(self.test_loader)): torch.manual_seed(self.config.random_seed) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) # duplicate 10 times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] uncertainities = [] # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ range_all - 1 for _ in range(self.batch_size) ] for t in range(self.config.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) if self.config.dynamic == True: # determine if it has achieve a threshold uncertainty probs_data = torch.exp(log_probas).data.tolist() diff_uncertainty_data = diff_uncertainty.data.tolist() for instance_idx, (prediction, uncertainty) in enumerate( zip(probs_data, diff_uncertainty_data)): a_star_idx = max(enumerate(prediction), key=lambda x: x[1])[0] a_prime_idx = max( [(idx, pred + self.config.exploration_rate * uncertainty[idx]) for idx, pred in enumerate(prediction) if idx != a_star_idx], key=lambda x: x[1])[0] a_star_lower_bound = prediction[ a_star_idx] - self.config.exploration_rate * uncertainty[ a_star_idx] a_prime_upper_bound = prediction[ a_prime_idx] - self.config.exploration_rate * uncertainty[ a_prime_idx] if a_star_lower_bound >= a_prime_upper_bound: num_glimpses_taken[instance_idx] = t if all([ num < self.config.num_glimpses - 1 for num 
in num_glimpses_taken ]): # print(num_glimpses_taken) break # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) if self.config.uncertainty == True or self.config.dynamic == True: uncertainities = torch.stack(uncertainities).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) all_num_glimpses_taken.extend(num_glimpses_taken) # calculate reward for num in range(range_all): num_glimpses_taken = [num for _ in range(self.batch_size)] num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) # log_probas = torch.cat([ torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() log_probas = all_log_probas[:, num] # print(all_log_probas.size(),log_probas.size()) # average the `self.M` times of prediction log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct[num] += pred.eq(y.data.view_as(pred)).cpu().sum() return [1.0 * cor.tolist() / self.num_test for cor in correct] # return 1.0 * correct.tolist() / self.num_test def save_checkpoint(self, state, is_best): """ Save a copy of the model so that it can be loaded at a future date. This function is used when the model is being evaluated on the test data. If this model has reached the best validation accuracy thus far, a seperate file with the suffix `best` is created. """ # print("[*] Saving model to {}".format(self.ckpt_dir)) filename = self.model_name + '_ckpt.pth.tar' ckpt_path = os.path.join(self.ckpt_dir, filename) torch.save(state, ckpt_path) if is_best: filename = self.model_name + '_model_best.pth.tar' shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename)) def load_checkpoint(self, best=False): """ Load the best copy of a model. This is useful for 2 cases: - Resuming training with the most recent model checkpoint. 
- Loading the best validation model to evaluate on the test data. Params ------ - best: if set to True, loads the best model. Use this if you want to evaluate your model on the test data. Else, set to False in which case the most recent version of the checkpoint is used. """ print("[*] Loading model from {}".format(self.ckpt_dir)) filename = self.model_name + '_ckpt.pth.tar' if best: filename = self.model_name + '_model_best.pth.tar' ckpt_path = os.path.join(self.ckpt_dir, filename) ckpt = torch.load(ckpt_path) # load variables from checkpoint self.start_epoch = ckpt['epoch'] self.best_valid_acc = ckpt['best_valid_acc'] self.model.load_state_dict(ckpt['model_state']) self.optimizer.load_state_dict(ckpt['optim_state']) if best: print("[*] Loaded {} checkpoint @ epoch {} " "with best valid acc of {:.3f}".format( filename, ckpt['epoch'], ckpt['best_valid_acc'])) else: print("[*] Loaded {} checkpoint @ epoch {}".format( filename, ckpt['epoch']))
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file. In this variant only the classification
    (action) loss is optimized; the baseline and REINFORCE
    terms are not part of the objective.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When `config.is_train` is set it
          is indexed as `(train_loader, valid_loader)`; otherwise it is
          used directly as the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        # hard-coded for the dataset this variant targets
        self.num_classes = 10
        self.num_channels = 3

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        # NOTE: kept for reference only; the optimizer below uses a
        # fixed learning rate and does not read `self.lr`.
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_file = config.ckpt_file
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}_{}'.format(
            config.num_glimpses, config.patch_size, config.patch_size,
            config.glimpse_scale,
            datetime.date.today().strftime("%y-%m-%d"))

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # when enabled, `train` flips `train_loc_flag` after 5 epochs
        # without improvement
        self.alternating_learning = False
        self.train_loc_flag = False

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced. Reads `self.batch_size`, which the
        caller must set beforehand.

        Returns
        -------
        - h_t: zero tensor of shape (batch_size, hidden_size).
        - l_t: tensor of shape (batch_size, 2) drawn uniformly
          from [-1, 1].
        """
        dtype = (torch.cuda.FloatTensor
                 if self.use_gpu else torch.FloatTensor)

        h_t = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops early once `self.counter` exceeds
        `self.train_patience`.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            # report the learning rate the optimizer actually uses
            # (it is built with a fixed value, not `self.lr`)
            current_lr = self.optimizer.param_groups[0]['lr']
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                epoch + 1, self.epochs, current_lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1

            if self.alternating_learning:
                if self.counter >= 5:
                    self.train_loc_flag = not self.train_loc_flag
                    print(
                        "[!] No improvement in a while. Switch loss. Now training:",
                        ["ActionNet", "LocationNet"][self.train_loc_flag])
                    self.counter = 0

            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return

            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def _dump_plot_data(self, prefix, epoch, payload):
        """Pickle `payload` to `<plot_dir><prefix>_<epoch + 1>.p`."""
        path = self.plot_dir + "{}_{}.p".format(prefix, epoch + 1)
        # context manager so the handle is closed even if pickling fails
        with open(path, "wb") as handle:
            pickle.dump(payload, handle)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Args
        ----
        - epoch: zero-based epoch index (used for plot dumps and
          tensorboard iteration numbering).

        Returns
        -------
        - (losses.avg, accs.avg): running averages of loss and
          accuracy (in percent) over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        # taken once per epoch, so `toc - tic` below is the time
        # elapsed since the epoch started
        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # dump plot data for the first batch every
                # `plot_freq` epochs
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_probs = []
                for t in range(self.num_glimpses):
                    locs.append(l_t)
                    h_t, l_t, _, log_probas, _ = self.model(x, l_t, h_t)
                    log_probs.append(log_probas[0:9])

                # (glimpse, sample, class) -> (sample, glimpse, class)
                log_probs = torch.stack(log_probs).transpose(1, 0)

                # prediction of the last glimpse
                predicted = torch.max(log_probas, 1)[1]

                # only the classification loss is optimized
                loss = F.nll_loss(log_probas, y)

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update the parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    # `.cpu()` is a no-op for tensors already on the CPU
                    imgs = [
                        g.detach().cpu().numpy().squeeze() for g in imgs
                    ]
                    locs = [l.detach().cpu().numpy()[:9] for l in locs]
                    log_probs = [
                        p.detach().cpu().numpy() for p in log_probs
                    ]
                    ys = [g.detach().cpu().numpy() for g in y[:9]]

                    self._dump_plot_data("g", epoch, imgs)
                    self._dump_plot_data("l", epoch, locs)
                    self._dump_plot_data("p", epoch, log_probs)
                    self._dump_plot_data("y", epoch, ys)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Every batch is duplicated `self.M` times and the resulting
        log-probabilities are averaged before predicting.

        Args
        ----
        - epoch: zero-based epoch index (used for tensorboard
          iteration numbering).

        Returns
        -------
        - (losses.avg, accs.avg): running averages of loss and
          accuracy (in percent) over the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # nothing is back-propagated here, so skip autograd bookkeeping
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                for t in range(self.num_glimpses):
                    h_t, l_t, _, log_probas, _ = self.model(x, l_t, h_t)

                # average over the M duplicates
                log_probas = log_probas.view(self.M, -1,
                                             log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                predicted = torch.max(log_probas, 1)[1]
                loss = F.nll_loss(log_probas, y)

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Besides printing the accuracy, it writes `temp.csv` with, per
        sample: two zero-filled offset columns, the glimpse locations
        (rescaled to [0, 1]) and the averaged log-probabilities of up
        to the first three glimpses, the label and whether the final
        prediction was correct.
        """
        torch.manual_seed(15)
        import pandas as pd

        correct = 0
        # the original layout logs three glimpses; use fewer when the
        # model is configured with fewer so indexing cannot fail
        num_logged = min(3, self.num_glimpses)
        loc_log = [[] for _ in range(num_logged)]
        proba_log = [[] for _ in range(num_logged)]
        ys = []
        corrs = []

        # load the best checkpoint
        self.load_checkpoint(best=self.best, ckpt_file=self.ckpt_file)

        # evaluation only: no gradients are needed
        with torch.no_grad():
            for i, (x, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                l_ts_temp = []
                probas_temp = []
                for t in range(self.num_glimpses):
                    # map locations from [-1, 1] to [0, 1]
                    l_ts_temp.append((l_t + 1) / 2.0)

                    # forward pass through model
                    h_t, l_t, _, log_probas, _ = self.model(x, l_t, h_t)

                    # prediction after each glimpse, averaged over the
                    # M duplicates
                    log_probas = log_probas.view(self.M, -1,
                                                 log_probas.shape[-1])
                    log_probas = torch.mean(log_probas, dim=0)
                    pred = log_probas.data.max(1, keepdim=True)[1]
                    probas_temp.append(log_probas)

                for k in range(num_logged):
                    loc_log[k].append(l_ts_temp[k].cpu().detach())
                    proba_log[k].append(probas_temp[k].cpu().detach())
                ys.append(y.cpu().detach())

                # `pred` is the prediction after the last glimpse
                hits = pred.eq(y.data.view_as(pred)).cpu().detach()
                corrs.append(hits)
                # keep a plain int so the summary prints as a number
                correct += int(hits.sum())

        loc_cols = [torch.cat(col) for col in loc_log]
        proba_cols = [torch.cat(col) for col in proba_log]
        ys = torch.unsqueeze(torch.cat(ys), dim=1).float()
        corrs = torch.cat(corrs).float()
        offset_x = torch.zeros(ys.shape)
        offset_y = torch.zeros(ys.shape)

        # NOTE: locations have one row per duplicated sample while the
        # probabilities and labels have one row per original sample, so
        # this column-wise concatenation only lines up when M == 1.
        data = torch.cat(
            [offset_x, offset_y] + loc_cols + proba_cols + [ys, corrs],
            dim=1)
        data = pd.DataFrame(data.numpy())
        data.to_csv('temp.csv')

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.

        Params
        ------
        - state: checkpoint dictionary passed to `torch.save`.
        - is_best: if True, also copy the checkpoint to the
          `_model_best` file.
        """
        print("[*] Saving model to {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False, ckpt_file=None):
        """
        Load a saved copy of the model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        - ckpt_file: optional checkpoint file name inside `self.ckpt_dir`.
          When given, it takes precedence over the name derived from
          `best` and `self.model_name`.
        """
        if ckpt_file is None:
            suffix = '_model_best.pth.tar' if best else '_ckpt.pth.tar'
            filename = self.model_name + suffix
        else:
            filename = ckpt_file

        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map to the CPU when not using the GPU so checkpoints written
        # on a GPU machine can still be loaded
        ckpt = torch.load(
            ckpt_path, map_location=None if self.use_gpu else 'cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))

    def get_loc_reward(self, locs):
        """
        Compute a proximity penalty for glimpses that land close to
        any previous glimpse.

        Args:
        ----
        locs: list of location tensors, one per glimpse, each with one
            row per sample.

        Returns:
        ----
        reward: tensor with one row per sample and one column per
            glimpse holding `exp(-10 * d)`, where `d` is the distance
            to the nearest earlier glimpse. The first glimpse has no
            predecessor and gets 0. Values grow towards 1 as glimpses
            get closer, so callers are expected to subtract this from
            the task reward.
        """
        pdist = torch.nn.PairwiseDistance(p=2)

        min_dists = []
        for i, l_t in enumerate(locs):
            if i == 0:
                # no earlier glimpse: infinite distance -> zero penalty;
                # created from `l_t` so device and dtype match
                nearest = l_t.new_full((l_t.shape[0], 1), float("inf"))
            else:
                # distance from this glimpse to every earlier one
                dists = torch.cat(
                    [pdist(prev, l_t).unsqueeze(1) for prev in locs[:i]],
                    dim=1)
                nearest = torch.min(dists, dim=1).values.unsqueeze(1)
            min_dists.append(nearest)

        min_dists = torch.cat(min_dists, dim=1)
        return torch.exp(-10 * min_dists)
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file. One checkpoint per epoch is written into a per-model
    sub-directory of `config.ckpt_dir`.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: a `(train_loader, valid_loader)` pair when
          `config.is_train` is set, otherwise the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.sampler)
        self.num_classes = config.num_classes
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.image_size = config.image_size

        self.model_name = '{}-{}_gnum:{}_gsize:{}x{}_imgsize:{}x{}'.format(
            config.dataset, config.selected_attrs[0], config.num_glimpses,
            config.patch_size, config.patch_size, config.image_size,
            config.image_size)

        # every model gets its own checkpoint and plot directory
        self.model_checkpoints = self.ckpt_dir + '/' + self.model_name + '/'
        if not os.path.exists(self.model_checkpoints):
            os.makedirs(self.model_checkpoints)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # NOTE: Adam runs with a fixed learning rate of 3e-4. `self.lr` and
        # `self.momentum` are kept as attributes but are not handed to the
        # optimizer, and no LR scheduler is installed.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _checkpoint_path(self, epoch):
        """Return the path of the checkpoint file written for `epoch`."""
        filename = self.model_name + '_' + str(epoch) + '_ckpt.pth.tar'
        return os.path.join(self.model_checkpoints, filename)

    def _latest_checkpoint_epoch(self):
        """Return the highest epoch that has a checkpoint file, or None."""
        prefix = self.model_name + '_'
        suffix = '_ckpt.pth.tar'
        epochs = []
        for name in os.listdir(self.model_checkpoints):
            if name.startswith(prefix) and name.endswith(suffix):
                tag = name[len(prefix):-len(suffix)]
                if tag.isdigit():
                    epochs.append(int(tag))
        return max(epochs) if epochs else None

    def _prepare_batch(self, x, y):
        """Move a batch to the GPU if requested and drop a size-1 dim 1 of `y`."""
        if self.use_gpu:
            x, y = x.cuda(), y.cuda()
        # `squeeze(1)` is only defined for tensors with at least two dims;
        # 1-D label tensors are passed through untouched.
        if y.dim() > 1:
            y = y.squeeze(1)
        return x, y

    def _run_glimpses(self, x):
        """
        Unroll the model for `self.num_glimpses` steps on the batch `x`.

        Sets `self.batch_size` from `x` and draws a fresh initial state via
        `reset()`.

        Returns
        -------
        - log_probas: class output of the final (`last=True`) step.
        - baselines: per-step baselines, stacked and transposed so that the
          step axis is dim 1.
        - log_pi: per-step `p` outputs of the model, arranged like
          `baselines`.
        - locs: list with the location returned by each step.
        """
        self.batch_size = x.shape[0]
        h_t, l_t = self.reset()

        locs, baselines, log_pi = [], [], []
        for _ in range(self.num_glimpses - 1):
            h_t, l_t, b_t, p = self.model(x, l_t, h_t)
            locs.append(l_t)
            baselines.append(b_t)
            log_pi.append(p)

        # the last step additionally yields the classification output
        h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
        locs.append(l_t)
        baselines.append(b_t)
        log_pi.append(p)

        baselines = torch.stack(baselines).transpose(1, 0)
        log_pi = torch.stack(log_pi).transpose(1, 0)
        return log_probas, baselines, log_pi, locs

    def _hybrid_loss(self, log_probas, baselines, log_pi, y):
        """
        Combine classification, baseline and REINFORCE terms into one loss.

        Returns
        -------
        - loss: scalar tensor `nll + mse(baseline, reward) + reinforce`.
        - acc: scalar tensor with the batch accuracy in percent.
        """
        predicted = torch.max(log_probas, 1)[1]

        # reward is 1 for a correct prediction, repeated for every glimpse
        R = (predicted.detach() == y).float()
        R = R.unsqueeze(1).repeat(1, self.num_glimpses)

        # losses for the differentiable modules
        loss_action = F.nll_loss(log_probas, y)
        loss_baseline = F.mse_loss(baselines, R)

        # reinforce loss, summed over timesteps and averaged across batch
        adjusted_reward = R - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        loss = loss_action + loss_baseline + loss_reinforce

        correct = (predicted == y).float()
        acc = 100 * (correct.sum() / len(y))
        return loss, acc

    @staticmethod
    def _f1_score(true_positives, reported, relevant):
        """
        Return the F1 score for the given counts.

        Precision and recall are taken as 0 when their denominator is 0,
        and the score is 0 when both are 0, instead of raising
        `ZeroDivisionError`.
        """
        precision = float(true_positives) / float(reported) if reported else 0.0
        recall = float(true_positives) / float(relevant) if relevant else 0.0
        if precision + recall == 0:
            return 0.0
        return 2 * (precision * recall / (precision + recall))

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced. Uses `self.batch_size`, which must have been set
        for the current minibatch.

        Returns
        -------
        - h_t: zero tensor of shape (batch_size, hidden_size).
        - l_t: tensor of shape (batch_size, 2) sampled uniformly from
          [-1, 1).
        """
        h_t = torch.zeros(self.batch_size, self.hidden_size)
        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        if self.use_gpu:
            h_t, l_t = h_t.cuda(), l_t.cuda()
        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # Resume from the newest per-epoch checkpoint, if there is one.
        if self.resume:
            latest = self._latest_checkpoint_epoch()
            if latest is None:
                print("[!] No checkpoint found in {}, training from "
                      "scratch".format(self.model_checkpoints))
            else:
                self.load_checkpoint(epoch=latest)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):
            # report the rate the optimizer really uses, not config.init_lr
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                epoch + 1, self.epochs, self.optimizer.param_groups[0]['lr']))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # `self.counter` counts epochs without improvement; early
            # stopping on `self.train_patience` is currently not enforced.
            if not is_best:
                self.counter += 1
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns
        -------
        - (average loss, average accuracy in percent) over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                x, y = self._prepare_batch(x, y)

                # glimpses of the first batch are dumped every plot_freq epochs
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                log_probas, baselines, log_pi, locs = self._run_glimpses(x)
                loss, acc = self._hybrid_loss(log_probas, baselines, log_pi, y)

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update the parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time (since the start of the epoch)
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the first 9 images and their glimpse locations
                if plot:
                    imgs = [x[0:9].detach().cpu().numpy().squeeze()]
                    glimpse_locs = [
                        l[0:9].detach().cpu().numpy() for l in locs
                    ]
                    with open(self.plot_dir + "g_{}.p".format(epoch + 1),
                              "wb") as f:
                        pickle.dump(imgs, f)
                    with open(self.plot_dir + "l_{}.p".format(epoch + 1),
                              "wb") as f:
                        pickle.dump(glimpse_locs, f)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Every batch is repeated `self.M` times and the model outputs are
        averaged over the repetitions before the loss is computed.

        Returns
        -------
        - (average loss, average accuracy in percent) over the set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # nothing is back-propagated here, so skip building autograd graphs
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                x, y = self._prepare_batch(x, y)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                log_probas, baselines, log_pi, _ = self._run_glimpses(x)

                # average over the M repetitions
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(
                    self.M, -1, baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(
                    self.M, -1, log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                loss, acc = self._hybrid_loss(log_probas, baselines, log_pi, y)

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.

        Every per-epoch checkpoint, starting at epoch 1, is loaded and
        evaluated until an epoch without a checkpoint file is reached;
        accuracy and F1 per checkpoint are printed and finally plotted.

        The F1 computation treats label 1 as the positive class and sums
        predictions/labels to count positives, i.e. it presumes binary
        {0, 1} labels -- NOTE(review): confirm for `num_classes > 2`.
        """
        print("Testing trained model with ", self.num_test, " examples")

        f1s = []
        accs = []
        epoch = 1
        # stop at the first missing checkpoint; real load errors propagate
        while os.path.isfile(self._checkpoint_path(epoch)):
            self.load_checkpoint(epoch=epoch)

            correct = 0
            f1_correct = 0
            f1_reported = 0
            f1_relevant = 0

            with torch.no_grad():
                for x, y in self.test_loader:
                    x, y = self._prepare_batch(x, y)

                    # duplicate M times
                    x = x.repeat(self.M, 1, 1, 1)

                    log_probas, _, _, _ = self._run_glimpses(x)
                    log_probas = log_probas.view(
                        self.M, -1, log_probas.shape[-1])
                    log_probas = torch.mean(log_probas, dim=0)

                    pred = log_probas.max(1)[1]
                    target = y.view_as(pred)

                    correct += int(pred.eq(target).sum().item())
                    # true positives: matching predictions of a non-zero
                    # class, compared on the device both tensors live on
                    hits = (pred == target) & (pred != 0)
                    f1_correct += int(hits.sum().item())
                    f1_reported += int(pred.sum().item())
                    f1_relevant += int(target.sum().item())

            perc = (100. * correct) / (self.num_test)
            error = 100 - perc
            f1_score = self._f1_score(f1_correct, f1_reported, f1_relevant)
            accuracy = float(correct) / float(self.num_test)

            print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%) : F1 Score - {} \n'.
                  format(correct, self.num_test, perc, error, f1_score))

            epoch += 1
            f1s.append(f1_score)
            accs.append(accuracy)

        fig, ax = plt.subplots()
        ax.plot(np.arange(len(f1s)), f1s)
        ax.plot(np.arange(len(accs)), accs)
        plt.show()

    def kde(self, epoch=5):
        """
        Plot a kernel density estimate of the final glimpse locations.

        Loads the checkpoint of `epoch`, runs the model once over the test
        loader, collects the location returned by the last step for every
        sample and shows a Gaussian KDE of these locations over the image
        area.

        Params
        ------
        - epoch: epoch whose checkpoint is visualised (default 5).
        """
        print("plotting kde of trained model with ", self.num_test,
              " examples")
        self.load_checkpoint(epoch=epoch)

        fig, ax = plt.subplots()

        img_min = 0
        img_max = self.image_size
        X, Y = np.mgrid[img_min:img_max:100j, img_min:img_max:100j]
        positions = np.vstack([X.ravel(), Y.ravel()])

        final_locs = []
        with torch.no_grad():
            for x, y in self.test_loader:
                x, y = self._prepare_batch(x, y)
                _, _, _, locs = self._run_glimpses(x)
                # gather on the CPU so batches can be concatenated
                # regardless of where the model runs
                final_locs.append(locs[-1].cpu())
        all_locations = torch.cat(final_locs)

        coords = denormalize(self.image_size, all_locations)
        coords = coords + (self.patch_size / 2)
        # flip the second coordinate so the plot's origin is bottom-left
        values = torch.stack(
            (coords[:, 0], (self.image_size - coords[:, 1])))

        kernel = stats.gaussian_kde(values.numpy())
        Z = np.reshape(kernel(positions).T, X.shape)
        ax.imshow(np.rot90(Z),
                  cmap=plt.cm.gist_earth_r,
                  extent=[img_min, img_max, img_min, img_max])
        plt.show()

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        The file name contains `state['epoch']`, so every epoch keeps its
        own checkpoint. If this model has reached the best validation
        accuracy thus far, a separate file with the suffix `best` is
        created.

        Params
        ------
        - state: dict to serialise; must contain the key 'epoch'.
        - is_best: whether to also copy the file to the `best` checkpoint.
        """
        ckpt_path = self._checkpoint_path(state['epoch'])
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path,
                            os.path.join(self.model_checkpoints, filename))

    def load_checkpoint(self, epoch=1, best=False):
        """
        Load a saved copy of the model and optimizer.

        This is useful for 2 cases:

        - Resuming training from the checkpoint of a given epoch.
        - Loading the best validation model to evaluate on the test data.

        `self.start_epoch` and `self.best_valid_acc` are overwritten with
        the stored values.

        Params
        ------
        - epoch: epoch number of the checkpoint to load.
        - best: if set to True, loads the best model instead and `epoch`
          is ignored.
        """
        print("[*] Loading model from {}".format(self.model_checkpoints))

        if best:
            filename = self.model_name + '_model_best.pth.tar'
            ckpt_path = os.path.join(self.model_checkpoints, filename)
        else:
            ckpt_path = self._checkpoint_path(epoch)
            filename = os.path.basename(ckpt_path)

        # Remap storages onto the CPU when not running on the GPU, so that
        # checkpoints written during GPU training can still be loaded.
        ckpt = torch.load(
            ckpt_path, map_location=None if self.use_gpu else 'cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
class Trainer(object):
    """
    Trainer for per-point classification of point clouds.

    Supports a binary and a multi-class (8 classes) setting, each either
    fully supervised or "semi" (class features of a fixed subset of points
    are zeroed and only those points are scored). Hyperparameters come
    from the `config` object.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: a `(train_loader, valid_loader)` pair.
        """
        self.config = config
        self.train_loader = data_loader[0]
        self.valid_loader = data_loader[1]
        self.num_train = len(self.train_loader.sampler.indices)
        self.num_valid = len(self.valid_loader.sampler.indices)

        if self.config.binary:
            self.num_classes = 2
            self.loss = F.binary_cross_entropy_with_logits
            namestr2 = 'binary'
            namestr3 = str(config.cat)
        else:
            self.num_classes = 8
            self.loss = F.cross_entropy
            namestr2 = 'all'
            namestr3 = 'nocat'

        # model params
        if self.config.semi:
            # the class features are appended to the 3 leading columns
            namestr1 = 'semi'
            self.input_dim = self.num_classes + 3
        else:
            namestr1 = 'fully'
            self.input_dim = 3
        self.output_dim = self.num_classes
        self.mask_rate = config.mask_rate
        self.pc_size = config.pc_size

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.ckpt_dir = config.ckpt_dir
        self.best = config.best
        self.best_mIoU = -10
        self.best_acc = 0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.resume = config.resume
        self.model_name = 'dsseg_{}_{}_{}_{}_{}'.format(
            config.init_lr, namestr1, namestr2, namestr3, config.pc_size)

        # glimpse network params
        self.num_points_per_pc = config.pc_size
        self.num_points_per_sample = config.num_points_per_sample
        self.box_size = config.box_size
        self.glimpse_scale = config.glimpse_scale
        self.num_samples = config.num_samples
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # `num_channels` is required by the model constructor but was never
        # assigned, which made construction fail with AttributeError.
        # NOTE(review): falling back to the per-point feature width is an
        # assumption -- confirm against RecurrentAttention's signature.
        self.num_channels = getattr(config, 'num_channels', self.input_dim)

        # build RAM model
        self.model = RecurrentAttention(
            self.num_points_per_pc, self.num_points_per_sample,
            self.num_samples, self.box_size, self.glimpse_scale,
            self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std,
            self.hidden_size, self.num_classes, self.use_gpu)

        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.lr,
        )

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _checkpoint_path(self, best):
        """Return the path of the latest (or, if `best`, the best) checkpoint."""
        suffix = '_model_best.pth.tar' if best else '_ckpt.pth.tar'
        return os.path.join(self.ckpt_dir, self.model_name + suffix)

    def _prepare_batch(self, x, y):
        """
        Cast, move and reshape one batch.

        `x` becomes a float tensor of shape (batch, pc_size, input_dim);
        `y` is cast to float in binary mode and to long otherwise. Also
        records `self.batch_size`.
        """
        x = x.float()
        y = y.float() if self.config.binary else y.long()
        if self.use_gpu:
            x, y = x.cuda(), y.cuda()

        self.batch_size = x.shape[0]
        x = x.view(self.batch_size, self.pc_size, self.input_dim)
        return x, y

    def _forward(self, x, y):
        """
        Run the model and return the `(pred, labels)` rows to be scored.

        In semi mode the class features of the masked points are zeroed
        first and only those points are returned.
        """
        if self.config.semi:
            x, mask_indices = self.mask_tensor(x, self.mask_rate)

        out = self.model(x)

        # TODO: Instead of squeeze, change view to handle tensors of
        # batch_size != 1
        pred = out.squeeze()
        labels = y.squeeze()
        if self.config.semi:
            pred = pred[mask_indices]
            labels = labels[mask_indices]
        return pred, labels

    @staticmethod
    def _percent_equal(predicted, true):
        """Return the percentage of positions where both tensors agree."""
        correct = (predicted == true).float()
        return 100 * (correct.sum() / true.shape[0])

    def _baseline_predictions(self, true):
        """
        Return the two trivial baseline predictions for the current mode.

        Binary mode: all-zeros and all-ones. Otherwise: uniformly random
        classes and all-zeros (class 0 is used as the "majority" class).
        """
        zeros = torch.zeros_like(true)
        if self.config.binary:
            return zeros, torch.ones_like(true)
        random_guess = torch.zeros(
            true.shape[0], dtype=torch.long).random_(0, self.num_classes)
        return random_guess.to(true.device), zeros

    def _batch_metrics(self, x, y):
        """
        Evaluate one raw batch.

        Returns
        -------
        - loss: scalar loss tensor.
        - acc: model accuracy in percent.
        - first, second: accuracies (percent) of the two baselines from
          `_baseline_predictions`.
        - num_points: number of scored points.
        """
        x, y = self._prepare_batch(x, y)
        pred, labels = self._forward(x, y)

        # labels carry one column per class; the arg-max is the class index
        true = torch.max(labels, 1)[1]
        # BCE-with-logits takes the per-class targets, cross-entropy the
        # class indices
        loss = self.loss(pred, labels if self.config.binary else true)

        predicted = torch.max(pred, 1)[1]
        acc = self._percent_equal(predicted, true)

        first, second = [
            self._percent_equal(guess, true)
            for guess in self._baseline_predictions(true)
        ]
        return loss, acc, first, second, labels.shape[0]

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def mask_tensor(self, x, rate):
        """
        Masks a percentage of the entries in tensor x randomly.

        The permutation is drawn from a generator that is re-seeded with 42
        on every call, so the same points are selected each time. For the
        selected points all columns from index 3 on are set to zero, in
        place.

        Params
        ------
        - x: tensor indexed as (batch, point, feature).
        - rate: fraction of the points (dim 1) to mask.

        Returns
        -------
        - x: the (modified) input tensor.
        - indices: array with the masked point indices; all point indices
          when `rate` is 0, in which case nothing is masked.
        """
        tensor_len = x.shape[1]
        if (rate == 0.):
            return x, np.arange(tensor_len)

        num_index = int(rate * tensor_len)
        permute_indices = np.random.RandomState(
            seed=42).permutation(tensor_len)[:num_index]
        x[:, permute_indices, 3:] = 0
        return x, permute_indices

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set. Training stops
        early once more than `train_patience` consecutive epochs bring no
        improvement.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        # the two modes only differ in how the baseline columns are labelled
        if self.config.binary:
            msg1 = "train loss: {:.3f} - train acc: {:.3f} - zeros acc: {:.3f} - ones acc: {:.3f}\n"
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val_zeros acc: {:.3f} - val_ones acc: {:.3f}"
        else:
            msg1 = "train loss: {:.3f} - train acc: {:.3f} - rand acc: {:.3f} - maj acc: {:.3f}\n"
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val rand acc: {:.3f} - val maj acc: {:.3f}"

        for epoch in range(self.start_epoch, self.epochs):
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                epoch + 1, self.epochs, self.lr))

            # train for 1 epoch
            train_loss, train_acc, base_a, base_b = self.train_one_epoch(
                epoch)

            # evaluate on validation set
            valid_loss, valid_acc, val_base_a, val_base_b = self.validate(
                epoch)

            is_best = valid_acc > self.best_acc
            msg = msg1 + msg2
            if is_best:
                self.counter = 0
                msg += " [*]"
            print(
                msg.format(train_loss, train_acc, base_a, base_b,
                           valid_loss, valid_acc, val_base_a, val_base_b))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_acc = max(valid_acc, self.best_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_acc': self.best_acc
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns
        -------
        - (avg loss, avg acc, avg first baseline acc, avg second baseline
          acc); the baselines are zeros/ones in binary mode and
          random/majority otherwise.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        first_baseline = AverageMeter()
        second_baseline = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for x, y in self.train_loader:
                loss, acc, first, second, num_points = self._batch_metrics(
                    x, y)

                # store
                first_baseline.update(first.item(), num_points)
                second_baseline.update(second.item(), num_points)
                losses.update(loss.item(), num_points)
                accs.update(acc.item(), num_points)

                # compute gradients and update the parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time (since the start of the epoch)
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

        return losses.avg, accs.avg, first_baseline.avg, second_baseline.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Returns the same 4-tuple as `train_one_epoch`.
        """
        losses = AverageMeter()
        accs = AverageMeter()
        first_baseline = AverageMeter()
        second_baseline = AverageMeter()

        # nothing is back-propagated here, so skip building autograd graphs
        with torch.no_grad():
            for x, y in self.valid_loader:
                loss, acc, first, second, num_points = self._batch_metrics(
                    x, y)

                # store
                first_baseline.update(first.item(), num_points)
                second_baseline.update(second.item(), num_points)
                losses.update(loss.item(), num_points)
                accs.update(acc.item(), num_points)

        return losses.avg, accs.avg, first_baseline.avg, second_baseline.avg

    def test(self):
        """
        Evaluate a saved model and print its accuracy next to the baselines.

        Loads the best checkpoint if `self.best` is set, otherwise the most
        recent one, and iterates over the validation loader (this trainer
        holds no separate test loader). The printed numbers are fractions
        of all scored points.
        """
        # load the requested checkpoint
        self.load_checkpoint(best=self.best)

        total_acc = 0
        total_first = 0
        total_second = 0
        total_num_points = 0

        num_batches = len(self.valid_loader)
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                x, y = self._prepare_batch(x, y)
                pred, labels = self._forward(x, y)
                total_num_points += labels.shape[0]

                # compute accuracy
                predicted = torch.max(pred, 1)[1]
                true = torch.max(labels, 1)[1]
                total_acc += int((predicted == true).sum().item())

                first, second = self._baseline_predictions(true)
                total_first += int((first == true).sum().item())
                total_second += int((second == true).sum().item())

                # progress in batches processed out of batches available
                print("Done with %.3f%%" % (100. * (i + 1) / num_batches))
        print()

        if total_num_points == 0:
            print("[!] No points were evaluated.")
            return

        denom = float(total_num_points)
        if self.config.binary:
            msg = "Final Accuracy: {:.3f} - Background Baseline: {:.3f} - Foreground Baseline: {:.3f}\n"
        else:
            msg = "Final Accuracy: {:.3f} - Random Baseline: {:.3f} - Majority Class Baseline: {:.3f}\n"
        print(
            msg.format(total_acc / denom, total_first / denom,
                       total_second / denom))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        ckpt_path = self._checkpoint_path(best=False)
        torch.save(state, ckpt_path)

        if is_best:
            shutil.copyfile(ckpt_path, self._checkpoint_path(best=True))

    def load_checkpoint(self, best=False):
        """
        Load a saved copy of the model and optimizer.

        This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        `self.start_epoch` is overwritten, and so is `self.best_acc` when
        the checkpoint stores it, so that a resumed run does not overwrite
        the best checkpoint with a worse model.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        ckpt_path = self._checkpoint_path(best)
        filename = os.path.basename(ckpt_path)
        # Remap storages onto the CPU when not running on the GPU, so that
        # checkpoints written during GPU training can still be loaded.
        ckpt = torch.load(
            ckpt_path, map_location=None if self.use_gpu else 'cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_acc = ckpt.get('best_acc', self.best_acc)
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        print("Successfully loaded model...")

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], self.best_acc))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
class Trainer:
    """A Recurrent Attention Model trainer.

    All hyperparameters are provided by the user in the config file.
    """

    # Number of leading samples of a batch whose images/locations are dumped
    # to disk when plotting is enabled.
    NUM_PLOT_SAMPLES = 9
    # Weight applied to the REINFORCE term when building the hybrid loss.
    REINFORCE_LOSS_WEIGHT = 0.01

    def __init__(self, config, data_loader):
        """Construct a new Trainer instance.

        Args:
            config: object containing command line arguments.
            data_loader: when ``config.is_train`` is true, an indexable pair
                ``(train_loader, valid_loader)`` whose samplers expose an
                ``indices`` attribute; otherwise the test data loader.
        """
        self.config = config

        # Fall back to the CPU when a GPU is requested but not available.
        if config.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = "ram_{}_{}x{}_{}".format(
            config.num_glimpses,
            config.patch_size,
            config.patch_size,
            config.glimpse_scale,
        )

        self.plot_dir = "./plots/" + self.model_name + "/"
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            # os.path.join keeps the path correct whether or not logs_dir
            # ends with a path separator.
            tensorboard_dir = os.path.join(self.logs_dir, self.model_name)
            print("[*] Saving tensorboard logs to {}".format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        self.model.to(self.device)

        # initialize optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config.init_lr
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, "min", patience=self.lr_patience
        )

    def reset(self):
        """Create the initial hidden state and glimpse location for a batch.

        Reads ``self.batch_size``, which must have been set for the current
        batch before calling.

        Returns:
            A tuple ``(h_t, l_t)``: ``h_t`` is an all-zero float tensor of
            shape ``(batch_size, hidden_size)`` and ``l_t`` is a float tensor
            of shape ``(batch_size, 2)`` filled by ``uniform_(-1, 1)``. Both
            live on ``self.device`` and have ``requires_grad`` set.
        """
        h_t = torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=torch.float,
            device=self.device,
            requires_grad=True,
        )
        l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1, 1).to(self.device)
        l_t.requires_grad = True

        return h_t, l_t

    def _rollout(self, x):
        """Run the model for ``self.num_glimpses`` steps on the batch ``x``.

        Sets ``self.batch_size`` from ``x`` and starts from ``self.reset()``.

        Returns:
            A tuple ``(log_probas, baselines, log_pi, locs)``. ``log_probas``
            is the extra output of the final (``last=True``) step;
            ``baselines`` and ``log_pi`` are the per-step outputs stacked and
            then transposed so the step axis comes second (presumably
            ``(batch, num_glimpses)`` - confirm against the model); ``locs``
            is a list holding the first ``NUM_PLOT_SAMPLES`` locations of
            every step.
        """
        self.batch_size = x.shape[0]
        h_t, l_t = self.reset()

        locs = []
        log_pi = []
        baselines = []
        for _ in range(self.num_glimpses - 1):
            h_t, l_t, b_t, p = self.model(x, l_t, h_t)
            locs.append(l_t[0 : self.NUM_PLOT_SAMPLES])
            baselines.append(b_t)
            log_pi.append(p)

        # the last step additionally yields the class log-probabilities
        h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
        log_pi.append(p)
        baselines.append(b_t)
        locs.append(l_t[0 : self.NUM_PLOT_SAMPLES])

        baselines = torch.stack(baselines).transpose(1, 0)
        log_pi = torch.stack(log_pi).transpose(1, 0)
        return log_probas, baselines, log_pi, locs

    def _average_over_samples(self, tensor):
        """Average a tensor whose leading axis holds ``self.M`` stacked copies.

        The leading axis is split into ``(M, -1)`` and the mean is taken over
        the ``M`` axis; the last axis is preserved.
        """
        grouped = tensor.contiguous().view(self.M, -1, tensor.shape[-1])
        return torch.mean(grouped, dim=0)

    def _hybrid_loss(self, log_probas, baselines, log_pi, y):
        """Compute the hybrid loss and the batch accuracy.

        The reward is 1 for a correct prediction and 0 otherwise, repeated
        for every glimpse.

        Returns:
            A tuple ``(loss, acc)`` of 0-dim tensors, where ``loss`` is the
            sum of the NLL loss, the baseline MSE loss and the weighted
            REINFORCE loss, and ``acc`` is the accuracy in percent.
        """
        predicted = torch.max(log_probas, 1)[1]
        R = (predicted.detach() == y).float()
        R = R.unsqueeze(1).repeat(1, self.num_glimpses)

        # losses for the differentiable modules
        loss_action = F.nll_loss(log_probas, y)
        loss_baseline = F.mse_loss(baselines, R)

        # reinforce loss: summed over timesteps and averaged across the batch
        adjusted_reward = R - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        loss = (
            loss_action + loss_baseline + loss_reinforce * self.REINFORCE_LOSS_WEIGHT
        )

        correct = (predicted == y).float()
        acc = 100 * (correct.sum() / len(y))
        return loss, acc

    def train(self):
        """Train the model on the training set.

        A checkpoint of the model is saved after each epoch and if
        the validation accuracy is improved upon, a separate ckpt
        is created for use on the test set. Training stops early once
        more than ``self.train_patience`` consecutive epochs bring no
        improvement in validation accuracy.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print(
            "\n[*] Train on {} samples, validate on {} samples".format(
                self.num_train, self.num_valid
            )
        )

        for epoch in range(self.start_epoch, self.epochs):
            print(
                "\nEpoch: {}/{} - LR: {:.6f}".format(
                    epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"]
                )
            )

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # The scheduler runs in "min" mode, so it is fed the negated
            # accuracy: the lr is reduced when validation accuracy plateaus.
            self.scheduler.step(-valid_acc)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(
                msg.format(
                    train_loss, train_acc, valid_loss, valid_acc, 100 - valid_acc
                )
            )

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "model_state": self.model.state_dict(),
                    "optim_state": self.optimizer.state_dict(),
                    "best_valid_acc": self.best_valid_acc,
                },
                is_best,
            )

    def train_one_epoch(self, epoch):
        """Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Args:
            epoch: zero-based index of the current epoch.

        Returns:
            A tuple ``(average loss, average accuracy)`` over the epoch.
        """
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                x, y = x.to(self.device), y.to(self.device)

                # only the first batch of every plot_freq-th epoch is dumped
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # save images
                imgs = [x[0 : self.NUM_PLOT_SAMPLES]]

                # extract the glimpses
                log_probas, baselines, log_pi, locs = self._rollout(x)

                loss, acc = self._hybrid_loss(log_probas, baselines, log_pi, y)

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update the parameters
                loss.backward()
                self.optimizer.step()

                # elapsed time, measured from the start of the epoch
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    (
                        "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                            (toc - tic), loss.item(), acc.item()
                        )
                    )
                )
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [l.cpu().data.numpy() for l in locs]
                    # context managers make sure the files are flushed and closed
                    with open(
                        self.plot_dir + "g_{}.p".format(epoch + 1), "wb"
                    ) as f:
                        pickle.dump(imgs, f)
                    with open(
                        self.plot_dir + "l_{}.p".format(epoch + 1), "wb"
                    ) as f:
                        pickle.dump(locs, f)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

            return losses.avg, accs.avg

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.

        Every batch is duplicated ``self.M`` times and the model outputs are
        averaged over those copies before the loss is computed.

        NOTE(review): the model's train/eval mode is not changed here, so it
        stays in whatever mode it was left in (train_one_epoch switches it to
        train mode). Confirm whether RecurrentAttention has mode-dependent
        layers that would require ``self.model.eval()``.

        Args:
            epoch: zero-based index of the current epoch.

        Returns:
            A tuple ``(average loss, average accuracy)`` over the set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # extract the glimpses
            log_probas, baselines, log_pi, _ = self._rollout(x)

            # average over the M duplicates
            log_probas = self._average_over_samples(log_probas)
            baselines = self._average_over_samples(baselines)
            log_pi = self._average_over_samples(log_pi)

            loss, acc = self._hybrid_loss(log_probas, baselines, log_pi, y)

            # store
            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """Test the RAM model.

        This function should only be called at the very
        end once the model has finished training.

        Loads a checkpoint first (the best one when ``self.best`` is true,
        otherwise the most recent one) and prints the accuracy and error.

        NOTE(review): as in validate(), the model's train/eval mode is left
        untouched - confirm whether ``self.model.eval()`` is needed.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for x, y in self.test_loader:
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # extract the glimpses
            log_probas, _, _, _ = self._rollout(x)

            # average over the M duplicates
            log_probas = self._average_over_samples(log_probas)

            pred = log_probas.data.max(1, keepdim=True)[1]
            # accumulate as a plain int so the summary prints plain numbers
            correct += pred.eq(y.data.view_as(pred)).sum().item()

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print(
            "[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format(
                correct, self.num_test, perc, error
            )
        )

    def save_checkpoint(self, state, is_best):
        """Save a checkpoint of the model.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.

        Args:
            state: the object to serialize with ``torch.save``.
            is_best: when true, the checkpoint is also copied to the
                ``<model_name>_model_best.pth.tar`` file.
        """
        # make sure the target directory exists before writing into it
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)

        filename = self.model_name + "_ckpt.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)
        if is_best:
            filename = self.model_name + "_model_best.pth.tar"
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load the best copy of a model.

        This is useful for 2 cases:
        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Restores ``self.start_epoch``, ``self.best_valid_acc`` and the model
        and optimizer states.

        Args:
            best: if set to True, loads the best model.
                Use this if you want to evaluate your model
                on the test data. Else, set to False in which
                case the most recent version of the checkpoint
                is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + "_ckpt.pth.tar"
        if best:
            filename = self.model_name + "_model_best.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map_location lets a checkpoint saved on one device be restored on
        # the device this trainer actually runs on (e.g. GPU -> CPU).
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # load variables from checkpoint
        self.start_epoch = ckpt["epoch"]
        self.best_valid_acc = ckpt["best_valid_acc"]
        self.model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optim_state"])

        if best:
            print(
                "[*] Loaded {} checkpoint @ epoch {} "
                "with best valid acc of {:.3f}".format(
                    filename, ckpt["epoch"], ckpt["best_valid_acc"]
                )
            )
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))