def __init__(self):
        """Build the episodic training pipeline, the prototypical network,
        its losses and SGD optimiser, and the few-shot evaluation helper.

        Every setting is read from the module-level ``Config`` object.
        """
        cfg = Config
        self.best_accuracy = 0.0
        # Learning-rate policy is supplied by the config and invoked from train().
        self.adjust_learning_rate = cfg.adjust_learning_rate

        # Training data: every sample, wrapped into num_way / num_shot tasks.
        self.data_train = MiniImageNetDataset.get_data_all(cfg.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, cfg.num_way,
                                              cfg.num_shot)
        self.task_train_loader = DataLoader(self.task_train, cfg.batch_size,
                                            shuffle=True,
                                            num_workers=cfg.num_workers)

        # Network, moved with RunnerTool.to_cuda and re-initialised in place.
        self.proto_net = RunnerTool.to_cuda(cfg.proto_net)
        initialised_net = self.proto_net.apply(RunnerTool.weights_init)
        RunnerTool.to_cuda(initialised_net)

        # Both criteria are kept; train() picks one according to the config.
        self.loss_ce = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.loss_mse = RunnerTool.to_cuda(nn.MSELoss())

        sgd_settings = dict(lr=cfg.learning_rate, momentum=0.9,
                            weight_decay=5e-4)
        self.proto_net_optim = torch.optim.SGD(self.proto_net.parameters(),
                                               **sgd_settings)

        # Few-shot evaluator; uses the dataset's test-time transform.
        eval_settings = dict(data_root=cfg.data_root,
                             num_way=cfg.num_way,
                             num_shot=cfg.num_shot,
                             episode_size=cfg.episode_size,
                             test_episode=cfg.test_episode,
                             transform=self.task_train.transform_test)
        self.test_tool = TestTool(self.proto_test, **eval_settings)
