    def forward(ctx, input, weight, bias, running_mean, running_variance, eps,
                track_running_stats=True, momentum=1.0, process_group=None,
                channel_last=False):
        torch.cuda.nvtx.range_push("sync_BN_fw")
        input = input.contiguous()
        world_size = 0

        mean = None
        var_biased = None
        inv_std = None
        var = None
        out = None
        count = None
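        # training path: compute per-process statistics, reduce them across ranks, and update the running buffers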
        if track_running_stats:
            if channel_last:
                count = int(input.numel()/input.size(-1))
                mean, var_biased = syncbn.welford_mean_var_c_last(input)
            else:
                count = int(input.numel()/input.size(1))
                mean, var_biased = syncbn.welford_mean_var(input)

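            # all-gather each rank's mean and biased variance, then reduce them into global statistics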
            if torch.distributed.is_initialized():
                if not process_group:
                    process_group = torch.distributed.group.WORLD
                world_size = torch.distributed.get_world_size(process_group)
                mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
                var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
                mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
                torch.distributed.all_gather(mean_l, mean, process_group)
                torch.distributed.all_gather(var_l, var_biased, process_group)
                mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
                # TODO(Jie): should do fp32 math instead!
            else:
                inv_std = 1.0 / torch.sqrt(var_biased + eps)
                var = var_biased * count / (count - 1)

            if count == 1 and world_size < 2:
                raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(input.size()))

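            # fold the batch statistics into the running buffers, matching the buffer dtype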
            r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
            r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
            running_mean.data = running_mean.data * (1-momentum) + momentum*r_m_inc
            running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc
        else:
            mean = running_mean.data
            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

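        # stash tensors and flags needed by the backward pass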
        ctx.save_for_backward(input, weight, mean, inv_std)
        ctx.process_group = process_group
        ctx.channel_last = channel_last
        ctx.world_size = world_size

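        # normalize and apply the affine transform with the fused CUDA kernel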
        if channel_last:
            out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
        else:
            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)

        torch.cuda.nvtx.range_pop()
        return out
    def forward(ctx,
                input,
                weight,
                bias,
                running_mean,
                running_variance,
                eps,
                track_running_stats=True,
                momentum=1.0):
        torch.cuda.nvtx.range_push("sync_BN_fw")
        input = input.contiguous()

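        # older variant: statistics are carried as biased variance and eps is applied inside the kernel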
        if track_running_stats:
            mean, var, var_biased = syncbn.welford_mean_var(input)

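            # gather per-rank statistics over the default process group and reduce with welford_parallel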
            if torch.distributed.is_initialized():
                world_size = torch.distributed.get_world_size()
                mean_all = torch.empty(world_size,
                                       mean.size(0),
                                       dtype=mean.dtype,
                                       device=mean.device)
                var_all = torch.empty(world_size,
                                      var.size(0),
                                      dtype=var.dtype,
                                      device=var.device)
                mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
                torch.distributed.all_gather(mean_l, mean)
                torch.distributed.all_gather(var_l, var_biased)
                mean, var, var_biased = syncbn.welford_parallel(
                    mean_all.transpose(1, 0).contiguous(),
                    var_all.transpose(1, 0).contiguous(),
                    int(input.numel() / input.size(1)))
                # TODO(Jie): should do fp32 math instead!

            r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
            r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
            running_mean.data = running_mean.data * (1 - momentum) + momentum * r_m_inc
            running_variance.data = running_variance.data * (1 - momentum) + momentum * r_v_inc
        else:
            mean = running_mean.data
            var_biased = running_variance.data

        ctx.save_for_backward(input, weight, mean, var_biased)
        ctx.eps = eps

        out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias,
                                       eps)

        torch.cuda.nvtx.range_pop()
        return out
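# Unit-test fragment: compares the syncbn CUDA kernels against reference statistics
# and a torch.nn.BatchNorm2d module. It assumes the surrounding test harness defines
# inp, weight, bias, grad, type_tensor, ref_tensor, feature_size, args, optim and
# syncbn; the inp_t and weight_t casts below are assumed to follow the same pattern
# as the bias_t line.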
inp_t = type_tensor(inp)
weight_t = type_tensor(weight)
bias_t = type_tensor(bias)

inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
inp2_r = ref_tensor(inp)
weight_r = ref_tensor(weight).view(-1, 1, 1)
bias_r = ref_tensor(bias).view(-1, 1, 1)

grad_output_t = type_tensor(grad)

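# reference statistics computed with native PyTorch ops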
m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)

eps = 1e-5

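# statistics from the syncbn CUDA kernel under test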
mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)

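# reference BatchNorm2d module run in the same precision as the kernel under test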
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
bn.weight.data = weight_t.clone()
bn.bias.data = bias_t.clone()
if args.fp16:
    bn.half()
if args.fp64:
    bn.double()
inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn)
out_bn.backward(grad_bn)
bn_opt = optim.SGD(bn.parameters(), lr=1.0)
    def forward(ctx,
                input,
                z,
                weight,
                bias,
                running_mean,
                running_variance,
                eps,
                track_running_stats=True,
                momentum=1.0,
                process_group=None,
                channel_last=False,
                fuse_relu=False):
        input = input.contiguous()
        world_size = 0

        mean = None
        var_biased = None
        inv_std = None
        var = None
        out = None
        count = None
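        # training path: per-process Welford stats, one packed all_gather across ranks, then running-stat update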
        if track_running_stats:
            if channel_last:
                count = int(input.numel() / input.size(-1))
                mean, var_biased = syncbn.welford_mean_var_c_last(input)
                num_channels = input.size(-1)
            else:
                count = int(input.numel() / input.size(1))
                mean, var_biased = syncbn.welford_mean_var(input)
                num_channels = input.size(1)

            if torch.distributed.is_initialized():
                if not process_group:
                    process_group = torch.distributed.group.WORLD
                device = mean.device
                world_size = torch.distributed.get_world_size(process_group)

                count_t = torch.empty(1, dtype=mean.dtype,
                                      device=mean.device).fill_(count)
                combined = torch.cat(
                    [mean.view(-1),
                     var_biased.view(-1), count_t], dim=0)
                combined_list = [
                    torch.empty_like(combined) for k in range(world_size)
                ]
                torch.distributed.all_gather(combined_list, combined,
                                             process_group)
                combined = torch.stack(combined_list, dim=0)
                mean_all, invstd_all, count_all = torch.split(combined,
                                                              num_channels,
                                                              dim=1)
                count_all = count_all.view(-1)
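                # note: despite its name, invstd_all holds the gathered per-rank biased variances;
                # welford_parallel combines them with the counts into global mean, var and inv_std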
                mean, var, inv_std = syncbn.welford_parallel(
                    mean_all, invstd_all, count_all.to(torch.int32), eps)
            else:
                device = mean.device
                count_all = torch.cuda.IntTensor([count], device=device)
                inv_std = 1.0 / torch.sqrt(var_biased + eps)
                var = var_biased * (count) / (count - 1)

            if count == 1 and world_size < 2:
                raise ValueError(
                    'Expected more than 1 value per channel when training, got input size {}'
                    .format(input.size()))

            r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
            r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
            running_mean.data = running_mean.data * (1 - momentum) + momentum * r_m_inc
            running_variance.data = running_variance.data * (1 - momentum) + momentum * r_v_inc
        else:
            mean = running_mean.data
            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

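        # stash tensors (including z, bias and per-rank counts) needed by the backward pass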
        ctx.save_for_backward(input, weight, mean, inv_std, z, bias,
                              count_all.to(torch.int32))
        ctx.process_group = process_group
        ctx.channel_last = channel_last
        ctx.world_size = world_size
        ctx.fuse_relu = fuse_relu

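        # the channel-last kernel optionally fuses the residual add (z) and ReLU; the NCHW path does not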
        if channel_last:
            out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std,
                                                  weight, bias, fuse_relu)
        else:
            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)

        return out