Beispiel #1
0
    def distributed_sinkhornknopp(self, Q: torch.Tensor):
        """
        Run the distributed Sinkhorn-Knopp optimization on the score matrix
        ``Q`` to find the assignments.

        ``Q`` is normalized IN PLACE:

        1. It is divided by its total sum, all-reduced across replicas.
        2. For ``self.loss_config.num_iters`` iterations, it alternates:
           - Rows are rescaled so their (all-reduced) sums match ``1 / k``,
             where ``k = Q.shape[0]``.
           - Columns are rescaled so their local sums match ``1 / N``, where
             ``N = world_size * Q.shape[1]``.

        Presumably rows are prototypes and columns are this replica's samples.
        TODO confirm with the caller.

        Returns:
            ``Q`` with each column normalized to sum to 1, transposed and cast
            to float32.
        """
        with torch.no_grad():
            # Global normalization so the whole (distributed) matrix sums to 1.
            grand_total = torch.sum(Q, dtype=Q.dtype)
            all_reduce_sum(grand_total)
            Q /= grand_total

            num_rows, num_local_cols = Q.shape[0], Q.shape[1]
            num_global_cols = get_world_size() * num_local_cols

            # Target marginals, following the u, r, c and Q notations of
            # https://arxiv.org/abs/1911.05371 (r: rows, c: columns).
            row_target = torch.ones(num_rows) / num_rows
            col_target = torch.ones(num_local_cols) / num_global_cols
            if self.use_gpu:
                row_target = row_target.cuda(non_blocking=True)
                col_target = col_target.cuda(non_blocking=True)

            # Row sums must cover every replica's columns, hence the all-reduce.
            row_totals = torch.sum(Q, dim=1, dtype=Q.dtype)
            all_reduce_sum(row_totals)

            for _ in range(self.loss_config.num_iters):
                Q *= (row_target / row_totals).unsqueeze(1)
                local_col_totals = torch.sum(Q, dim=0, dtype=Q.dtype)
                Q *= (col_target / local_col_totals).unsqueeze(0)
                row_totals = torch.sum(Q, dim=1, dtype=Q.dtype)
                all_reduce_sum(row_totals)

            col_norms = torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)
            normalized = Q / col_norms
            return normalized.t().float()
Beispiel #2
0
    def _build_momentum_network(self, task: tasks.ClassyTask) -> None:
        """
        Build the teacher: an exponential moving average of the student.

        Side effects:
            - Sets ``task.config["MODEL"]["TRUNK"]["VISION_TRANSFORMERS"]
              ["DROP_PATH_RATE"]`` to 0 in place before building, so the
              teacher has no stochastic depth. Any later reader of
              ``task.config`` also sees the 0.
            - Stores the teacher in ``task.loss.momentum_teacher``. The model
              is converted to SyncBatchNorm and moved to ``task.device``. When
              ``get_world_size() > 1`` it is also wrapped with
              ``init_distributed_data_parallel_model``.
            - If ``task.loss.checkpoint`` is set, loads it into ``task.loss``.
              Otherwise copies the student's (``task.model``) weights into
              the teacher.
        """
        logging.info("Building momentum encoder")

        # Same architecture as the student, but without stochastic depth.
        vit_config = task.config["MODEL"]["TRUNK"]["VISION_TRANSFORMERS"]
        vit_config["DROP_PATH_RATE"] = 0

        teacher = build_model(task.config["MODEL"], task.config["OPTIMIZER"])
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
        teacher.to(task.device)
        if get_world_size() > 1:
            teacher = init_distributed_data_parallel_model(teacher)
        task.loss.momentum_teacher = teacher

        # Restore a hypothetical checkpoint; otherwise start from the student.
        checkpoint = task.loss.checkpoint
        if checkpoint is None:
            task.loss.momentum_teacher.load_state_dict(task.model.state_dict())
        else:
            task.loss.load_state_dict(checkpoint)
 def _get_sampler(self, epoch) -> "DistributedSampler":
     """
     Build the distributed clip sampler for ``epoch``.

     Per-video clip selection depends on the split:
         - "train": RandomClipSampler. Using every possible clip per epoch is
           usually unnecessary, so at most N clips per video are sampled at
           random (N is often 1).
         - otherwise: UniformClipSampler. N evenly spaced clips per video;
           predictions are averaged over them.

     Here N is ``self.clips_per_video`` and clips come from
     ``self.video_clips``. The per-video sampler is wrapped in
     MaxLengthClipSampler with ``num_samples=self.num_samples``.

     The result is sharded across replicas by a DistributedSampler with
     ``group_size=self.clips_per_video``. This presumably keeps all clips of
     one video on the same replica; TODO confirm against DistributedSampler.

     Args:
         epoch: forwarded to ``set_epoch`` on the returned sampler.

     Returns:
         The configured DistributedSampler.
     """
     clip_sampler_cls = (
         RandomClipSampler if self.split == "train" else UniformClipSampler
     )
     per_video = clip_sampler_cls(self.video_clips, self.clips_per_video)
     capped = MaxLengthClipSampler(per_video, num_samples=self.num_samples)

     num_replicas = get_world_size()
     replica_rank = get_rank()
     distributed = DistributedSampler(
         capped,
         num_replicas=num_replicas,
         rank=replica_rank,
         shuffle=self.shuffle,
         group_size=self.clips_per_video,
     )
     distributed.set_epoch(epoch)
     return distributed
