def setup_distributed(num_images=None):
    """Setup distributed related parameters."""
    # init distributed
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        # per-GPU data loader workers (round returns the nearest integer)
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size())
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
    if hasattr(FLAGS, 'base_lr'):
        # linear learning rate scaling with the total batch size
        FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    if num_images:
        # NOTE: don't drop the last batch, thus must use ceil (the smallest
        # integer not less than x), otherwise the learning rate will be
        # negative
        FLAGS._steps_per_epoch = math.ceil(num_images / FLAGS.batch_size)
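# A minimal, self-contained sketch (not part of the original code) of the two
# scaling rules above, with hypothetical numbers: base_lr, base_total_batch,
# per_gpu_batch_size, world_size and num_images are assumptions chosen only
# for illustration.
import math

base_lr = 0.1
base_total_batch = 256
per_gpu_batch_size = 64
world_size = 8
num_images = 1281167  # e.g. an ImageNet-sized training set

batch_size = world_size * per_gpu_batch_size          # 512
lr = base_lr * (batch_size / base_total_batch)        # 0.2, linear scaling
steps_per_epoch = math.ceil(num_images / batch_size)  # 2503, ceil keeps the last partial batch
print(lr, steps_per_epoch)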
def forward(self, input):
    if du.get_local_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    # per-channel mean and mean of squares over (N, T, H, W)
    mean = torch.mean(input, dim=[0, 2, 3, 4])
    meansqr = torch.mean(input * input, dim=[0, 2, 3, 4])

    # sum the stats across the devices in this sync group, then average
    vec = torch.cat([mean, meansqr], dim=0)
    vec = GroupGather.apply(vec, self.num_sync_devices,
                            self.num_groups) * (1.0 / self.num_sync_devices)

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    # fold mean/var and the affine parameters into a single scale and bias
    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1, 1)
    return input * scale + bias
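# A small numerical check (illustrative only; shapes and values are made up)
# that folding the statistics into a single scale/bias, as done above, matches
# the textbook batch-norm formula (x - mean) / sqrt(var + eps) * weight + bias.
import torch

x = torch.randn(4, 3, 2, 5, 5)                       # (N, C, T, H, W)
mean = x.mean(dim=[0, 2, 3, 4])
var = (x * x).mean(dim=[0, 2, 3, 4]) - mean * mean
weight, bias, eps = torch.randn(3), torch.randn(3), 1e-5

invstd = torch.rsqrt(var + eps)
scale = (weight * invstd).reshape(1, -1, 1, 1, 1)
shift = (bias - mean * weight * invstd).reshape(1, -1, 1, 1, 1)
folded = x * scale + shift

reference = (x - mean.reshape(1, -1, 1, 1, 1)) \
    / torch.sqrt(var.reshape(1, -1, 1, 1, 1) + eps) \
    * weight.reshape(1, -1, 1, 1, 1) + bias.reshape(1, -1, 1, 1, 1)
print(torch.allclose(folded, reference, atol=1e-5))  # True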
def __init__(self, num_sync_devices, **args):
    """
    Naive version of Synchronized 3D BatchNorm.

    Args:
        num_sync_devices (int): number of devices to sync within each group.
        args (dict): other keyword arguments passed to BatchNorm3d.
    """
    self.num_sync_devices = num_sync_devices
    if self.num_sync_devices > 0:
        assert du.get_local_size() % self.num_sync_devices == 0, (
            du.get_local_size(),
            self.num_sync_devices,
        )
        self.num_groups = du.get_local_size() // self.num_sync_devices
    else:
        # sync across all local devices as a single group
        self.num_sync_devices = du.get_local_size()
        self.num_groups = 1
    super(NaiveSyncBatchNorm3d, self).__init__(**args)
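# A minimal usage sketch, assuming torch.distributed is already initialized
# and the node has 8 local GPUs split into two sync groups of 4. All numbers
# here are illustrative assumptions, not from the original code; the layer
# expects 5D input (N, C, T, H, W) like nn.BatchNorm3d.
import torch

bn = NaiveSyncBatchNorm3d(num_sync_devices=4, num_features=64).cuda()
x = torch.randn(2, 64, 8, 56, 56).cuda()
y = bn(x)  # batch stats are averaged over the 4 devices in this rank's group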
def main():
    """Entry."""
    # init distributed
    global is_root_rank
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size())
        is_root_rank = udist.is_master()
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
        is_root_rank = True
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # NOTE: don't drop last batch, thus must use ceil, otherwise learning rate
    # will be negative
    FLAGS._steps_per_epoch = int(
        np.ceil(NUM_IMAGENET_TRAIN / FLAGS.batch_size))

    if is_root_rank:
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        create_exp_dir(
            FLAGS.log_dir,
            FLAGS.config_path,
            blacklist_dirs=[
                'exp',
                '.git',
                'pretrained',
                'tmp',
                'deprecated',
                'bak',
            ],
        )
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
def backward(ctx, grad_output):
    """
    Perform the backward pass, gathering the gradients across the different
    processes/GPU group.
    """
    grad_output_list = [
        torch.zeros_like(grad_output) for k in range(du.get_local_size())
    ]
    dist.all_gather(
        grad_output_list,
        grad_output,
        async_op=False,
        group=du._LOCAL_PROCESS_GROUP,
    )
    grads = torch.stack(grad_output_list, dim=0)
    if ctx.num_groups > 1:
        # keep only the gradients from the devices in this rank's sync group
        rank = du.get_local_rank()
        group_idx = rank // ctx.num_sync_devices
        grads = grads[group_idx * ctx.num_sync_devices:
                      (group_idx + 1) * ctx.num_sync_devices]
    grads = torch.sum(grads, dim=0)
    return grads, None, None
def forward(ctx, input, num_sync_devices, num_groups):
    """
    Perform the forward pass, gathering the stats across the different
    processes/GPU group.
    """
    ctx.num_sync_devices = num_sync_devices
    ctx.num_groups = num_groups

    input_list = [
        torch.zeros_like(input) for k in range(du.get_local_size())
    ]
    dist.all_gather(input_list,
                    input,
                    async_op=False,
                    group=du._LOCAL_PROCESS_GROUP)

    inputs = torch.stack(input_list, dim=0)
    if num_groups > 1:
        # keep only the stats from the devices in this rank's sync group
        rank = du.get_local_rank()
        group_idx = rank // num_sync_devices
        inputs = inputs[group_idx * num_sync_devices:
                        (group_idx + 1) * num_sync_devices]
    inputs = torch.sum(inputs, dim=0)
    return inputs
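# A small self-contained sketch (illustrative numbers only) of the group
# slicing used in the forward/backward above: with 8 local devices and
# num_sync_devices=4 there are 2 groups, and the all-gathered stack of 8
# entries is reduced to the sum of the 4 entries belonging to this rank's own
# group. A 1-D tensor stands in for the stacked per-device stats.
import torch

local_size, num_sync_devices = 8, 4
num_groups = local_size // num_sync_devices            # 2
stacked = torch.arange(local_size, dtype=torch.float)  # stand-in for the gathered stats

for rank in range(local_size):
    group_idx = rank // num_sync_devices
    group = stacked[group_idx * num_sync_devices:(group_idx + 1) * num_sync_devices]
    print(rank, group.sum().item())  # ranks 0-3 -> 6.0, ranks 4-7 -> 22.0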