Пример #2
0
    def __init__(self):
        """Build the training data, the prototypical network, an Adam
        optimiser with a StepLR schedule, and the few-shot evaluator.

        All values come from the module-level ``Config`` object.
        """
        cfg = Config
        self.best_accuracy = 0.0

        # Episodic training data.
        self.data_train = MiniImageNetDataset.get_data_all(cfg.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, cfg.num_way,
                                              cfg.num_shot)
        self.task_train_loader = DataLoader(self.task_train, cfg.batch_size,
                                            shuffle=True,
                                            num_workers=cfg.num_workers)

        # Network + project-specific weight initialisation.
        self.proto_net = RunnerTool.to_cuda(cfg.proto_net)
        RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))

        # Adam; StepLR halves the learning rate every (train_epoch // 3)
        # scheduler steps.
        self.proto_net_optim = torch.optim.Adam(self.proto_net.parameters(),
                                                lr=cfg.learning_rate)
        decay_every = cfg.train_epoch // 3
        self.proto_net_scheduler = StepLR(self.proto_net_optim, decay_every,
                                          gamma=0.5)

        eval_settings = dict(data_root=cfg.data_root,
                             num_way=cfg.num_way,
                             num_shot=cfg.num_shot,
                             episode_size=cfg.episode_size,
                             test_episode=cfg.test_episode,
                             transform=self.task_train.transform_test)
        self.test_tool = TestTool(self.proto_test, **eval_settings)
    def train(self):
        """Run episodic training for ``Config.train_epoch`` epochs (1-based).

        Each epoch calls ``self.adjust_learning_rate(epoch=...)``, optimises
        the network on every batch (gradient norm clipped to 0.5), and logs
        the mean loss. Every ``Config.val_freq`` epochs it validates and saves
        the network's ``state_dict`` to ``Config.pn_dir`` when the validation
        accuracy beats the best seen so far.
        """
        Tools.print()
        Tools.print("Training...")

        for epoch in range(1, Config.train_epoch + 1):
            self.proto_net.train()
            Tools.print()

            self.adjust_learning_rate(epoch=epoch)
            current_lr = self.proto_net_optim.param_groups[0]["lr"]
            Tools.print("{:6} lr:{}".format(epoch, current_lr))

            loader = self.task_train_loader
            epoch_loss = 0.0
            for batch_data, batch_labels, _batch_index in tqdm(loader):
                batch_data = RunnerTool.to_cuda(batch_data)
                batch_labels = RunnerTool.to_cuda(batch_labels)

                # self.proto returns log_p_y (presumably log-probabilities
                # aligned with batch_labels).
                log_p_y = self.proto(batch_data)

                # Label-weighted negative log-likelihood, normalised by the
                # sum of the labels.
                loss = -(log_p_y * batch_labels).sum() / batch_labels.sum()
                epoch_loss += loss.item()

                self.proto_net.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
                self.proto_net_optim.step()

            Tools.print("{:6} loss:{:.3f}".format(epoch, epoch_loss / len(loader)))

            # Only every val_freq-th epoch is validated.
            if epoch % Config.val_freq != 0:
                continue

            Tools.print()
            Tools.print("Test {} {} .......".format(epoch, Config.model_name))
            self.proto_net.eval()

            val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
    def __init__(self):
        """Build the training data, the IC pseudo-label producer, the
        prototypical network and the ICResNet head, one SGD optimiser per
        network, the IC loss, and the few-shot and IC evaluators.

        All values come from the module-level ``Config`` object.
        """
        cfg = Config
        self.best_accuracy = 0.0
        self.adjust_learning_rate = cfg.adjust_learning_rate

        # Episodic training data.
        self.data_train = MiniImageNetDataset.get_data_all(cfg.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, cfg.num_way,
                                              cfg.num_shot)
        self.task_train_loader = DataLoader(self.task_train, cfg.batch_size,
                                            shuffle=True,
                                            num_workers=cfg.num_workers)

        # IC pseudo-label producer, sized by len(self.data_train); the
        # dataset is handed the producer's class and feature tables.
        producer = ProduceClass(len(self.data_train), cfg.ic_out_dim,
                                cfg.ic_ratio)
        self.produce_class = producer
        producer.init()
        self.task_train.set_samples_class(producer.classes)
        self.task_train.set_samples_feature(producer.features)

        # Networks, then weight initialisation (proto_net first, then ic_model).
        self.proto_net = RunnerTool.to_cuda(cfg.proto_net)
        self.ic_model = RunnerTool.to_cuda(ICResNet(low_dim=cfg.ic_out_dim))
        for net in (self.proto_net, self.ic_model):
            RunnerTool.to_cuda(net.apply(RunnerTool.weights_init))

        # Identical SGD settings for both networks.
        def make_sgd(params):
            return torch.optim.SGD(params, lr=cfg.learning_rate, momentum=0.9,
                                   weight_decay=5e-4)

        self.proto_net_optim = make_sgd(self.proto_net.parameters())
        self.ic_model_optim = make_sgd(self.ic_model.parameters())

        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())

        # Evaluators: few-shot accuracy and IC clustering (no feature encoder).
        self.test_tool_fsl = TestTool(self.proto_test,
                                      data_root=cfg.data_root,
                                      num_way=cfg.num_way,
                                      num_shot=cfg.num_shot,
                                      episode_size=cfg.episode_size,
                                      test_episode=cfg.test_episode,
                                      transform=self.task_train.transform_test)
        self.test_tool_ic = ICTestTool(feature_encoder=None,
                                       ic_model=self.ic_model,
                                       data_root=cfg.data_root,
                                       batch_size=cfg.batch_size,
                                       num_workers=cfg.num_workers,
                                       ic_out_dim=cfg.ic_out_dim)
    def __init__(self):
        """Build the training data, the IC pseudo-label producer, the
        prototypical network and the IC network from the config, Adam
        optimisers with StepLR schedules for both, the IC loss, and the
        few-shot and IC evaluators (the IC evaluator uses proto_net as its
        feature encoder).
        """
        cfg = Config
        self.best_accuracy = 0.0

        # Episodic training data.
        self.data_train = MiniImageNetDataset.get_data_all(cfg.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, cfg.num_way, cfg.num_shot)
        self.task_train_loader = DataLoader(self.task_train, cfg.batch_size,
                                            shuffle=True, num_workers=cfg.num_workers)

        # IC pseudo-label producer, sized by len(self.data_train).
        self.produce_class = ProduceClass(len(self.data_train), cfg.ic_out_dim, cfg.ic_ratio)
        self.produce_class.init()

        # Networks, then weight initialisation (proto_net first, then ic_model).
        self.proto_net = RunnerTool.to_cuda(cfg.proto_net)
        self.ic_model = RunnerTool.to_cuda(cfg.ic_proto_net)
        for net in (self.proto_net, self.ic_model):
            RunnerTool.to_cuda(net.apply(RunnerTool.weights_init))

        # Adam for both nets; StepLR halves the lr every (train_epoch // 3) steps.
        lr = cfg.learning_rate
        self.proto_net_optim = torch.optim.Adam(self.proto_net.parameters(), lr=lr)
        self.ic_model_optim = torch.optim.Adam(self.ic_model.parameters(), lr=lr)

        decay_every = cfg.train_epoch // 3
        self.proto_net_scheduler = StepLR(self.proto_net_optim, decay_every, gamma=0.5)
        self.ic_model_scheduler = StepLR(self.ic_model_optim, decay_every, gamma=0.5)

        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())

        # Evaluators.
        self.test_tool_fsl = TestTool(self.proto_test,
                                      data_root=cfg.data_root,
                                      num_way=cfg.num_way,
                                      num_shot=cfg.num_shot,
                                      episode_size=cfg.episode_size,
                                      test_episode=cfg.test_episode,
                                      transform=self.task_train.transform_test)
        self.test_tool_ic = ICTestTool(feature_encoder=self.proto_net,
                                       ic_model=self.ic_model,
                                       data_root=cfg.data_root,
                                       batch_size=cfg.batch_size,
                                       num_workers=cfg.num_workers,
                                       ic_out_dim=cfg.ic_out_dim)
    def train(self):
        """Episodic training on the distances returned by ``self.proto``.

        The objective is MSE when ``Config.is_mse`` is set, otherwise
        cross-entropy on the negated distances. Each epoch (1-based) sets the
        learning rate via ``self.adjust_learning_rate``, optimises batch by
        batch with the gradient norm clipped to 0.5, and logs the mean loss.
        Every ``Config.val_freq`` epochs it validates and saves the
        ``state_dict`` to ``Config.pn_dir`` on a new best accuracy.
        """
        Tools.print()
        Tools.print("Training...")

        for epoch in range(1, Config.train_epoch + 1):
            self.proto_net.train()
            Tools.print()

            pn_lr = self.adjust_learning_rate(self.proto_net_optim, epoch,
                                              Config.first_epoch,
                                              Config.t_epoch,
                                              Config.learning_rate)
            Tools.print('Epoch: [{}] pn_lr={}'.format(epoch, pn_lr))

            loader = self.task_train_loader
            epoch_loss = 0.0
            for batch_data, batch_labels, _batch_index in tqdm(loader):
                batch_data = RunnerTool.to_cuda(batch_data)
                batch_labels = RunnerTool.to_cuda(batch_labels)

                dists = self.proto(batch_data)

                if Config.is_mse:
                    # Regress distances onto (1 - label): label 1 -> 0, label 0 -> 1.
                    loss = self.loss_mse(dists, -(batch_labels - 1))
                else:
                    # Column of the max label, integer-divided by num_shot,
                    # gives the class index; negated distances act as logits.
                    class_ids = torch.argmax(batch_labels, dim=1) // Config.num_shot
                    loss = self.loss_ce(-dists, class_ids)
                epoch_loss += loss.item()

                self.proto_net.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
                self.proto_net_optim.step()

            Tools.print("{:6} loss:{:.3f}".format(epoch, epoch_loss / len(loader)))

            # Only every val_freq-th epoch is validated.
            if epoch % Config.val_freq != 0:
                continue

            Tools.print()
            Tools.print("Test {} {} .......".format(epoch, Config.model_name))
            self.proto_net.eval()

            val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
    def train(self):
        """Jointly train the prototypical network and the IC head.

        For each of ``Config.train_epoch`` epochs (numbered from 1):

        * the FSL loss (label-weighted NLL of ``log_p_y``) and the IC
          cross-entropy against pseudo-labels from ``self.produce_class`` are
          scaled by ``Config.loss_fsl_ratio`` / ``Config.loss_ic_ratio``,
          summed, and back-propagated through both networks;
        * the mean losses and the current learning rate are logged, then both
          StepLR schedulers are stepped;
        * every ``Config.val_freq`` epochs both evaluators run and, when the
          few-shot validation accuracy improves, both ``state_dict``s are
          saved to ``Config.pn_dir`` / ``Config.ic_dir``.
        """
        Tools.print()
        Tools.print("Training...")

        # Epochs are 1-based, like the other training loops, so the logged
        # epoch, the value passed to the evaluators and the
        # ``epoch % val_freq`` schedule all agree. The final epoch is
        # validated whenever train_epoch is a multiple of val_freq.
        for epoch in range(1, 1 + Config.train_epoch):
            self.proto_net.train()
            self.ic_model.train()

            Tools.print()
            self.produce_class.reset()
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index in tqdm(self.task_train_loader):
                # The last column of task_index keys the IC pseudo-labels
                # (presumably the per-sample dataset index).
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

                # 1 forward: FSL log-probs plus the query features fed to the IC head
                log_p_y, query_features = self.proto(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

                # 2 read the current pseudo-labels, then update them from this batch
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss (ratios are applied before backward)
                loss_fsl = -(log_p_y * task_labels).sum() / task_labels.sum() * Config.loss_fsl_ratio
                loss_ic = self.ic_loss(ic_out_logits, ic_targets) * Config.loss_ic_ratio
                loss = loss_fsl + loss_ic
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward: one combined loss, both optimisers step
                self.proto_net.zero_grad()
                self.ic_model.zero_grad()
                loss.backward()
                self.proto_net_optim.step()
                self.ic_model_optim.step()

            # Guard against ZeroDivisionError if the loader yields no batches.
            num_batches = max(len(self.task_train_loader), 1)
            Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} lr:{}".format(
                epoch, all_loss / num_batches, all_loss_fsl / num_batches,
                all_loss_ic / num_batches, self.proto_net_scheduler.get_last_lr()))
            Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count, self.produce_class.count_2))
            self.proto_net_scheduler.step()
            self.ic_model_scheduler.step()

            # Val
            if epoch % Config.val_freq == 0:
                self.proto_net.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)

                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
    def train(self):
        """Train the prototypical network and the IC model with separate
        optimisers and separate backward passes.

        Per epoch (1-based):

        * both learning rates are set through ``self.adjust_learning_rate``;
        * the pseudo-label state in ``self.produce_class`` is reset;
        * for every batch, the IC model is updated on its cross-entropy loss
          (only when ``Config.train_ic``) and the prototypical network on the
          FSL loss;
        * the mean losses and the fraction of set ``task_ok`` flags are logged.

        Every ``Config.val_freq`` epochs both evaluators run and both networks
        are saved when the few-shot validation accuracy improves.
        """
        Tools.print()
        Tools.print("Training...")

        for epoch in range(1, 1 + Config.train_epoch):
            self.proto_net.train()
            self.ic_model.train()

            Tools.print()
            pn_lr = self.adjust_learning_rate(self.proto_net_optim, epoch,
                                              Config.first_epoch,
                                              Config.t_epoch,
                                              Config.learning_rate)
            ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                              Config.first_epoch,
                                              Config.t_epoch,
                                              Config.learning_rate)
            Tools.print('Epoch: [{}] pn_lr={} ic_lr={}'.format(
                epoch, pn_lr, ic_lr))

            self.produce_class.reset()
            Tools.print(self.task_train.classes)
            is_ok_total, is_ok_acc = 0, 0
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                # 1 forward: the IC model sees only the last element along
                # dim 1 of task_data (presumably the query image).
                log_p_y = self.proto(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])

                # 2 read the current pseudo-labels, then update them from this batch
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss
                loss_fsl = -(log_p_y * task_labels).sum() / task_labels.sum()
                loss_ic = self.ic_loss(ic_out_logits, ic_targets)
                # NOTE(review): the ratios only affect this logged total; each
                # network is back-propagated on its *unscaled* loss below.
                # Confirm this is intended.
                loss = loss_fsl * Config.loss_fsl_ratio + loss_ic * Config.loss_ic_ratio
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward: independent updates for the two networks
                if Config.train_ic:
                    self.ic_model.zero_grad()
                    loss_ic.backward()
                    self.ic_model_optim.step()

                self.proto_net.zero_grad()
                loss_fsl.backward()
                self.proto_net_optim.step()

                # Accumulate as Python numbers so the log shows plain values
                # rather than tensor(...) reprs; concatenate the flags once.
                ok_flags = torch.cat(task_ok)
                is_ok_acc += ok_flags.sum().item()
                is_ok_total += ok_flags.numel()

            # Guard the means against an empty loader / no ok-flags.
            num_batches = max(len(self.task_train_loader), 1)
            ok_ratio = int(is_ok_acc) / is_ok_total if is_ok_total else 0.0
            Tools.print(
                "{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} ok:{:.3f}({}/{})".
                format(
                    epoch,
                    all_loss / num_batches,
                    all_loss_fsl / num_batches,
                    all_loss_ic / num_batches,
                    ok_ratio,
                    int(is_ok_acc),
                    is_ok_total,
                ))
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2))

            # Val
            if epoch % Config.val_freq == 0:
                self.proto_net.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True)

                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
 def __init__(self):
     """Store ``Config.proto_net``, passed through ``RunnerTool.to_cuda``,
     as ``self.proto_net``."""
     network = RunnerTool.to_cuda(Config.proto_net)
     self.proto_net = network