class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase.

    The constructor prepares the output folder, the episodic train/val
    loaders, the ``MtlLearner`` model (initialised from a pre-trained encoder
    checkpoint), an Adam optimizer and a StepLR scheduler.  ``train`` runs
    episodic training with one validation pass per epoch and ``eval`` reports
    the accuracy on the test split.
    """

    def __init__(self, args):
        """Build folders, data loaders, model, optimizer and scheduler.

        Args:
            args: option namespace. ``args.save_path`` is overwritten here
                with the run-specific checkpoint folder.
        """
        # Set the folder to save the records and checkpoints.
        # makedirs creates './logs/' and './logs/meta' in one call and
        # tolerates folders that already exist.
        log_base_dir = './logs/'
        meta_base_dir = osp.join(log_base_dir, 'meta')
        os.makedirs(meta_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = ''.join([
            'shot', str(args.shot),
            '_way', str(args.way),
            '_query', str(args.train_query),
            '_step', str(args.step_size),
            '_gamma', str(args.gamma),
            '_lr1', str(args.meta_lr1),
            '_lr2', str(args.meta_lr2),
            '_batch', str(args.num_batch),
            '_maxepoch', str(args.max_epoch),
            '_baselr', str(args.base_lr),
            '_updatestep', str(args.update_step),
            '_stepsize', str(args.step_size),
            '_', args.meta_label,
        ])
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(
            self.trainset.label, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(
            dataset=self.trainset, batch_sampler=self.train_sampler,
            num_workers=8, pin_memory=True)

        # Load meta-val set (600 episodes per validation pass)
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(
            dataset=self.valset, batch_sampler=self.val_sampler,
            num_workers=8, pin_memory=True)

        # Build meta-transfer learning model
        self.model = MtlLearner(self.args)

        # Set optimizer: trainable encoder parameters use meta_lr1 (the
        # default lr), the base learner uses its own rate meta_lr2.
        self.optimizer = torch.optim.Adam(
            [
                {'params': filter(lambda p: p.requires_grad,
                                  self.model.encoder.parameters())},
                {'params': self.model.base_learner.parameters(),
                 'lr': self.args.meta_lr2},
            ],
            lr=self.args.meta_lr1)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.step_size,
            gamma=self.args.gamma)

        # Load pretrained model without FC classifier
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']
        else:
            # Fall back to the checkpoint folder the pre-train phase derives
            # from the same pre-train options.
            pre_base_dir = osp.join(log_base_dir, 'pre')
            pre_save_path1 = '_'.join([args.dataset, args.model_type])
            pre_save_path2 = 'batchsize' + str(args.pre_batch_size) + \
                '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + \
                '_step' + str(args.pre_step_size) + \
                '_maxepoch' + str(args.pre_max_epoch)
            pre_save_path = pre_base_dir + '/' + pre_save_path1 + '_' + pre_save_path2
            pretrained_dict = torch.load(
                osp.join(pre_save_path, 'max_acc.pth'))['params']
        # Saved keys are encoder-relative; prefix them and keep only the
        # entries that exist in this model.
        pretrained_dict = {
            'encoder.' + k: v for k, v in pretrained_dict.items()
        }
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in self.model_dict
        }
        print(pretrained_dict.keys())
        self.model_dict.update(pretrained_dict)
        self.model.load_state_dict(self.model_dict)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """The function to save checkpoints.

        Args:
            name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def _episode_labels(self, per_class):
        """Return labels ``0..way-1`` tiled ``per_class`` times as a LongTensor.

        The tensor is a CUDA tensor when CUDA is available, a CPU one otherwise.
        """
        labels = torch.arange(self.args.way).repeat(per_class)
        if torch.cuda.is_available():
            return labels.type(torch.cuda.LongTensor)
        return labels.type(torch.LongTensor)

    def _split_episode(self, batch):
        """Split a loader batch into (support data, query data).

        The first ``shot * way`` items are the support set, the rest the
        query set.
        """
        if torch.cuda.is_available():
            data, _ = [item.cuda() for item in batch]
        else:
            data = batch[0]
        num_support = self.args.shot * self.args.way
        return data[:num_support], data[num_support:]

    def _run_validation(self, label_shot, label):
        """Run the model over ``self.val_loader``.

        Returns:
            (mean loss, mean accuracy) as produced by ``Averager.item()``.
        """
        val_loss_averager = Averager()
        val_acc_averager = Averager()
        for batch in self.val_loader:
            data_shot, data_query = self._split_episode(batch)
            logits = self.model((data_shot, label_shot, data_query))
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            val_loss_averager.add(loss.item())
            val_acc_averager.add(acc)
        return val_loss_averager.item(), val_acc_averager.item()

    def train(self):
        """The function for the meta-train phase."""
        # Set the meta-train log
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # The writer is closed in ``finally`` so it is released even when
        # training is interrupted by an exception.
        try:
            # Episode labels depend only on args, so build them once.
            label_shot = self._episode_labels(self.args.shot)
            train_label = self._episode_labels(self.args.train_query)
            val_label = self._episode_labels(self.args.val_query)

            # Start meta-train
            for epoch in range(1, self.args.max_epoch + 1):
                # Update learning rate
                # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after
                # optimizer.step(); confirm the intended schedule before moving it.
                self.lr_scheduler.step()
                # Set the model to train mode
                self.model.train()
                # Set averager classes to record training losses and accuracies
                train_loss_averager = Averager()
                train_acc_averager = Averager()

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    # Update global count number
                    global_count = global_count + 1
                    data_shot, data_query = self._split_episode(batch)
                    # Output logits for model
                    logits = self.model((data_shot, label_shot, data_query))
                    # Calculate meta-train loss
                    loss = F.cross_entropy(logits, train_label)
                    # Calculate meta-train accuracy
                    acc = count_acc(logits, train_label)
                    # Write the tensorboardX records
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    # Print loss and accuracy for this step
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, loss.item(), acc))

                    # Add loss and accuracy for the averagers
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(acc)

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Update the averagers
                train_loss_averager = train_loss_averager.item()
                train_acc_averager = train_acc_averager.item()

                # Start validation for this epoch, set model to eval mode
                self.model.eval()

                # Print previous information
                if epoch % 10 == 0:
                    print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                        trlog['max_acc_epoch'], trlog['max_acc']))
                # Run meta-validation
                val_loss_averager, val_acc_averager = self._run_validation(
                    label_shot, val_label)

                # Write the tensorboardX records
                writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
                writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
                # Print loss and accuracy for this epoch
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, val_loss_averager, val_acc_averager))

                # Update best saved model
                if val_acc_averager > trlog['max_acc']:
                    trlog['max_acc'] = val_acc_averager
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')
                # Save model every 10 epochs
                if epoch % 10 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['train_loss'].append(train_loss_averager)
                trlog['train_acc'].append(train_acc_averager)
                trlog['val_loss'].append(val_loss_averager)
                trlog['val_acc'].append(val_acc_averager)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                if epoch % 10 == 0:
                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.max_epoch)))
        finally:
            writer.close()

    def eval(self):
        """The function for the meta-eval phase."""
        # Load the logs
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Load meta-test set; the sampler and the accuracy record share the
        # same episode count.
        num_episodes = 600
        test_set = Dataset('test', self.args)
        sampler = CategoriesSampler(test_set.label, num_episodes, self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set, batch_sampler=sampler,
                            num_workers=8, pin_memory=True)

        # Set test accuracy recorder
        test_acc_record = np.zeros((num_episodes, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            self.model.load_state_dict(
                torch.load(self.args.eval_weights)['params'])
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_acc' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels
        label = self._episode_labels(self.args.val_query)
        label_shot = self._episode_labels(self.args.shot)

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            data_shot, data_query = self._split_episode(batch)
            logits = self.model((data_shot, label_shot, data_query))
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            if i % 100 == 0:
                print('batch {}: {:.2f}({:.2f})'.format(
                    i, ave_acc.item() * 100, acc * 100))

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
            trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
class PreTrainer(object):
    """The class that contains the code for the pretrain phase.

    The model is trained as a classifier over all training classes
    (``mode='pre'``) and validated on few-shot episodes (``mode='preval'``).
    Only the encoder weights are written to checkpoints.
    """

    def __init__(self, args):
        """Build folders, data loaders, model, optimizer and scheduler.

        Args:
            args: option namespace. ``args.save_path`` is overwritten here
                with the run-specific checkpoint folder.
        """
        # Set the folder to save the records and checkpoints.
        # makedirs creates './logs/' and './logs/pre' in one call and
        # tolerates folders that already exist.
        log_base_dir = './logs/'
        pre_base_dir = osp.join(log_base_dir, 'pre')
        os.makedirs(pre_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + \
            '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + \
            '_step' + str(args.pre_step_size) + \
            '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load pretrain set
        self.trainset = Dataset('train', self.args, train_aug=True)
        self.train_loader = DataLoader(
            dataset=self.trainset, batch_size=args.pre_batch_size,
            shuffle=True, num_workers=8, pin_memory=True)

        # Load meta-val set (600 episodes per validation pass)
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(
            dataset=self.valset, batch_sampler=self.val_sampler,
            num_workers=8, pin_memory=True)

        # Set pretrain class number
        num_class_pretrain = self.trainset.num_class

        # Build pretrain model
        self.model = MtlLearner(self.args, mode='pre',
                                num_cls=num_class_pretrain)

        # Set optimizer
        self.optimizer = torch.optim.SGD(
            [
                {'params': self.model.encoder.parameters(),
                 'lr': self.args.pre_lr},
                {'params': self.model.pre_fc.parameters(),
                 'lr': self.args.pre_lr},
            ],
            momentum=self.args.pre_custom_momentum, nesterov=True,
            weight_decay=self.args.pre_custom_weight_decay)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.pre_step_size,
            gamma=self.args.pre_gamma)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """The function to save checkpoints.

        Only the encoder's state dict is stored.

        Args:
            name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def _episode_labels(self, per_class):
        """Return labels ``0..way-1`` tiled ``per_class`` times as a LongTensor.

        The tensor is a CUDA tensor when CUDA is available, a CPU one otherwise.
        """
        labels = torch.arange(self.args.way).repeat(per_class)
        if torch.cuda.is_available():
            return labels.type(torch.cuda.LongTensor)
        return labels.type(torch.LongTensor)

    def _epoch_fraction(self, epoch):
        """Return the fraction of the pre-train schedule completed at ``epoch``.

        The pre-train loop is bounded by ``args.pre_max_epoch``, so progress
        is measured against that value rather than ``args.max_epoch``.
        """
        return epoch / self.args.pre_max_epoch

    def _run_validation(self, label_shot, label):
        """Run the model over the episodes of ``self.val_loader``.

        Returns:
            (mean loss, mean accuracy) as produced by ``Averager.item()``.
        """
        val_loss_averager = Averager()
        val_acc_averager = Averager()
        for batch in self.val_loader:
            if torch.cuda.is_available():
                data, _ = [item.cuda() for item in batch]
            else:
                data = batch[0]
            # The first shot * way items are the support set.
            p = self.args.shot * self.args.way
            data_shot, data_query = data[:p], data[p:]
            logits = self.model((data_shot, label_shot, data_query))
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            val_loss_averager.add(loss.item())
            val_acc_averager.add(acc)
        return val_loss_averager.item(), val_acc_averager.item()

    def train(self):
        """The function for the pre-train phase."""
        # Set the pretrain log
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # The writer is closed in ``finally`` so it is released even when
        # training is interrupted by an exception.
        try:
            # Validation episode labels depend only on args; build them once.
            val_label = self._episode_labels(self.args.val_query)
            label_shot = self._episode_labels(self.args.shot)

            # Start pretrain
            for epoch in range(1, self.args.pre_max_epoch + 1):
                # Update learning rate
                # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after
                # optimizer.step(); confirm the intended schedule before moving it.
                self.lr_scheduler.step()
                # Set the model to train mode
                self.model.train()
                self.model.mode = 'pre'
                # Set averager classes to record training losses and accuracies
                train_loss_averager = Averager()
                train_acc_averager = Averager()

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    # Update global count number
                    global_count = global_count + 1
                    if torch.cuda.is_available():
                        data, _ = [item.cuda() for item in batch]
                        label = batch[1].type(torch.cuda.LongTensor)
                    else:
                        data = batch[0]
                        label = batch[1].type(torch.LongTensor)
                    # Output logits for model
                    logits = self.model(data)
                    # Calculate train loss
                    loss = F.cross_entropy(logits, label)
                    # Calculate train accuracy
                    acc = count_acc(logits, label)
                    # Write the tensorboardX records
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    # Print loss and accuracy for this step
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, loss.item(), acc))

                    # Add loss and accuracy for the averagers
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(acc)

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Update the averagers
                train_loss_averager = train_loss_averager.item()
                train_acc_averager = train_acc_averager.item()

                # Start validation for this epoch, set model to eval mode
                self.model.eval()
                self.model.mode = 'preval'

                # Print previous information
                if epoch % 10 == 0:
                    print('Best Epoch {}, Best Val acc={:.4f}'.format(
                        trlog['max_acc_epoch'], trlog['max_acc']))
                # Run meta-validation
                val_loss_averager, val_acc_averager = self._run_validation(
                    label_shot, val_label)

                # Write the tensorboardX records
                writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
                writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
                # Print loss and accuracy for this epoch
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, val_loss_averager, val_acc_averager))

                # Update best saved model
                if val_acc_averager > trlog['max_acc']:
                    trlog['max_acc'] = val_acc_averager
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')
                # Save model every 10 epochs
                if epoch % 10 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['train_loss'].append(train_loss_averager)
                trlog['train_acc'].append(train_acc_averager)
                trlog['val_loss'].append(val_loss_averager)
                trlog['val_acc'].append(val_acc_averager)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                if epoch % 10 == 0:
                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(self._epoch_fraction(epoch))))
        finally:
            writer.close()
class PreTrainer(object):
    """Pre-train the encoder as a classifier and validate it on episodes.

    The dataset class is selected from ``args.dataset``.  Training uses the
    model in ``'pre'`` mode, validation in ``'preval'`` mode.  Checkpoints
    contain the encoder weights only.
    """

    def __init__(self, args):
        """Build folders, data loaders, model, optimizer and scheduler.

        Args:
            args: option namespace. ``args.save_path`` is overwritten here
                with the run-specific checkpoint folder.

        Raises:
            ValueError: if ``args.dataset`` is not one of the supported names.
        """
        # makedirs creates './logs/' and './logs/pre' in one call and
        # tolerates folders that already exist.
        log_base_dir = './logs/'
        pre_base_dir = osp.join(log_base_dir, 'pre')
        os.makedirs(pre_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + \
            '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + \
            '_step' + str(args.pre_step_size) + \
            '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)
        self.args = args

        # Import lazily so only the selected dataset module is loaded.
        if self.args.dataset == 'MiniImageNet':
            from dataloader.mini_imagenet import MiniImageNet as Dataset
        elif self.args.dataset == 'TieredImageNet':
            from dataloader.tiered_imagenet import TieredImageNet as Dataset
        elif self.args.dataset == 'FC100':
            from dataloader.fewshotcifar import FewshotCifar as Dataset
        else:
            raise ValueError('Please set correct dataset.')

        self.trainset = Dataset('train', self.args, train_aug=True)
        self.train_loader = DataLoader(
            dataset=self.trainset, batch_size=args.pre_batch_size,
            shuffle=True, num_workers=8, pin_memory=True)

        # NOTE(review): validation episodes are drawn from the 'test' split
        # here -- confirm this is intended rather than 'val'.
        self.valset = Dataset('test', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(
            dataset=self.valset, batch_sampler=self.val_sampler,
            num_workers=8, pin_memory=True)

        num_class_pretrain = self.trainset.num_class
        self.model = MtlLearner(self.args, mode='pre',
                                num_cls=num_class_pretrain)

        self.optimizer = torch.optim.SGD(
            [
                {'params': self.model.encoder.parameters(),
                 'lr': self.args.pre_lr},
                {'params': self.model.pre_fc.parameters(),
                 'lr': self.args.pre_lr},
            ],
            momentum=self.args.pre_custom_momentum, nesterov=True,
            weight_decay=self.args.pre_custom_weight_decay)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.pre_step_size,
            gamma=self.args.pre_gamma)

        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """Save the encoder's state dict as ``<save_path>/<name>.pth``.

        Args:
            name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def _episode_labels(self, per_class):
        """Return labels ``0..way-1`` tiled ``per_class`` times as a LongTensor.

        The tensor is a CUDA tensor when CUDA is available, a CPU one otherwise.
        """
        labels = torch.arange(self.args.way).repeat(per_class)
        if torch.cuda.is_available():
            return labels.type(torch.cuda.LongTensor)
        return labels.type(torch.LongTensor)

    def _epoch_fraction(self, epoch):
        """Return the fraction of the pre-train schedule completed at ``epoch``.

        The pre-train loop is bounded by ``args.pre_max_epoch``, so progress
        is measured against that value rather than ``args.max_epoch``.
        """
        return epoch / self.args.pre_max_epoch

    def _run_validation(self, label_shot, label):
        """Run the model over the episodes of ``self.val_loader``.

        Returns:
            (mean loss, mean accuracy) as produced by ``Averager.item()``.
        """
        vl = Averager()
        va = Averager()
        for batch in self.val_loader:
            if torch.cuda.is_available():
                data, _ = [item.cuda() for item in batch]
            else:
                data = batch[0]
            # The first shot * way items are the support set.
            p = self.args.shot * self.args.way
            data_shot, data_query = data[:p], data[p:]
            logits = self.model((data_shot, label_shot, data_query))
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            vl.add(loss.item())
            va.add(acc)
        return vl.item(), va.item()

    def train(self):
        """Run pre-training with an episodic validation pass after each epoch."""
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }
        timer = Timer()
        global_count = 0
        writer = SummaryWriter(comment=self.args.save_path)

        # The writer is closed in ``finally`` so it is released even when
        # training is interrupted by an exception.
        try:
            # Validation episode labels depend only on args; build them once.
            val_label = self._episode_labels(self.args.val_query)
            label_shot = self._episode_labels(self.args.shot)

            for epoch in range(1, self.args.pre_max_epoch + 1):
                # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after
                # optimizer.step(); confirm the intended schedule before moving it.
                self.lr_scheduler.step()
                self.model.train()
                self.model.mode = 'pre'
                tl = Averager()
                ta = Averager()

                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    global_count = global_count + 1
                    if torch.cuda.is_available():
                        data, _ = [item.cuda() for item in batch]
                        label = batch[1].type(torch.cuda.LongTensor)
                    else:
                        data = batch[0]
                        label = batch[1].type(torch.LongTensor)
                    logits = self.model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, loss.item(), acc))
                    tl.add(loss.item())
                    ta.add(acc)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tl = tl.item()
                ta = ta.item()

                # Episodic validation
                self.model.eval()
                self.model.mode = 'preval'
                print('Best Epoch {}, Best Val acc={:.4f}'.format(
                    trlog['max_acc_epoch'], trlog['max_acc']))
                vl, va = self._run_validation(label_shot, val_label)

                writer.add_scalar('data/val_loss', float(vl), epoch)
                writer.add_scalar('data/val_acc', float(va), epoch)
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, vl, va))

                # Keep the checkpoint with the best validation accuracy
                if va > trlog['max_acc']:
                    trlog['max_acc'] = va
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')
                # Periodic checkpoint every 20 epochs
                if epoch % 20 == 0:
                    self.save_model('epoch' + str(epoch))

                trlog['train_loss'].append(tl)
                trlog['train_acc'].append(ta)
                trlog['val_loss'].append(vl)
                trlog['val_acc'].append(va)
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                # During the last two epochs also keep the latest encoder
                # and the optimizer state.
                if epoch > self.args.pre_max_epoch - 2:
                    self.save_model('epoch-last')
                    torch.save(
                        self.optimizer.state_dict(),
                        osp.join(self.args.save_path, 'optimizer_latest.pth'))

                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(self._epoch_fraction(epoch))))
        finally:
            writer.close()
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase.

    This trainer works on segmentation episodes: every task has
    ``args.way + 1`` classes (background included), ``args.train_query``
    support images and ``args.test_query`` query images per class.  Training
    resumes from the ``epoch24`` checkpoints stored in ``args.save_path``.
    """

    def __init__(self, args):
        """Build folders, data loaders, model, losses, optimizer and scheduler.

        Args:
            args: option namespace. ``args.save_path`` and
                ``args.save_image_dir`` are set here.
        """
        # Set the folder to save the records and checkpoints.
        # makedirs tolerates folders that already exist.
        save_image_dir = '../results1/'
        os.makedirs(save_image_dir, exist_ok=True)
        log_base_dir = '../logs1/'
        meta_base_dir = osp.join(log_base_dir, 'meta')
        os.makedirs(meta_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        # NOTE(review): both '_shot' and '_query' are filled from
        # args.train_query -- confirm whether '_query' should use
        # args.test_query (changing it would rename existing run folders).
        save_path2 = ''.join([
            '_mtype', str(args.mtype),
            '_shot', str(args.train_query),
            '_way', str(args.way),
            '_query', str(args.train_query),
            '_step', str(args.step_size),
            '_gamma', str(args.gamma),
            '_lr', str(args.meta_lr),
            '_batch', str(args.num_batch),
            '_maxepoch', str(args.max_epoch),
            '_baselr', str(args.base_lr),
            '_updatestep', str(args.update_step),
            '_stepsize', str(args.step_size),
            '_', args.meta_label,
        ])
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        args.save_image_dir = save_image_dir
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(
            self.trainset.labeln, self.args.num_batch, self.args.way + 1,
            self.args.train_query, self.args.test_query)
        self.train_loader = DataLoader(
            dataset=self.trainset, batch_sampler=self.train_sampler,
            num_workers=8, pin_memory=True)

        # Load meta-val set
        if self.args.valdata == 'Yes':
            self.valset = Dataset('val', self.args)
            self.val_sampler = CategoriesSampler(
                self.valset.labeln, self.args.num_batch, self.args.way + 1,
                self.args.train_query, self.args.test_query)
            self.val_loader = DataLoader(
                dataset=self.valset, batch_sampler=self.val_sampler,
                num_workers=8, pin_memory=True)

        # Build meta-transfer learning model and the available losses
        self.model = MtlLearner(self.args)
        self.CD = CE_DiceLoss()
        self.FL = FocalLoss()
        self.LS = LovaszSoftmax()

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

        # Set optimizer (only trainable encoder parameters)
        self.optimizer = torch.optim.Adam(
            [{'params': filter(lambda p: p.requires_grad,
                               self.model.encoder.parameters())}],
            lr=self.args.meta_lr)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.step_size,
            gamma=self.args.gamma)

        # Load pretrained model.
        # The checkpoint name should be changed accordingly when resuming
        # from a different epoch (see ``initial_epoch`` in ``train``).
        self.model.load_state_dict(
            torch.load(osp.join(self.args.save_path,
                                'epoch24' + '.pth'))['params'])
        self.optimizer.load_state_dict(
            torch.load(osp.join(self.args.save_path,
                                'epoch24' + '_o.pth'))['params_o'])
        self.lr_scheduler.load_state_dict(
            torch.load(osp.join(self.args.save_path,
                                'epoch24' + '_s.pth'))['params_s'])
        self.model_dict = self.model.state_dict()
        self.optimizer_dict = self.optimizer.state_dict()
        self.lr_scheduler_dict = self.lr_scheduler.state_dict()

        # Total Model Parameters
        pytorch_total_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        print("Total Trainable Parameters in the Model: " +
              str(pytorch_total_params))

    def _reset_metrics(self):
        """Zero the accumulated segmentation counters."""
        self.total_inter, self.total_union = 0, 0
        self.total_correct, self.total_label = 0, 0

    def _update_seg_metrics(self, correct, labeled, inter, union):
        """Add one batch's counts to the accumulated segmentation counters."""
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def _get_seg_metrics(self, n_class):
        """Summarise the accumulated counters.

        Args:
            n_class: number of classes used to key the per-class IoU.

        Returns:
            dict with ``"Pixel_Accuracy"``, ``"Mean_IoU"`` and ``"Class_IoU"``
            (class index -> IoU), each rounded to three decimals.
        """
        self.n_class = n_class
        # np.spacing(1) keeps the divisions defined when a denominator is 0.
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return {
            "Pixel_Accuracy": np.round(pixAcc, 3),
            "Mean_IoU": np.round(mIoU, 3),
            "Class_IoU": dict(zip(range(self.n_class), np.round(IoU, 3)))
        }

    def save_model(self, name):
        """The function to save checkpoints.

        Writes the model, optimizer and scheduler states to ``<name>.pth``,
        ``<name>_o.pth`` and ``<name>_s.pth`` inside ``args.save_path``.

        Args:
            name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))
        torch.save(dict(params_o=self.optimizer.state_dict()),
                   osp.join(self.args.save_path, name + '_o.pth'))
        torch.save(dict(params_s=self.lr_scheduler.state_dict()),
                   osp.join(self.args.save_path, name + '_s.pth'))

    def _forward_episode(self, batch, K, N):
        """Prepare one episode and run the model on it.

        The first ``K * N`` items of the batch are used as the task's
        training (support) set, the remainder as its test (query) set.

        Args:
            batch: loader batch; index 0 holds images, index 1 labels.
            K: number of classes in the task (background included).
            N: number of support images per class.

        Returns:
            tuple ``(im_test, out_test, GteT, Yte)``: query images, query
            labels, model output with dims 1 and 2 transposed, and the query
            labels flattened per image.
        """
        if torch.cuda.is_available():
            data, labels, _ = [item.cuda() for item in batch]
        else:
            data = batch[0]
            labels = batch[1]

        p = K * N
        im_train, im_test = data[:p], data[p:]

        # Adjusting labels for each meta task
        labels = downlabel(labels, K)
        out_train, out_test = labels[:p], labels[p:]

        if torch.cuda.is_available():
            im_train = im_train.cuda()
            im_test = im_test.cuda()
            out_train = out_train.cuda()
            out_test = out_test.cuda()

        # Reshaping train set output; one-hot encoding for the model input
        Ytr = onehot(out_train.reshape(-1), K)
        Yte = out_test.reshape(out_test.shape[0], -1)
        if torch.cuda.is_available():
            Ytr = Ytr.cuda()
            Yte = Yte.cuda()

        # Output logits for model
        Gte = self.model(im_train, Ytr, im_test, Yte)
        GteT = torch.transpose(Gte, 1, 2)
        return im_test, out_test, GteT, Yte

    def _episode_metrics(self, GteT, Yte, K):
        """Return ``(pixel accuracy, mean IoU)`` for a single episode."""
        self._reset_metrics()
        self._update_seg_metrics(*eval_metrics(GteT, Yte, K))
        pixAcc, mIoU, _ = self._get_seg_metrics(K).values()
        return pixAcc, mIoU

    def train(self):
        """The function for the meta-train phase (with optional meta-val)."""
        # Change when resuming training; must follow the checkpoint loaded
        # in __init__.
        initial_epoch = 25

        # Set the meta-train and meta-val log
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'train_acc': [],
            'train_iou': [],
            'val_loss': [],
            'val_acc': [],
            'val_iou': [],
            # Best result carried over from the run being resumed.
            'max_iou': 0.2856,
            'max_iou_epoch': 4,
        }

        # Set the timer
        timer = Timer()
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        K = self.args.way + 1  # included Background as class
        N = self.args.train_query
        separator = '----------------------------------------------------------------------------------------------------------------------------------------------------------'

        # The writer is closed in ``finally`` so it is released even when
        # training is interrupted by an exception.
        try:
            # Start meta-train
            for epoch in range(initial_epoch, self.args.max_epoch + 1):
                print(separator)

                # Update learning rate
                # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after
                # optimizer.step(); confirm the intended schedule before moving it.
                self.lr_scheduler.step()

                # Set the model to train mode
                self.model.train()
                # Set averager classes to record training losses and accuracies
                train_loss_averager = Averager()
                train_acc_averager = Averager()
                train_iou_averager = Averager()

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    _, _, GteT, Yte = self._forward_episode(batch, K, N)

                    # Calculate meta-train loss
                    loss = self.FL(GteT, Yte)

                    # Calculate meta-train accuracy
                    pixAcc, mIoU = self._episode_metrics(GteT, Yte, K)

                    # Print loss and accuracy for this step
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                            epoch, loss.item(), pixAcc * 100.0, mIoU))

                    # Calculate the running averages
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(pixAcc)
                    train_iou_averager.add(mIoU)

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Update the averagers
                train_loss_averager = train_loss_averager.item()
                train_acc_averager = train_acc_averager.item()
                train_iou_averager = train_iou_averager.item()

                # Adding to Tensorboard
                writer.add_scalar('data/train_loss (Meta)',
                                  float(train_loss_averager), epoch)
                writer.add_scalar('data/train_acc (Meta)',
                                  float(train_acc_averager) * 100.0, epoch)
                writer.add_scalar('data/train_iou (Meta)',
                                  float(train_iou_averager), epoch)

                # Update best saved model if validation set is not present
                if self.args.valdata == 'No':
                    if train_iou_averager > trlog['max_iou']:
                        print("New Best!")
                        trlog['max_iou'] = train_iou_averager
                        trlog['max_iou_epoch'] = epoch
                        self.save_model('max_iou')

                # Save model every 2 epochs.  The weights do not change during
                # validation, so one save per epoch is sufficient.
                if epoch % 2 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['train_loss'].append(train_loss_averager)
                trlog['train_acc'].append(train_acc_averager)
                trlog['train_iou'].append(train_iou_averager)

                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))
                print('Epoch:{}, Average Loss: {:.4f}, Average mIoU: {:.4f}'.
                      format(epoch, train_loss_averager, train_iou_averager))

                # Meta-val phase
                if self.args.valdata == 'Yes':
                    # Set the model to val mode
                    self.model.eval()

                    # Set averager classes to record validation losses and accuracies
                    val_loss_averager = Averager()
                    val_acc_averager = Averager()
                    val_iou_averager = Averager()

                    # Using tqdm to read samples from val loader
                    tqdm_gen = tqdm.tqdm(self.val_loader)
                    for batch in tqdm_gen:
                        _, _, GteT, Yte = self._forward_episode(batch, K, N)

                        # Compute the loss on this validation episode; without
                        # this the logged value would be the loss of the last
                        # training batch.
                        loss = self.FL(GteT, Yte)

                        # Calculate meta-val accuracy
                        pixAcc, mIoU = self._episode_metrics(GteT, Yte, K)

                        # Print loss and accuracy for this step
                        tqdm_gen.set_description(
                            'Epoch {}, Val Loss={:.4f} Val Acc={:.4f} Val IoU={:.4f}'
                            .format(epoch, loss.item(), pixAcc * 100.0, mIoU))

                        # Calculate the running averages
                        val_loss_averager.add(loss.item())
                        val_acc_averager.add(pixAcc)
                        val_iou_averager.add(mIoU)

                    # Update the averagers
                    val_loss_averager = val_loss_averager.item()
                    val_acc_averager = val_acc_averager.item()
                    val_iou_averager = val_iou_averager.item()

                    # Adding to Tensorboard
                    writer.add_scalar('data/val_loss (Meta)',
                                      float(val_loss_averager), epoch)
                    writer.add_scalar('data/val_acc (Meta)',
                                      float(val_acc_averager) * 100.0, epoch)
                    writer.add_scalar('data/val_iou (Meta)',
                                      float(val_iou_averager), epoch)

                    # Update best saved model
                    if val_iou_averager > trlog['max_iou']:
                        print("New Best (Validation)")
                        trlog['max_iou'] = val_iou_averager
                        trlog['max_iou_epoch'] = epoch
                        self.save_model('max_iou')

                    # Update the logs
                    trlog['val_loss'].append(val_loss_averager)
                    trlog['val_acc'].append(val_acc_averager)
                    trlog['val_iou'].append(val_iou_averager)

                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.max_epoch)))
                    print(
                        'Epoch:{}, Average Val Loss: {:.4f}, Average Val mIoU: {:.4f}'
                        .format(epoch, val_loss_averager, val_iou_averager))

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                print(separator)
        finally:
            writer.close()

    def eval(self):
        """The function for the meta-evaluate (test) phase.

        Reports the mean IoU over the test episodes and writes, for every
        query image, the input (``<n>a.jpg``), the ground truth
        (``<n>b.png``) and the prediction (``<n>c.png``) to
        ``args.save_image_dir``.
        """
        # Load meta-test set
        self.test_set = Dataset('test', self.args)
        self.sampler = CategoriesSampler(
            self.test_set.labeln, self.args.num_batch, self.args.way + 1,
            self.args.train_query, self.args.test_query)
        self.loader = DataLoader(
            dataset=self.test_set, batch_sampler=self.sampler,
            num_workers=8, pin_memory=True)

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            self.model.load_state_dict(
                torch.load(self.args.eval_weights)['params'])
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_iou' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy(IoU) averager
        ave_acc = Averager()

        # Start meta-test
        K = self.args.way + 1
        N = self.args.train_query
        Q = self.args.test_query

        count = 1
        for batch in self.loader:
            im_test, out_test, GteT, Yte = self._forward_episode(batch, K, N)

            # Calculate meta-test accuracy
            _, mIoU = self._episode_metrics(GteT, Yte, K)
            ave_acc.add(mIoU)

            # Saving Test Image, Ground Truth Image and Predicted Image
            for j in range(K * Q):
                x1 = im_test[j].detach().cpu()
                y1 = out_test[j].detach().cpu()
                z1 = GteT[j].detach().cpu()
                # Per-pixel class with the highest score
                z1 = torch.argmax(z1, dim=0)
                # The flat prediction is reshaped to a square mask.
                m = int(math.sqrt(z1.shape[0]))
                z2 = z1.reshape(m, m)

                x = transforms.ToPILImage()(x1).convert("RGB")
                y = Image.fromarray(decode_segmap(y1, K))
                z = Image.fromarray(decode_segmap(z2, K))

                px = self.args.save_image_dir + str(count) + 'a.jpg'
                py = self.args.save_image_dir + str(count) + 'b.png'
                pz = self.args.save_image_dir + str(count) + 'c.png'
                x.save(px)
                y.save(py)
                z.save(pz)
                count = count + 1

        # Test mIoU
        ave_acc = ave_acc.item()
        print("=============================================================")
        print('Average Test mIoU: {:.4f}'.format(ave_acc))
        print("Images Saved!")
        print("=============================================================")
class PreTrainer(object):
    """Trainer for the pre-train phase of the segmentation model.

    Each epoch trains on regular mini-batches from the train loader, then
    validates on episodes drawn by ``CategoriesSampler`` and keeps the
    checkpoint with the best validation mIoU. Only ``self.model.encoder``
    is optimized and saved.
    """

    def __init__(self, args):
        """Create the output folder, data loaders, model, losses and optimizer.

        Args:
            args: parsed option namespace. ``args.save_path`` is written here
                and the namespace is kept as ``self.args``.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = '../logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        pre_base_dir = osp.join(log_base_dir, 'pre')
        if not osp.exists(pre_base_dir):
            os.mkdir(pre_base_dir)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + \
            '_lr' + str(args.pre_lr) + \
            '_gamma' + str(args.pre_gamma) + \
            '_step' + str(args.pre_step_size) + \
            '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        # ensure_path is a project helper; presumably it creates the folder
        # when missing -- TODO confirm.
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load pretrain set
        self.trainset = Dataset('train', self.args)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_size=args.pre_batch_size,
                                       shuffle=True,
                                       num_workers=8,
                                       pin_memory=True)

        # Load pre-val set (episodic sampling)
        self.valset = mDataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.labeln, self.args.num_batch, self.args.way,
            self.args.shot + self.args.val_query, self.args.shot)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # Build pretrain model
        self.model = MtlLearner(self.args, mode='train')
        print(self.model)
        # NOTE: warm-starting the encoder from args.pre_init_weights is
        # currently disabled (the code for it was commented out).

        # Loss functions. Only self.CD is used by train(); the other two are
        # kept available for experiments.
        self.FL = FocalLoss()
        self.CD = CE_DiceLoss()
        self.LS = LovaszSoftmax()

        # Set optimizer (encoder parameters only)
        self.optimizer = torch.optim.SGD(
            [{'params': self.model.encoder.parameters(),
              'lr': self.args.pre_lr}],
            momentum=self.args.pre_custom_momentum,
            nesterov=True,
            weight_decay=self.args.pre_custom_weight_decay)

        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.args.pre_step_size,
            gamma=self.args.pre_gamma)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """The function to save checkpoints.

        Only the encoder weights are stored, under the key ``'params'``.

        Args:
            name: the name for saved checkpoint (``.pth`` is appended)
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def _reset_metrics(self):
        """Zero the running intersection/union and correct/labeled totals."""
        self.total_inter, self.total_union = 0, 0
        self.total_correct, self.total_label = 0, 0

    def _update_seg_metrics(self, correct, labeled, inter, union):
        """Accumulate one batch of segmentation statistics.

        Args:
            correct: number of correctly predicted pixels
            labeled: number of labeled pixels
            inter: intersection counts (assumed per-class array -- the value
                is later averaged with ``.mean()``)
            union: union counts, same layout as ``inter``
        """
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def _get_seg_metrics(self, n_class):
        """Return pixel accuracy, mean IoU and class-wise IoU from the totals.

        Args:
            n_class: number of classes used to key the ``Class_IoU`` dict

        Returns:
            dict with keys ``Pixel_Accuracy``, ``Mean_IoU`` and ``Class_IoU``
            (in that order); all values are rounded to 3 decimals.
        """
        self.n_class = n_class
        # np.spacing(1) keeps the denominators non-zero.
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return {
            "Pixel_Accuracy": np.round(pixAcc, 3),
            "Mean_IoU": np.round(mIoU, 3),
            "Class_IoU": dict(zip(range(self.n_class), np.round(IoU, 3)))
        }

    def train(self):
        """The function for the pre-train phase.

        Runs ``args.pre_max_epoch`` epochs of training followed by episodic
        validation, writes tensorboard scalars, saves ``max_iou.pth`` whenever
        validation mIoU improves, a snapshot every 10 epochs, and the running
        log ``trlog`` after every epoch.
        """
        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['train_iou'] = []
        trlog['val_iou'] = []
        trlog['max_iou'] = 0.0
        trlog['max_iou_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Start pretrain
        for epoch in range(1, self.args.pre_max_epoch + 1):
            # Update learning rate
            # NOTE(review): PyTorch >= 1.1 documents calling the scheduler
            # after optimizer.step(); the original order is kept so the
            # learning-rate schedule does not shift.
            self.lr_scheduler.step()
            # Set the model to train mode
            self.model.train()
            self.model.mode = 'train'
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()
            train_iou_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for batch in tqdm_gen:
                if torch.cuda.is_available():
                    data, label = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                    label = batch[1]

                # Output logits for model
                logits = self.model(data)
                # Calculate train loss
                # CD loss is modified in the whole project to incorporate
                # only Cross Entropy loss. Modify as per requirement, e.g.
                # self.FL(logits, label) + self.CD(logits, label)
                #     + self.LS(logits, label)
                loss = self.CD(logits, label)

                # Calculate train accuracy (metrics are per batch: the totals
                # are reset before every update)
                self._reset_metrics()
                seg_metrics = eval_metrics(logits, label,
                                           self.args.num_classes)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(
                    self.args.num_classes).values()

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(pixAcc)
                train_iou_averager.add(mIoU)

                # Print loss and accuracy till this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f} IOU={:.4f}'.format(
                        epoch, train_loss_averager.item(),
                        train_acc_averager.item() * 100.0,
                        train_iou_averager.item()))

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()
            train_iou_averager = train_iou_averager.item()

            writer.add_scalar('data/train_loss(Pre)',
                              float(train_loss_averager), epoch)
            writer.add_scalar('data/train_acc(Pre)',
                              float(train_acc_averager) * 100.0, epoch)
            writer.add_scalar('data/train_iou (Pre)',
                              float(train_iou_averager), epoch)
            print(
                'Epoch {}, Train: Loss={:.4f}, Acc={:.4f}, IoU={:.4f}'.format(
                    epoch, train_loss_averager, train_acc_averager * 100.0,
                    train_iou_averager))

            # Start validation for this epoch, set model to eval mode
            self.model.eval()
            self.model.mode = 'val'
            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()
            val_iou_averager = Averager()

            # Print previous information
            print('Best Val Epoch {}, Best Val IoU={:.4f}'.format(
                trlog['max_iou_epoch'], trlog['max_iou']))

            # Run validation
            for batch in self.val_loader:
                if torch.cuda.is_available():
                    data, labels, _ = [_.cuda() for _ in batch]
                else:
                    # Fixed: the CPU path used to read `labels` before it
                    # was assigned, which raised NameError.
                    data = batch[0]
                    labels = batch[1]
                # The first way*shot samples are the support set, the rest
                # are the query set.
                p = self.args.way * self.args.shot
                data_shot, data_query = data[:p], data[p:]
                label_shot, label = labels[:p], labels[p:]

                logits = self.model((data_shot, label_shot, data_query))
                # Calculate preval loss (same loss as in training)
                loss = self.CD(logits, label)

                # Calculate val accuracy
                self._reset_metrics()
                seg_metrics = eval_metrics(logits, label, self.args.way)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(
                    self.args.way).values()

                val_loss_averager.add(loss.item())
                val_acc_averager.add(pixAcc)
                val_iou_averager.add(mIoU)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            val_iou_averager = val_iou_averager.item()

            writer.add_scalar('data/val_loss(Pre)',
                              float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc(Pre)',
                              float(val_acc_averager) * 100.0, epoch)
            writer.add_scalar('data/val_iou (Pre)',
                              float(val_iou_averager), epoch)

            # Print loss and accuracy for this epoch
            print('Epoch {}, Val: Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager * 100.0,
                val_iou_averager))

            # Update best saved model
            if val_iou_averager > trlog['max_iou']:
                trlog['max_iou'] = val_iou_averager
                trlog['max_iou_epoch'] = epoch
                print("model saved in max_iou")
                self.save_model('max_iou')

            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)
            trlog['train_iou'].append(train_iou_averager)
            trlog['val_iou'].append(val_iou_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            # Fixed: progress is measured against pre_max_epoch (the bound of
            # this loop), not the meta-train max_epoch.
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(),
                timer.measure(epoch / self.args.pre_max_epoch)))

        writer.close()
class PreTrainer(object):
    """The class that contains the code for the pretrain phase.

    Trains the encoder plus the ``pre_fc`` classifier with Adam, evaluates
    on the full validation arrays (``val_orig``) and on sampled episodes
    after every epoch, and keeps the encoder checkpoint with the best
    episodic validation accuracy.
    """

    def __init__(self, args):
        """Create the output folder, data loaders, model and optimizer.

        Args:
            args: parsed option namespace. ``args.save_path`` is written here
                and the namespace is kept as ``self.args``.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        pre_base_dir = osp.join(log_base_dir, 'pre')
        if not osp.exists(pre_base_dir):
            os.mkdir(pre_base_dir)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + \
            '_lr' + str(args.pre_lr) + \
            '_gamma' + str(args.pre_gamma) + \
            '_step' + str(args.pre_step_size) + \
            '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        # ensure_path is a project helper; presumably it creates the folder
        # when missing -- TODO confirm.
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load pretrain set
        self.trainset = Dataset('train', self.args, train_aug=False)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_size=args.pre_batch_size,
                                       shuffle=True,
                                       num_workers=8,
                                       pin_memory=True)

        # Load meta-val set (20 episodes per validation pass)
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 20, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # Set pretrain class number
        num_class_pretrain = self.trainset.num_class

        # Build pretrain model
        self.model = MtlLearner(self.args, mode='pre',
                                num_cls=num_class_pretrain)

        # Set optimizer (encoder + pretrain classifier, Adam defaults)
        params = list(self.model.encoder.parameters()) + list(
            self.model.pre_fc.parameters())
        self.optimizer = optim.Adam(params)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """The function to save checkpoints.

        Only the encoder weights are stored, under the key ``'params'``.

        Args:
            name: the name for saved checkpoint (``.pth`` is appended)
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the pre-train phase.

        Runs ``args.pre_max_epoch`` epochs. After each epoch it prints the
        accuracy from ``val_orig``, runs episodic validation, saves
        ``max_acc.pth`` on improvement, a snapshot every 10 epochs, and the
        running log ``trlog``.
        """
        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero (x-axis of the per-step tensorboard curves)
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Label tensors live on the GPU whenever one is available.
        if torch.cuda.is_available():
            long_type = torch.cuda.LongTensor
        else:
            long_type = torch.LongTensor

        # Start pretrain
        for epoch in range(1, self.args.pre_max_epoch + 1):
            # Set the model to train mode
            print('Epoch {}'.format(epoch))
            self.model.train()
            self.model.mode = 'pre'
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for batch in tqdm_gen:
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                label = batch[1].type(long_type)

                logits = self.model(data)
                loss = F.cross_entropy(logits, label)
                # Calculate train accuracy
                acc = count_acc(logits, label)

                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Start the original (non-episodic) evaluation; val_orig switches
            # the model to eval mode and mode 'origval' itself.
            _, valid_results = self.val_orig(self.valset.X_val,
                                             self.valset.y_val)
            print('validation accuracy ', valid_results[0])

            # Start validation for this epoch, set model to eval mode
            self.model.eval()
            self.model.mode = 'preval'

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Generate the labels for the query and support samples
            label = torch.arange(self.args.way).repeat(
                self.args.val_query).type(long_type)
            label_shot = torch.arange(self.args.way).repeat(
                self.args.shot).type(long_type)

            # Run meta-validation
            for batch in self.val_loader:
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # The first shot*way samples are the support set.
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')

            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                # Fixed: progress is measured against pre_max_epoch (the
                # bound of this loop), not the meta-train max_epoch.
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.pre_max_epoch)))

        writer.close()

    def val_orig(self, X_val, y_val):
        """Evaluate the model on the whole validation set in one forward pass.

        Args:
            X_val: validation inputs as a numpy array
            y_val: validation labels as a numpy array

        Returns:
            tuple ``(predicted, results)`` where ``predicted`` is the argmax
            class per sample and ``results`` is
            ``[accuracy, micro recall, micro F-measure]``.
        """
        inputs = torch.from_numpy(X_val)
        labels = torch.FloatTensor(y_val * 1.0)

        self.model.eval()
        self.model.mode = 'origval'

        if torch.cuda.is_available():
            inputs = inputs.type(torch.cuda.FloatTensor)
        else:
            inputs = inputs.type(torch.FloatTensor)

        # NOTE(review): this forward pass still tracks gradients; wrapping it
        # in torch.no_grad() would save memory if the 'origval' path of the
        # model does not need autograd -- confirm before changing.
        scores = self.model(inputs).data.cpu().numpy()
        Y = labels.numpy()
        predicted = np.argmax(scores, axis=1)
        rounded = np.round(predicted)

        accuracy = accuracy_score(Y, rounded)
        recall = recall_score(Y, rounded, average='micro')
        precision = precision_score(Y, rounded, average='micro')
        # Guard the harmonic mean: with no correct prediction both precision
        # and recall are 0 and the original expression divided 0 by 0.
        denominator = precision + recall
        if denominator > 0:
            fmeasure = 2 * precision * recall / denominator
        else:
            fmeasure = 0.0

        results = [accuracy, recall, fmeasure]
        return predicted, results
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase."""

    def __init__(self, args):
        """Build data loaders, model, optimizer and loss from a named config.

        Args:
            args: parsed option namespace. ``args.config`` names a callable in
                the ``configs`` module; several fields of ``args`` are
                overwritten from the object it returns, and
                ``args.save_path`` is written here.
        """
        # The config object overrides part of the command-line options.
        param = configs.__dict__[args.config]()
        args.shot = param.shot
        args.test = param.test
        args.debug = param.debug
        args.deconfound = param.deconfound
        args.meta_label = param.meta_label
        args.init_weights = param.init_weights
        self.test_iter = param.test_iter
        args.param = param
        pprint(vars(args))

        # Set the folder to save the records and checkpoints
        log_base_dir = '/data2/yuezhongqi/Model/mtl/logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        meta_base_dir = osp.join(log_base_dir, 'meta')
        if not osp.exists(meta_base_dir):
            os.mkdir(meta_base_dir)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + \
            '_query' + str(args.train_query) + \
            '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + \
            '_lr1' + str(args.meta_lr1) + '_lr2' + str(args.meta_lr2) + \
            '_batch' + str(args.num_batch) + \
            '_maxepoch' + str(args.max_epoch) + \
            '_baselr' + str(args.base_lr) + \
            '_updatestep' + str(args.update_step) + \
            '_stepsize' + str(args.step_size) + '_' + args.meta_label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        # ensure_path is a project helper; presumably it creates the folder
        # when missing -- TODO confirm.
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Worker processes are disabled in debug mode.
        num_workers = 8
        if args.debug:
            num_workers = 0

        # Load meta-train set
        self.trainset = Dataset('train', self.args,
                                dataset=self.args.param.dataset,
                                train_aug=False)
        self.train_sampler = CategoriesSampler(
            self.trainset.label, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_sampler=self.train_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)

        # Load meta-val set
        self.valset = Dataset('val', self.args,
                              dataset=self.args.param.dataset,
                              train_aug=False)
        self.val_sampler = CategoriesSampler(
            self.valset.label, self.test_iter, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=num_workers,
                                     pin_memory=True)

        # Build meta-transfer learning model
        self.model = MtlLearner(self.args)

        # load pretrained model without FC classifier
        self.model.load_pretrain_weight(self.args.init_weights)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()
            if self.args.param.model == "wideres":
                print("Using Parallel")
                self.model.encoder = torch.nn.DataParallel(
                    self.model.encoder).cuda()

        # Set optimizer: trainable encoder parameters use meta_lr1, the base
        # learner uses meta_lr2.
        self.optimizer = torch.optim.Adam(
            [{
                'params':
                filter(lambda p: p.requires_grad,
                       self.model.encoder.parameters())
            }, {
                'params': self.model.base_learner.parameters(),
                'lr': self.args.meta_lr2
            }],
            lr=self.args.meta_lr1)

        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.args.step_size,
            gamma=self.args.gamma)

        # The deconfounded model is trained with NLLLoss instead of
        # CrossEntropyLoss.
        if not self.args.deconfound:
            self.criterion = torch.nn.CrossEntropyLoss().cuda()
        else:
            self.criterion = torch.nn.NLLLoss().cuda()

        # Enable evaluation with Cross. This is set after the train/val sets
        # were built, so only datasets created later (see eval) use it.
        if args.cross:
            args.param.dataset = "cross"

    def _episode_labels(self, num_per_class):
        """Return ``arange(way)`` repeated ``num_per_class`` times as a long
        tensor, placed on the GPU when one is available."""
        labels = torch.arange(self.args.way).repeat(num_per_class)
        if torch.cuda.is_available():
            return labels.type(torch.cuda.LongTensor)
        return labels.type(torch.LongTensor)

    def write_output_message(self, message, file_name=None):
        """Append ``message`` as one line to ``outputs/<file_name>.txt``.

        Args:
            message: text to append (a newline is added)
            file_name: file stem inside ``outputs``; defaults to ``"results"``
        """
        if file_name is None:
            file_name = "results"
        # Fixed: create the folder on first use instead of failing with
        # FileNotFoundError when "outputs" does not exist yet.
        os.makedirs("outputs", exist_ok=True)
        output_file = os.path.join("outputs", file_name + ".txt")
        with open(output_file, "a") as f:
            f.write(message + "\n")

    def save_model(self, name):
        """The function to save checkpoints.

        The whole model state dict is stored under the key ``'params'``.

        Args:
            name: the name for saved checkpoint (``.pth`` is appended)
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the meta-train phase.

        Runs ``args.max_epoch`` epochs of episodic training followed by
        meta-validation, saves ``max_acc.pth`` on improvement, a snapshot
        every 10 epochs, and the running log ``trlog`` after every epoch.
        """
        # Set the meta-train log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero (x-axis of the per-step tensorboard curves)
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Generate the labels for train set of the episodes
        label_shot = self._episode_labels(self.args.shot)

        # Start meta-train
        for epoch in range(1, self.args.max_epoch + 1):
            # Update learning rate
            # NOTE(review): PyTorch >= 1.1 documents calling the scheduler
            # after optimizer.step(); the original order is kept so the
            # learning-rate schedule does not shift.
            self.lr_scheduler.step()
            # Set the model to train mode
            self.model.train()
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Generate the labels for test set of the episodes during
            # meta-train updates
            label = self._episode_labels(self.args.train_query)

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for batch in tqdm_gen:
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # The first shot*way samples are the support set.
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                # Output logits for model (last flag is False for training)
                logits = self.model(
                    (data_shot, label_shot, data_query, False))
                # Calculate meta-train loss
                loss = self.criterion(logits, label)
                # Calculate meta-train accuracy
                acc = count_acc(logits, label)
                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Print loss and accuracy for this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                        epoch, loss.item(), acc))

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Start validation for this epoch, set model to eval mode
            self.model.eval()

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Generate the labels for test set of the episodes during
            # meta-val for this epoch
            label = self._episode_labels(self.args.val_query)

            # Print previous information
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                    trlog['max_acc_epoch'], trlog['max_acc']))

            # Run meta-validation, reporting five times per pass.
            # Fixed: clamp to 1 so test_iter < 5 no longer causes a modulo
            # by zero.
            print_freq = max(1, int(self.test_iter / 5))
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model(
                    (data_shot, label_shot, data_query, True))
                if not self.args.deconfound:
                    loss = F.cross_entropy(logits, label)
                else:
                    loss = F.nll_loss(logits, label)
                acc = count_acc(logits, label)

                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

                if i % print_freq == 0:
                    # Intermediate running averages
                    val_loss_averager_item = val_loss_averager.item()
                    val_acc_averager_item = val_acc_averager.item()
                    # Write the tensorboardX records
                    writer.add_scalar('data/val_loss',
                                      float(val_loss_averager_item), epoch)
                    writer.add_scalar('data/val_acc',
                                      float(val_acc_averager_item), epoch)
                    # Print loss and accuracy so far
                    print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                        epoch, val_loss_averager_item,
                        val_acc_averager_item))

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
            # Print loss and accuracy for this epoch
            msg = 'Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager)
            print(msg)
            self.write_output_message(msg)

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))

        writer.close()

    def eval(self):
        """The function for the meta-eval phase.

        Evaluates 2000 test episodes, prints running averages every 100
        episodes, appends the final summary to ``outputs/<case>.txt`` and,
        when ``args.save_hacc`` is set, pickles the hardness recorder to
        ``hacc/<case>``.
        """
        num_workers = 8
        if self.args.debug:
            num_workers = 0
        self.test_iter = 2000

        # Load meta-test set
        test_set = Dataset('test', self.args,
                           dataset=self.args.param.dataset,
                           train_aug=False)
        sampler = CategoriesSampler(test_set.label, self.test_iter,
                                    self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=num_workers,
                            pin_memory=True)

        # Set test accuracy recorder
        test_acc_record = np.zeros((self.test_iter, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            self.model.load_state_dict(
                torch.load(self.args.eval_weights)['params'])
            # Fixed: add_path used to be set only in the branch below, so
            # evaluating explicit weights crashed with AttributeError when the
            # test case name was built. Use the checkpoint file stem.
            self.add_path = osp.splitext(
                osp.basename(self.args.eval_weights))[0]
        else:
            # Load according to config file
            args = self.args
            base_path = "/data2/yuezhongqi/Model/ifsl/mtl"
            if args.param.dataset == "tiered":
                add_path = "tiered_"
            else:
                add_path = ""
            if args.param.model == "ResNet10":
                add_path += "resnet_"
            elif args.param.model == "wideres":
                add_path += "wrn_"
            elif "baseline" in args.config:
                add_path += "baseline_"
            else:
                add_path += "edsplit_"
            add_path += str(args.param.shot)
            self.add_path = add_path
            self.model.load_state_dict(
                torch.load(osp.join(base_path, add_path + '.pth'))['params'])

        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels
        label = self._episode_labels(self.args.val_query)
        label_shot = self._episode_labels(self.args.shot)

        hacc = Hacc()
        # Start meta-test
        for i, batch in enumerate(loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # The first way*shot samples are the support set.
            k = self.args.way * self.args.shot
            data_shot, data_query = data[:k], data[k:]
            logits = self.model((data_shot, label_shot, data_query, True))
            acc = count_acc(logits, label)
            hardness, correct = get_hardness_correct(
                logits, label_shot, label, data_shot, data_query,
                self.model.pretrain)
            ave_acc.add(acc)
            hacc.add_data(hardness, correct)
            test_acc_record[i - 1] = acc
            if i % 100 == 0:
                print("Average acc:{:.4f}, Average hAcc:{:.4f}".format(
                    ave_acc.item(), hacc.get_topk_hard_acc()))

        # Modify add path to generate test case name
        test_case_name = self.add_path
        if self.args.cross:
            test_case_name += "_cross"

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        msg = test_case_name + \
            ' Test Acc {:.4f} +- {:.4f}, hAcc {:.4f}'.format(
                ave_acc.item() * 100, pm * 100, hacc.get_topk_hard_acc())
        print(msg)
        self.write_output_message(msg, test_case_name)

        if self.args.save_hacc:
            print("Saving hacc!")
            # Fixed: create the folder if needed and close the file handle
            # deterministically.
            os.makedirs("hacc", exist_ok=True)
            with open("hacc/" + test_case_name, "wb") as f:
                pickle.dump(hacc, f)
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
class MetaTrainer(object):
    """Meta-train and meta-validate an ``MtlLearner`` with hyperprior learners.

    The constructor prepares the checkpoint directory, the episodic train and
    validation loaders, the model and its optimizer. ``train`` runs the epoch
    loop, writes TensorBoard scalars, and stores checkpoints plus a ``trlog``
    file under ``args.save_path``.
    """

    # Inner-loop step indices whose first combination / base-step value is
    # written to TensorBoard after every training episode.
    _LOGGED_STEPS = (0, 24, 49, 74, 99)

    def __init__(self, args):
        """Build data loaders, model, optimizer and scheduler from ``args``.

        Args:
            args: parsed options. ``args.save_path`` is set here as a side
                effect so that other components can find the run directory.

        Raises:
            ValueError: if ``args.dataset`` is not one of the supported names.
        """
        log_base_dir = './logs/'
        meta_base_dir = osp.join(log_base_dir, 'meta')
        # makedirs(exist_ok=True) creates both levels at once and avoids the
        # exists()/mkdir() race when several runs start together.
        os.makedirs(meta_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = (
            'shot' + str(args.shot)
            + '_way' + str(args.way)
            + '_query' + str(args.train_query)
            + '_step' + str(args.step_size)
            + '_gamma' + str(args.gamma)
            + '_lr' + str(args.lr)
            + '_lrbase' + str(args.lr_base)
            + '_lrc' + str(args.lr_combination)
            + '_lrch' + str(args.lr_combination_hyperprior)
            + '_lrbs' + str(args.lr_basestep)
            + '_lrbsh' + str(args.lr_basestep_hyperprior)
            + '_batch' + str(args.num_batch)
            + '_maxepoch' + str(args.max_epoch)
            + '_csw' + str(args.hyperprior_combination_softweight)
            + '_cbsw' + str(args.hyperprior_basestep_softweight)
            + '_baselr' + str(args.base_lr)
            + '_updatestep' + str(args.update_step)
            + '_stepsize' + str(args.step_size)
            + '_' + args.label
        )
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)
        self.args = args

        # The dataset class is imported lazily so only the selected loader's
        # module has to be importable.
        if self.args.dataset == 'MiniImageNet':
            from dataloader.mini_imagenet import MiniImageNet as Dataset
        elif self.args.dataset == 'TieredImageNet':
            from dataloader.tiered_imagenet import TieredImageNet as Dataset
        elif self.args.dataset == 'FC100':
            from dataloader.fewshotcifar import FewshotCifar as Dataset
        else:
            raise ValueError('Non-supported Dataset.')

        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(
            self.trainset.label, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(
            dataset=self.trainset, batch_sampler=self.train_sampler,
            num_workers=8, pin_memory=True)

        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 3000, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(
            dataset=self.valset, batch_sampler=self.val_sampler,
            num_workers=8, pin_memory=True)

        self.model = MtlLearner(self.args)

        # One parameter group per component so each can have its own
        # learning rate; the encoder group falls back to ``args.lr``.
        new_para = filter(lambda p: p.requires_grad,
                          self.model.encoder.parameters())
        self.optimizer = torch.optim.Adam(
            [
                {'params': new_para},
                {'params': self.model.base_learner.parameters(),
                 'lr': self.args.lr_base},
                {'params': self.model.get_hyperprior_combination_initialization_vars(),
                 'lr': self.args.lr_combination},
                {'params': self.model.get_hyperprior_combination_mapping_vars(),
                 'lr': self.args.lr_combination_hyperprior},
                {'params': self.model.get_hyperprior_basestep_initialization_vars(),
                 'lr': self.args.lr_basestep},
                {'params': self.model.get_hyperprior_stepsize_mapping_vars(),
                 'lr': self.args.lr_basestep_hyperprior},
            ],
            lr=self.args.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.step_size,
            gamma=self.args.gamma)

        self.model_dict = self.model.state_dict()
        # Everything that touches ``pretrained_dict`` stays under this guard:
        # without ``init_weights`` there is nothing to remap or load.
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']
            # Checkpoint keys are prefixed so they line up with the
            # ``encoder.*`` entries of this model's state dict.
            pretrained_dict = {
                'encoder.' + k: v for k, v in pretrained_dict.items()
            }
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in self.model_dict
            }
            print(pretrained_dict.keys())
            self.model_dict.update(pretrained_dict)
            self.model.load_state_dict(self.model_dict)

        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """Save the model state dict as ``<args.save_path>/<name>.pth``.

        Args:
            name: checkpoint file name without extension.
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def _episode_labels(self, repeats):
        """Return ``arange(way)`` tiled ``repeats`` times as a long tensor.

        The tensor is created on the GPU when CUDA is available.
        """
        labels = torch.arange(self.args.way).repeat(repeats)
        if torch.cuda.is_available():
            return labels.type(torch.cuda.LongTensor)
        return labels.type(torch.LongTensor)

    def _split_episode(self, batch):
        """Split a loader batch into ``(data_shot, data_query)``.

        The first ``shot * way`` samples form the support set; the rest are
        the query set. Dataset labels in the batch are ignored.
        """
        if torch.cuda.is_available():
            data, _ = [item.cuda() for item in batch]
        else:
            data = batch[0]
        p = self.args.shot * self.args.way
        return data[:p], data[p:]

    def _log_step_values(self, writer, prefix, values, global_count):
        """Log ``values[step][0]`` for every step in ``_LOGGED_STEPS``.

        Steps beyond ``len(values)`` are skipped, so a model that returns
        fewer than 100 entries does not abort training with an IndexError.
        """
        for step in self._LOGGED_STEPS:
            if step < len(values):
                writer.add_scalar('{}/{}'.format(prefix, step),
                                  float(values[step][0]), global_count)

    def train(self):
        """Run meta-training with per-epoch meta-validation.

        Every epoch the best-validation checkpoint (``max_acc``), a rolling
        ``epoch-last`` checkpoint and the ``trlog`` dictionary are written to
        ``args.save_path``; a numbered checkpoint is added every 10 epochs.
        """
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }
        global_count = 0
        writer = SummaryWriter(logdir=self.args.save_path)

        # Episode labels only depend on args, so build them once.
        label_shot = self._episode_labels(self.args.shot)
        train_label = self._episode_labels(self.args.train_query)
        val_label = self._episode_labels(self.args.val_query)

        try:
            for epoch in range(1, self.args.max_epoch + 1):
                # NOTE(review): the scheduler is stepped before the first
                # optimizer.step(); recent PyTorch versions warn about this
                # order. It is kept so the learning-rate schedule of existing
                # runs does not shift.
                self.lr_scheduler.step()
                self.model.train()
                tl = Averager()
                ta = Averager()

                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    global_count = global_count + 1
                    data_shot, data_query = self._split_episode(batch)
                    logits, combination_list, basestep_list = self.model(
                        (data_shot, label_shot, data_query))
                    loss = F.cross_entropy(logits, train_label)
                    acc = count_acc(logits, train_label)

                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    self._log_step_values(writer, 'combination_value',
                                          combination_list, global_count)
                    self._log_step_values(writer, 'basestep_value',
                                          basestep_list, global_count)
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, loss.item(), acc))
                    tl.add(loss.item())
                    ta.add(acc)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tl = tl.item()
                ta = ta.item()

                self.model.eval()
                vl = Averager()
                va = Averager()
                print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                    trlog['max_acc_epoch'], trlog['max_acc']))

                tqdm_gen1 = tqdm.tqdm(self.val_loader)
                for batch in tqdm_gen1:
                    data_shot, data_query = self._split_episode(batch)
                    logits, _, _ = self.model(
                        (data_shot, label_shot, data_query))
                    loss = F.cross_entropy(logits, val_label)
                    acc = count_acc(logits, val_label)
                    vl.add(loss.item())
                    va.add(acc)

                vl = vl.item()
                va = va.item()
                writer.add_scalar('data/val_loss', float(vl), epoch)
                writer.add_scalar('data/val_acc', float(va), epoch)
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, vl, va))

                if va > trlog['max_acc']:
                    trlog['max_acc'] = va
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')
                if epoch % 10 == 0:
                    self.save_model('epoch' + str(epoch))

                trlog['train_loss'].append(tl)
                trlog['train_acc'].append(ta)
                trlog['val_loss'].append(vl)
                trlog['val_acc'].append(va)

                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))
                self.save_model('epoch-last')
        finally:
            # Flush TensorBoard events even when training is interrupted.
            writer.close()
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase.

    This variant trains an ``MtlLearner`` on episodes that carry per-sample
    label tensors and tracks pixel accuracy and IoU; the best checkpoint is
    selected by validation IoU.
    """

    def __init__(self, args):
        """Build directories, loaders, model, losses, optimizer and scheduler.

        Args:
            args: parsed options. ``args.save_path`` and
                ``args.save_image_dir`` are set here as a side effect.
        """
        # Set the folders to save the images, records and checkpoints.
        # makedirs(exist_ok=True) replaces the racy exists()/mkdir() pairs.
        save_image_dir = '../results/'
        os.makedirs(save_image_dir, exist_ok=True)
        log_base_dir = '../logs/'
        meta_base_dir = osp.join(log_base_dir, 'meta')
        os.makedirs(meta_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = (
            'shot' + str(args.shot)
            + '_way' + str(args.way)
            + '_query' + str(args.train_query)
            + '_step' + str(args.step_size)
            + '_gamma' + str(args.gamma)
            + '_lr1' + str(args.meta_lr1)
            + '_lr2' + str(args.meta_lr2)
            + '_batch' + str(args.num_batch)
            + '_maxepoch' + str(args.max_epoch)
            + '_baselr' + str(args.base_lr)
            + '_updatestep' + str(args.update_step)
            + '_stepsize' + str(args.step_size)
            + '_' + args.meta_label
        )
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        args.save_image_dir = save_image_dir
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = mDataset('meta', self.args)
        self.train_sampler = CategoriesSampler(
            self.trainset.labeln, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query, self.args.shot)
        self.train_loader = DataLoader(
            dataset=self.trainset, batch_sampler=self.train_sampler,
            num_workers=8, pin_memory=True)

        # Load meta-val set
        self.valset = mDataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.labeln, self.args.num_batch, self.args.way,
            self.args.shot + self.args.val_query, self.args.shot)
        self.val_loader = DataLoader(
            dataset=self.valset, batch_sampler=self.val_sampler,
            num_workers=8, pin_memory=True)

        # Build meta-transfer learning model and the available losses
        # (only ``self.CD`` is used by train()).
        self.model = MtlLearner(self.args)
        self.FL = FocalLoss()
        self.CD = CE_DiceLoss()
        self.LS = LovaszSoftmax()

        # Set optimizer: encoder uses meta_lr1, base learner uses meta_lr2
        self.optimizer = torch.optim.Adam(
            [
                {'params': filter(lambda p: p.requires_grad,
                                  self.model.encoder.parameters())},
                {'params': self.model.base_learner.parameters(),
                 'lr': self.args.meta_lr2},
            ],
            lr=self.args.meta_lr1)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.step_size,
            gamma=self.args.gamma)

        # Load pretrained model: explicit path if given, otherwise the
        # pre-train run directory derived from the pre-train options.
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']
        else:
            pre_base_dir = osp.join(log_base_dir, 'pre')
            pre_save_path1 = '_'.join([args.dataset, args.model_type])
            pre_save_path2 = (
                'batchsize' + str(args.pre_batch_size)
                + '_lr' + str(args.pre_lr)
                + '_gamma' + str(args.pre_gamma)
                + '_step' + str(args.pre_step_size)
                + '_maxepoch' + str(args.pre_max_epoch)
            )
            pre_save_path = (pre_base_dir + '/' + pre_save_path1
                             + '_' + pre_save_path2)
            pretrained_dict = torch.load(
                osp.join(pre_save_path, 'max_iou.pth'))['params']
        # Prefix keys so they line up with this model's ``encoder.*`` entries
        # and drop everything the model does not have.
        pretrained_dict = {
            'encoder.' + k: v for k, v in pretrained_dict.items()
        }
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in self.model_dict
        }
        print(pretrained_dict.keys())
        self.model_dict.update(pretrained_dict)
        self.model.load_state_dict(self.model_dict)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def _reset_metrics(self):
        """Zero the running intersection/union and correct/labeled totals."""
        self.total_inter, self.total_union = 0, 0
        self.total_correct, self.total_label = 0, 0

    def _update_seg_metrics(self, correct, labeled, inter, union):
        """Add one batch's counts to the running totals."""
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def _get_seg_metrics(self, n_class):
        """Return pixel accuracy, mean IoU and per-class IoU from the totals.

        Args:
            n_class: number of classes; also stored on ``self.n_class``.

        Returns:
            dict with keys ``"Pixel_Accuracy"``, ``"Mean_IoU"`` and
            ``"Class_IoU"`` (class index -> IoU), all rounded to 3 decimals.
        """
        self.n_class = n_class
        # np.spacing(1) keeps the divisions finite when a total is zero.
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return {
            "Pixel_Accuracy": np.round(pixAcc, 3),
            "Mean_IoU": np.round(mIoU, 3),
            "Class_IoU": dict(zip(range(self.n_class), np.round(IoU, 3))),
        }

    def _split_episode(self, batch, n_support):
        """Split a loader batch into support and query parts.

        Args:
            batch: sequence whose first two items are data and labels.
            n_support: number of leading samples that form the support set.

        Returns:
            ``(data_shot, label_shot, data_query, label_query)``.
        """
        if torch.cuda.is_available():
            data, labels, _ = [item.cuda() for item in batch]
        else:
            data = batch[0]
            labels = batch[1]
        return (data[:n_support], labels[:n_support],
                data[n_support:], labels[n_support:])

    def _episode_metrics(self, logits, label):
        """Return ``(pixAcc, mIoU)`` for a single episode.

        The running totals are reset first, so the values describe this
        episode only.
        """
        self._reset_metrics()
        seg_metrics = eval_metrics(logits, label, self.args.way)
        self._update_seg_metrics(*seg_metrics)
        metrics = self._get_seg_metrics(self.args.way)
        # Read by key instead of relying on dict value order.
        return metrics["Pixel_Accuracy"], metrics["Mean_IoU"]

    def save_model(self, name):
        """The function to save checkpoints.

        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the meta-train phase."""
        # Set the meta-train log
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'train_iou': [],
            'val_iou': [],
            'max_iou': 0.0,
            'max_iou_epoch': 0,
        }
        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)
        n_support = self.args.way * self.args.shot

        try:
            # Start meta-train
            for epoch in range(1, self.args.max_epoch + 1):
                # Update learning rate.
                # NOTE(review): stepping the scheduler before the first
                # optimizer.step() triggers a warning on recent PyTorch; the
                # order is kept so existing schedules do not shift.
                self.lr_scheduler.step()
                # Set the model to train mode
                self.model.train()
                # Averagers for training loss, accuracy and IoU
                train_loss_averager = Averager()
                train_acc_averager = Averager()
                train_iou_averager = Averager()

                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    # Update global count number
                    global_count = global_count + 1
                    data_shot, label_shot, data_query, label = \
                        self._split_episode(batch, n_support)

                    # Output logits for model
                    logits = self.model((data_shot, label_shot, data_query))
                    # Calculate meta-train loss
                    loss = self.CD(logits, label)
                    # Calculate meta-train accuracy for this episode
                    pixAcc, mIoU = self._episode_metrics(logits, label)

                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(pixAcc)
                    train_iou_averager.add(mIoU)

                    # Print running averages till this step
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                            epoch, train_loss_averager.item(),
                            train_acc_averager.item() * 100.0,
                            train_iou_averager.item()))

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Collapse the averagers to plain numbers
                train_loss_averager = train_loss_averager.item()
                train_acc_averager = train_acc_averager.item()
                train_iou_averager = train_iou_averager.item()

                writer.add_scalar('data/train_loss (Meta)',
                                  float(train_loss_averager), epoch)
                writer.add_scalar('data/train_acc (Meta)',
                                  float(train_acc_averager) * 100.0, epoch)
                writer.add_scalar('data/train_iou (Meta)',
                                  float(train_iou_averager), epoch)

                # Start validation for this epoch, set model to eval mode
                self.model.eval()
                val_loss_averager = Averager()
                val_acc_averager = Averager()
                val_iou_averager = Averager()

                # Print previous best
                print('Best Val Epoch {}, Best Val IoU={:.4f}'.format(
                    trlog['max_iou_epoch'], trlog['max_iou']))

                # Run meta-validation
                for batch in self.val_loader:
                    data_shot, label_shot, data_query, label = \
                        self._split_episode(batch, n_support)
                    logits = self.model((data_shot, label_shot, data_query))
                    loss = self.CD(logits, label)
                    pixAcc, mIoU = self._episode_metrics(logits, label)

                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(pixAcc)
                    val_iou_averager.add(mIoU)

                # Update validation averagers
                val_loss_averager = val_loss_averager.item()
                val_acc_averager = val_acc_averager.item()
                val_iou_averager = val_iou_averager.item()

                # Write the tensorboardX records
                writer.add_scalar('data/val_loss (Meta)',
                                  float(val_loss_averager), epoch)
                writer.add_scalar('data/val_acc (Meta)',
                                  float(val_acc_averager) * 100.0, epoch)
                writer.add_scalar('data/val_iou (Meta)',
                                  float(val_iou_averager), epoch)

                # Print loss and accuracy for this epoch
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                    epoch, val_loss_averager, val_acc_averager * 100.0,
                    val_iou_averager))

                # Update best saved model
                if val_iou_averager > trlog['max_iou']:
                    trlog['max_iou'] = val_iou_averager
                    trlog['max_iou_epoch'] = epoch
                    self.save_model('max_iou')
                # Save model every 10 epochs
                if epoch % 10 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['train_loss'].append(train_loss_averager)
                trlog['train_acc'].append(train_acc_averager)
                trlog['val_loss'].append(val_loss_averager)
                trlog['val_acc'].append(val_acc_averager)
                trlog['train_iou'].append(train_iou_averager)
                trlog['val_iou'].append(val_iou_averager)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))
        finally:
            # Flush TensorBoard events even when training is interrupted.
            writer.close()

    def eval(self):
        """The function for the meta-evaluate (test) phase.

        Runs the model on test episodes, accumulates segmentation metrics
        over all of them, and saves every query image together with its
        ground truth and prediction under ``args.save_image_dir``.

        Returns:
            The aggregated metrics dict from ``_get_seg_metrics``, or ``None``
            when the loader produced no episode.
        """
        # Load meta-test set
        self.test_set = mDataset('test', self.args)
        self.sampler = CategoriesSampler(
            self.test_set.labeln, self.args.num_batch, self.args.way,
            self.args.teshot + self.args.test_query, self.args.teshot)
        self.loader = DataLoader(
            dataset=self.test_set, batch_sampler=self.sampler,
            num_workers=8, pin_memory=True)

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            self.model.load_state_dict(
                torch.load(self.args.eval_weights)['params'])
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_iou' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Totals are reset once, so metrics aggregate over all test episodes.
        self._reset_metrics()
        metrics = None
        n_support = self.args.teshot * self.args.way
        # Labels and predicted class ids are scaled into [0, 1] for saving.
        scale = 1.0 * (self.args.way - 1)
        count = 1
        for batch in self.loader:
            data_shot, label_shot, data_query, label = \
                self._split_episode(batch, n_support)
            logits = self.model((data_shot, label_shot, data_query))
            seg_metrics = eval_metrics(logits, label, self.args.way)
            self._update_seg_metrics(*seg_metrics)
            metrics = self._get_seg_metrics(self.args.way)

            # Saving Test Image, Ground Truth Image and Predicted Image
            for j in range(len(data_query)):
                x1 = data_query[j].detach().cpu()
                y1 = label[j].detach().cpu()
                z1 = logits[j].detach().cpu()
                x = transforms.ToPILImage()(x1).convert("RGB")
                y = transforms.ToPILImage()(y1 / scale).convert("LA")
                # argmax over the first axis of the logits gives the
                # predicted class id per position.
                im = torch.tensor(np.argmax(np.array(z1), axis=0) / scale)
                im = im.type(torch.FloatTensor)
                z = transforms.ToPILImage()(im).convert("LA")
                px = self.args.save_image_dir + str(count) + 'a.jpg'
                py = self.args.save_image_dir + str(count) + 'b.png'
                pz = self.args.save_image_dir + str(count) + 'c.png'
                x.save(px)
                y.save(py)
                z.save(pz)
                count = count + 1

        # Report the aggregate result instead of silently discarding it.
        if metrics is not None:
            print('Test, Acc={:.4f} IoU={:.4f}'.format(
                metrics["Pixel_Accuracy"] * 100.0, metrics["Mean_IoU"]))
        return metrics
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase.

    This variant optionally distils from a teacher ``MtlLearner``
    (``args.distill_id``), supports a cross-attention loss
    (``args.cross_att``) and hard-task re-sampling (``args.hard_tasks``).
    """

    def __init__(self, args):
        """Build loaders, student (and optional teacher) model and optimizer.

        Args:
            args: parsed options. ``args.save_path`` is set here as a side
                effect.
        """
        # Set the folder to save the records and checkpoints.
        # makedirs(exist_ok=True) replaces the racy exists()/mkdir() pairs.
        log_base_dir = './logs/'
        meta_base_dir = osp.join(log_base_dir, 'meta')
        os.makedirs(meta_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = (
            'shot' + str(args.shot)
            + '_way' + str(args.way)
            + '_query' + str(args.train_query)
            + '_step' + str(args.step_size)
            + '_gamma' + str(args.gamma)
            + '_lr1' + str(args.meta_lr1)
            + '_lr2' + str(args.meta_lr2)
            + '_batch' + str(args.num_batch)
            + '_maxepoch' + str(args.max_epoch)
            + '_baselr' + str(args.base_lr)
            + '_updatestep' + str(args.update_step)
            + '_stepsize' + str(args.step_size)
            + '_' + args.meta_label
        )
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(
            self.trainset.label, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(
            dataset=self.trainset, batch_sampler=self.train_sampler,
            num_workers=args.num_workers, pin_memory=True)

        # Load meta-val set
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(
            dataset=self.valset, batch_sampler=self.val_sampler,
            num_workers=args.num_workers, pin_memory=True)

        # Build meta-transfer learning model
        multi_gpu = len(args.gpu.split(",")) > 1
        high_res = self.args.distill_id or self.args.high_res
        self.model = MtlLearner(
            self.args, res="high" if high_res else "low",
            multi_gpu=multi_gpu, crossAtt=self.args.cross_att)

        if self.args.distill_id:
            # The teacher is a low-resolution learner restored from the best
            # trial of the run named by ``distill_id``.
            self.teacher = MtlLearner(
                self.args, res="low", repVecNb=self.args.nb_parts_teach,
                multi_gpu=multi_gpu)
            bestTeach = "../models/{}/meta_{}_trial{}_max_acc.pth".format(
                self.args.exp_id, self.args.distill_id,
                self.args.best_trial_teach - 1)
            self.teacher.load_state_dict(torch.load(bestTeach)["params"])

        # Set optimizer: encoder uses meta_lr1, base learner uses meta_lr2
        self.optimizer = torch.optim.Adam(
            [
                {'params': filter(lambda p: p.requires_grad,
                                  self.model.encoder.parameters())},
                {'params': self.model.base_learner.parameters(),
                 'lr': self.args.meta_lr2},
            ],
            lr=self.args.meta_lr1)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.step_size,
            gamma=self.args.gamma)

        # Load pretrained model without FC classifier
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']
            # Prefix keys so they line up with this model's ``encoder.*``
            # entries and drop everything the model does not have.
            pretrained_dict = {
                'encoder.' + k: v for k, v in pretrained_dict.items()
            }
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in self.model_dict
            }
            self.model_dict.update(pretrained_dict)
            self.model.load_state_dict(self.model_dict)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()
            if self.args.distill_id:
                self.teacher = self.teacher.cuda()

        if self.args.cross_att:
            self.criterion = crossAttModule.CrossEntropyLoss()

    def crossAttLoss(self, ytest, cls_scores, labels_test, pids):
        """Return ``criterion(ytest, pids) + 0.5 * criterion(cls_scores, labels_test)``.

        Both label tensors are flattened before being passed to the criterion.
        """
        loss1 = self.criterion(ytest, pids.view(-1))
        loss2 = self.criterion(cls_scores, labels_test.view(-1))
        return loss1 + 0.5 * loss2

    def one_hot(self, labels_train):
        """Turn integer labels into a one-hot encoding.

        Args:
            labels_train: integer label tensor of any shape.
        Return:
            CPU float tensor with one extra trailing axis of size
            ``labels_train.max() + 1``.
        """
        labels_train = labels_train.cpu()
        # Convert to a plain int so the size list holds no tensors.
        nKnovel = int(labels_train.max().item()) + 1
        labels_train_1hot_size = list(labels_train.size()) + [nKnovel]
        labels_train_unsqueeze = labels_train.unsqueeze(
            dim=labels_train.dim())
        return torch.zeros(labels_train_1hot_size).scatter_(
            len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1)

    def save_model(self, name):
        """The function to save checkpoints.

        Args:
          name: the name for saved checkpoint
        """
        torch.save(
            dict(params=self.model.state_dict()),
            "../models/{}/meta_{}_trial{}_{}.pth".format(
                self.args.exp_id, self.args.model_id,
                self.args.trial_number, name))

    def _episode_labels(self, repeats):
        """Return ``arange(way)`` tiled ``repeats`` times as a long tensor.

        The tensor is created on the GPU when CUDA is available.
        """
        labels = torch.arange(self.args.way).repeat(repeats)
        if torch.cuda.is_available():
            return labels.type(torch.cuda.LongTensor)
        return labels.type(torch.LongTensor)

    def _save_result(self, obj, template, i, suff):
        """``torch.save`` ``obj`` to ``template`` filled with run ids, ``i`` and ``suff``."""
        torch.save(obj, template.format(
            self.args.exp_id, self.args.model_id, i, suff))

    def train(self, trial):
        """The function for the meta-train phase.

        Args:
            trial: object whose ``report(value, step)`` method receives the
                validation accuracy after every epoch.
        """
        # Set the meta-train log
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }
        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Episode labels only depend on args, so build them once.
        label_shot = self._episode_labels(self.args.shot)
        train_label = self._episode_labels(self.args.train_query)
        val_label = self._episode_labels(self.args.val_query)
        p = self.args.shot * self.args.way

        worstClasses = []
        try:
            # Start meta-train
            for epoch in range(1, self.args.max_epoch + 1):
                # Update learning rate.
                # NOTE(review): stepping the scheduler before the first
                # optimizer.step() triggers a warning on recent PyTorch; the
                # order is kept so existing schedules do not shift.
                self.lr_scheduler.step()
                # Set the model to train mode
                self.model.train()
                train_loss_averager = Averager()
                train_acc_averager = Averager()
                label = train_label

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    # Update global count number
                    global_count = global_count + 1
                    if torch.cuda.is_available():
                        data, targ = [item.cuda() for item in batch]
                    else:
                        data, targ = batch
                    data_shot, data_query = data[:p], data[p:]

                    # Output logits for model and calculate meta-train loss
                    if self.args.cross_att:
                        label_one_hot = self.one_hot(label).to(label.device)
                        ytest, cls_scores, logits = self.model(
                            (data_shot, label_shot, data_query),
                            ytest=label_one_hot)
                        pids = label_shot
                        loss = self.crossAttLoss(ytest, cls_scores, label,
                                                 pids)
                        logits = logits[0]
                    else:
                        logits = self.model(
                            (data_shot, label_shot, data_query))
                        loss = F.cross_entropy(logits, label)

                    if self.args.distill_id:
                        # Detach so no gradient is propagated into (or
                        # accumulated on) the frozen teacher.
                        teachLogits = self.teacher(
                            (data_shot, label_shot, data_query)).detach()
                        kl = F.kl_div(
                            F.log_softmax(logits / self.args.kl_temp, dim=1),
                            F.softmax(teachLogits / self.args.kl_temp, dim=1),
                            reduction="batchmean")
                        loss = (kl * self.args.kl_interp * self.args.kl_temp
                                * self.args.kl_temp
                                + loss * (1 - self.args.kl_interp))

                    acc = count_acc(logits, label)
                    # Write the tensorboardX records
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    # Print loss and accuracy for this step
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, loss.item(), acc))

                    # Add loss and accuracy for the averagers
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(acc)

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    if self.args.hard_tasks:
                        if len(worstClasses) == self.args.way:
                            # Enough hard classes collected: take one extra
                            # step on an episode sampled from them.
                            inds = self.train_sampler.hardBatch(worstClasses)
                            hard_batch = [self.trainset[ind][0]
                                          for ind in inds]
                            # The re-sampled episode must replace ``data``;
                            # otherwise the extra step would just repeat the
                            # episode trained on above.
                            # NOTE(review): assumes trainset[ind][0] is an
                            # image tensor - confirm against the Dataset.
                            data = torch.stack(hard_batch).to(data.device)
                            data_shot, data_query = data[:p], data[p:]
                            logits = self.model(
                                (data_shot, label_shot, data_query))
                            loss = F.cross_entropy(logits, label)
                            self.optimizer.zero_grad()
                            loss.backward()
                            self.optimizer.step()
                            worstClasses = []
                        else:
                            # Per-class query accuracy; the class with the
                            # lowest value is remembered by its dataset label.
                            error_mat = (logits.argmax(dim=1) == label).view(
                                self.args.train_query, self.args.way)
                            worst = error_mat.float().mean(dim=0).argmin()
                            worst_trueInd = targ[worst]
                            worstClasses.append(worst_trueInd)

                # Update the averagers
                train_loss_averager = train_loss_averager.item()
                train_acc_averager = train_acc_averager.item()

                # Start validation for this epoch, set model to eval mode
                self.model.eval()
                val_loss_averager = Averager()
                val_acc_averager = Averager()
                label = val_label

                # Print previous information
                if epoch % 10 == 0:
                    print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                        trlog['max_acc_epoch'], trlog['max_acc']))

                # Run meta-validation
                for batch in self.val_loader:
                    if torch.cuda.is_available():
                        data, _ = [item.cuda() for item in batch]
                    else:
                        data = batch[0]
                    data_shot, data_query = data[:p], data[p:]
                    if self.args.cross_att:
                        label_one_hot = self.one_hot(label).to(label.device)
                        ytest, cls_scores, logits = self.model(
                            (data_shot, label_shot, data_query),
                            ytest=label_one_hot)
                        pids = label_shot
                        loss = self.crossAttLoss(ytest, cls_scores, label,
                                                 pids)
                        logits = logits[0]
                    else:
                        logits = self.model(
                            (data_shot, label_shot, data_query))
                        loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)
                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(acc)

                # Update validation averagers
                val_loss_averager = val_loss_averager.item()
                val_acc_averager = val_acc_averager.item()
                # Write the tensorboardX records
                writer.add_scalar('data/val_loss', float(val_loss_averager),
                                  epoch)
                writer.add_scalar('data/val_acc', float(val_acc_averager),
                                  epoch)
                # Print loss and accuracy for this epoch
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, val_loss_averager, val_acc_averager))

                # Update best saved model
                if val_acc_averager > trlog['max_acc']:
                    trlog['max_acc'] = val_acc_averager
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')

                # Update the logs
                trlog['train_loss'].append(train_loss_averager)
                trlog['train_acc'].append(train_acc_averager)
                trlog['val_loss'].append(val_loss_averager)
                trlog['val_acc'].append(val_acc_averager)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                if epoch % 10 == 0:
                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.max_epoch)))

                trial.report(val_acc_averager, epoch)
        finally:
            # Flush TensorBoard events even when training is interrupted.
            writer.close()

    def eval(self, gradcam=False, rise=False, test_on_val=False):
        """The function for the meta-eval phase.

        Args:
            gradcam: on every 5th episode also save Grad-CAM, Grad-CAM++ and
                guided-backprop maps of the query images.
            rise: on every 5th episode also run ``ScoreCam`` on the queries.
            test_on_val: evaluate on the ``'val'`` split instead of ``'test'``.

        Returns:
            The first value returned by ``compute_confidence_interval`` for
            the per-episode accuracies.
        """
        # Load the logs
        trlog_path = osp.join(self.args.save_path, 'trlog')
        if os.path.exists(trlog_path):
            trlog = torch.load(trlog_path)
        else:
            trlog = None

        # Fixed seeds so the sampled test episodes are reproducible.
        torch.manual_seed(1)
        np.random.seed(1)

        # Number of test episodes; sizes both the sampler and the record.
        n_episodes = 600
        # Load meta-test set
        test_set = Dataset('val' if test_on_val else 'test', self.args)
        sampler = CategoriesSampler(
            test_set.label, n_episodes, self.args.way,
            self.args.shot + self.args.val_query)
        loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8,
                            pin_memory=True)
        # Set test accuracy recorder
        test_acc_record = np.zeros((n_episodes, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            weights = self.addOrRemoveModule(
                self.model, torch.load(self.args.eval_weights)['params'])
            self.model.load_state_dict(weights)
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_acc' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()
        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels
        label = self._episode_labels(self.args.val_query)
        label_shot = self._episode_labels(self.args.shot)

        if gradcam:
            # The visualisation helpers look these attributes up on the model.
            self.model.layer3 = self.model.encoder.layer3
            model_dict = dict(type="resnet", arch=self.model,
                              layer_name='layer3')
            grad_cam = GradCAM(model_dict, True)
            grad_cam_pp = GradCAMpp(model_dict, True)
            self.model.features = self.model.encoder
            guided = GuidedBackprop(self.model)
        if rise:
            self.model.layer3 = self.model.encoder.layer3
            score_mod = ScoreCam(self.model)

        k = self.args.way * self.args.shot
        # Start meta-test
        for i, batch in enumerate(loader, 1):
            if torch.cuda.is_available():
                data, _ = [item.cuda() for item in batch]
            else:
                data = batch[0]
            data_shot, data_query = data[:k], data[k:]

            if i % 5 == 0:
                # Every 5th episode additionally dumps intermediate tensors.
                suff = "_val" if test_on_val else ""
                if self.args.rep_vec or self.args.cross_att:
                    # ``acc`` here is the accuracy of the previous episode.
                    print('batch {}: {:.2f}({:.2f})'.format(
                        i, ave_acc.item() * 100, acc * 100))
                    if self.args.cross_att:
                        label_one_hot = self.one_hot(label).to(label.device)
                        # NOTE(review): this branch yields no
                        # ``fast_weights``, so gradcam/rise combined with
                        # cross_att fails on that name below.
                        (_, _, logits, simMapQuer, simMapShot, normQuer,
                         normShot) = self.model(
                            (data_shot, label_shot, data_query),
                            ytest=label_one_hot, retSimMap=True)
                    else:
                        (logits, simMapQuer, simMapShot, normQuer, normShot,
                         fast_weights) = self.model(
                            (data_shot, label_shot, data_query),
                            retSimMap=True)
                    self._save_result(
                        simMapQuer, "../results/{}/{}_simMapQuer{}{}.th",
                        i, suff)
                    self._save_result(
                        simMapShot, "../results/{}/{}_simMapShot{}{}.th",
                        i, suff)
                    self._save_result(
                        data_query, "../results/{}/{}_dataQuer{}{}.th",
                        i, suff)
                    self._save_result(
                        data_shot, "../results/{}/{}_dataShot{}{}.th",
                        i, suff)
                    self._save_result(
                        normQuer, "../results/{}/{}_normQuer{}{}.th",
                        i, suff)
                    self._save_result(
                        normShot, "../results/{}/{}_normShot{}{}.th",
                        i, suff)
                else:
                    logits, normQuer, normShot, fast_weights = self.model(
                        (data_shot, label_shot, data_query),
                        retFastW=True, retNorm=True)
                    self._save_result(
                        normQuer, "../results/{}/{}_normQuer{}{}.th",
                        i, suff)
                    self._save_result(
                        normShot, "../results/{}/{}_normShot{}{}.th",
                        i, suff)

                if gradcam:
                    print("Saving gradmaps", i)
                    allMasks, allMasks_pp, allMaps = [], [], []
                    # One query image at a time, kept 4-D via slicing.
                    for l in range(len(data_query)):
                        query = data_query[l:l + 1]
                        allMasks.append(grad_cam(query, fast_weights, None))
                        allMasks_pp.append(
                            grad_cam_pp(query, fast_weights, None))
                        allMaps.append(
                            guided.generate_gradients(query, fast_weights))
                    allMasks = torch.cat(allMasks, dim=0)
                    allMasks_pp = torch.cat(allMasks_pp, dim=0)
                    allMaps = torch.cat(allMaps, dim=0)
                    self._save_result(
                        allMasks, "../results/{}/{}_gradcamQuer{}{}.th",
                        i, suff)
                    self._save_result(
                        allMasks_pp, "../results/{}/{}_gradcamppQuer{}{}.th",
                        i, suff)
                    self._save_result(
                        allMaps, "../results/{}/{}_guidedQuer{}{}.th",
                        i, suff)

                if rise:
                    print("Saving risemaps", i)
                    # NOTE(review): the scores are computed but not stored.
                    allScore = []
                    for l in range(len(data_query)):
                        allScore.append(
                            score_mod(data_query[l:l + 1], fast_weights))
            else:
                if self.args.cross_att:
                    label_one_hot = self.one_hot(label).to(label.device)
                    _, _, logits = self.model(
                        (data_shot, label_shot, data_query),
                        ytest=label_one_hot)
                else:
                    logits = self.model((data_shot, label_shot, data_query))

            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        if trlog is not None:
            print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
                trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
        return m

    def addOrRemoveModule(self, net, weights):
        """Align the ``module`` level of checkpoint keys with ``net``'s keys.

        The first key containing ``"encoder"`` in ``weights`` and in
        ``net.state_dict()`` decides whether ``"module."`` is stripped from
        every key or inserted after the first component of every encoder key.

        Args:
            net: object exposing ``state_dict()``.
            weights: mapping from parameter name to value.

        Returns:
            A re-keyed dict, or ``weights`` itself when no change is needed
            or when either side has no ``encoder`` key.
        """
        exKeyWei = None
        for key in weights:
            if key.find("encoder") != -1:
                exKeyWei = key
                break
            else:
                print(key)

        exKeyNet = None
        for key in net.state_dict():
            if key.find("encoder") != -1:
                exKeyNet = key
                break

        print(exKeyWei, exKeyNet)

        # Without an encoder key on both sides there is nothing to compare;
        # return the weights untouched instead of failing on ``None.find``.
        if exKeyWei is None or exKeyNet is None:
            return weights

        weights_wrapped = exKeyWei.find("module") != -1
        net_wrapped = exKeyNet.find("module") != -1

        if weights_wrapped and not net_wrapped:
            # remove module
            return {
                param.replace("module.", ""): value
                for param, value in weights.items()
            }

        if net_wrapped and not weights_wrapped:
            # add module
            newWeights = {}
            for param, value in weights.items():
                if param.find("encoder") != -1:
                    param_split = param.split(".")
                    param = (param_split[0] + "." + "module."
                             + ".".join(param_split[1:]))
                newWeights[param] = value
            return newWeights

        return weights