def build(self, model):
        """Build the kNN sampler state: one mean (prototype) feature and one
        intra-class loss per label, then the top-k nearest labels for each label.

        Writes self.features, self.losses, self.distance and self.knn.
        Assumes each batch holds whole label groups of size self.each_class
        (targets within a group share one label) — TODO confirm with dataloader.
        """
        model.eval()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(self.dataloader):
                if batch_idx % 100 == 0:
                    done = batch_idx * self.batch_size // self.each_class
                    logger.info("building knn sampler: {} labels".format(done))
                # (num_labels_in_batch, each_class, dim)
                embeddings = model(inputs).reshape(-1, self.each_class, self.dim)
                for row in range(len(embeddings)):
                    # each_class x dim, row-wise L2-normalized
                    normed = F.normalize(embeddings[row])
                    lbl = int(targets[row * self.each_class].item())
                    self.features[lbl] = normed.mean(0)
                    # loss = 1 - min pairwise cosine similarity within the label
                    min_sim = torch.matmul(normed, normed.t()).min().item()
                    self.losses[lbl] = max(1.0 - min_sim, 1e-6)

            # label-to-label cosine similarity; take the knn_num most similar
            self.distance = torch.matmul(self.features, self.features.t())
            _, self.knn = self.distance.topk(self.knn_num, 1, True, True)
            self.knn = self.knn.tolist()
            logger.info("initialize knn: {}".format(
                beautify_info(self.knn[:10])))
# Beispiel #2
# 0
 def from_config(cls, cfg, mode="train"):
     """Construct a dataset for *mode* and wrap it in a DataLoader.

     Returns None when the config has no (or an empty/falsy) entry for *mode*.
     The last incomplete batch is dropped for every mode except "test".
     """
     if mode not in cfg or not cfg[mode]:
         return None
     dataset = cls(cfg, mode)
     loader = DataLoader(dataset,
                         batch_size=dataset.batch_size,
                         num_workers=dataset.num_workers,
                         pin_memory=cfg["pin_memory"],
                         sampler=dataset.sampler,
                         drop_last=(mode != "test"))
     logger.info("data loader is: {}".format(beautify_info(loader)))
     return loader
# Beispiel #3
# 0
 def from_config(cls, cfg, mode="train"):
     """Construct a dataset for *mode* and wrap it in a DataLoader.

     Returns None when the config has no (or an empty/falsy) entry for *mode*.
     Shuffling is enabled only for training and only when no explicit
     sampler is configured; the last incomplete batch is dropped for every
     mode except "test".
     """
     if mode not in cfg or not cfg[mode]:
         return None
     dataset = cls(cfg, mode)
     logger.info("===mode-{}===\ncollate_fn: {}\nsampler:{}".format(
         mode, beautify_info(dataset.collate_fn), beautify_info(dataset.sampler)))
     # DataLoader forbids shuffle together with a sampler, so shuffle only
     # when training without one.
     shuffle = (mode == "train") and (dataset.sampler is None)
     loader = DataLoader(dataset,
                         batch_size=dataset.batch_size,
                         shuffle=shuffle,
                         num_workers=dataset.num_workers,
                         collate_fn=dataset.collate_fn,
                         pin_memory=cfg["pin_memory"],
                         sampler=dataset.sampler,
                         drop_last=(mode != "test"))
     logger.info("data loader is: {}".format(beautify_info(loader)))
     return loader
 def __iter__(self):
     """Yield dataset indices: self.each_class indices per sampled label.

     For each label returned by self._sample(), draws exactly
     self.each_class indices from self.label2index[label]. Labels with at
     least each_class indices are sampled WITHOUT replacement (fix: the
     original used replacement when the count equaled each_class, which
     could duplicate some indices and silently drop others); labels with
     fewer indices are padded by sampling with replacement.
     """
     selected_labels = self._sample()
     if self.p_steps == 0:
         logger.info("selected labels: {}".format(
             beautify_info(selected_labels[:self.target_each_batch])))
     sample_list = []
     for sl in selected_labels:
         indices = self.label2index[sl]
         # Replacement only when the label is too small to fill the slot.
         need_replace = len(indices) < self.each_class
         selected_index = np.random.choice(indices,
                                           size=self.each_class,
                                           replace=need_replace)
         sample_list.extend(list(selected_index))
     return iter(sample_list)
# Beispiel #5
# 0
 def __repr__(self):
     """Delegate the representation to the project's beautify_info helper."""
     pretty = beautify_info(self)
     return pretty
# Beispiel #6
# 0
def main():
    """Entry point: build the config, initialize logging, and run."""
    config = create_config()
    init_logger(config["log_file"])
    message = "create configure successfully: {}".format(beautify_info(config))
    logger.info(message)
    runner(config)