def __init__(self):
        """Build everything the CUB few-shot runner needs before training.

        Sets up the task dataset and its shuffled loader, the matching
        network with its loss, the Adam optimizer with a multi-step learning
        rate schedule (gamma 0.5), and the FSL test tool.
        """
        # Highest validation accuracy recorded so far.
        self.best_accuracy = 0.0

        # Data: load all training samples, then wrap them with the configured
        # way/shot counts.
        self.data_train = CUBDataset.get_data_all(Config.data_root)
        self.task_train = CUBDataset(
            self.data_train, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(
            self.task_train,
            batch_size=Config.batch_size,
            shuffle=True,
            num_workers=Config.num_workers,
        )

        # Model: device placement is delegated to RunnerTool.to_cuda; the
        # weight initialiser is applied to the network afterwards.
        self.matching_net = RunnerTool.to_cuda(Config.matching_net)
        initialised_net = self.matching_net.apply(RunnerTool.weights_init)
        RunnerTool.to_cuda(initialised_net)
        self.norm = Normalize(2)

        # Loss
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # Optimizer and schedule; milestones come from Config.train_epoch_lr.
        self.matching_net_optim = torch.optim.Adam(
            self.matching_net.parameters(), lr=Config.learning_rate)
        self.matching_net_scheduler = MultiStepLR(
            self.matching_net_optim, Config.train_epoch_lr, gamma=0.5)

        # Evaluation: the test tool calls back into self.matching_test.
        eval_options = dict(
            data_root=Config.data_root,
            num_way=Config.num_way,
            num_shot=Config.num_shot,
            episode_size=Config.episode_size,
            test_episode=Config.test_episode,
            transform=self.task_train.transform_test,
        )
        self.test_tool = FSLTestTool(self.matching_test, **eval_options)
    def __init__(self):
        """Build everything the Omniglot few-shot runner needs before training.

        Sets up the task dataset and its shuffled loader, a freshly
        constructed ``MatchingNet`` with its loss, an SGD optimizer, and the
        FSL test tool (which is evaluated with ``Config.num_way_test``).
        """
        # Data: load all training samples, then wrap them with the configured
        # way/shot counts.
        self.data_train = OmniglotDataset.get_data_all(Config.data_root)
        self.task_train = OmniglotDataset(
            self.data_train, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(
            self.task_train,
            batch_size=Config.batch_size,
            shuffle=True,
            num_workers=Config.num_workers,
        )

        # Model: device placement is delegated to RunnerTool.to_cuda; the
        # weight initialiser is applied to the network afterwards.
        network = MatchingNet(hid_dim=64, z_dim=64)
        self.matching_net = RunnerTool.to_cuda(network)
        initialised_net = self.matching_net.apply(RunnerTool.weights_init)
        RunnerTool.to_cuda(initialised_net)
        self.norm = Normalize(2)

        # Loss
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # Optimizer
        sgd_options = dict(
            lr=Config.learning_rate, momentum=0.9, weight_decay=5e-4)
        self.matching_net_optim = torch.optim.SGD(
            self.matching_net.parameters(), **sgd_options)

        # Evaluation: the test tool calls back into self.matching_test.
        eval_options = dict(
            data_root=Config.data_root,
            num_way=Config.num_way_test,
            num_shot=Config.num_shot,
            episode_size=Config.episode_size,
            test_episode=Config.test_episode,
            transform=self.task_train.transform_test,
        )
        self.test_tool = FSLTestTool(self.matching_test, **eval_options)
    def train(self):
        """Train the matching network, validating and checkpointing as it goes.

        For every epoch from 1 to ``Config.train_epoch``: trains on each batch
        of ``self.task_train_loader``, logs the mean loss together with the
        scheduler's last learning rate, and advances the scheduler. Every
        ``Config.val_freq`` epochs the test tool is run, and the network's
        state dict is saved to ``Config.mn_dir`` when the validation accuracy
        exceeds ``self.best_accuracy``.
        """
        Tools.print()
        Tools.print("Training...")

        for epoch in range(1, 1 + Config.train_epoch):
            self.matching_net.train()
            Tools.print()

            running_loss = 0.0
            for batch in tqdm(self.task_train_loader):
                # The third item (task index) is not needed here.
                task_data, task_labels, _ = batch
                task_data = RunnerTool.to_cuda(task_data)
                task_labels = RunnerTool.to_cuda(task_labels)

                # Forward pass and loss
                predicts = self.matching(task_data)
                loss = self.loss(predicts, task_labels)
                running_loss += loss.item()

                # Backward pass and parameter update
                self.matching_net.zero_grad()
                loss.backward()
                self.matching_net_optim.step()

            # Epoch summary; the learning rate is read before the scheduler
            # advances.
            mean_loss = running_loss / len(self.task_train_loader)
            current_lr = self.matching_net_scheduler.get_last_lr()
            Tools.print("{:6} loss:{:.3f} lr:{}".format(
                epoch, mean_loss, current_lr))
            self.matching_net_scheduler.step()

            # Validation only runs on multiples of Config.val_freq.
            if epoch % Config.val_freq != 0:
                continue

            Tools.print()
            Tools.print("Test {} {} .......".format(epoch, Config.model_name))
            self.matching_net.eval()

            val_accuracy = self.test_tool.val(
                episode=epoch, is_print=True, has_test=False)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.matching_net.state_dict(), Config.mn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
    def __init__(self):
        """Build everything the tiered-ImageNet IC runner needs before training.

        Sets up the dataset with a shuffled training loader and an unshuffled
        evaluation loader over the same dataset, the ``ProduceClass`` label
        producer, the ``ICResNet`` model wrapped in ``nn.DataParallel``, its
        cross-entropy loss and SGD optimizer, and the IC test tool.
        """
        self.adjust_learning_rate = Config.ic_adjust_learning_rate

        # Data: one dataset, two loaders (shuffled for training, ordered for
        # evaluation).
        self.data_train = TieredImageNetICDataset.get_data_all(
            Config.data_root)
        self.tiered_imagenet_dataset = TieredImageNetICDataset(self.data_train)
        self.ic_train_loader = DataLoader(
            self.tiered_imagenet_dataset,
            batch_size=Config.ic_batch_size,
            shuffle=True,
            num_workers=Config.num_workers,
        )
        self.ic_train_loader_eval = DataLoader(
            self.tiered_imagenet_dataset,
            batch_size=Config.ic_batch_size,
            shuffle=False,
            num_workers=Config.num_workers,
        )

        # IC label producer, sized by the number of training samples.
        sample_count = len(self.data_train)
        self.produce_class = ProduceClass(
            sample_count, Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()

        # Model: build, place on device, wrap for data parallelism, then apply
        # the weight initialiser to the wrapped module.
        backbone = ICResNet(
            Config.ic_resnet,
            low_dim=Config.ic_out_dim,
            modify_head=Config.ic_modify_head,
        )
        self.ic_model = RunnerTool.to_cuda(backbone)
        self.ic_model = RunnerTool.to_cuda(nn.DataParallel(self.ic_model))
        cudnn.benchmark = True
        RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())

        # Optimizer
        self.ic_model_optim = torch.optim.SGD(
            self.ic_model.parameters(),
            lr=Config.ic_learning_rate,
            momentum=0.9,
            weight_decay=5e-4,
        )

        # Evaluation
        ic_eval_options = dict(
            feature_encoder=None,
            ic_model=self.ic_model,
            data_root=Config.data_root,
            batch_size=Config.ic_batch_size,
            num_workers=Config.num_workers,
            ic_out_dim=Config.ic_out_dim,
            transform=self.ic_train_loader.dataset.transform_test,
            k=Config.ic_knn,
        )
        self.test_tool_ic = ICTestTool(**ic_eval_options)
    def __init__(self):
        """Build everything the Omniglot IC + few-shot runner needs.

        Sets up the task dataset and loader, the ``ProduceClass`` label
        producer (whose classes are handed to the dataset), the matching
        network and IC model with one SGD optimizer each, both losses, and
        the FSL and IC test tools.
        """
        self.adjust_learning_rate = Config.adjust_learning_rate

        # Data
        self.data_train = OmniglotDataset.get_data_all(Config.data_root)
        self.task_train = OmniglotDataset(
            self.data_train, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(
            self.task_train,
            batch_size=Config.batch_size,
            shuffle=True,
            num_workers=Config.num_workers,
        )

        # IC label producer; its class assignments are shared with the dataset.
        sample_count = len(self.data_train)
        self.produce_class = ProduceClass(
            sample_count, Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()
        self.task_train.set_samples_class(self.produce_class.classes)

        # Models
        self.matching_net = RunnerTool.to_cuda(Config.matching_net)
        ic_network = ICResNet(low_dim=Config.ic_out_dim, encoder=Config.ic_net)
        self.ic_model = RunnerTool.to_cuda(ic_network)
        self.norm = Normalize(2)
        if Config.multi_gpu:
            # Only the matching network is wrapped for data parallelism.
            parallel_net = nn.DataParallel(self.matching_net)
            self.matching_net = RunnerTool.to_cuda(parallel_net)
            cudnn.benchmark = True
        # Custom weight initialisation is currently disabled:
        # RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        # RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

        # Optimizers: both networks share the same SGD settings.
        sgd_options = dict(
            lr=Config.learning_rate, momentum=0.9, weight_decay=5e-4)
        self.matching_net_optim = torch.optim.SGD(
            self.matching_net.parameters(), **sgd_options)
        self.ic_model_optim = torch.optim.SGD(
            self.ic_model.parameters(), **sgd_options)

        # Losses
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Evaluation
        fsl_eval_options = dict(
            data_root=Config.data_root,
            num_way=Config.num_way_test,
            num_shot=Config.num_shot,
            episode_size=Config.episode_size,
            test_episode=Config.test_episode,
            transform=self.task_train.transform_fsl_test,
        )
        self.test_tool_fsl = FSLTestTool(self.matching_test, **fsl_eval_options)
        ic_eval_options = dict(
            feature_encoder=None,
            ic_model=self.ic_model,
            data_root=Config.data_root,
            batch_size=Config.batch_size,
            num_workers=Config.num_workers,
            ic_out_dim=Config.ic_out_dim,
            transform=self.task_train_loader.dataset.transform_ic_test,
            k=20,
        )
        self.test_tool_ic = ICTestTool(**ic_eval_options)
 def eval(self):
     """Compute an IC class index for every item of the training dataset.

     Puts ``self.ic_model`` in eval mode and, with gradients disabled, runs
     it over ``self.ic_train_loader_eval``. For each batch the argmax over
     the last axis of the model's second output is stored at the positions
     given by the batch's ``idx``.

     Returns:
         tuple: ``(self.data_train, ic_classes)`` where ``ic_classes`` is a
         1-D integer NumPy array with one entry per dataset item.
     """
     self.ic_model.eval()
     # Builtin ``int`` replaces the ``np.int`` alias, which was deprecated in
     # NumPy 1.20 and removed in 1.24; the resulting dtype is the same.
     ic_classes = np.zeros(shape=(len(self.tiered_imagenet_dataset),),
                           dtype=int)
     with torch.no_grad():
         for image, _label, idx in tqdm(self.ic_train_loader_eval):
             # Only the second model output is used for class assignment.
             _, ic_out_l2norm = self.ic_model(RunnerTool.to_cuda(image))
             ic_classes[idx] = np.argmax(ic_out_l2norm.cpu().numpy(),
                                         axis=-1)
     return self.data_train, ic_classes
    def __init__(self, data_train, classes):
        """Build the tiered-ImageNet few-shot runner from pre-assigned classes.

        Args:
            data_train: Training samples, stored as ``self.data_train`` and
                passed to ``TieredImageNetFSLDataset``.
            classes: Class assignments passed to ``TieredImageNetFSLDataset``
                alongside the data.

        Sets up the task dataset and shuffled loader, the matching network
        wrapped in ``nn.DataParallel``, an Adam optimizer with a multi-step
        schedule (gamma 1/3), the loss, and the FSL test tool.
        """
        # Data
        self.data_train = data_train
        self.task_train = TieredImageNetFSLDataset(
            self.data_train, classes, Config.fsl_num_way, Config.fsl_num_shot)
        self.task_train_loader = DataLoader(
            self.task_train,
            batch_size=Config.fsl_batch_size,
            shuffle=True,
            num_workers=Config.num_workers,
        )

        # Model: place on device, wrap for data parallelism, then apply the
        # weight initialiser to the wrapped module.
        self.matching_net = RunnerTool.to_cuda(Config.fsl_matching_net)
        parallel_net = nn.DataParallel(self.matching_net)
        self.matching_net = RunnerTool.to_cuda(parallel_net)
        cudnn.benchmark = True
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        self.norm = Normalize(2)

        # Optimizer and schedule; milestones come from Config.fsl_lr_schedule.
        self.matching_net_optim = torch.optim.Adam(
            self.matching_net.parameters(), lr=Config.fsl_learning_rate)
        self.matching_net_scheduler = MultiStepLR(
            self.matching_net_optim, Config.fsl_lr_schedule, gamma=1 / 3)

        # Loss
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Evaluation: the test tool calls back into self.matching_test.
        fsl_eval_options = dict(
            data_root=Config.data_root,
            num_way=Config.fsl_num_way,
            num_shot=Config.fsl_num_shot,
            episode_size=Config.fsl_episode_size,
            test_episode=Config.fsl_test_episode,
            transform=self.task_train.transform_test,
        )
        self.test_tool_fsl = FSLTestTool(self.matching_test, **fsl_eval_options)
示例#8
0
    def __init__(self):
        """Build everything the tiered-ImageNet IC + few-shot runner needs.

        Sets up the task dataset and loader, the ``ProduceClass`` label
        producer (whose classes and features are handed to the dataset), the
        matching network and IC model with one SGD optimizer each, both
        losses, and the FSL and IC test tools.
        """
        self.best_accuracy = 0.0
        self.adjust_learning_rate = Config.adjust_learning_rate

        # Data
        self.data_train = TieredImageNetDataset.get_data_all(Config.data_root)
        self.task_train = TieredImageNetDataset(
            self.data_train, Config.num_way, Config.num_shot,
            load_data=Config.load_data)
        self.task_train_loader = DataLoader(
            self.task_train,
            batch_size=Config.batch_size,
            shuffle=True,
            num_workers=Config.num_workers,
        )

        # IC label producer; its classes and features are shared with the
        # dataset.
        sample_count = len(self.data_train)
        self.produce_class = ProduceClass(
            sample_count, Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()
        self.task_train.set_samples_class(self.produce_class.classes)
        self.task_train.set_samples_feature(self.produce_class.features)

        # Models: both are placed on device, then the weight initialiser is
        # applied to each.
        self.matching_net = RunnerTool.to_cuda(Config.matching_net)
        ic_network = ICResNet(
            low_dim=Config.ic_out_dim,
            resnet=Config.resnet,
            modify_head=Config.modify_head,
        )
        self.ic_model = RunnerTool.to_cuda(ic_network)
        self.norm = Normalize(2)
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

        # Optimizers: both networks share the same SGD settings.
        sgd_options = dict(
            lr=Config.learning_rate, momentum=0.9, weight_decay=5e-4)
        self.matching_net_optim = torch.optim.SGD(
            self.matching_net.parameters(), **sgd_options)
        self.ic_model_optim = torch.optim.SGD(
            self.ic_model.parameters(), **sgd_options)

        # Losses
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Evaluation
        fsl_eval_options = dict(
            data_root=Config.data_root,
            num_way=Config.num_way,
            num_shot=Config.num_shot,
            episode_size=Config.episode_size,
            test_episode=Config.test_episode,
            transform=self.task_train.transform_test,
        )
        self.test_tool_fsl = FSLTestTool(self.matching_test, **fsl_eval_options)
        ic_eval_options = dict(
            feature_encoder=None,
            ic_model=self.ic_model,
            data_root=Config.data_root,
            transform=self.task_train.transform_test,
            batch_size=Config.batch_size,
            num_workers=Config.num_workers,
            ic_out_dim=Config.ic_out_dim,
            k=Config.knn,
            load_data=Config.load_data,
        )
        self.test_tool_ic = ICTestTool(**ic_eval_options)
    def train(self):
        """Jointly train the matching network and the IC model.

        First runs one label-initialisation pass over the training loader
        with both networks in eval mode and gradients disabled: the IC
        model's second output for the last entry along dimension 1 of each
        batch is passed to ``self.produce_class.cal_label``.

        Then, for every epoch from 1 to ``Config.train_epoch``: adjusts both
        learning rates, trains on every batch with ``self.fsl_loss`` and
        ``self.ic_loss`` (each network is stepped by its own optimizer), and
        logs the mean losses and the ``task_ok`` ratio. Every
        ``Config.val_freq`` epochs both test tools are run, and both state
        dicts are saved whenever the FSL validation accuracy improves.
        """
        Tools.print()
        Tools.print("Training...")

        # Init Update: inference only, so no autograd graph is needed.
        self.matching_net.eval()
        self.ic_model.eval()
        Tools.print("Init label .......")
        self.produce_class.reset()
        with torch.no_grad():
            for task_data, task_labels, task_index, task_ok in tqdm(self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data = RunnerTool.to_cuda(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
        Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))

        best_accuracy = 0.0
        for epoch in range(1, 1 + Config.train_epoch):
            self.matching_net.train()
            self.ic_model.train()

            Tools.print()
            mn_lr = self.adjust_learning_rate(self.matching_net_optim, epoch,
                                              Config.first_epoch, Config.t_epoch, Config.learning_rate)
            ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                              Config.first_epoch, Config.t_epoch, Config.learning_rate)
            Tools.print('Epoch: [{}] mn_lr={} ic_lr={}'.format(epoch, mn_lr, ic_lr))

            self.produce_class.reset()
            Tools.print(self.task_train.classes)
            is_ok_total, is_ok_acc = 0, 0
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index, task_ok in tqdm(self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations = self.matching(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])

                # 2 read this step's targets before cal_label records the new outputs
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss
                loss_fsl = self.fsl_loss(relations, task_labels)
                loss_ic = self.ic_loss(ic_out_logits, ic_targets)
                loss = loss_fsl + loss_ic
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward: each network is updated from its own loss
                self.ic_model.zero_grad()
                loss_ic.backward()
                self.ic_model_optim.step()

                self.matching_net.zero_grad()
                loss_fsl.backward()
                self.matching_net_optim.step()

                # is ok: count set flags and total flags in this batch
                ok_flags = torch.cat(task_ok)
                is_ok_acc += torch.sum(ok_flags)
                is_ok_total += torch.prod(torch.tensor(ok_flags.shape))
                ###########################################################################

            ###########################################################################
            # print (the epoch number is already 1-based, so it is logged as-is)
            num_batches = len(self.task_train_loader)
            Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} ok:{:.3f}({}/{})".format(
                epoch, all_loss / num_batches,
                all_loss_fsl / num_batches, all_loss_ic / num_batches,
                int(is_ok_acc) / int(is_ok_total), is_ok_acc, is_ok_total, ))
            Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count, self.produce_class.count_2))
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.val_freq == 0:
                self.matching_net.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)

                # Checkpoint both networks whenever FSL validation improves.
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(), Config.mn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
            ###########################################################################
    def train(self):
        """Train the matching network on few-shot tasks with periodic validation.

        For every epoch from 1 to ``Config.fsl_train_epoch``: trains on each
        batch of ``self.task_train_loader``, logs the mean loss, the
        ``task_ok`` ratio and the scheduler's last learning rate, then
        advances the scheduler. Every ``Config.fsl_val_freq`` epochs the FSL
        test tool is run, and the state dict is saved to ``Config.mn_dir``
        when validation accuracy improves.
        """
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        for epoch in range(1, 1 + Config.fsl_train_epoch):
            self.matching_net.train()
            Tools.print()

            all_loss = 0.0
            is_ok_total = 0
            is_ok_acc = 0
            for batch in tqdm(self.task_train_loader):
                # The third item (task index) is not needed here.
                task_data, task_labels, _, task_ok = batch
                task_data = RunnerTool.to_cuda(task_data)
                task_labels = RunnerTool.to_cuda(task_labels)

                # Forward pass and loss
                relations = self.matching(task_data)
                loss = self.fsl_loss(relations, task_labels)
                all_loss += loss.item()

                # Backward pass and parameter update
                self.matching_net.zero_grad()
                loss.backward()
                self.matching_net_optim.step()

                # Count set flags and total flags in this batch.
                ok_flags = torch.cat(task_ok)
                is_ok_acc += torch.sum(ok_flags)
                is_ok_total += torch.prod(torch.tensor(ok_flags.shape))

            # Epoch summary; the learning rate is read before the scheduler
            # advances.
            mean_loss = all_loss / len(self.task_train_loader)
            ok_ratio = int(is_ok_acc) / int(is_ok_total)
            current_lr = self.matching_net_scheduler.get_last_lr()
            Tools.print("{:6} loss:{:.3f} ok:{:.3f}({}/{}) lr:{}".format(
                epoch, mean_loss, ok_ratio, is_ok_acc, is_ok_total,
                current_lr))
            self.matching_net_scheduler.step()

            # Validation only runs on multiples of Config.fsl_val_freq.
            if epoch % Config.fsl_val_freq != 0:
                continue

            self.matching_net.eval()
            val_accuracy = self.test_tool_fsl.val(
                episode=epoch, is_print=True, has_test=False)

            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save(self.matching_net.state_dict(), Config.mn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
    def train(self):
        """Train the IC model with labels produced by ``self.produce_class``.

        First runs one label-initialisation pass over ``self.ic_train_loader``
        in eval mode with gradients disabled, passing the model's second
        output and the sample indices to ``self.produce_class.cal_label``.

        Then, for every epoch from 1 to ``Config.ic_train_epoch``: adjusts the
        learning rate, trains on every batch with ``self.ic_loss`` against the
        targets returned by ``self.produce_class.get_label``, and logs the
        mean loss. Every ``Config.ic_val_freq`` epochs the IC test tool is
        run, and the state dict is saved to ``Config.ic_dir`` when validation
        accuracy improves.
        """
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        # Init Update: inference only, so no autograd graph is needed.
        self.ic_model.eval()
        Tools.print("Init label .......")
        self.produce_class.reset()
        with torch.no_grad():
            for image, label, idx in tqdm(self.ic_train_loader):
                image, idx = RunnerTool.to_cuda(image), RunnerTool.to_cuda(idx)
                ic_out_logits, ic_out_l2norm = self.ic_model(image)
                self.produce_class.cal_label(ic_out_l2norm, idx)
        Tools.print("Epoch: {}/{}".format(self.produce_class.count,
                                          self.produce_class.count_2))

        for epoch in range(1, 1 + Config.ic_train_epoch):
            self.ic_model.train()

            Tools.print()
            ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                              Config.ic_first_epoch,
                                              Config.ic_t_epoch,
                                              Config.ic_learning_rate)
            Tools.print('Epoch: [{}] ic_lr={}'.format(epoch, ic_lr))

            all_loss = 0.0
            self.produce_class.reset()
            for image, label, idx in tqdm(self.ic_train_loader):
                # ``label`` is not used for training, so it is not transferred.
                image, idx = RunnerTool.to_cuda(image), RunnerTool.to_cuda(idx)

                ###########################################################################
                # 1 calculate features
                ic_out_logits, ic_out_l2norm = self.ic_model(image)

                # 2 read this step's targets before cal_label records the new outputs
                ic_targets = self.produce_class.get_label(idx)
                self.produce_class.cal_label(ic_out_l2norm, idx)

                # 3 loss
                loss = self.ic_loss(ic_out_logits, ic_targets)
                all_loss += loss.item()

                # 4 backward
                self.ic_model.zero_grad()
                loss.backward()
                self.ic_model_optim.step()
                ###########################################################################

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f}".format(
                epoch, all_loss / len(self.ic_train_loader)))
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2))
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.ic_val_freq == 0:
                self.ic_model.eval()

                val_accuracy = self.test_tool_ic.val(epoch=epoch)
                # Checkpoint whenever IC validation accuracy improves.
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
            ###########################################################################