def train(self):
    """
    Run the main training loop.

    One optimization step is performed per batch of tasks drawn from the
    train data loader. Progress is logged every ``self.arg.log_step``
    steps; every ``self.eval_opt['interval']`` steps the model is
    evaluated on the test split and checkpointed (best and newest).

    :return: None
    """
    # precompute task-shape constants and edge masks for the train settings
    num_supports, num_samples, query_edge_mask, evaluation_mask = \
        preprocessing(self.train_opt['num_ways'],
                      self.train_opt['num_shots'],
                      self.train_opt['num_queries'],
                      self.train_opt['batch_size'],
                      self.arg.device)

    # main training loop, batch size is the number of tasks
    for iteration, batch in enumerate(self.data_loader['train']()):
        # init grad
        self.optimizer.zero_grad()

        # set current step
        self.global_step += 1

        # initialize nodes and edges for dual graph model
        support_data, support_label, query_data, query_label, all_data, all_label_in_edge, node_feature_gd, \
            edge_feature_gp, edge_feature_gd = initialize_nodes_edges(batch,
                                                                      num_supports,
                                                                      self.tensors,
                                                                      self.train_opt['batch_size'],
                                                                      self.train_opt['num_queries'],
                                                                      self.train_opt['num_ways'],
                                                                      self.arg.device)

        # set as train mode
        self.enc_module.train()
        self.gnn_module.train()

        # use backbone encode image
        last_layer_data, second_last_layer_data = backbone_two_stage_initialization(
            all_data, self.enc_module)

        # run the DPGN model
        point_similarity, node_similarity_l2, distribution_similarities = self.gnn_module(
            second_last_layer_data,
            last_layer_data,
            node_feature_gd,
            edge_feature_gd,
            edge_feature_gp)

        # compute loss
        total_loss, query_node_cls_acc_generations, query_edge_loss_generations = \
            self.compute_train_loss_pred(all_label_in_edge,
                                         point_similarity,
                                         node_similarity_l2,
                                         query_edge_mask,
                                         evaluation_mask,
                                         num_supports,
                                         support_label,
                                         query_label,
                                         distribution_similarities)

        # back propagation & update
        total_loss.backward()
        self.optimizer.step()

        # adjust learning rate
        adjust_learning_rate(optimizers=[self.optimizer],
                             lr=self.train_opt['lr'],
                             iteration=self.global_step,
                             dec_lr_step=self.train_opt['dec_lr'],
                             lr_adj_base=self.train_opt['lr_adj_base'])

        # log training info (losses/accuracies of the last generation)
        if self.global_step % self.arg.log_step == 0:
            self.log.info(
                'step : {} train_edge_loss : {} node_acc : {}'.format(
                    self.global_step,
                    query_edge_loss_generations[-1],
                    query_node_cls_acc_generations[-1]))

        # evaluation
        if self.global_step % self.eval_opt['interval'] == 0:
            # use a proper boolean instead of 0/1 ints; bool subclasses int,
            # so save_checkpoint's is_best handling is unaffected
            is_best = False
            test_acc = self.eval(partition='test')
            if test_acc > self.test_acc:
                is_best = True
                self.test_acc = test_acc
                self.best_step = self.global_step

            # log evaluation info
            self.log.info('test_acc : {} step : {} '.format(
                test_acc, self.global_step))
            self.log.info('test_best_acc : {} step : {}'.format(
                self.test_acc, self.best_step))

            # save checkpoints (best and newest)
            save_checkpoint(
                {
                    'iteration': self.global_step,
                    'enc_module_state_dict': self.enc_module.state_dict(),
                    'gnn_module_state_dict': self.gnn_module.state_dict(),
                    'test_acc': self.test_acc,
                    'optimizer': self.optimizer.state_dict(),
                },
                is_best,
                os.path.join(self.arg.checkpoint_dir, self.arg.exp_name))
def eval(self, partition='test', log_flag=True):
    """
    Evaluate the model on one data partition.

    :param partition: which part of data is used (e.g. 'test')
    :param log_flag: if True, log the evaluation statistics
    :return: mean query-node classification accuracy over all
             evaluated batches and generations (the original docstring
             incorrectly said None)
    """
    # precompute task-shape constants and edge masks for the eval settings
    num_supports, num_samples, query_edge_mask, evaluation_mask = preprocessing(
        self.eval_opt['num_ways'],
        self.eval_opt['num_shots'],
        self.eval_opt['num_queries'],
        self.eval_opt['batch_size'],
        self.arg.device)

    query_edge_loss_generations = []
    query_node_cls_acc_generations = []

    # number of evaluation batches actually processed; fixes the previous
    # off-by-one (the last 0-based enumerate index was logged as a count)
    # and the NameError when the loader yields no batches
    total_count = 0

    # main evaluation loop, batch size is the number of tasks
    for current_iteration, batch in enumerate(
            self.data_loader[partition]()):
        total_count = current_iteration + 1

        # initialize nodes and edges for dual graph model
        support_data, support_label, query_data, query_label, all_data, all_label_in_edge, node_feature_gd, \
            edge_feature_gp, edge_feature_gd = initialize_nodes_edges(batch,
                                                                      num_supports,
                                                                      self.tensors,
                                                                      self.eval_opt['batch_size'],
                                                                      self.eval_opt['num_queries'],
                                                                      self.eval_opt['num_ways'],
                                                                      self.arg.device)

        # set as eval mode
        self.enc_module.eval()
        self.gnn_module.eval()

        # use backbone encode image
        last_layer_data, second_last_layer_data = backbone_two_stage_initialization(
            all_data, self.enc_module)

        # run the DPGN model (only point similarity is needed for eval)
        point_similarity, _, _ = self.gnn_module(second_last_layer_data,
                                                 last_layer_data,
                                                 node_feature_gd,
                                                 edge_feature_gd,
                                                 edge_feature_gp)

        # accumulate per-generation losses and accuracies
        query_node_cls_acc_generations, query_edge_loss_generations = \
            self.compute_eval_loss_pred(query_edge_loss_generations,
                                        query_node_cls_acc_generations,
                                        all_label_in_edge,
                                        point_similarity,
                                        query_edge_mask,
                                        evaluation_mask,
                                        num_supports,
                                        support_label,
                                        query_label)

    # convert once; the original rebuilt np.array(...) for every statistic
    edge_losses = np.array(query_edge_loss_generations)
    node_accs = np.array(query_node_cls_acc_generations)

    # logging
    if log_flag:
        self.log.info('------------------------------------')
        self.log.info(
            'step : {} {}_edge_loss : {} {}_node_acc : {}'.format(
                self.global_step,
                partition,
                edge_losses.mean(),
                partition,
                node_accs.mean()))
        self.log.info(
            'evaluation: total_count=%d, accuracy: mean=%.2f%%, std=%.2f%%, ci95=%.2f%%' %
            (total_count,
             node_accs.mean() * 100,
             node_accs.std() * 100,
             # 95% confidence interval half-width: 1.96 * std / sqrt(N)
             1.96 * node_accs.std() / np.sqrt(float(len(node_accs))) * 100))
        self.log.info('------------------------------------')

    return node_accs.mean()