    def _get_cum_sample_fetch_times(self, phase_type) -> Optional[List[float]]:
        if not self.sample_fetch_times:
            return None

        # Take the max fetch time of each sample across ranks, then extend the
        # running cumulative total from the previously stored value.
        sample_fetch_times = torch.Tensor(self.sample_fetch_times)
        max_sample_fetch_times = all_reduce_max(sample_fetch_times).tolist()
        cum_sample_fetch_times = list(
            accumulate(
                [self.state.cum_sample_fetch_time[phase_type]] + max_sample_fetch_times
            )
        )[1:]
        self.state.cum_sample_fetch_time[phase_type] = cum_sample_fetch_times[-1]
        return cum_sample_fetch_times
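The running-total pattern above (seeding itertools.accumulate with the previous cumulative value and dropping the seed with [1:]) can be checked in isolation. A minimal sketch with made-up fetch times and a hypothetical previous total:

from itertools import accumulate

prev_cum_time = 10.0                      # hypothetical previous cumulative fetch time
max_sample_fetch_times = [0.5, 0.2, 0.3]  # hypothetical per-sample maxima across ranks

# Seed the running sum with the previous total, then drop the seed itself.
cum_times = list(accumulate([prev_cum_time] + max_sample_fetch_times))[1:]
print(cum_times)  # ~[10.5, 10.7, 11.0] (up to float rounding)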
Example #2
    def forward(self, scores: torch.Tensor, head_id: int):
        assert scores.shape[0] % self.num_crops == 0
        bs = scores.shape[0] // self.num_crops

        total_loss = 0
        n_term_loss = 0

        # 2 big crops are normally used for the assignment
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                scores_this_crop = scores[bs * crop_id:bs * (crop_id + 1)]
                if self.use_queue:
                    queue = getattr(self,
                                    "local_queue" + str(head_id))[i].clone()
                    scores_this_crop = torch.cat((scores_this_crop, queue))
                if self.use_double_prec:
                    assignments = torch.exp(scores_this_crop.double() /
                                            np.float64(self.epsilon)).t()
                    assignments = assignments.double()
                else:
                    assignments = scores_this_crop / self.epsilon
                    # use the log-sum-exp trick for numerical stability.
                    M = torch.max(assignments)
                    all_reduce_max(M)
                    assignments -= M
                    assignments = torch.exp(assignments).t()
                assignments = self.distributed_sinkhornknopp(assignments)[:bs]
                idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)

            loss = 0
            for p in idx_crop_pred:
                if self.use_double_prec:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)].double()
                                / np.float64(self.temperature)
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )
                else:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)] / self.temperature
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )
            loss /= len(idx_crop_pred)
            total_loss += loss
            n_term_loss += 1

            # stop training if NaN appears and log the output to help debugging
            # TODO (prigoyal): extract the logic to be common for all losses
            # debug_state() method that all losses can override
            if torch.isnan(loss):
                logging.info(
                    f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
                )
                scores_output_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_scores" + str(i) + ".pth",
                )
                assignments_out_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_assignments" + str(i) +
                    ".pth",
                )
                with PathManager.open(scores_output_file, "wb") as fwrite:
                    torch.save(scores, fwrite)
                with PathManager.open(assignments_out_file, "wb") as fwrite:
                    torch.save(assignments, fwrite)
                logging.info(
                    f"Saved the scores matrix to: {scores_output_file}")
                logging.info(
                    f"Saved the assignment matrix to: {assignments_out_file}")
        total_loss /= n_term_loss
        return total_loss
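The max-subtraction in the single-precision branch above is the standard way to keep exp(scores / epsilon) finite: the constant factor cancels in the later Sinkhorn-Knopp normalization, so subtracting the (all-reduced) maximum changes nothing but the numerics. A minimal single-process sketch, with an assumed epsilon and random scores (in the real code the max is reduced across GPUs):

import torch

epsilon = 0.05                      # assumed value for illustration
scores = torch.randn(8, 3000) * 10  # logits large enough that exp(scores / epsilon) overflows fp32

shifted = scores / epsilon
shifted -= shifted.max()            # in the multi-GPU case this max is all-reduced first
Q = torch.exp(shifted).t()          # (num_prototypes, batch), every entry finite

# The shift rescales Q by a constant factor only, so the row/column
# normalizations of Sinkhorn-Knopp yield the same assignments.
assert torch.isfinite(Q).all()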
Example #3
    def forward(self, scores: torch.Tensor, head_id: int):
        assert scores.shape[0] % self.num_crops == 0
        bs = scores.shape[0] // self.num_crops

        total_loss = 0
        n_term_loss = 0

        # 2 big crops are normally used for the assignment
        for i, crop_id in enumerate(self.crops_for_assign):

            # Compute the target assignments, using the scores of crop_id to
            # compute the codes onto which the other crops will be mapped
            with torch.no_grad():
                scores_this_crop = scores[bs * crop_id:bs * (crop_id + 1)]

                # Add representations from the queue (useful when the batch
                # size is small: it gives Sinkhorn-Knopp more samples, making
                # an equal partition over the prototypes possible)
                if self.use_queue:
                    queue = getattr(self,
                                    "local_queue" + str(head_id))[i].clone()
                    scores_this_crop = torch.cat((scores_this_crop, queue))

                # Divide by epsilon (a temperature that sharpens the
                # distribution of the assignments)
                if self.use_double_prec:
                    assignments = torch.exp(scores_this_crop.double() /
                                            np.float64(self.epsilon)).t()
                    assignments = assignments.double()
                else:
                    assignments = scores_this_crop / self.epsilon
                    # use the log-sum-exp trick for numerical stability.
                    M = torch.max(assignments)
                    all_reduce_max(M)
                    assignments -= M
                    assignments = torch.exp(assignments).t()

                # Apply the Sinkhorn-Knopp algorithm to spread the assignments
                # equally across the prototypes
                assignments = distributed_sinkhornknopp(
                    Q=assignments,
                    hard_assignment=self.num_iteration < self.temp_hard_assignment_iters,
                    world_size=self.world_size,
                    num_iter=self.nmb_sinkhornknopp_iters,
                    use_gpu=self.use_gpu,
                    use_double_prec=self.use_double_prec,
                )
                assignments = assignments[:bs]

            # For each crop other than the one used as the target assignment,
            # compute the cross entropy between the target assignment and the
            # softmax of the dot product of each crop with the prototypes
            loss = 0
            idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)
            for p in idx_crop_pred:
                if self.use_double_prec:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)].double()
                                / np.float64(self.temperature)
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )
                else:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)] / self.temperature
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )

            # Average the contribution of each crop (we don't want an increase
            # in the number of crops to change the loss magnitude and force us
            # to update the LR)
            loss /= len(idx_crop_pred)

            # Average the contribution of each swapped assignment (the
            # division by 'n_term_loss' is done at the end of the loop)
            # for the same reason as above
            total_loss += loss
            n_term_loss += 1

            # Stop training if NaN appears and log the output to help debugging
            # TODO (prigoyal): extract the logic to be common for all losses
            # debug_state() method that all losses can override
            if torch.isnan(loss):
                logging.info(
                    f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
                )
                scores_output_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_scores" + str(i) + ".pth",
                )
                assignments_out_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_assignments" + str(i) +
                    ".pth",
                )
                with PathManager.open(scores_output_file, "wb") as fwrite:
                    torch.save(scores, fwrite)
                with PathManager.open(assignments_out_file, "wb") as fwrite:
                    torch.save(assignments, fwrite)
                logging.info(
                    f"Saved the scores matrix to: {scores_output_file}")
                logging.info(
                    f"Saved the assignment matrix to: {assignments_out_file}")

        total_loss /= n_term_loss
        return total_loss
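distributed_sinkhornknopp above comes from VISSL and also handles the cross-GPU reductions; the single-process sketch below only illustrates the underlying idea, alternating row and column normalization so that every prototype receives roughly the same mass. The function name, shapes, and iteration count here are assumptions for illustration, not VISSL's exact implementation:

import torch

def sinkhorn_knopp(Q: torch.Tensor, num_iter: int = 3) -> torch.Tensor:
    # Q: (num_prototypes, num_samples), non-negative, e.g. exp(scores / epsilon).t()
    Q = Q / Q.sum()
    K, B = Q.shape
    for _ in range(num_iter):
        # normalize rows: each prototype receives total mass 1 / K
        Q = Q / Q.sum(dim=1, keepdim=True) / K
        # normalize columns: each sample carries total mass 1 / B
        Q = Q / Q.sum(dim=0, keepdim=True) / B
    # rescale so that each sample's assignment sums to 1, as expected by the loss
    return (Q * B).t()  # (num_samples, num_prototypes)

scores = torch.randn(16, 300)  # toy (batch, num_prototypes) scores
assignments = sinkhorn_knopp(torch.exp(scores / 0.5).t())

The rows of the returned matrix play the same role as `assignments` in the cross-entropy terms of the loss above.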