Beispiel #4
0
    def get_global_batchsize(self):
        """
        Return the batch size summed over all replicas.

        Returns:
            ``self.get_batchsize_per_replica()`` multiplied by the number of
            replicas reported by ``get_world_size()``.
        """
        per_replica = self.get_batchsize_per_replica()
        num_replicas = get_world_size()
        return per_replica * num_replicas
Beispiel #5
0
 def _get_sampler(self, epoch):
     """
     Return a DistributedSampler over this dataset for the given epoch.

     The dataset is split across ``get_world_size()`` replicas. This process
     reads the shard for ``get_rank()``, shuffled according to
     ``self.shuffle``. The epoch is passed to ``set_epoch`` before returning.
     """
     epoch_sampler = DistributedSampler(
         self,
         num_replicas=get_world_size(),
         rank=get_rank(),
         shuffle=self.shuffle,
     )
     epoch_sampler.set_epoch(epoch)
     return epoch_sampler
Beispiel #6
0
    def update_center(self):
        """
        Update the center used for the teacher output.

        Steps:
            1. Sum ``self.teacher_output`` over dim 0 (keeping the dim).
            2. All-reduce that sum across processes when a process group
               exists.
            3. Divide by ``len(self.teacher_output) * get_world_size()``.
            4. Blend into ``self.center`` with an exponential moving average
               of momentum ``m = self.loss_config.ema_center``:
               ``center = center * m + batch_center * (1 - m)``.

        Runs under ``torch.no_grad()`` so the running center never carries an
        autograd graph from one iteration into the next.
        """
        with torch.no_grad():
            batch_center = torch.sum(self.teacher_output, dim=0, keepdim=True)
            # A bare dist.all_reduce raises when no default process group is
            # initialized (single-process runs), so only reduce when one exists.
            if dist.is_available() and dist.is_initialized():
                dist.all_reduce(batch_center)
            # NOTE(review): assumes get_world_size() returns 1 when no process
            # group is initialized, keeping this average consistent — confirm.
            batch_center = batch_center / (
                len(self.teacher_output) * get_world_size()
            )

            # Exponential moving average update.
            m = self.loss_config.ema_center
            self.center = self.center * m + batch_center * (1 - m)
Beispiel #7
0
 def __init__(self, cfg: AttrDict, data_source: str, path: str, split: str,
              dataset_name: str):
     """
     Set up an Airstore-backed dataset for one data split.

     Args:
         cfg: full config. ``cfg["DATA"][split]`` provides
             ``BATCHSIZE_PER_REPLICA``, which is used both as the parent
             queue size and as ``self.batch_size``, and
             ``ENABLE_QUEUE_DATASET``.
         data_source: not used in this constructor.
         path: stored as ``self.airstore_uri``.
         split: name of the split whose section of ``cfg["DATA"]`` is read.
         dataset_name: not used in this constructor.
     """
     split_cfg = cfg["DATA"][split]
     batch_size = split_cfg["BATCHSIZE_PER_REPLICA"]
     super(AirstoreDataset, self).__init__(queue_size=batch_size)

     self.pathmanager = create_path_manager()
     self.cfg = cfg
     self.batch_size = batch_size
     self.airstore_uri = path
     self.split = split

     # Iteration bookkeeping starts from scratch.
     self.epoch = 0
     self.start_iter = 0

     self.enable_queue_dataset = split_cfg["ENABLE_QUEUE_DATASET"]

     # Distributed context of this process.
     self.global_rank = get_rank()
     self.global_world_size = get_world_size()

     # Created lazily elsewhere; none yet.
     self._iterator = None
Beispiel #8
0
    def _get_sampler(self, epoch: int):
        """
        Return a :class:`torch.utils.data.sampler.Sampler` for this replica.

        The data is distributed across ``get_world_size()`` replicas, and this
        process receives the shard for ``get_rank()``. If ``self.shuffle`` is
        enabled, every epoch has a different shuffle, driven by the epoch
        passed to ``set_epoch``.

        Args:
            epoch: The epoch being fetched.

        Returns:
            A sampler which tells the data loader which sample to load next.
        """
        num_replicas = get_world_size()
        replica_rank = get_rank()
        shard_sampler = DistributedSampler(
            self,
            num_replicas=num_replicas,
            rank=replica_rank,
            shuffle=self.shuffle,
        )
        shard_sampler.set_epoch(epoch)
        return shard_sampler
