def __init__(self):
        """Set up data, networks, SGD optimizers, loss and the test tool.

        All settings are read from ``Config``. The learning-rate policy is
        stored from ``Config.adjust_learning_rate`` on the instance.
        """
        self.best_accuracy = 0.0
        self.adjust_learning_rate = Config.adjust_learning_rate

        # Training data: whatever get_data_all returns for the data root,
        # wrapped in a MiniImageNetDataset built with num_way / num_shot.
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(
            self.data_train, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(
            self.task_train,
            Config.batch_size,
            shuffle=True,
            num_workers=Config.num_workers)

        # Networks: pass both through to_cuda, then apply the shared
        # weight initialiser to each of them in the same order.
        self.feature_encoder = RunnerTool.to_cuda(Config.feature_encoder)
        self.relation_network = RunnerTool.to_cuda(Config.relation_network)
        for network in (self.feature_encoder, self.relation_network):
            RunnerTool.to_cuda(network.apply(RunnerTool.weights_init))

        # Optimizers: one SGD instance per network, identical settings.
        sgd_options = dict(lr=Config.learning_rate,
                           momentum=0.9,
                           weight_decay=5e-4)
        self.feature_encoder_optim = torch.optim.SGD(
            self.feature_encoder.parameters(), **sgd_options)
        self.relation_network_optim = torch.optim.SGD(
            self.relation_network.parameters(), **sgd_options)

        # Loss: mean squared error, passed through to_cuda.
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # Evaluation helper; it calls back into self.compare_fsl_test and
        # uses the test-time transform of the training dataset.
        test_options = dict(data_root=Config.data_root,
                            num_way=Config.num_way,
                            num_shot=Config.num_shot,
                            episode_size=Config.episode_size,
                            test_episode=Config.test_episode,
                            transform=self.task_train.transform_test)
        self.test_tool = TestTool(self.compare_fsl_test, **test_options)
Ejemplo n.º 2
0
    def __init__(self):
        """Set up data, networks, Adam optimizers, schedulers and test tool.

        All settings are read from ``Config``. Each optimizer gets a
        ``StepLR`` scheduler that halves the learning rate (gamma=0.5) every
        ``Config.train_epoch // 3`` epochs, clamped to at least one epoch.
        """
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way,
                                              Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            shuffle=True,
                                            num_workers=Config.num_workers)

        # model: pass both networks through to_cuda, then apply the shared
        # weight initialiser to each of them.
        self.feature_encoder = RunnerTool.to_cuda(Config.feature_encoder)
        self.relation_network = RunnerTool.to_cuda(Config.relation_network)
        RunnerTool.to_cuda(self.feature_encoder.apply(RunnerTool.weights_init))
        RunnerTool.to_cuda(self.relation_network.apply(
            RunnerTool.weights_init))

        # optim
        # StepLR takes the epoch counter modulo step_size, so a step size of
        # 0 (train_epoch < 3) would raise ZeroDivisionError on the first
        # scheduler step; clamp it to at least one epoch.
        lr_step_size = max(1, Config.train_epoch // 3)

        self.feature_encoder_optim = torch.optim.Adam(
            self.feature_encoder.parameters(), lr=Config.learning_rate)
        self.feature_encoder_scheduler = StepLR(self.feature_encoder_optim,
                                                lr_step_size,
                                                gamma=0.5)
        self.relation_network_optim = torch.optim.Adam(
            self.relation_network.parameters(), lr=Config.learning_rate)
        self.relation_network_scheduler = StepLR(self.relation_network_optim,
                                                 lr_step_size,
                                                 gamma=0.5)

        # loss
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # Evaluation helper; it calls back into self.compare_fsl_test and
        # uses the test-time transform of the training dataset.
        self.test_tool = TestTool(self.compare_fsl_test,
                                  data_root=Config.data_root,
                                  num_way=Config.num_way,
                                  num_shot=Config.num_shot,
                                  episode_size=Config.episode_size,
                                  test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)
    def train(self):
        """Run the few-shot training loop for ``Config.train_epoch`` epochs.

        Each epoch calls ``self.adjust_learning_rate`` for both optimizers,
        trains on every batch of ``self.task_train_loader`` with
        gradient-norm clipping at 0.5, and prints the mean batch loss. Every
        ``Config.val_freq`` epochs it validates through ``self.test_tool``
        and saves both state dicts when the accuracy beats
        ``self.best_accuracy``.
        """
        Tools.print()
        Tools.print("Training...")

        networks = (self.feature_encoder, self.relation_network)
        optimizers = (self.feature_encoder_optim, self.relation_network_optim)

        for epoch in range(1, 1 + Config.train_epoch):
            for network in networks:
                network.train()

            Tools.print()
            # Same arguments for both optimizers; the returned values are
            # only used for the log line below.
            fe_lr, rn_lr = [
                self.adjust_learning_rate(optimizer, epoch,
                                          Config.first_epoch, Config.t_epoch,
                                          Config.learning_rate)
                for optimizer in optimizers
            ]
            Tools.print('Epoch: [{}] fe_lr={} rn_lr={}'.format(
                epoch, fe_lr, rn_lr))

            epoch_loss = 0.0
            for batch_data, batch_labels, _ in tqdm(self.task_train_loader):
                batch_data = RunnerTool.to_cuda(batch_data)
                batch_labels = RunnerTool.to_cuda(batch_labels)

                # 1 forward pass
                relations = self.compare_fsl(batch_data)

                # 2 loss
                batch_loss = self.loss(relations, batch_labels)
                epoch_loss += batch_loss.item()

                # 3 backward: clear gradients, back-propagate, clip, step
                for network in networks:
                    network.zero_grad()
                batch_loss.backward()
                for network in networks:
                    torch.nn.utils.clip_grad_norm_(network.parameters(), 0.5)
                for optimizer in optimizers:
                    optimizer.step()

            # Mean loss over the batches of this epoch.
            Tools.print("{:6} loss:{:.3f}".format(
                epoch, epoch_loss / len(self.task_train_loader)))

            # Validation only every val_freq epochs.
            if epoch % Config.val_freq != 0:
                continue

            for network in networks:
                network.eval()

            Tools.print()
            Tools.print("Test {} .......".format(epoch))
            val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.feature_encoder.state_dict(), Config.fe_dir)
                torch.save(self.relation_network.state_dict(), Config.rn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
Ejemplo n.º 4
0
    def train(self):
        """Jointly train the feature encoder, relation network and IC model.

        First runs one pass over ``self.task_train_loader`` in eval mode,
        feeding the IC model outputs to ``self.produce_class.cal_label``.
        That pass runs under ``torch.no_grad()`` because nothing is
        back-propagated from it.

        Then, for each of ``Config.train_epoch`` epochs (counted from 0):
        calls ``self.adjust_learning_rate`` for the three optimizers,
        optimises ``loss_fsl * Config.loss_fsl_ratio + loss_ic *
        Config.loss_ic_ratio`` with gradient-norm clipping at 0.5, prints the
        mean losses, and every ``Config.val_freq`` epochs validates and
        saves the three state dicts when the few-shot validation accuracy
        beats ``self.best_accuracy``.
        """
        Tools.print()
        Tools.print("Training...")

        # Init Update
        self.feature_encoder.eval()
        self.relation_network.eval()
        self.ic_model.eval()
        # NOTE(review): the "{}" placeholder is never formatted, so it is
        # printed literally; message kept verbatim.
        Tools.print("Init label {} .......")
        self.produce_class.reset()
        # Forward passes only: disable autograd so no graph is built for
        # tensors that are never back-propagated.
        with torch.no_grad():
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)
                relations, query_features = self.compare_fsl(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
        Tools.print("Epoch: {}/{}".format(self.produce_class.count,
                                          self.produce_class.count_2))

        for epoch in range(Config.train_epoch):
            self.feature_encoder.train()
            self.relation_network.train()
            self.ic_model.train()

            Tools.print()
            fe_lr = self.adjust_learning_rate(self.feature_encoder_optim,
                                              epoch, Config.first_epoch,
                                              Config.t_epoch,
                                              Config.learning_rate)
            rn_lr = self.adjust_learning_rate(self.relation_network_optim,
                                              epoch, Config.first_epoch,
                                              Config.t_epoch,
                                              Config.learning_rate)
            ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                              Config.first_epoch,
                                              Config.t_epoch,
                                              Config.learning_rate)
            Tools.print('Epoch: [{}] fe_lr={} rn_lr={} ic_lr={}'.format(
                epoch, fe_lr, rn_lr, ic_lr))

            self.produce_class.reset()
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                # The last column of task_index is used as the IC label.
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                # 1 calculate features
                relations, query_features = self.compare_fsl(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

                # 2 fetch the current IC targets before cal_label is fed
                # this batch's outputs
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss
                loss_fsl = self.fsl_loss(relations,
                                         task_labels) * Config.loss_fsl_ratio
                loss_ic = self.ic_loss(ic_out_logits,
                                       ic_targets) * Config.loss_ic_ratio
                loss = loss_fsl + loss_ic
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward
                self.feature_encoder.zero_grad()
                self.relation_network.zero_grad()
                self.ic_model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.feature_encoder.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(
                    self.relation_network.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
                self.feature_encoder_optim.step()
                self.relation_network_optim.step()
                self.ic_model_optim.step()

            # print: mean losses over the batches of this epoch
            # NOTE(review): this line reports epoch + 1 while the other log
            # lines report the 0-based epoch.
            Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f}".format(
                epoch + 1, all_loss / len(self.task_train_loader),
                all_loss_fsl / len(self.task_train_loader),
                all_loss_ic / len(self.task_train_loader)))
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2))

            # Val
            if epoch % Config.val_freq == 0:
                self.feature_encoder.eval()
                self.relation_network.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True)

                # Only the few-shot accuracy decides whether to checkpoint.
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.feature_encoder.state_dict(),
                               Config.fe_dir)
                    torch.save(self.relation_network.state_dict(),
                               Config.rn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
Ejemplo n.º 5
0
    def __init__(self):
        """Set up data, IC labelling, networks, SGD optimizers and test tools.

        All settings are read from ``Config``. The learning-rate policy is
        stored from ``Config.adjust_learning_rate`` on the instance.
        """
        self.best_accuracy = 0.0
        self.adjust_learning_rate = Config.adjust_learning_rate

        # Training data: whatever get_data_all returns for the data root,
        # wrapped in a MiniImageNetDataset built with num_way / num_shot.
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(
            self.data_train, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(
            self.task_train,
            Config.batch_size,
            shuffle=True,
            num_workers=Config.num_workers)

        # IC: label producer sized by the number of training items; its
        # classes are handed to the training dataset.
        self.produce_class = ProduceClass(
            len(self.data_train), Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()
        self.task_train.set_samples_class(self.produce_class.classes)

        # Networks: create all three first, then apply the shared weight
        # initialiser to each of them in the same order.
        self.feature_encoder = RunnerTool.to_cuda(Config.feature_encoder)
        self.relation_network = RunnerTool.to_cuda(Config.relation_network)
        self.ic_model = RunnerTool.to_cuda(
            ICModel(in_dim=Config.ic_in_dim, out_dim=Config.ic_out_dim))
        for network in (self.feature_encoder, self.relation_network,
                        self.ic_model):
            RunnerTool.to_cuda(network.apply(RunnerTool.weights_init))

        # Optimizers: one SGD instance per network, identical settings.
        sgd_options = dict(lr=Config.learning_rate,
                           momentum=0.9,
                           weight_decay=5e-4)
        self.feature_encoder_optim = torch.optim.SGD(
            self.feature_encoder.parameters(), **sgd_options)
        self.relation_network_optim = torch.optim.SGD(
            self.relation_network.parameters(), **sgd_options)
        self.ic_model_optim = torch.optim.SGD(
            self.ic_model.parameters(), **sgd_options)

        # Losses: cross entropy for the IC branch, MSE for the FSL branch.
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Evaluation helpers for the few-shot task and for the IC model.
        fsl_test_options = dict(data_root=Config.data_root,
                                num_way=Config.num_way,
                                num_shot=Config.num_shot,
                                episode_size=Config.episode_size,
                                test_episode=Config.test_episode,
                                transform=self.task_train.transform_test)
        self.test_tool_fsl = TestTool(self.compare_fsl_test,
                                      **fsl_test_options)
        ic_test_options = dict(feature_encoder=self.feature_encoder,
                               ic_model=self.ic_model,
                               data_root=Config.data_root,
                               batch_size=Config.batch_size,
                               num_workers=Config.num_workers,
                               ic_out_dim=Config.ic_out_dim)
        self.test_tool_ic = ICTestTool(**ic_test_options)
    def __init__(self):
        """Set up data, IC labelling, networks, Adam optimizers and schedulers.

        All settings are read from ``Config``. Each optimizer gets a
        ``StepLR`` scheduler that halves the learning rate (gamma=0.5) every
        ``Config.train_epoch // 3`` epochs, clamped to at least one epoch.
        Two test tools are created: one for the few-shot task and one for
        the IC model.
        """
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way,
                                              Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            shuffle=True,
                                            num_workers=Config.num_workers)

        # IC: label producer sized by the number of training items.
        self.produce_class = ProduceClass(len(self.data_train),
                                          Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()

        # model: pass all three networks through to_cuda, then apply the
        # shared weight initialiser to each of them.
        self.feature_encoder = RunnerTool.to_cuda(Config.feature_encoder)
        self.relation_network = RunnerTool.to_cuda(Config.relation_network)
        self.ic_model = RunnerTool.to_cuda(Config.ic_model)

        RunnerTool.to_cuda(self.feature_encoder.apply(RunnerTool.weights_init))
        RunnerTool.to_cuda(self.relation_network.apply(
            RunnerTool.weights_init))
        RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

        # optim
        self.feature_encoder_optim = torch.optim.Adam(
            self.feature_encoder.parameters(), lr=Config.learning_rate)
        self.relation_network_optim = torch.optim.Adam(
            self.relation_network.parameters(), lr=Config.learning_rate)
        self.ic_model_optim = torch.optim.Adam(self.ic_model.parameters(),
                                               lr=Config.learning_rate)

        # StepLR takes the epoch counter modulo step_size, so a step size of
        # 0 (train_epoch < 3) would raise ZeroDivisionError on the first
        # scheduler step; clamp it to at least one epoch.
        lr_step_size = max(1, Config.train_epoch // 3)

        self.feature_encoder_scheduler = StepLR(self.feature_encoder_optim,
                                                lr_step_size,
                                                gamma=0.5)
        self.relation_network_scheduler = StepLR(self.relation_network_optim,
                                                 lr_step_size,
                                                 gamma=0.5)
        self.ic_model_scheduler = StepLR(self.ic_model_optim,
                                         lr_step_size,
                                         gamma=0.5)

        # loss: cross entropy for the IC branch, MSE for the FSL branch
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Eval
        self.test_tool_fsl = TestTool(self.compare_fsl_test,
                                      data_root=Config.data_root,
                                      num_way=Config.num_way,
                                      num_shot=Config.num_shot,
                                      episode_size=Config.episode_size,
                                      test_episode=Config.test_episode,
                                      transform=self.task_train.transform_test)
        self.test_tool_ic = ICTestTool(feature_encoder=self.feature_encoder,
                                       ic_model=self.ic_model,
                                       data_root=Config.data_root,
                                       batch_size=Config.batch_size,
                                       num_workers=Config.num_workers,
                                       ic_out_dim=Config.ic_out_dim)
Ejemplo n.º 7
0
 def __init__(self):
     """Store both configured networks after passing them through to_cuda."""
     encoder = Config.feature_encoder
     self.feature_encoder = RunnerTool.to_cuda(encoder)

     relation = Config.relation_network
     self.relation_network = RunnerTool.to_cuda(relation)