コード例 #1
0
    def __init__(self):
        """Build the training pipeline: data, model, loss, optimiser and evaluator.

        Every setting is read from the module-level ``Config``; nothing is passed
        in by the caller.
        """
        self.best_accuracy = 0.0
        self.adjust_learning_rate = Config.adjust_learning_rate

        # Training data: the full train split, wrapped as an episodic task
        # dataset (num_way / num_shot) and fed through a shuffling DataLoader.
        self.data_train = data = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = task = MiniImageNetDataset(
            data, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(
            task, Config.batch_size, True, num_workers=Config.num_workers)

        # Model. ``to_cuda`` presumably moves the module to the GPU (TODO confirm);
        # ``apply`` runs ``RunnerTool.weights_init`` over the network's
        # submodules in place, so the outer ``to_cuda`` result is not kept.
        self.matching_net = net = RunnerTool.to_cuda(Config.matching_net)
        RunnerTool.to_cuda(net.apply(RunnerTool.weights_init))
        self.norm = Normalize(2)

        # Loss.
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # Optimiser: SGD with momentum 0.9 and weight decay 5e-4.
        self.matching_net_optim = torch.optim.SGD(
            net.parameters(),
            lr=Config.learning_rate,
            momentum=0.9,
            weight_decay=5e-4,
        )

        # Evaluator; note it uses ``Config.num_way_test`` rather than
        # ``Config.num_way`` and the train dataset's ``transform_test``.
        self.test_tool = TestTool(
            self.matching_test,
            data_root=Config.data_root,
            num_way=Config.num_way_test,
            num_shot=Config.num_shot,
            episode_size=Config.episode_size,
            test_episode=Config.test_episode,
            transform=task.transform_test,
        )
コード例 #2
0
 def __init__(self, resnet, low_dim=512, modify_head=False):
     """Build the ResNet backbone and a ``Normalize(2)`` layer (``l2norm``).

     Args:
         resnet: backbone constructor, invoked as ``resnet(num_classes=low_dim)``.
         low_dim: value forwarded to the backbone as ``num_classes``.
         modify_head: when True, the backbone's ``conv1`` is replaced by a
             3x3, stride-1, padding-1 convolution (3 -> 64 channels, no bias).
     """
     super().__init__()
     self.resnet = resnet(num_classes=low_dim)
     self.l2norm = Normalize(2)
     if not modify_head:
         return
     # Stride-1 3x3 stem instead of the backbone's own conv1; presumably for
     # low-resolution inputs — NOTE(review): confirm against the data pipeline.
     self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
                                   padding=1, bias=False)
コード例 #3
0
    def __init__(self):
        """Build data, class producer, model, optimiser, losses and evaluators.

        All settings come from the module-level ``Config``. Two evaluators are
        created: ``test_tool_fsl`` (driven by ``self.matching_test``) and
        ``test_tool_ic`` (driven by the matching network itself).
        """
        self.adjust_learning_rate = Config.adjust_learning_rate

        # ---- data: full train split -> episodic task dataset -> shuffled loader
        self.data_train = data = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = task = MiniImageNetDataset(data, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(task, Config.batch_size, True,
                                            num_workers=Config.num_workers)

        # ---- IC: ProduceClass is sized by the number of training samples; its
        # ``classes`` are handed to the task dataset after ``init()``.
        self.produce_class = producer = ProduceClass(len(data), Config.ic_out_dim,
                                                     Config.ic_ratio)
        producer.init()
        task.set_samples_class(producer.classes)

        # ---- model
        self.matching_net = RunnerTool.to_cuda(Config.matching_net)
        self.norm = Normalize(2)
        if Config.multi_gpu:
            # Wrap in DataParallel; cudnn autotuning is only enabled in this mode.
            self.matching_net = RunnerTool.to_cuda(nn.DataParallel(self.matching_net))
            cudnn.benchmark = True
        net = self.matching_net
        # ``apply`` re-initialises the (possibly wrapped) network in place; the
        # outer ``to_cuda`` return value is discarded.
        RunnerTool.to_cuda(net.apply(RunnerTool.weights_init))

        # ---- optimiser: SGD, momentum 0.9, weight decay 5e-4
        self.matching_net_optim = torch.optim.SGD(net.parameters(),
                                                  lr=Config.learning_rate,
                                                  momentum=0.9,
                                                  weight_decay=5e-4)

        # ---- losses
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # ---- evaluation
        self.test_tool_fsl = TestTool(
            self.matching_test,
            data_root=Config.data_root,
            num_way=Config.num_way,
            num_shot=Config.num_shot,
            episode_size=Config.episode_size,
            test_episode=Config.test_episode,
            transform=task.transform_test,
        )
        self.test_tool_ic = ICTestTool(
            feature_encoder=None,
            ic_model=net,
            data_root=Config.data_root,
            batch_size=Config.batch_size,
            num_workers=Config.num_workers,
            ic_out_dim=Config.ic_out_dim,
        )
コード例 #4
0
    def __init__(self):
        """Build data, model, loss, Adam optimiser, StepLR schedule and evaluator.

        All settings are read from the module-level ``Config``.
        """
        self.best_accuracy = 0.0

        # all data: full train split -> episodic task dataset -> shuffled loader
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way,
                                              Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            True,
                                            num_workers=Config.num_workers)

        # model: ``apply`` re-initialises weights in place with
        # ``RunnerTool.weights_init``; the outer ``to_cuda`` result is discarded.
        self.matching_net = RunnerTool.to_cuda(Config.matching_net)
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        self.norm = Normalize(2)

        # Feature flags copied from Config for use elsewhere in the runner.
        self.has_norm = Config.has_norm
        self.has_softmax = Config.has_softmax

        # loss
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # optim: Adam, LR multiplied by 0.5 every ``train_epoch // 3`` scheduler
        # steps (presumably one step per epoch — confirm in the training loop).
        self.matching_net_optim = torch.optim.Adam(
            self.matching_net.parameters(), lr=Config.learning_rate)
        # StepLR computes ``last_epoch % step_size`` on every step() after
        # construction, so a step size of 0 (train_epoch < 3) would raise
        # ZeroDivisionError at the first scheduler step; clamp to at least 1.
        lr_step_size = max(1, Config.train_epoch // 3)
        self.matching_net_scheduler = StepLR(self.matching_net_optim,
                                             lr_step_size,
                                             gamma=0.5)

        self.test_tool = TestTool(self.matching_test,
                                  data_root=Config.data_root,
                                  num_way=Config.num_way,
                                  num_shot=Config.num_shot,
                                  episode_size=Config.episode_size,
                                  test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)
 def __init__(self, encoder, low_dim=512):
     """Attach a linear head of width ``low_dim`` on top of ``encoder``.

     Args:
         encoder: backbone; must expose an ``out_dim`` attribute, which is used
             as the input width of the linear head.
         low_dim: output width of the linear head (``self.fc``).
     """
     super().__init__()
     self.encoder = encoder
     head_in = encoder.out_dim
     self.fc = nn.Linear(head_in, low_dim)
     self.l2norm = Normalize(2)