def forward(self, input):
    # Fall back to the plain BatchNorm2d path when there is nothing to sync.
    if du.get_local_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    # Gather per-device E[x] and E[x^2] across the sync group and average.
    vec = torch.cat([mean, meansqr], dim=0)
    vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
        1.0 / self.num_sync_devices
    )

    mean, meansqr = torch.split(vec, C)
    # Var[x] = E[x^2] - E[x]^2, computed from the synchronized moments.
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    # Fold normalization, weight, and bias into a single affine transform.
    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
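# A minimal single-process sanity sketch (not part of the module): it checks
# that the fused `input * scale + bias` form above matches direct
# normalization, out = (x - mean) / sqrt(var + eps) * weight + bias.
# All names below are local to the sketch.
import torch

x = torch.randn(8, 3, 4, 4)
mean = x.mean(dim=[0, 2, 3])
var = (x * x).mean(dim=[0, 2, 3]) - mean * mean  # E[x^2] - E[x]^2
weight, bias, eps = torch.ones(3), torch.zeros(3), 1e-5
invstd = torch.rsqrt(var + eps)
scale = (weight * invstd).reshape(1, -1, 1, 1)
shift = (bias - mean * weight * invstd).reshape(1, -1, 1, 1)
fused = x * scale + shift
direct = (x - mean.reshape(1, -1, 1, 1)) * invstd.reshape(
    1, -1, 1, 1
) * weight.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
assert torch.allclose(fused, direct, atol=1e-5)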
@staticmethod
def backward(ctx, grad_output):
    """
    Perform the backward pass, gathering the gradients across different
    processes / GPU groups.
    """
    grad_output_list = [
        torch.zeros_like(grad_output) for k in range(du.get_local_size())
    ]
    dist.all_gather(
        grad_output_list,
        grad_output,
        async_op=False,
        group=du._LOCAL_PROCESS_GROUP,
    )

    grads = torch.stack(grad_output_list, dim=0)
    # Keep only the gradients from devices in this rank's sync group.
    if ctx.num_groups > 1:
        rank = du.get_local_rank()
        group_idx = rank // ctx.num_sync_devices
        grads = grads[
            group_idx
            * ctx.num_sync_devices : (group_idx + 1)
            * ctx.num_sync_devices
        ]
    # Summing mirrors the forward: every rank's output received this rank's
    # contribution, so the incoming gradients accumulate here.
    grads = torch.sum(grads, dim=0)
    # One gradient for `input`; None for the two non-tensor arguments.
    return grads, None, None
@staticmethod
def forward(ctx, input, num_sync_devices, num_groups):
    """
    Perform the forward pass, gathering the stats across different
    processes / GPU groups.
    """
    ctx.num_sync_devices = num_sync_devices
    ctx.num_groups = num_groups

    input_list = [torch.zeros_like(input) for k in range(du.get_local_size())]
    dist.all_gather(
        input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP
    )

    inputs = torch.stack(input_list, dim=0)
    # Keep only the stats from devices in this rank's sync group.
    if num_groups > 1:
        rank = du.get_local_rank()
        group_idx = rank // num_sync_devices
        inputs = inputs[
            group_idx * num_sync_devices : (group_idx + 1) * num_sync_devices
        ]
    inputs = torch.sum(inputs, dim=0)
    return inputs
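# A self-contained sketch of the same gather-then-sum pattern using plain
# torch.distributed on a single gloo process; the `du` helpers and
# du._LOCAL_PROCESS_GROUP above are project-specific and assumed to be set up
# by the training launcher in real use. With num_groups == 1, GroupGather
# behaves like a differentiable all-reduce(sum) over the local group.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

vec = torch.randn(6)
gathered = [torch.zeros_like(vec) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, vec, async_op=False)
summed = torch.stack(gathered, dim=0).sum(dim=0)  # == vec when world_size == 1
assert torch.allclose(summed, vec)
dist.destroy_process_group()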
def __init__(self, num_sync_devices, **args):
    """
    Naive version of Synchronized 2D BatchNorm.
    Args:
        num_sync_devices (int): number of devices to sync with.
        args (dict): other keyword arguments, forwarded to BatchNorm2d.
    """
    self.num_sync_devices = num_sync_devices
    if self.num_sync_devices > 0:
        # The local devices must split evenly into sync groups.
        assert du.get_local_size() % self.num_sync_devices == 0, (
            du.get_local_size(),
            self.num_sync_devices,
        )
        self.num_groups = du.get_local_size() // self.num_sync_devices
    else:
        # Non-positive value: sync across all local devices as one group.
        self.num_sync_devices = du.get_local_size()
        self.num_groups = 1
    super(NaiveSyncBatchNorm2d, self).__init__(**args)
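# Hedged construction sketch; kwargs pass through to BatchNorm2d. On a machine
# with 8 local GPUs, num_sync_devices=4 yields 8 // 4 = 2 sync groups, while
# num_sync_devices=0 syncs all 8 as one group. In a single-process run, where
# du.get_local_size() is assumed to report 1, only num_sync_devices in {0, 1}
# passes the divisibility assert, and forward() falls back to plain
# BatchNorm2d math.
bn = NaiveSyncBatchNorm2d(num_sync_devices=1, num_features=64)
out = bn(torch.randn(2, 64, 8, 8))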
def _simclr_precompute_pos_neg_mask_multi(self):
    # Computed once at the beginning of training.
    distributed = self.cfg.CONTRASTIVE.SIMCLR_DIST_ON
    if distributed:
        total_images = self.cfg.TRAIN.BATCH_SIZE * self.cfg.NUM_SHARDS
        world_size = du.get_world_size()
        rank = du.get_rank()
    else:
        total_images = self.cfg.TRAIN.BATCH_SIZE
        world_size = du.get_local_size()
        rank = du.get_local_rank()
    local_orig_images = total_images // world_size
    local_crops = local_orig_images * self.num_crops

    pos_temps = []
    for d in np.arange(self.num_crops):
        pos_temp, neg_temp = [], []
        for i in range(world_size):
            if i == rank:
                # Shifted identities mark crops of the same source image as
                # positives; everything on this rank is a negative candidate.
                pos = np.eye(local_crops, k=d * local_orig_images) + np.eye(
                    local_crops, k=-local_crops + d * local_orig_images
                )
                neg = np.ones((local_crops, local_crops))
            else:
                pos = np.zeros((local_crops, local_crops))
                neg = np.zeros((local_crops, local_crops))
            pos_temp.append(pos)
            neg_temp.append(neg)
        pos_temps.append(np.hstack(pos_temp))
        neg_temp = np.hstack(neg_temp)

    # Skip d=0 (each crop paired with itself); positives are excluded from
    # the negative mask.
    pos_mask = []
    for i in range(self.num_crops - 1):
        pos_mask.append(torch.from_numpy(pos_temps[1 + i]))
    neg_mask = torch.from_numpy(neg_temp - sum(pos_temps))

    if self.num_gpus:
        for i in range(len(pos_mask)):
            pos_mask[i] = pos_mask[i].cuda(non_blocking=True)
        neg_mask = neg_mask.cuda(non_blocking=True)
    self.pos_mask, self.neg_mask = pos_mask, neg_mask
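# Standalone sketch of the mask algebra for one process with 2 images and
# num_crops=2 (so local_crops=4); the sizes are illustrative, not from a
# config.
import numpy as np

local_orig_images, num_crops = 2, 2
local_crops = local_orig_images * num_crops
pos_temps = [
    np.eye(local_crops, k=d * local_orig_images)
    + np.eye(local_crops, k=-local_crops + d * local_orig_images)
    for d in range(num_crops)
]
# pos_temps[0] is the identity (a crop with itself, skipped above);
# pos_temps[1] pairs each crop with the other crop of the same image:
# [[0 0 1 0]
#  [0 0 0 1]
#  [1 0 0 0]
#  [0 1 0 0]]
neg = np.ones((local_crops, local_crops)) - sum(pos_temps)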
def _batch_shuffle(self, x):
    # x is a list of one or two crops; both are shuffled with the same
    # permutation so positive pairs stay aligned.
    another_crop = len(x) == 2
    if another_crop:
        x, x_crop = x[0], x[1]
    else:
        x = x[0]

    world_size = self.cfg.NUM_GPUS * self.cfg.NUM_SHARDS
    if self.num_gpus > 1:
        if self.cfg.CONTRASTIVE.LOCAL_SHUFFLE_BN:
            # Shuffle only within the local machine.
            x = du.cat_all_gather(x, local=True)
            if another_crop:
                x_crop = du.cat_all_gather(x_crop, local=True)
            world_size = du.get_local_size()
            gpu_idx = du.get_local_rank()
        else:
            x = du.cat_all_gather(x)
            if another_crop:
                x_crop = du.cat_all_gather(x_crop)
            gpu_idx = torch.distributed.get_rank()

    # One global permutation, broadcast from rank 0 so every rank slices the
    # gathered batch consistently.
    idx_randperm = torch.randperm(x.shape[0]).cuda()
    if self.num_gpus > 1:
        torch.distributed.broadcast(idx_randperm, src=0)
    else:
        gpu_idx = 0
    idx_randperm = idx_randperm.view(world_size, -1)
    x = x[idx_randperm[gpu_idx, :]]
    if another_crop:
        x_crop = x_crop[idx_randperm[gpu_idx, :]]
    # argsort of the flattened permutation restores the original order later.
    idx_restore = torch.argsort(idx_randperm.view(-1))
    idx_restore = idx_restore.view(world_size, -1)
    if another_crop:
        return [x, x_crop], idx_restore
    else:
        return [x], idx_restore
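# CPU sanity sketch (standalone): idx_restore is the argsort of the flattened
# permutation, so indexing the re-gathered, shuffled batch with it recovers
# the original order. world_size=2 here is illustrative.
import torch

world_size, per_gpu = 2, 4
batch = torch.arange(world_size * per_gpu)
idx_randperm = torch.randperm(batch.shape[0])
shuffled = batch[idx_randperm]            # gathered batch after the shuffle
idx_restore = torch.argsort(idx_randperm)
assert torch.equal(shuffled[idx_restore], batch)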