    def __init__(self):
        # all data
        self.data_train = MyDataset.get_data_split(Config.data_root,
                                                   split="train")
        self.task_train = RandomAndCssDataset(self.data_train, Config.num_way,
                                              Config.num_shot,
                                              Config.transform_train,
                                              Config.transform_test)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            shuffle=True,
                                            num_workers=Config.num_workers)

        # model
        self.net = RunnerTool.to_cuda(Config.net)
        RunnerTool.to_cuda(self.net.apply(RunnerTool.weights_init))
        self.norm = Normalize(2)

        # optim
        self.loss = RunnerTool.to_cuda(nn.MSELoss())
        self.net_optim = torch.optim.SGD(self.net.parameters(),
                                         lr=Config.learning_rate,
                                         momentum=0.9,
                                         weight_decay=5e-4)

        # Eval
        self.test_tool_fsl = FSLTestTool(
            self.matching_test,
            data_root=Config.data_root,
            num_way=Config.num_way_test,
            num_shot=Config.num_shot,
            episode_size=Config.episode_size,
            test_episode=Config.test_episode,
            transform=self.task_train.transform_test)
        pass
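
Normalize and RunnerTool are used throughout these examples but never shown. Below is a minimal sketch, assuming Normalize(2) is an L2-normalization module and RunnerTool only wraps device placement and weight initialization; the exact implementations in the original repository may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Normalize(nn.Module):
    # Sketch: L2-normalize features (power=2 matches Normalize(2) used above).

    def __init__(self, power=2):
        super().__init__()
        self.power = power

    def forward(self, x):
        return F.normalize(x, p=self.power, dim=1)


class RunnerTool(object):

    @staticmethod
    def to_cuda(x):
        # Move a tensor or module to the GPU when one is available.
        return x.cuda() if torch.cuda.is_available() else x

    @staticmethod
    def weights_init(m):
        # Assumed initializer applied via module.apply(...) in the runners.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
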
Example No. 2
    def __init__(self, config):
        self.config = config

        # model
        self.norm = Normalize(2)
        self.matching_net = RunnerTool.to_cuda(self.config.matching_net)
        pass
Example No. 3
    def __init__(self):
        self.adjust_learning_rate = Config.ic_adjust_learning_rate

        # data
        self.data_train = TieredImageNetICDataset.get_data_all(
            Config.data_root)
        self.tiered_imagenet_dataset = TieredImageNetICDataset(self.data_train)
        self.ic_train_loader = DataLoader(self.tiered_imagenet_dataset,
                                          Config.ic_batch_size,
                                          shuffle=True,
                                          num_workers=Config.num_workers)
        self.ic_train_loader_eval = DataLoader(self.tiered_imagenet_dataset,
                                               Config.ic_batch_size,
                                               shuffle=False,
                                               num_workers=Config.num_workers)

        # IC
        self.produce_class = ProduceClass(len(self.data_train),
                                          Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()

        # model
        self.ic_model = RunnerTool.to_cuda(
            ICResNet(low_dim=Config.ic_out_dim, encoder=Config.ic_net))
        RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())

        self.ic_model_optim = torch.optim.SGD(self.ic_model.parameters(),
                                              lr=Config.ic_learning_rate,
                                              momentum=0.9,
                                              weight_decay=5e-4)

        # Eval
        self.test_tool_ic = ICTestTool(
            feature_encoder=None,
            ic_model=self.ic_model,
            data_root=Config.data_root,
            batch_size=Config.ic_batch_size,
            num_workers=Config.num_workers,
            ic_out_dim=Config.ic_out_dim,
            transform=self.ic_train_loader.dataset.transform_test,
            k=Config.ic_knn)
        pass
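
ProduceClass is only used through its interface here (init, reset, cal_label, get_label, classes, count, count_2). The following is a rough sketch of that interface, assuming it keeps one pseudo-label per sample and caps each cluster at roughly ic_ratio * n_sample / ic_out_dim samples to keep clusters balanced; the semantics of count and count_2 are inferred from the log lines above, not taken from the original code.

import numpy as np
import torch


class ProduceClass(object):
    # Sketch of the pseudo-label producer inferred from its call sites above.

    def __init__(self, n_sample, out_dim, ratio=1):
        self.n_sample = n_sample
        self.out_dim = out_dim
        self.class_per_num = self.n_sample // self.out_dim * ratio
        self.count = 0      # assumed: labels that changed this epoch
        self.count_2 = 0    # assumed: labels that stayed the same
        self.classes = None
        self.class_num = None

    def init(self):
        self.classes = np.random.randint(0, self.out_dim, size=self.n_sample)
        self.class_num = np.zeros(self.out_dim, dtype=np.int64)

    def reset(self):
        self.count, self.count_2 = 0, 0
        self.class_num[:] = 0

    def cal_label(self, out, indexes):
        # Assign each sample to its highest-scoring class that still has room.
        ranked = torch.argsort(out, dim=1, descending=True).cpu().numpy()
        indexes = indexes.cpu().numpy()
        for candidates, idx in zip(ranked, indexes):
            for c in candidates:
                if self.class_num[c] < self.class_per_num:
                    if self.classes[idx] != c:
                        self.count += 1
                    else:
                        self.count_2 += 1
                    self.classes[idx] = c
                    self.class_num[c] += 1
                    break

    def get_label(self, indexes):
        labels = self.classes[indexes.cpu().numpy()]
        return torch.from_numpy(labels).long().to(indexes.device)
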
Example No. 4
    def eval(self):
        self.ic_model.eval()
        with torch.no_grad():
            ic_classes = np.zeros(shape=(len(self.tiered_imagenet_dataset), ),
                                  dtype=np.int64)
            for image, label, idx in tqdm(self.ic_train_loader_eval):
                ic_out_logits, ic_out_l2norm = self.ic_model(
                    RunnerTool.to_cuda(image))
                ic_classes[idx] = np.argmax(ic_out_l2norm.cpu().numpy(),
                                            axis=-1)
                pass
        return self.data_train, ic_classes
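
The (data_train, ic_classes) pair returned by eval() is exactly what the constructor in the next example expects, so the two runners are typically chained. A minimal hand-off, with ICRunner and FSLRunner as hypothetical names for the two classes shown in these examples:

ic_runner = ICRunner()
ic_runner.train()                       # clustering pre-training (shown further below)
data_train, classes = ic_runner.eval()  # one pseudo-label per sample
fsl_runner = FSLRunner(data_train, classes)
fsl_runner.train()
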
Example No. 5
    def __init__(self, data_train, classes):
        # all data
        self.data_train = data_train
        self.task_train = TieredImageNetFSLDataset(self.data_train, classes,
                                                   Config.fsl_num_way,
                                                   Config.fsl_num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.fsl_batch_size,
                                            shuffle=True,
                                            num_workers=Config.num_workers)

        # model
        self.matching_net = RunnerTool.to_cuda(Config.fsl_matching_net)
        self.matching_net = RunnerTool.to_cuda(
            nn.DataParallel(self.matching_net))
        cudnn.benchmark = True
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        self.norm = Normalize(2)

        # optim
        self.matching_net_optim = torch.optim.Adam(
            self.matching_net.parameters(), lr=Config.fsl_learning_rate)
        self.matching_net_scheduler = MultiStepLR(self.matching_net_optim,
                                                  Config.fsl_lr_schedule,
                                                  gamma=1 / 3)

        # loss
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Eval
        self.test_tool_fsl = FSLTestTool(
            self.matching_test,
            data_root=Config.data_root,
            num_way=Config.fsl_num_way,
            num_shot=Config.fsl_num_shot,
            episode_size=Config.fsl_episode_size,
            test_episode=Config.fsl_test_episode,
            transform=self.task_train.transform_test)
        pass
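
Both FSL runners rely on self.matching(task_data) and hand self.matching_test to FSLTestTool, but neither function appears in these examples. Below is a rough sketch of a matching-network episode forward, written as a method of the runner class above and using the module-level Config; it assumes task_data is shaped (batch, num_way * num_shot + 1, C, H, W) with the query image last, and the cosine-similarity readout with a softmax over ways is an assumption, not the repository's exact code. matching_test would presumably run the same computation on episodes assembled by FSLTestTool.

import torch
import torch.nn.functional as F


def matching(self, task_data):
    # task_data: (batch, num_way * num_shot + 1, C, H, W); last image is the query.
    batch, num_images, c, h, w = task_data.shape
    features = self.matching_net(task_data.view(batch * num_images, c, h, w))
    features = F.normalize(features.view(batch, num_images, -1), dim=-1)

    support, query = features[:, :-1], features[:, -1:]
    similarities = torch.bmm(query, support.transpose(1, 2))          # cosine similarities
    similarities = similarities.view(batch, Config.fsl_num_way,
                                     Config.fsl_num_shot).sum(dim=-1)  # sum over shots
    return F.softmax(similarities, dim=-1)                             # one score per way
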
Example No. 6
    def __init__(self, config):
        self.config = config

        # model
        self.norm = Normalize(2)
        self.matching_net = RunnerTool.to_cuda(self.config.matching_net)

        # check
        if self.config.is_check:
            return

        # Eval
        self.test_tool_fsl = FSLTestTool(self.matching_test,
                                         data_root=self.config.data_root,
                                         num_way=self.config.num_way,
                                         num_shot=self.config.num_shot,
                                         episode_size=self.config.episode_size,
                                         test_episode=self.config.test_episode,
                                         transform=self.config.transform_test,
                                         txt_path=self.config.log_file)
        pass
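
This runner only builds the model and the test tool; it does not load any weights, so the caller is expected to restore the checkpoint saved during training before using the test tool. A minimal sketch, assuming the checkpoint path is self.config.mn_dir, the same path the training runners save to:

def load_model(self):
    # Restore the matching network saved by torch.save(...) during training.
    state_dict = torch.load(self.config.mn_dir, map_location='cpu')
    self.matching_net.load_state_dict(state_dict)
    Tools.print("load model from {}".format(self.config.mn_dir),
                txt_path=self.config.log_file)
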
Example No. 7
    def __init__(self, config):
        self.config = config

        # all data
        self.data_train = MyDataset.get_data_split(
            self.config.data_root, split=MyDataset.dataset_split_train)
        self.task_train = TrainDataset(
            self.data_train,
            self.config.num_way,
            self.config.num_shot,
            transform_train_ic=self.config.transform_train_ic,
            transform_train_fsl=self.config.transform_train_fsl,
            transform_test=self.config.transform_test)
        self.task_train_loader = DataLoader(
            self.task_train,
            self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers)

        # IC
        self.produce_class = ProduceClass(len(self.data_train),
                                          self.config.ic_out_dim,
                                          self.config.ic_ratio)
        self.produce_class.init()
        self.task_train.set_samples_class(self.produce_class.classes)

        # model
        self.norm = Normalize(2)
        self.matching_net = RunnerTool.to_cuda(self.config.matching_net)
        self.ic_model = RunnerTool.to_cuda(
            ICResNet(low_dim=self.config.ic_out_dim,
                     resnet=self.config.resnet,
                     modify_head=self.config.modify_head))
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

        # optim
        self.matching_net_optim = torch.optim.SGD(
            self.matching_net.parameters(),
            lr=self.config.learning_rate,
            momentum=0.9,
            weight_decay=5e-4)
        self.ic_model_optim = torch.optim.SGD(self.ic_model.parameters(),
                                              lr=self.config.learning_rate,
                                              momentum=0.9,
                                              weight_decay=5e-4)

        # loss
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Eval
        self.test_tool_fsl = FSLTestTool(
            self.matching_test,
            data_root=self.config.data_root,
            num_way=self.config.num_way,
            num_shot=self.config.num_shot,
            episode_size=self.config.episode_size,
            test_episode=self.config.test_episode,
            transform=self.task_train.transform_test,
            txt_path=self.config.log_file)
        self.test_tool_ic = ICTestTool(
            feature_encoder=None,
            ic_model=self.ic_model,
            data_root=self.config.data_root,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
            ic_out_dim=self.config.ic_out_dim,
            transform=self.task_train.transform_test,
            txt_path=self.config.log_file)
        pass
Example No. 8
    def train(self):
        Tools.print()
        best_accuracy = 0.0
        Tools.print("Training...", txt_path=self.config.log_file)

        # Init Update
        try:
            self.matching_net.eval()
            self.ic_model.eval()
            Tools.print("Init label {} .......", txt_path=self.config.log_file)
            self.produce_class.reset()
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)
                ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
                pass
            Tools.print("Epoch: {}/{}".format(self.produce_class.count,
                                              self.produce_class.count_2),
                        txt_path=self.config.log_file)
        finally:
            pass

        for epoch in range(1, 1 + self.config.train_epoch):
            self.matching_net.train()
            self.ic_model.train()

            Tools.print()
            mn_lr = self.config.adjust_learning_rate(self.matching_net_optim,
                                                     epoch,
                                                     self.config.first_epoch,
                                                     self.config.t_epoch,
                                                     self.config.learning_rate)
            ic_lr = self.config.adjust_learning_rate(self.ic_model_optim,
                                                     epoch,
                                                     self.config.first_epoch,
                                                     self.config.t_epoch,
                                                     self.config.learning_rate)
            Tools.print('Epoch: [{}] mn_lr={} ic_lr={}'.format(
                epoch, mn_lr, ic_lr),
                        txt_path=self.config.log_file)

            self.produce_class.reset()
            Tools.print(self.task_train.classes)
            is_ok_total, is_ok_acc = 0, 0
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations = self.matching(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])

                # 2
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss
                loss_fsl = self.fsl_loss(relations, task_labels)
                loss_ic = self.ic_loss(ic_out_logits, ic_targets)
                loss = loss_fsl * self.config.loss_fsl_ratio + loss_ic * self.config.loss_ic_ratio
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward
                self.ic_model.zero_grad()
                loss_ic.backward()
                self.ic_model_optim.step()

                self.matching_net.zero_grad()
                loss_fsl.backward()
                self.matching_net_optim.step()

                # is ok
                is_ok_acc += torch.sum(torch.cat(task_ok))
                is_ok_total += torch.prod(
                    torch.tensor(torch.cat(task_ok).shape))
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print(
                "{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} ok:{:.3f}({}/{})".
                format(epoch, all_loss / len(self.task_train_loader),
                       all_loss_fsl / len(self.task_train_loader),
                       all_loss_ic / len(self.task_train_loader),
                       int(is_ok_acc) / int(is_ok_total), is_ok_acc,
                       is_ok_total),
                txt_path=self.config.log_file)
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2),
                        txt_path=self.config.log_file)
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % self.config.val_freq == 0:
                self.matching_net.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True)

                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(),
                               Tools.new_dir(self.config.mn_dir))
                    torch.save(self.ic_model.state_dict(),
                               Tools.new_dir(self.config.ic_dir))
                    Tools.print("Save networks for epoch: {}".format(epoch),
                                txt_path=self.config.log_file)
                    pass
                pass
            ###########################################################################
            pass

        pass
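
The schedule behind config.adjust_learning_rate is not included in these examples; only its call signature (optimizer, epoch, first_epoch, t_epoch, learning_rate) and the fact that it returns the new learning rate are visible above. A step-decay sketch consistent with that signature; the decay factor and breakpoints are assumptions, not the original Config's values:

def adjust_learning_rate(optimizer, epoch, first_epoch, t_epoch, init_learning_rate):
    # Keep init_learning_rate for the first `first_epoch` epochs, then divide
    # it by 10 every `t_epoch` epochs (assumed schedule).
    if epoch <= first_epoch:
        lr = init_learning_rate
    else:
        lr = init_learning_rate * (0.1 ** ((epoch - first_epoch - 1) // t_epoch + 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
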
Example No. 9
    def train(self):
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        for epoch in range(1, 1 + Config.fsl_train_epoch):
            self.matching_net.train()

            Tools.print()
            all_loss, is_ok_total, is_ok_acc = 0.0, 0, 0
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations = self.matching(task_data)

                # 3 loss
                loss = self.fsl_loss(relations, task_labels)
                all_loss += loss.item()

                # 4 backward
                self.matching_net.zero_grad()
                loss.backward()
                self.matching_net_optim.step()

                # is ok
                is_ok_acc += torch.sum(torch.cat(task_ok))
                is_ok_total += torch.prod(
                    torch.tensor(torch.cat(task_ok).shape))
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} ok:{:.3f}({}/{}) lr:{}".format(
                epoch, all_loss / len(self.task_train_loader),
                int(is_ok_acc) / int(is_ok_total), is_ok_acc, is_ok_total,
                self.matching_net_scheduler.get_last_lr()))
            self.matching_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.fsl_val_freq == 0:
                self.matching_net.eval()

                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True,
                                                      has_test=False)

                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(), Config.mn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass
Example No. 10
    def train(self):
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        # Init Update
        try:
            self.ic_model.eval()
            Tools.print("Init label {} .......")
            self.produce_class.reset()
            with torch.no_grad():
                for image, label, idx in tqdm(self.ic_train_loader):
                    image, idx = RunnerTool.to_cuda(image), RunnerTool.to_cuda(
                        idx)
                    ic_out_logits, ic_out_l2norm = self.ic_model(image)
                    self.produce_class.cal_label(ic_out_l2norm, idx)
                    pass
                pass
            Tools.print("Epoch: {}/{}".format(self.produce_class.count,
                                              self.produce_class.count_2))
        finally:
            pass

        for epoch in range(1, 1 + Config.ic_train_epoch):
            self.ic_model.train()

            Tools.print()
            ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                              Config.ic_first_epoch,
                                              Config.ic_t_epoch,
                                              Config.ic_learning_rate)
            Tools.print('Epoch: [{}] ic_lr={}'.format(epoch, ic_lr))

            all_loss = 0.0
            self.produce_class.reset()
            for image, label, idx in tqdm(self.ic_train_loader):
                image, label, idx = RunnerTool.to_cuda(
                    image), RunnerTool.to_cuda(label), RunnerTool.to_cuda(idx)

                ###########################################################################
                # 1 calculate features
                ic_out_logits, ic_out_l2norm = self.ic_model(image)

                # 2 calculate labels
                ic_targets = self.produce_class.get_label(idx)
                self.produce_class.cal_label(ic_out_l2norm, idx)

                # 3 loss
                loss = self.ic_loss(ic_out_logits, ic_targets)
                all_loss += loss.item()

                # 4 backward
                self.ic_model.zero_grad()
                loss.backward()
                self.ic_model_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.4f}".format(
                epoch, all_loss / len(self.ic_train_loader)))
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2))
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.ic_val_freq == 0:
                self.ic_model.eval()

                val_accuracy = self.test_tool_ic.val(epoch=epoch)
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    def train(self):
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        for epoch in range(1, 1 + Config.train_epoch):
            self.net.train()

            Tools.print()
            all_loss = 0.0
            net_lr = Config.adjust_learning_rate(self.net_optim, epoch,
                                                 Config.first_epoch,
                                                 Config.t_epoch,
                                                 Config.learning_rate)
            Tools.print('Epoch: [{}] net_lr={}'.format(epoch, net_lr))
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations = self.matching(task_data)

                # 2 loss
                loss = self.loss(relations, task_labels)
                all_loss += loss.item()

                # 3 backward
                self.net.zero_grad()
                loss.backward()
                self.net_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f}".format(
                epoch, all_loss / len(self.task_train_loader)))
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(
                    epoch, Config.model_name))

                self.net.eval()

                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True)
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.net.state_dict(), Config.net_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass
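
For completeness, a minimal driver for the single-stage runner whose __init__ opens these examples and whose train() is this last method; the class name MatchingRunner is a placeholder, and the final reload-and-validate step simply mirrors the checkpointing done inside train():

if __name__ == '__main__':
    runner = MatchingRunner()  # hypothetical name for the runner defined above
    runner.train()

    # Reload the best checkpoint saved during validation and evaluate it once more.
    runner.net.load_state_dict(torch.load(Config.net_dir))
    runner.net.eval()
    runner.test_tool_fsl.val(episode=Config.train_epoch, is_print=True)
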