Example #1
def setup_distributed(num_images=None):
    """Setup distributed related parameters."""
    # init distributed
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size()
        )  # per-GPU workers (round() returns the nearest integer)
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
    if hasattr(FLAGS, 'base_lr'):
        FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    if num_images:
        # NOTE: the last batch is not dropped, so ceil is required (the
        # smallest integer not less than x); with floor, the scheduled
        # learning rate could become negative
        FLAGS._steps_per_epoch = math.ceil(num_images / FLAGS.batch_size)
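A minimal, self-contained sketch of the scaling arithmetic above; the concrete numbers (base_lr, per-GPU batch size, world size) are illustrative assumptions, not values from the source.

import math

base_lr = 0.1             # hypothetical base learning rate
base_total_batch = 256    # hypothetical reference batch size for base_lr
per_gpu_batch_size = 64   # hypothetical per-GPU batch size
world_size = 8            # hypothetical number of processes / GPUs
num_images = 1281167      # ImageNet-1k training set size

batch_size = world_size * per_gpu_batch_size          # 512
lr = base_lr * (batch_size / base_total_batch)        # linear scaling: 0.2
steps_per_epoch = math.ceil(num_images / batch_size)  # ceil keeps the last,
                                                      # smaller batch: 2503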
Example #2
    def forward(self, input):
        if du.get_local_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        C = input.shape[1]
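        # per-channel statistics over (N, T, H, W) of a 5D (N, C, T, H, W) input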
        mean = torch.mean(input, dim=[0, 2, 3, 4])
        meansqr = torch.mean(input * input, dim=[0, 2, 3, 4])

        vec = torch.cat([mean, meansqr], dim=0)
        vec = GroupGather.apply(
            vec, self.num_sync_devices,
            self.num_groups) * (1.0 / self.num_sync_devices)

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        self.running_mean += self.momentum * \
            (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1, 1)
        return input * scale + bias
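As a sanity check of the statistics computed above (single process, so no gathering is involved), the mean/meansqr formulation matches the identity Var[x] = E[x^2] - E[x]^2. A hedged sketch, with shapes assuming 5D (N, C, T, H, W) input:

import torch

x = torch.randn(4, 3, 2, 8, 8)
mean = x.mean(dim=[0, 2, 3, 4])
meansqr = (x * x).mean(dim=[0, 2, 3, 4])
var = meansqr - mean * mean
# batch norm uses the biased variance estimator
assert torch.allclose(var, x.var(dim=[0, 2, 3, 4], unbiased=False), atol=1e-5)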
Example #3
    def __init__(self, num_sync_devices, **args):
        """
        Naive version of Synchronized 3D BatchNorm.
        Args:
            num_sync_devices (int): number of devices to sync over.
            args (dict): other keyword arguments.
        """
        self.num_sync_devices = num_sync_devices
        if self.num_sync_devices > 0:
            assert du.get_local_size() % self.num_sync_devices == 0, (
                du.get_local_size(),
                self.num_sync_devices,
            )
            self.num_groups = du.get_local_size() // self.num_sync_devices
        else:
            self.num_sync_devices = du.get_local_size()
            self.num_groups = 1
        super(NaiveSyncBatchNorm3d, self).__init__(**args)
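A standalone illustration of the grouping arithmetic in __init__, assuming a hypothetical 8-GPU node; du.get_local_size() is replaced by a constant for the sketch:

local_size = 8  # stand-in for du.get_local_size()

for num_sync_devices in (2, 4, 8):
    assert local_size % num_sync_devices == 0, (local_size, num_sync_devices)
    num_groups = local_size // num_sync_devices
    print(num_sync_devices, num_groups)  # -> (2, 4), (4, 2), (8, 1)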
Example #4
def main():
    """Entry."""
    # init distributed
    global is_root_rank
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(FLAGS.data_loader_workers /
                                          udist.get_local_size())
        is_root_rank = udist.is_master()
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size * count
        is_root_rank = True
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # NOTE: don't drop last batch, thus must use ceil, otherwise learning rate
    # will be negative
    FLAGS._steps_per_epoch = int(np.ceil(NUM_IMAGENET_TRAIN /
                                         FLAGS.batch_size))

    if is_root_rank:
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        create_exp_dir(
            FLAGS.log_dir,
            FLAGS.config_path,
            blacklist_dirs=[
                'exp',
                '.git',
                'pretrained',
                'tmp',
                'deprecated',
                'bak',
            ],
        )
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
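A minimal sketch of the root-rank-only setup pattern above: only the master process creates the timestamped run directory and writes logs, while the other ranks skip the filesystem work. `is_master` is a stand-in for udist.is_master(), and the paths are illustrative.

import os
import time

def setup_run(base_dir, is_master):
    if not is_master:
        return None  # non-master ranks create nothing
    run_dir = os.path.join(base_dir, time.strftime("%Y%m%d-%H%M%S"))
    os.makedirs(run_dir, exist_ok=True)
    return run_dir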
Example #5
    def backward(ctx, grad_output):
        """
        Perform the backward pass, gathering gradients across the processes
        in the GPU group.
        """
        grad_output_list = [
            torch.zeros_like(grad_output) for _ in range(du.get_local_size())
        ]
        dist.all_gather(
            grad_output_list,
            grad_output,
            async_op=False,
            group=du._LOCAL_PROCESS_GROUP,
        )

        grads = torch.stack(grad_output_list, dim=0)
        if ctx.num_groups > 1:
            rank = du.get_local_rank()
            group_idx = rank // ctx.num_sync_devices
            grads = grads[group_idx * ctx.num_sync_devices:(group_idx + 1) *
                          ctx.num_sync_devices]
        grads = torch.sum(grads, dim=0)
        return grads, None, None
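The `grads, None, None` return mirrors the three arguments of forward (input, num_sync_devices, num_groups): torch.autograd.Function.backward must return one gradient per forward input, with None for non-tensor arguments. A minimal sketch of the same convention, not part of the source:

import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):  # one tensor input, one plain number
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # gradient for x, then None for the non-tensor `factor`
        return grad_output * ctx.factor, None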
Example #6
    def forward(ctx, input, num_sync_devices, num_groups):
        """
        Perform the forward pass, gathering the stats across the processes
        in the GPU group.
        """
        ctx.num_sync_devices = num_sync_devices
        ctx.num_groups = num_groups

        input_list = [
            torch.zeros_like(input) for _ in range(du.get_local_size())
        ]
        dist.all_gather(input_list,
                        input,
                        async_op=False,
                        group=du._LOCAL_PROCESS_GROUP)

        inputs = torch.stack(input_list, dim=0)
        if num_groups > 1:
            rank = du.get_local_rank()
            group_idx = rank // num_sync_devices
            inputs = inputs[group_idx * num_sync_devices:(group_idx + 1) *
                            num_sync_devices]
        inputs = torch.sum(inputs, dim=0)
        return inputs
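A standalone illustration of the group-slice-and-sum step shared by the forward and backward above, with the distributed calls removed: 8 hypothetical ranks in groups of 4, and a fake gathered tensor standing in for the all_gather result.

import torch

local_size, num_sync_devices = 8, 4
gathered = torch.arange(local_size, dtype=torch.float32).unsqueeze(1)  # (8, 1)

for rank in range(local_size):
    group_idx = rank // num_sync_devices
    group = gathered[group_idx * num_sync_devices:(group_idx + 1) *
                     num_sync_devices]
    reduced = group.sum(dim=0)
    # ranks 0-3 reduce over [0, 1, 2, 3] -> 6.0; ranks 4-7 over [4, 5, 6, 7] -> 22.0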