def train(self):
    """The function for the pre-train phase.

    Trains the backbone with a plain classification objective, then after every
    epoch runs (a) an 'origval' whole-set validation and (b) an episodic
    ('preval') meta-validation; the best meta-val model is checkpointed.

    Side effects: writes tensorboard scalars, saves checkpoints and a 'trlog'
    dict under ``self.args.save_path``.
    """
    # Set the pretrain log (running history + best-model bookkeeping).
    trlog = {}
    trlog['args'] = vars(self.args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
    trlog['max_acc_epoch'] = 0

    # Set the timer
    timer = Timer()
    # Set global count to zero (x-axis for per-step tensorboard scalars)
    global_count = 0
    # Set tensorboardX
    writer = SummaryWriter(comment=self.args.save_path)

    # Start pretrain
    for epoch in range(1, self.args.pre_max_epoch + 1):
        # Set the model to train mode
        print('Epoch {}'.format(epoch))
        self.model.train()
        self.model.mode = 'pre'
        # Set averager classes to record training losses and accuracies
        train_loss_averager = Averager()
        train_acc_averager = Averager()

        # Using tqdm to read samples from train loader
        tqdm_gen = tqdm.tqdm(self.train_loader)
        for i, batch in enumerate(tqdm_gen, 1):
            # Update global count number
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # Labels come from the raw batch; dtype is fixed just below.
            label = batch[1]
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)
            logits = self.model(data)
            loss = F.cross_entropy(logits, label)
            # Calculate train accuracy
            acc = count_acc(logits, label)
            # Write the tensorboardX records
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            # Accumulate loss and accuracy for this step
            train_loss_averager.add(loss.item())
            train_acc_averager.add(acc)
            # Loss backwards and optimizer updates
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Update the averagers (rebind from Averager to their float means)
        train_loss_averager = train_loss_averager.item()
        train_acc_averager = train_acc_averager.item()

        # Start the original (non-episodic) evaluation on the held-out set.
        self.model.eval()
        self.model.mode = 'origval'
        _, valid_results = self.val_orig(self.valset.X_val, self.valset.y_val)
        print('validation accuracy ', valid_results[0])

        # Start validation for this epoch, set model to eval mode
        self.model.eval()
        self.model.mode = 'preval'
        # Set averager classes to record validation losses and accuracies
        val_loss_averager = Averager()
        val_acc_averager = Averager()

        # Generate the (fixed) episode labels for query and support sets.
        label = torch.arange(self.args.way).repeat(self.args.val_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        # Run meta-validation
        for i, batch in enumerate(self.val_loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # First way*shot samples are the support set, rest are queries.
            p = self.args.shot * self.args.way
            data_shot, data_query = data[:p], data[p:]
            logits = self.model((data_shot, label_shot, data_query))
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            val_loss_averager.add(loss.item())
            val_acc_averager.add(acc)

        # Update validation averagers (rebind to float means)
        val_loss_averager = val_loss_averager.item()
        val_acc_averager = val_acc_averager.item()
        # Write the tensorboardX records
        writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
        writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)

        # Update best saved model
        if val_acc_averager > trlog['max_acc']:
            trlog['max_acc'] = val_acc_averager
            trlog['max_acc_epoch'] = epoch
            self.save_model('max_acc')
        # Save model every 10 epochs
        if epoch % 10 == 0:
            self.save_model('epoch' + str(epoch))

        # Update the logs
        trlog['train_loss'].append(train_loss_averager)
        trlog['train_acc'].append(train_acc_averager)
        trlog['val_loss'].append(val_loss_averager)
        trlog['val_acc'].append(val_acc_averager)
        # Save log
        torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

        if epoch % 10 == 0:
            # BUGFIX: the ETA fraction previously divided by self.args.max_epoch,
            # but this loop runs to pre_max_epoch, so the estimate was wrong
            # whenever the two differ.
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(),
                timer.measure(epoch / self.args.pre_max_epoch)))
    writer.close()
