# Example 1
    def train(self):
        """
        Run the main training loop over the 'train' data loader.

        Each iteration trains the encoder and GNN modules on one batch of
        tasks, steps the optimizer, decays the learning rate, periodically
        logs training metrics, and periodically evaluates on the test
        partition, checkpointing the newest and best-so-far models.

        :return: None
        """

        num_supports, num_samples, query_edge_mask, evaluation_mask = \
            preprocessing(self.train_opt['num_ways'],
                          self.train_opt['num_shots'],
                          self.train_opt['num_queries'],
                          self.train_opt['batch_size'],
                          self.arg.device)

        # main training loop, batch size is the number of tasks
        for iteration, batch in enumerate(self.data_loader['train']()):
            # clear gradients accumulated from the previous step
            self.optimizer.zero_grad()

            # advance the global step counter driving logging, eval and lr decay
            self.global_step += 1

            # initialize nodes and edges for dual graph model
            support_data, support_label, query_data, query_label, all_data, all_label_in_edge, node_feature_gd, \
            edge_feature_gp, edge_feature_gd = initialize_nodes_edges(batch,
                                                                      num_supports,
                                                                      self.tensors,
                                                                      self.train_opt['batch_size'],
                                                                      self.train_opt['num_queries'],
                                                                      self.train_opt['num_ways'],
                                                                      self.arg.device)

            # set as train mode
            self.enc_module.train()
            self.gnn_module.train()

            # use backbone encode image
            last_layer_data, second_last_layer_data = backbone_two_stage_initialization(
                all_data, self.enc_module)

            # run the DPGN model
            point_similarity, node_similarity_l2, distribution_similarities = self.gnn_module(
                second_last_layer_data, last_layer_data, node_feature_gd,
                edge_feature_gd, edge_feature_gp)

            # compute loss
            total_loss, query_node_cls_acc_generations, query_edge_loss_generations = \
                self.compute_train_loss_pred(all_label_in_edge,
                                             point_similarity,
                                             node_similarity_l2,
                                             query_edge_mask,
                                             evaluation_mask,
                                             num_supports,
                                             support_label,
                                             query_label,
                                             distribution_similarities)

            # back propagation & update
            total_loss.backward()
            self.optimizer.step()

            # adjust learning rate
            adjust_learning_rate(optimizers=[self.optimizer],
                                 lr=self.train_opt['lr'],
                                 iteration=self.global_step,
                                 dec_lr_step=self.train_opt['dec_lr'],
                                 lr_adj_base=self.train_opt['lr_adj_base'])

            # log training info (last generation's loss/accuracy)
            if self.global_step % self.arg.log_step == 0:
                self.log.info(
                    'step : {}  train_edge_loss : {}  node_acc : {}'.format(
                        self.global_step, query_edge_loss_generations[-1],
                        query_node_cls_acc_generations[-1]))

            # evaluation
            if self.global_step % self.eval_opt['interval'] == 0:
                # use a real boolean flag; bool subclasses int, so callers
                # treating it as 0/1 still behave the same
                is_best = False
                test_acc = self.eval(partition='test')
                if test_acc > self.test_acc:
                    is_best = True
                    self.test_acc = test_acc
                    self.best_step = self.global_step

                # log evaluation info
                self.log.info('test_acc : {}         step : {} '.format(
                    test_acc, self.global_step))
                self.log.info('test_best_acc : {}    step : {}'.format(
                    self.test_acc, self.best_step))

                # save checkpoints (best and newest)
                save_checkpoint(
                    {
                        'iteration': self.global_step,
                        'enc_module_state_dict': self.enc_module.state_dict(),
                        'gnn_module_state_dict': self.gnn_module.state_dict(),
                        'test_acc': self.test_acc,
                        'optimizer': self.optimizer.state_dict(),
                    }, is_best,
                    os.path.join(self.arg.checkpoint_dir, self.arg.exp_name))
# Example 2
    def eval(self, partition='test', log_flag=True):
        """
        Evaluate the model on one data partition.

        :param partition: which part of data is used
        :param log_flag: if True, log the evaluation statistics
        :return: mean query-node classification accuracy over the whole run
        """

        num_supports, num_samples, query_edge_mask, evaluation_mask = preprocessing(
            self.eval_opt['num_ways'], self.eval_opt['num_shots'],
            self.eval_opt['num_queries'], self.eval_opt['batch_size'],
            self.arg.device)

        query_edge_loss_generations = []
        query_node_cls_acc_generations = []
        # initialize so the post-loop logging is well-defined even if the
        # loader yields no batches (otherwise NameError below)
        current_iteration = -1
        # main evaluation loop, batch size is the number of tasks
        for current_iteration, batch in enumerate(
                self.data_loader[partition]()):

            # initialize nodes and edges for dual graph model
            support_data, support_label, query_data, query_label, all_data, all_label_in_edge, node_feature_gd, \
            edge_feature_gp, edge_feature_gd = initialize_nodes_edges(batch,
                                                                      num_supports,
                                                                      self.tensors,
                                                                      self.eval_opt['batch_size'],
                                                                      self.eval_opt['num_queries'],
                                                                      self.eval_opt['num_ways'],
                                                                      self.arg.device)

            # set as eval mode
            # NOTE(review): no no-grad context is visible here; gradients may
            # be tracked during evaluation — confirm against project intent
            self.enc_module.eval()
            self.gnn_module.eval()

            last_layer_data, second_last_layer_data = backbone_two_stage_initialization(
                all_data, self.enc_module)

            # run the DPGN model (only point similarity is needed for eval)
            point_similarity, _, _ = self.gnn_module(second_last_layer_data,
                                                     last_layer_data,
                                                     node_feature_gd,
                                                     edge_feature_gd,
                                                     edge_feature_gp)

            # accumulate per-batch loss/accuracy into the running lists
            query_node_cls_acc_generations, query_edge_loss_generations = \
                self.compute_eval_loss_pred(query_edge_loss_generations,
                                            query_node_cls_acc_generations,
                                            all_label_in_edge,
                                            point_similarity,
                                            query_edge_mask,
                                            evaluation_mask,
                                            num_supports,
                                            support_label,
                                            query_label)

        # logging
        if log_flag:
            self.log.info('------------------------------------')
            self.log.info(
                'step : {}  {}_edge_loss : {}  {}_node_acc : {}'.format(
                    self.global_step, partition,
                    np.array(query_edge_loss_generations).mean(), partition,
                    np.array(query_node_cls_acc_generations).mean()))

            # current_iteration is a 0-based enumerate index, so the actual
            # number of evaluated batches is current_iteration + 1
            self.log.info(
                'evaluation: total_count=%d, accuracy: mean=%.2f%%, std=%.2f%%, ci95=%.2f%%'
                % (current_iteration + 1,
                   np.array(query_node_cls_acc_generations).mean() * 100,
                   np.array(query_node_cls_acc_generations).std() * 100,
                   1.96 * np.array(query_node_cls_acc_generations).std() /
                   np.sqrt(float(len(
                       np.array(query_node_cls_acc_generations)))) * 100))
            self.log.info('------------------------------------')

        return np.array(query_node_cls_acc_generations).mean()