def _get_cum_sample_fetch_times(self, phase_type) -> Tuple[List[float], ...]: if not self.sample_fetch_times: return None sample_fetch_times = torch.Tensor(self.sample_fetch_times) max_sample_fetch_times = all_reduce_max(sample_fetch_times).tolist() cum_sample_fetch_times = list( accumulate( [self.state.cum_sample_fetch_time[phase_type]] + max_sample_fetch_times ) )[1:] self.state.cum_sample_fetch_time[phase_type] = cum_sample_fetch_times[-1] return cum_sample_fetch_times
def forward(self, scores: torch.Tensor, head_id: int):
    """Compute the swapped-assignment loss for one head.

    ``scores`` stacks ``self.num_crops`` crops along dim 0, ``bs`` rows per
    crop. For every crop index in ``self.crops_for_assign``, target
    assignments are computed without gradient from that crop's scores
    (optionally concatenated with ``local_queue<head_id>[i]``), divided by
    ``self.epsilon``, exponentiated and passed through
    ``self.distributed_sinkhornknopp``. Every other crop then contributes
    ``-mean(sum(assignments * log_softmax(crop_scores / temperature)))``.

    Args:
        scores: tensor whose first dimension is divisible by ``num_crops``.
        head_id: suffix selecting the ``local_queue<head_id>`` attribute.

    Returns:
        The loss averaged over predicting crops and over assignment crops.

    When a per-crop loss is NaN or infinite, the loss value is logged and
    ``scores`` and the assignments are saved under ``self.output_dir``;
    training is not stopped here.
    """
    assert scores.shape[0] % self.num_crops == 0
    bs = scores.shape[0] // self.num_crops

    total_loss = 0
    n_term_loss = 0
    # 2 big crops are normally used for the assignment
    for i, crop_id in enumerate(self.crops_for_assign):
        with torch.no_grad():
            scores_this_crop = scores[bs * crop_id:bs * (crop_id + 1)]
            if self.use_queue:
                queue = getattr(self, "local_queue" + str(head_id))[i].clone()
                scores_this_crop = torch.cat((scores_this_crop, queue))
            if self.use_double_prec:
                assignments = torch.exp(
                    scores_this_crop.double() / np.float64(self.epsilon)
                ).t()
                assignments = assignments.double()
            else:
                assignments = scores_this_crop / self.epsilon
                # use the log-sum-exp trick for numerical stability.
                M = torch.max(assignments)
                # Return value is ignored: presumably reduces M in place
                # across processes -- confirm against all_reduce_max.
                all_reduce_max(M)
                assignments -= M
                assignments = torch.exp(assignments).t()
            # Keep the first bs rows (queue rows, if any, were appended last).
            assignments = self.distributed_sinkhornknopp(assignments)[:bs]

        idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)
        loss = 0
        for p in idx_crop_pred:
            if self.use_double_prec:
                loss -= torch.mean(
                    torch.sum(
                        assignments
                        * self.log_softmax(
                            scores[bs * p:bs * (p + 1)].double()
                            / np.float64(self.temperature)
                        ),
                        dim=1,
                        dtype=assignments.dtype,
                    )
                )
            else:
                loss -= torch.mean(
                    torch.sum(
                        assignments
                        * self.log_softmax(
                            scores[bs * p:bs * (p + 1)] / self.temperature
                        ),
                        dim=1,
                        dtype=assignments.dtype,
                    )
                )
        # Average over predicting crops so the loss scale does not depend on
        # the number of crops.
        loss /= len(idx_crop_pred)
        total_loss += loss
        n_term_loss += 1

        # Log and dump inputs when the loss is not finite to help debugging.
        # isfinite (rather than isnan) also catches +/-inf, matching the
        # "Infinite Loss or NaN" message below.
        # TODO (prigoyal): extract the logic to be common for all losses
        # debug_state() method that all losses can override
        if not torch.isfinite(loss):
            logging.info(
                f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
            )
            scores_output_file = os.path.join(
                self.output_dir,
                "rank" + str(self.dist_rank) + "_scores" + str(i) + ".pth",
            )
            assignments_out_file = os.path.join(
                self.output_dir,
                "rank" + str(self.dist_rank) + "_assignments" + str(i) + ".pth",
            )
            with PathManager.open(scores_output_file, "wb") as fwrite:
                torch.save(scores, fwrite)
            with PathManager.open(assignments_out_file, "wb") as fwrite:
                torch.save(assignments, fwrite)
            logging.info(f"Saved the scores matrix to: {scores_output_file}")
            logging.info(f"Saved the assignment matrix to: {assignments_out_file}")

    # Average over the assignment crops for the same scale-invariance reason.
    total_loss /= n_term_loss
    return total_loss