def train(self):
    """The function for the meta-train phase.

    Episodic training for segmentation: each batch is split into a support
    set (first way*shot samples) and a query set; pixel accuracy and mIoU
    are tracked via the self._*_seg_metrics helpers. Best model is selected
    by validation mIoU. Writes tensorboard scalars, checkpoints, and a
    'trlog' dict under ``self.args.save_path``.
    """
    # Set the meta-train log
    trlog = {}
    trlog['args'] = vars(self.args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['train_iou'] = []
    trlog['val_iou'] = []
    trlog['max_iou'] = 0.0
    trlog['max_iou_epoch'] = 0

    # Set the timer
    timer = Timer()
    # Set global count to zero (x-axis for per-step tensorboard scalars)
    global_count = 0
    # Set tensorboardX
    writer = SummaryWriter(comment=self.args.save_path)

    # Start meta-train
    for epoch in range(1, self.args.max_epoch + 1):
        # Update learning rate.
        # NOTE(review): scheduler.step() before the epoch's optimizer steps is
        # the pre-PyTorch-1.1 ordering; confirm intended schedule.
        self.lr_scheduler.step()
        # Set the model to train mode
        self.model.train()
        # Set averager classes to record training losses and accuracies
        train_loss_averager = Averager()
        train_acc_averager = Averager()
        train_iou_averager = Averager()

        # Using tqdm to read samples from train loader
        tqdm_gen = tqdm.tqdm(self.train_loader)
        self._reset_metrics()
        for i, batch in enumerate(tqdm_gen, 1):
            # Update global count number
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, labels, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
                labels = batch[1]
            # Split episode: first way*shot are support, rest are query.
            p = self.args.way * self.args.shot
            data_shot, data_query = data[:p], data[p:]
            label_shot, label = labels[:p], labels[p:]

            # Output logits for model
            par = data_shot, label_shot, data_query
            logits = self.model(par)
            # Calculate meta-train loss (CD only; combined loss kept for reference)
            #loss = self.FL(logits, label) + self.CD(logits,label) + self.LS(logits,label)
            loss = self.CD(logits, label)

            # Calculate meta-train accuracy: metrics are reset per batch, so
            # pixAcc/mIoU below are per-batch values.
            self._reset_metrics()
            seg_metrics = eval_metrics(logits, label, self.args.way)
            self._update_seg_metrics(*seg_metrics)
            pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()

            # Add loss and accuracy for the averagers
            train_loss_averager.add(loss.item())
            train_acc_averager.add(pixAcc)
            train_iou_averager.add(mIoU)

            # Print loss and accuracy till this step
            tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(epoch, train_loss_averager.item(), train_acc_averager.item()*100.0, train_iou_averager.item()))

            # Loss backwards and optimizer updates
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Update the averagers (rebind from Averager to their float means)
        train_loss_averager = train_loss_averager.item()
        train_acc_averager = train_acc_averager.item()
        train_iou_averager = train_iou_averager.item()

        writer.add_scalar('data/train_loss (Meta)', float(train_loss_averager), epoch)
        writer.add_scalar('data/train_acc (Meta)', float(train_acc_averager)*100.0, epoch)
        writer.add_scalar('data/train_iou (Meta)', float(train_iou_averager), epoch)

        # Start validation for this epoch, set model to eval mode
        self.model.eval()
        # Set averager classes to record validation losses and accuracies
        val_loss_averager = Averager()
        val_acc_averager = Averager()
        val_iou_averager = Averager()

        # Print previous information
        if epoch % 1 == 0:
            print('Best Val Epoch {}, Best Val IoU={:.4f}'.format(trlog['max_iou_epoch'], trlog['max_iou']))

        # Run meta-validation (same episode structure as training)
        for i, batch in enumerate(self.val_loader, 1):
            if torch.cuda.is_available():
                data, labels, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
                labels = batch[1]
            p = self.args.way * self.args.shot
            data_shot, data_query = data[:p], data[p:]
            label_shot, label = labels[:p], labels[p:]
            par = data_shot, label_shot, data_query
            logits = self.model(par)
            # Calculate meta-val loss
            #loss = self.FL(logits, label) + self.CD(logits,label) + self.LS(logits,label)
            loss = self.CD(logits, label)
            # Calculate meta-val accuracy (per-batch, see note above)
            self._reset_metrics()
            seg_metrics = eval_metrics(logits, label, self.args.way)
            self._update_seg_metrics(*seg_metrics)
            pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()
            val_loss_averager.add(loss.item())
            val_acc_averager.add(pixAcc)
            val_iou_averager.add(mIoU)

        # Update validation averagers (rebind to float means)
        val_loss_averager = val_loss_averager.item()
        val_acc_averager = val_acc_averager.item()
        val_iou_averager = val_iou_averager.item()

        # Write the tensorboardX records
        writer.add_scalar('data/val_loss (Meta)', float(val_loss_averager), epoch)
        writer.add_scalar('data/val_acc (Meta)', float(val_acc_averager)*100.0, epoch)
        writer.add_scalar('data/val_iou (Meta)', float(val_iou_averager), epoch)

        # Print loss and accuracy for this epoch
        print('Epoch {}, Val, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(epoch, val_loss_averager, val_acc_averager*100.0, val_iou_averager))

        # Update best saved model (model selection criterion: val mIoU)
        if val_iou_averager > trlog['max_iou']:
            trlog['max_iou'] = val_iou_averager
            trlog['max_iou_epoch'] = epoch
            self.save_model('max_iou')
        # Save model every 10 epochs
        if epoch % 10 == 0:
            self.save_model('epoch'+str(epoch))

        # Update the logs
        trlog['train_loss'].append(train_loss_averager)
        trlog['train_acc'].append(train_acc_averager)
        trlog['val_loss'].append(val_loss_averager)
        trlog['val_acc'].append(val_acc_averager)
        trlog['train_iou'].append(train_iou_averager)
        trlog['val_iou'].append(val_iou_averager)

        # Save log
        torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

        if epoch % 1 == 0:
            print('Running Time: {}, Estimated Time: {}'.format(timer.measure(), timer.measure(epoch / self.args.max_epoch)))
    writer.close()
def eval(self):
    """The function for the meta-evaluate (test) phase.

    Loads the best (max_iou) checkpoint unless explicit eval weights are
    given, runs episodic segmentation testing, and dumps, for every query
    image: the input ('a.jpg'), ground truth ('b.png'), and prediction
    ('c.png') into ``self.args.save_image_dir``.
    """
    # Load the logs
    trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

    # Load meta-test set; sampler draws num_batch episodes with
    # teshot support + test_query query samples per class.
    self.test_set = mDataset('test', self.args)
    self.sampler = CategoriesSampler(self.test_set.labeln, self.args.num_batch, self.args.way, self.args.teshot + self.args.test_query, self.args.teshot)
    self.loader = DataLoader(dataset=self.test_set, batch_sampler=self.sampler, num_workers=8, pin_memory=True)
    #self.loader = DataLoader(dataset=self.test_set,batch_size=10, shuffle=False, num_workers=8, pin_memory=True)

    # Set test accuracy recorder
    #test_acc_record = np.zeros((600,))

    # Load model for meta-test phase
    if self.args.eval_weights is not None:
        self.model.load_state_dict(torch.load(self.args.eval_weights)['params'])
    else:
        self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_iou' + '.pth'))['params'])
    # Set model to eval mode
    self.model.eval()

    # Set accuracy averager
    ave_acc = Averager()

    # Start meta-test; seg metrics accumulate over ALL batches here
    # (no per-batch reset, unlike the train loop).
    self._reset_metrics()
    count = 1  # running index for saved image filenames
    for i, batch in enumerate(self.loader, 1):
        if torch.cuda.is_available():
            data, labels, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
            labels = batch[1]
        # Split episode: first way*teshot are support, rest are query.
        p = self.args.teshot * self.args.way
        data_shot, data_query = data[:p], data[p:]
        label_shot, label = labels[:p], labels[p:]
        logits = self.model((data_shot, label_shot, data_query))

        seg_metrics = eval_metrics(logits, label, self.args.way)
        self._update_seg_metrics(*seg_metrics)
        pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()
        ave_acc.add(pixAcc)
        #test_acc_record[i-1] = acc
        #if i % 100 == 0:
        #print('batch {}: {Average Accuracy:.2f}({Pixel Accuracy:.2f} {IoU :.2f} )'.format(i, ave_acc.item() * 100.0, pixAcc * 100.0,mIoU))

        # Saving Test Image, Ground Truth Image and Predicted Image.
        # Labels/predictions are rescaled by 1/(way-1) so class ids map
        # into [0, 1] before PIL conversion.
        for j in range(len(data_query)):
            x1 = data_query[j].detach().cpu()
            y1 = label[j].detach().cpu()
            z1 = logits[j].detach().cpu()
            x = transforms.ToPILImage()(x1).convert("RGB")
            y = transforms.ToPILImage()(y1 / (1.0*(self.args.way-1))).convert("LA")
            # argmax over channel axis -> predicted class map, then rescale
            im = torch.tensor(np.argmax(np.array(z1), axis=0)/(1.0*(self.args.way-1)))
            im = im.type(torch.FloatTensor)
            z = transforms.ToPILImage()(im).convert("LA")
            # NOTE(review): assumes save_image_dir ends with a path separator
            px = self.args.save_image_dir + str(count) + 'a.jpg'
            py = self.args.save_image_dir + str(count) + 'b.png'
            pz = self.args.save_image_dir + str(count) + 'c.png'
            x.save(px)
            y.save(py)
            z.save(pz)
            count = count + 1
def eval(self, gradcam=False, rise=False, test_on_val=False):
    """The function for the meta-eval phase.

    Runs 600 test (or validation, if ``test_on_val``) episodes and reports
    mean accuracy with a confidence interval. Every 5th episode, similarity
    maps / norms (and optionally Grad-CAM / RISE-style maps) are saved to
    ../results/<exp_id>/.

    Args:
        gradcam: if True, also compute and save Grad-CAM, Grad-CAM++ and
            guided-backprop maps for the query images.
        rise: if True, also compute ScoreCam maps for the query images.
        test_on_val: evaluate on the validation split instead of test.

    Returns:
        Mean episode accuracy (float).
    """
    # Load the logs (optional; printing of the best epoch is skipped if absent)
    if os.path.exists(osp.join(self.args.save_path, 'trlog')):
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))
    else:
        trlog = None

    # Fix seeds so the sampled episodes are reproducible
    torch.manual_seed(1)
    np.random.seed(1)

    # Load meta-test set (600 episodes)
    test_set = Dataset('val' if test_on_val else 'test', self.args)
    sampler = CategoriesSampler(test_set.label, 600, self.args.way, self.args.shot + self.args.val_query)
    loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)

    # Set test accuracy recorder
    test_acc_record = np.zeros((600, ))

    # Load model for meta-test phase
    if self.args.eval_weights is not None:
        weights = self.addOrRemoveModule(
            self.model, torch.load(self.args.eval_weights)['params'])
        self.model.load_state_dict(weights)
    else:
        self.model.load_state_dict(
            torch.load(osp.join(self.args.save_path, 'max_acc' + '.pth'))['params'])
    # Set model to eval mode
    self.model.eval()

    # Set accuracy averager
    ave_acc = Averager()

    # Generate fixed episode labels (query and support)
    label = torch.arange(self.args.way).repeat(self.args.val_query)
    if torch.cuda.is_available():
        label = label.type(torch.cuda.LongTensor)
    else:
        label = label.type(torch.LongTensor)
    label_shot = torch.arange(self.args.way).repeat(self.args.shot)
    if torch.cuda.is_available():
        label_shot = label_shot.type(torch.cuda.LongTensor)
    else:
        label_shot = label_shot.type(torch.LongTensor)

    if gradcam:
        # Expose the encoder's layer3 so the CAM wrappers can hook it.
        self.model.layer3 = self.model.encoder.layer3
        model_dict = dict(type="resnet", arch=self.model, layer_name='layer3')
        grad_cam = GradCAM(model_dict, True)
        grad_cam_pp = GradCAMpp(model_dict, True)
        self.model.features = self.model.encoder
        guided = GuidedBackprop(self.model)
    if rise:
        self.model.layer3 = self.model.encoder.layer3
        score_mod = ScoreCam(self.model)

    # Start meta-test
    for i, batch in enumerate(loader, 1):
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        k = self.args.way * self.args.shot
        data_shot, data_query = data[:k], data[k:]

        if i % 5 == 0:
            # Every 5th episode: run the model with extra outputs and save them.
            suff = "_val" if test_on_val else ""
            if self.args.rep_vec or self.args.cross_att:
                # NOTE(review): 'acc' here is the value from the PREVIOUS
                # iteration (set at the bottom of the loop).
                print('batch {}: {:.2f}({:.2f})'.format(
                    i, ave_acc.item() * 100, acc * 100))
                if self.args.cross_att:
                    label_one_hot = self.one_hot(label).to(label.device)
                    _, _, logits, simMapQuer, simMapShot, normQuer, normShot = self.model(
                        (data_shot, label_shot, data_query),
                        ytest=label_one_hot, retSimMap=True)
                else:
                    logits, simMapQuer, simMapShot, normQuer, normShot, fast_weights = self.model(
                        (data_shot, label_shot, data_query), retSimMap=True)
                # Persist similarity maps, raw episode data and feature norms.
                torch.save(simMapQuer, "../results/{}/{}_simMapQuer{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(simMapShot, "../results/{}/{}_simMapShot{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(data_query, "../results/{}/{}_dataQuer{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(data_shot, "../results/{}/{}_dataShot{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(normQuer, "../results/{}/{}_normQuer{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(normShot, "../results/{}/{}_normShot{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
            else:
                logits, normQuer, normShot, fast_weights = self.model(
                    (data_shot, label_shot, data_query),
                    retFastW=True, retNorm=True)
                torch.save(normQuer, "../results/{}/{}_normQuer{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(normShot, "../results/{}/{}_normShot{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
            if gradcam:
                # NOTE(review): fast_weights is only bound on the non-cross_att
                # paths above; with cross_att=True this would raise NameError.
                print("Saving gradmaps", i)
                allMasks, allMasks_pp, allMaps = [], [], []
                for l in range(len(data_query)):
                    allMasks.append(
                        grad_cam(data_query[l:l + 1], fast_weights, None))
                    allMasks_pp.append(
                        grad_cam_pp(data_query[l:l + 1], fast_weights, None))
                    allMaps.append(
                        guided.generate_gradients(data_query[l:l + 1], fast_weights))
                allMasks = torch.cat(allMasks, dim=0)
                allMasks_pp = torch.cat(allMasks_pp, dim=0)
                allMaps = torch.cat(allMaps, dim=0)
                torch.save(allMasks, "../results/{}/{}_gradcamQuer{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(allMasks_pp, "../results/{}/{}_gradcamppQuer{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
                torch.save(allMaps, "../results/{}/{}_guidedQuer{}{}.th".format(
                    self.args.exp_id, self.args.model_id, i, suff))
            if rise:
                # NOTE(review): allScore is computed but never saved/used.
                print("Saving risemaps", i)
                allScore = []
                for l in range(len(data_query)):
                    allScore.append(
                        score_mod(data_query[l:l + 1], fast_weights))
        else:
            # Regular episode: plain forward pass for accuracy only.
            if self.args.cross_att:
                label_one_hot = self.one_hot(label).to(label.device)
                _, _, logits = self.model(
                    (data_shot, label_shot, data_query), ytest=label_one_hot)
            else:
                logits = self.model((data_shot, label_shot, data_query))

        acc = count_acc(logits, label)
        ave_acc.add(acc)
        test_acc_record[i - 1] = acc

    # Calculate the confidence interval, update the logs
    m, pm = compute_confidence_interval(test_acc_record)
    if trlog is not None:
        print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
            trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
    return m
def train(self, trial):
    """The function for the meta-train phase.

    Episodic classification training with optional cross-attention loss,
    knowledge distillation from ``self.teacher``, and "hard task" replay.
    Reports val accuracy per epoch to the Optuna ``trial``.

    Args:
        trial: an Optuna trial; ``trial.report`` is called each epoch.
    """
    # Set the meta-train log
    trlog = {}
    trlog['args'] = vars(self.args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
    trlog['max_acc_epoch'] = 0

    # Set the timer
    timer = Timer()
    # Set global count to zero (x-axis for per-step tensorboard scalars)
    global_count = 0
    # Set tensorboardX
    writer = SummaryWriter(comment=self.args.save_path)

    # Generate the labels for train set of the episodes (support labels)
    label_shot = torch.arange(self.args.way).repeat(self.args.shot)
    if torch.cuda.is_available():
        label_shot = label_shot.type(torch.cuda.LongTensor)
    else:
        label_shot = label_shot.type(torch.LongTensor)

    # Accumulates the true class index of the worst class per episode,
    # until one per way has been collected (hard-task replay).
    worstClasses = []

    # Start meta-train
    for epoch in range(1, self.args.max_epoch + 1):
        # Update learning rate
        # NOTE(review): pre-PyTorch-1.1 ordering (step before optimizer steps).
        self.lr_scheduler.step()
        # Set the model to train mode
        self.model.train()
        # Set averager classes to record training losses and accuracies
        train_loss_averager = Averager()
        train_acc_averager = Averager()

        # Generate the labels for test set of the episodes during meta-train updates
        label = torch.arange(self.args.way).repeat(self.args.train_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        # Using tqdm to read samples from train loader
        tqdm_gen = tqdm.tqdm(self.train_loader)
        for i, batch in enumerate(tqdm_gen, 1):
            # Update global count number
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, targ = [_.cuda() for _ in batch]
            else:
                data, targ = batch
            # Split episode: first way*shot are support, rest are query.
            p = self.args.shot * self.args.way
            data_shot, data_query = data[:p], data[p:]

            # Output logits for model
            if self.args.cross_att:
                label_one_hot = self.one_hot(label).to(label.device)
                ytest, cls_scores, logits = self.model(
                    (data_shot, label_shot, data_query), ytest=label_one_hot)
                pids = label_shot
                loss = self.crossAttLoss(ytest, cls_scores, label, pids)
                logits = logits[0]
            else:
                logits = self.model((data_shot, label_shot, data_query))
                # Calculate meta-train loss
                loss = F.cross_entropy(logits, label)

            if self.args.distill_id:
                # Distillation: KL between temperature-softened student and
                # teacher logits, interpolated with the task loss.
                teachLogits = self.teacher(
                    (data_shot, label_shot, data_query))
                kl = F.kl_div(F.log_softmax(logits / self.args.kl_temp, dim=1),
                              F.softmax(teachLogits / self.args.kl_temp, dim=1),
                              reduction="batchmean")
                loss = (kl * self.args.kl_interp * self.args.kl_temp * self.args.kl_temp
                        + loss * (1 - self.args.kl_interp))

            acc = count_acc(logits, label)
            # Write the tensorboardX records
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            # Print loss and accuracy for this step
            tqdm_gen.set_description(
                'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, loss.item(), acc))

            # Add loss and accuracy for the averagers
            train_loss_averager.add(loss.item())
            train_acc_averager.add(acc)

            # Loss backwards and optimizer updates
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.args.hard_tasks:
                if len(worstClasses) == self.args.way:
                    # Replay an episode built from the hardest classes.
                    inds = self.train_sampler.hardBatch(worstClasses)
                    batch = [self.trainset[i][0] for i in inds]
                    # NOTE(review): 'batch' above is built but not used; the
                    # slices below reuse the current 'data' tensor. Looks like
                    # the hard batch was meant to be stacked and sliced here —
                    # confirm intent before changing.
                    data_shot, data_query = data[:p], data[p:]
                    logits = self.model(
                        (data_shot, label_shot, data_query))
                    loss = F.cross_entropy(logits, label)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    worstClasses = []
                else:
                    # Find the class with the lowest per-class accuracy in
                    # this episode and remember its true (dataset) index.
                    error_mat = (logits.argmax(dim=1) == label).view(
                        self.args.train_query, self.args.way)
                    worst = error_mat.float().mean(dim=0).argmin()
                    worst_trueInd = targ[worst]
                    worstClasses.append(worst_trueInd)

        # Update the averagers (rebind from Averager to their float means)
        train_loss_averager = train_loss_averager.item()
        train_acc_averager = train_acc_averager.item()

        # Start validation for this epoch, set model to eval mode
        self.model.eval()
        # Set averager classes to record validation losses and accuracies
        val_loss_averager = Averager()
        val_acc_averager = Averager()

        # Generate the labels for test set of the episodes during meta-val for this epoch
        label = torch.arange(self.args.way).repeat(self.args.val_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        # Print previous information
        if epoch % 10 == 0:
            print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                trlog['max_acc_epoch'], trlog['max_acc']))

        # Run meta-validation
        for i, batch in enumerate(self.val_loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            p = self.args.shot * self.args.way
            data_shot, data_query = data[:p], data[p:]
            if self.args.cross_att:
                label_one_hot = self.one_hot(label).to(label.device)
                ytest, cls_scores, logits = self.model(
                    (data_shot, label_shot, data_query), ytest=label_one_hot)
                pids = label_shot
                loss = self.crossAttLoss(ytest, cls_scores, label, pids)
                logits = logits[0]
            else:
                logits = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            val_loss_averager.add(loss.item())
            val_acc_averager.add(acc)

        # Update validation averagers (rebind to float means)
        val_loss_averager = val_loss_averager.item()
        val_acc_averager = val_acc_averager.item()
        # Write the tensorboardX records
        writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
        writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
        # Print loss and accuracy for this epoch
        print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
            epoch, val_loss_averager, val_acc_averager))

        # Update best saved model
        if val_acc_averager > trlog['max_acc']:
            trlog['max_acc'] = val_acc_averager
            trlog['max_acc_epoch'] = epoch
            self.save_model('max_acc')

        # Update the logs
        trlog['train_loss'].append(train_loss_averager)
        trlog['train_acc'].append(train_acc_averager)
        trlog['val_loss'].append(val_loss_averager)
        trlog['val_acc'].append(val_acc_averager)

        # Save log
        torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

        if epoch % 10 == 0:
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(), timer.measure(epoch / self.args.max_epoch)))
        # Report intermediate objective value to Optuna for pruning.
        trial.report(val_acc_averager, epoch)
    writer.close()
def eval(self):
    """The function for the meta-eval phase.

    Runs 600 meta-test episodes with the best (or explicitly supplied)
    checkpoint and prints mean accuracy with its confidence interval.
    """
    # Restore the training log so the best validation epoch can be reported.
    log = torch.load(osp.join(self.args.save_path, 'trlog'))

    # Build the meta-test episode loader: 600 episodes, each holding
    # way * (shot + val_query) samples.
    test_set = Dataset('test', self.args)
    episode_sampler = CategoriesSampler(test_set.label, 600, self.args.way,
                                        self.args.shot + self.args.val_query)
    loader = DataLoader(test_set, batch_sampler=episode_sampler,
                        num_workers=8, pin_memory=True)

    # Per-episode accuracy record, used for the confidence interval.
    test_acc_record = np.zeros((600, ))

    # Pick the checkpoint: explicit weights if provided, else best-val model.
    if self.args.eval_weights is None:
        ckpt_path = osp.join(self.args.save_path, 'max_acc' + '.pth')
    else:
        ckpt_path = self.args.eval_weights
    self.model.load_state_dict(torch.load(ckpt_path)['params'])
    self.model.eval()

    # Running mean of episode accuracies.
    ave_acc = Averager()

    # Fixed episode labels: query labels for scoring, support labels for the model.
    use_cuda = torch.cuda.is_available()
    long_type = torch.cuda.LongTensor if use_cuda else torch.LongTensor
    label = torch.arange(self.args.way).repeat(self.args.val_query).type(long_type)
    label_shot = torch.arange(self.args.way).repeat(self.args.shot).type(long_type)

    # Support-set size is constant across episodes; compute it once.
    k = self.args.way * self.args.shot
    for i, batch in enumerate(loader, 1):
        if use_cuda:
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        support, query = data[:k], data[k:]
        logits = self.model((support, label_shot, query))
        acc = count_acc(logits, label)
        ave_acc.add(acc)
        test_acc_record[i - 1] = acc
        if i % 100 == 0:
            print('batch {}: {:.2f}({:.2f})'.format(
                i, ave_acc.item() * 100, acc * 100))

    # Mean accuracy and its confidence interval over all episodes.
    m, pm = compute_confidence_interval(test_acc_record)
    print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
        log['max_acc_epoch'], log['max_acc'], ave_acc.item()))
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
def train(self):
    """The function for the meta-train phase.

    Plain episodic classification training: per-epoch train loop over
    episodes, then a meta-validation pass; best val-accuracy model is
    checkpointed. Writes tensorboard scalars and a 'trlog' dict under
    ``self.args.save_path``.
    """
    # Set the meta-train log
    trlog = {}
    trlog['args'] = vars(self.args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
    trlog['max_acc_epoch'] = 0

    # Set the timer
    timer = Timer()
    # Set global count to zero (x-axis for per-step tensorboard scalars)
    global_count = 0
    # Set tensorboardX
    writer = SummaryWriter(comment=self.args.save_path)

    # Generate the labels for train set of the episodes (support labels)
    label_shot = torch.arange(self.args.way).repeat(self.args.shot)
    if torch.cuda.is_available():
        label_shot = label_shot.type(torch.cuda.LongTensor)
    else:
        label_shot = label_shot.type(torch.LongTensor)

    # Start meta-train
    for epoch in range(1, self.args.max_epoch + 1):
        # Set the model to train mode
        self.model.train()
        # Set averager classes to record training losses and accuracies
        train_loss_averager = Averager()
        train_acc_averager = Averager()

        # Generate the labels for test set of the episodes during meta-train updates
        label = torch.arange(self.args.way).repeat(self.args.train_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        # Using tqdm to read samples from train loader
        tqdm_gen = tqdm.tqdm(self.train_loader)
        for i, batch in enumerate(tqdm_gen, 1):
            # Update global count number
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # Split episode: first way*shot are support, rest are query.
            p = self.args.shot * self.args.way
            data_shot, data_query = data[:p], data[p:]

            # Output logits for model
            logits = self.model((data_shot, label_shot, data_query))
            # Calculate meta-train loss
            loss = F.cross_entropy(logits, label)
            # Calculate meta-train accuracy
            acc = count_acc(logits, label)

            # Write the tensorboardX records
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            # Print loss and accuracy for this step
            tqdm_gen.set_description(
                'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, loss.item(), acc))

            # Add loss and accuracy for the averagers
            train_loss_averager.add(loss.item())
            train_acc_averager.add(acc)

            # Loss backwards and optimizer updates
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Update learning rate once per epoch, after the optimizer steps
        # (PyTorch >= 1.1 ordering).
        # NOTE(review): the original formatting is ambiguous about whether
        # this step was per-batch or per-epoch — confirm against git history.
        self.lr_scheduler.step()

        # Update the averagers (rebind from Averager to their float means)
        train_loss_averager = train_loss_averager.item()
        train_acc_averager = train_acc_averager.item()

        # Start validation for this epoch, set model to eval mode
        self.model.eval()
        # Set averager classes to record validation losses and accuracies
        val_loss_averager = Averager()
        val_acc_averager = Averager()

        # Generate the labels for test set of the episodes during meta-val for this epoch
        label = torch.arange(self.args.way).repeat(self.args.val_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        # Print previous information
        if epoch % 10 == 0:
            print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                trlog['max_acc_epoch'], trlog['max_acc']))

        # Run meta-validation
        for i, batch in enumerate(self.val_loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            p = self.args.shot * self.args.way
            data_shot, data_query = data[:p], data[p:]
            logits = self.model((data_shot, label_shot, data_query))
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            val_loss_averager.add(loss.item())
            val_acc_averager.add(acc)

        # Update validation averagers (rebind to float means)
        val_loss_averager = val_loss_averager.item()
        val_acc_averager = val_acc_averager.item()
        # Write the tensorboardX records
        writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
        writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
        # Print loss and accuracy for this epoch
        print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
            epoch, val_loss_averager, val_acc_averager))

        # Update best saved model
        if val_acc_averager > trlog['max_acc']:
            trlog['max_acc'] = val_acc_averager
            trlog['max_acc_epoch'] = epoch
            self.save_model('max_acc')
        # Save model every 10 epochs
        if epoch % 10 == 0:
            self.save_model('epoch' + str(epoch))

        # Update the logs
        trlog['train_loss'].append(train_loss_averager)
        trlog['train_acc'].append(train_acc_averager)
        trlog['val_loss'].append(val_loss_averager)
        trlog['val_acc'].append(val_acc_averager)

        # Save log
        torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

        if epoch % 10 == 0:
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(), timer.measure(epoch / self.args.max_epoch)))
    writer.close()
def eval(self):
    """The function for the meta-eval phase.

    Runs 20 meta-test episodes and reports accuracy, macro-F1, and a
    one-vs-rest ROC-AUC (computed from hard predictions, not scores),
    each with a confidence interval.
    """
    # Helper: binarize labels/predictions and compute one-vs-rest ROC-AUC.
    # NOTE(review): AUC from argmax predictions (0/1 per class), not from
    # probabilities — this is coarser than score-based AUC; confirm intent.
    def multiclass_roc_auc_score(y_test, y_pred, average="macro"):
        lb = LabelBinarizer()
        lb.fit(y_test)
        y_test = lb.transform(y_test)
        y_pred = lb.transform(y_pred)
        return roc_auc_score(y_test, y_pred, average=average)

    # Load the logs
    trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

    # Load meta-test set (only 20 episodes)
    test_set = Dataset('test', self.args)
    sampler = CategoriesSampler(test_set.label, 20, self.args.way,
                                self.args.shot + self.args.val_query)
    loader = DataLoader(test_set, batch_sampler=sampler,
                        num_workers=8, pin_memory=True)

    # Set per-episode metric recorders
    test_acc_record = np.zeros((20, ))
    test_f1_record = np.zeros((20, ))
    test_auc_record = np.zeros((20, ))

    # Load model for meta-test phase
    if self.args.eval_weights is not None:
        self.model.load_state_dict(
            torch.load(self.args.eval_weights)['params'])
    else:
        self.model.load_state_dict(
            torch.load(osp.join(self.args.save_path,
                                'max_acc' + '.pth'))['params'])
    # Set model to eval mode
    self.model.eval()

    # Set accuracy averager
    ave_acc = Averager()

    # Generate fixed episode labels (query and support)
    label = torch.arange(self.args.way).repeat(self.args.val_query)
    if torch.cuda.is_available():
        label = label.type(torch.cuda.LongTensor)
    else:
        label = label.type(torch.LongTensor)
    label_shot = torch.arange(self.args.way).repeat(self.args.shot)
    if torch.cuda.is_available():
        label_shot = label_shot.type(torch.cuda.LongTensor)
    else:
        label_shot = label_shot.type(torch.LongTensor)

    # Ground-truth query labels as a numpy array for sklearn metrics.
    Y = label.data.cpu().numpy()

    # Start meta-test
    for i, batch in enumerate(loader, 1):
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        k = self.args.way * self.args.shot
        data_shot, data_query = data[:k], data[k:]
        logits = self.model((data_shot, label_shot, data_query))
        acc = count_acc(logits, label)
        # Hard predictions for F1/AUC
        logits = logits.data.cpu().numpy()
        predicted = np.argmax(logits, axis=1)
        f1 = f1_score(Y, predicted, average='macro')
        auc = multiclass_roc_auc_score(Y, predicted)
        ave_acc.add(acc)
        test_acc_record[i - 1] = acc
        test_f1_record[i - 1] = f1
        test_auc_record[i - 1] = auc
        # NOTE(review): with only 20 episodes this progress print never fires.
        if i % 100 == 0:
            print('batch {}: {:.2f}({:.2f})'.format(
                i, ave_acc.item() * 100, acc * 100))

    # Calculate the confidence interval, update the logs
    m, pm = compute_confidence_interval(test_acc_record)
    f1_m, f1_pm = compute_confidence_interval(test_f1_record)
    auc_m, auc_pm = compute_confidence_interval(test_auc_record)
    print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
        trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
    print('Test f1 {:.4f} + {:.4f}'.format(f1_m, f1_pm))
    print('Test auc {:.4f} + {:.4f}'.format(auc_m, auc_pm))