Beispiel #9
0
    def __init__(self, loss_config: AttrDict):
        """
        Set up the memory bank, per-sample cluster assignments and centroids.

        Each process owns ``ceil(num_train_samples / get_world_size())``
        memory slots. When ``DROP_LAST`` is set, that count is rounded down to
        a multiple of ``BATCHSIZE_PER_REPLICA``.

        Buffers are registered in this order:
            - ``local_memory_embeddings``: zeros of shape
              (len(crops_for_mb), slots, embedding_dim).
            - ``local_memory_index``: int64 zeros of shape (slots,).
            - ``assignments``: int64 tensor of shape
              (len(num_clusters), num_train_samples) filled with -100. This
              matches the loss's ``ignore_index``.
            - ``centroids{i}``: ``torch.rand`` of shape
              (num_clusters[i], embedding_dim), one per head.

        Args:
            loss_config: loss settings. Fields read here are
                ``num_train_samples``, ``DROP_LAST``,
                ``BATCHSIZE_PER_REPLICA``, ``memory_params.crops_for_mb``,
                ``memory_params.embedding_dim``, ``num_clusters``,
                ``num_crops``, ``temperature`` and ``kmeans_iters``.
        """
        super().__init__()

        self.loss_config = loss_config
        cfg = self.loss_config

        num_samples = cfg.num_train_samples
        slots_per_process = int(math.ceil(num_samples * 1.0 / get_world_size()))
        if cfg.DROP_LAST:
            # Round down to a whole number of per-replica batches.
            slots_per_process -= slots_per_process % cfg.BATCHSIZE_PER_REPLICA

        mem_params = cfg.memory_params
        self.nmb_mbs = len(mem_params.crops_for_mb)
        self.nmb_heads = len(cfg.num_clusters)
        self.num_clusters = cfg.num_clusters
        self.embedding_dim = mem_params.embedding_dim
        self.crops_for_mb = mem_params.crops_for_mb
        self.nmb_unique_idx = cfg.BATCHSIZE_PER_REPLICA
        self.num_crops = cfg.num_crops
        self.temperature = cfg.temperature
        self.nmb_kmeans_iters = cfg.kmeans_iters
        self.start_idx = 0

        self.register_buffer(
            "local_memory_embeddings",
            torch.zeros(self.nmb_mbs, slots_per_process, self.embedding_dim),
        )
        self.register_buffer(
            "local_memory_index",
            torch.zeros(slots_per_process, dtype=torch.long),
        )
        # -100 marks "no assignment yet" (same value as ignore_index below).
        self.register_buffer(
            "assignments",
            torch.full((self.nmb_heads, num_samples), -100, dtype=torch.long),
        )
        for head_idx, num_centroids in enumerate(cfg.num_clusters):
            self.register_buffer(
                "centroids" + str(head_idx),
                torch.rand(num_centroids, self.embedding_dim),
            )

        self.cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=-100)
Beispiel #10
0
    def __init__(
        self,
        temperature: float,
        crops_for_assign: List[int],
        num_crops: int,
        num_iters: int,
        epsilon: float,
        use_double_prec: bool,
        num_prototypes: List[int],
        local_queue_length: int,
        embedding_dim: int,
        temp_hard_assignment_iters: int,
        output_dir: str,
    ):
        """
        Store the SwAV loss hyper-parameters and set up auxiliary state.

        Arguments are stored as attributes of the same name, except:
            - ``num_iters`` becomes ``self.nmb_sinkhornknopp_iters``.
            - ``self.nmb_heads`` is ``len(num_prototypes)``.

        Other state set here:
            - ``self.use_gpu``: True when ``get_cuda_device_index() > -1``.
            - ``self.dist_rank`` and ``self.world_size``: the distributed
              context of this process.
            - ``num_iteration``: a registered one-element integer buffer of
              zeros.
            - ``self.initialize_queue()`` is called only when
              ``local_queue_length > 0``; ``self.use_queue`` starts as False.
        """
        super(SwAVCriterion, self).__init__()

        self.use_gpu = get_cuda_device_index() > -1

        # Assignment / Sinkhorn settings.
        self.temperature = temperature
        self.epsilon = epsilon
        self.nmb_sinkhornknopp_iters = num_iters
        self.use_double_prec = use_double_prec
        self.temp_hard_assignment_iters = temp_hard_assignment_iters

        # Crop, prototype and embedding layout.
        self.crops_for_assign = crops_for_assign
        self.num_crops = num_crops
        self.num_prototypes = num_prototypes
        self.nmb_heads = len(num_prototypes)
        self.embedding_dim = embedding_dim
        self.local_queue_length = local_queue_length

        # Distributed context of this process.
        self.dist_rank = get_rank()
        self.world_size = get_world_size()

        self.log_softmax = nn.LogSoftmax(dim=1).cuda()
        self.softmax = nn.Softmax(dim=1).cuda()
        self.register_buffer("num_iteration", torch.zeros(1, dtype=int))

        # The queue is only allocated when a positive length is requested.
        self.use_queue = False
        if local_queue_length > 0:
            self.initialize_queue()
        self.output_dir = output_dir
Beispiel #11
0
 def get_global_batchsize(self):
     """
     Return the batch size summed over every trainer (replica).

     Computed as ``self.get_batchsize_per_replica()`` times ``get_world_size()``.
     """
     batch_per_trainer = self.get_batchsize_per_replica()
     return batch_per_trainer * get_world_size()