def forward(self, scores: torch.Tensor, head_id: int):
    """Compute the swapped-assignment loss for one head.

    ``scores`` stacks ``self.num_crops`` crops along dim 0, ``bs`` rows per
    crop. For every crop index in ``self.crops_for_assign``, target
    assignments are computed without gradient from that crop's scores
    (optionally concatenated with ``local_queue<head_id>[i]``), divided by
    ``self.epsilon``, exponentiated and passed through
    ``distributed_sinkhornknopp``. Every other crop then contributes
    ``-mean(sum(assignments * log_softmax(crop_scores / temperature)))``.

    Args:
        scores: tensor whose first dimension is divisible by ``num_crops``.
        head_id: suffix selecting the ``local_queue<head_id>`` attribute.

    Returns:
        The loss averaged over predicting crops and over assignment crops.

    When a per-crop loss is NaN or infinite, the loss value is logged and
    ``scores`` and the assignments are saved under ``self.output_dir``;
    training is not stopped here.
    """
    assert scores.shape[0] % self.num_crops == 0
    bs = scores.shape[0] // self.num_crops

    total_loss = 0
    n_term_loss = 0
    # 2 big crops are normally used for the assignment
    for i, crop_id in enumerate(self.crops_for_assign):
        # Compute the target assignments, taking crop_id as the features
        # used to compute the codes to which other crops will be mapped
        with torch.no_grad():
            scores_this_crop = scores[bs * crop_id:bs * (crop_id + 1)]

            # Add representations of the queue (this option is useful when
            # the batch size is small, to increase the number of samples
            # in sinkhornknopp to make equal repartition possible)
            if self.use_queue:
                queue = getattr(self, "local_queue" + str(head_id))[i].clone()
                scores_this_crop = torch.cat((scores_this_crop, queue))

            # Divide by epsilon (which can be seen as a temperature which
            # helps to sharpen the distribution of the assignments)
            if self.use_double_prec:
                assignments = torch.exp(
                    scores_this_crop.double() / np.float64(self.epsilon)
                ).t()
                assignments = assignments.double()
            else:
                assignments = scores_this_crop / self.epsilon
                # use the log-sum-exp trick for numerical stability.
                M = torch.max(assignments)
                # Return value is ignored: presumably reduces M in place
                # across processes -- confirm against all_reduce_max.
                all_reduce_max(M)
                assignments -= M
                assignments = torch.exp(assignments).t()

            # Apply sinkhornknopp algorithm to divide equally the
            # assignment to each of the prototypes
            assignments = distributed_sinkhornknopp(
                Q=assignments,
                hard_assignment=self.num_iteration < self.temp_hard_assignment_iters,
                world_size=self.world_size,
                num_iter=self.nmb_sinkhornknopp_iters,
                use_gpu=self.use_gpu,
                use_double_prec=self.use_double_prec,
            )
            # Keep the first bs rows (queue rows, if any, were appended last).
            assignments = assignments[:bs]

        # For each crop other than the one used as target assignment
        # compute the cross entropy between the target assignment and
        # the soft-max of the dot product of each crop to the prototypes
        loss = 0
        idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)
        for p in idx_crop_pred:
            if self.use_double_prec:
                loss -= torch.mean(
                    torch.sum(
                        assignments
                        * self.log_softmax(
                            scores[bs * p:bs * (p + 1)].double()
                            / np.float64(self.temperature)
                        ),
                        dim=1,
                        dtype=assignments.dtype,
                    )
                )
            else:
                loss -= torch.mean(
                    torch.sum(
                        assignments
                        * self.log_softmax(
                            scores[bs * p:bs * (p + 1)] / self.temperature
                        ),
                        dim=1,
                        dtype=assignments.dtype,
                    )
                )

        # Average of the contribution of each crop (we don't want an
        # increase in the number of crops to impact the loss magnitude
        # and force us to update the LR)
        loss /= len(idx_crop_pred)

        # Average the contribution of each swapped assignment (the
        # division by 'n_term_loss' is done at the end of the loop)
        # for the same reason as above
        total_loss += loss
        n_term_loss += 1

        # Log and dump inputs when the loss is not finite to help debugging.
        # isfinite (rather than isnan) also catches +/-inf, matching the
        # "Infinite Loss or NaN" message below.
        # TODO (prigoyal): extract the logic to be common for all losses
        # debug_state() method that all losses can override
        if not torch.isfinite(loss):
            logging.info(
                f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
            )
            scores_output_file = os.path.join(
                self.output_dir,
                "rank" + str(self.dist_rank) + "_scores" + str(i) + ".pth",
            )
            assignments_out_file = os.path.join(
                self.output_dir,
                "rank" + str(self.dist_rank) + "_assignments" + str(i) + ".pth",
            )
            with PathManager.open(scores_output_file, "wb") as fwrite:
                torch.save(scores, fwrite)
            with PathManager.open(assignments_out_file, "wb") as fwrite:
                torch.save(assignments, fwrite)
            logging.info(f"Saved the scores matrix to: {scores_output_file}")
            logging.info(f"Saved the assignment matrix to: {assignments_out_file}")

    total_loss /= n_term_loss
    return total_loss