示例#1
0
    def backward(ctx, grad_ouput):
        """Backward pass of the synchronized batch norm.

        Collects gradient statistics with
        ``syncbn_gpu.batch_norm_collect_grad_statistics``. In training mode
        with ``ctx.sync`` set, the ``grad_ex``/``grad_exs`` terms are then
        averaged across devices: workers send theirs to the master through
        ``master_queue``, and the master broadcasts the mean back through
        ``worker_queues``. Finally the input gradient is computed.

        Returns:
            A 9-tuple ``(grad_input, grad_gamma, grad_beta)`` followed by six
            ``None`` placeholders.
        """
        x, ex, exs, gamma, beta = ctx.saved_tensors

        grad_stats = syncbn_gpu.batch_norm_collect_grad_statistics(
            x, grad_ouput, gamma, ex, exs, ctx.eps)
        grad_gamma, grad_beta, grad_ex, grad_exs = grad_stats

        needs_sync = ctx.training and ctx.sync
        if needs_sync and ctx.is_master:
            # Master: add one (grad_ex, grad_exs) pair per queued worker.
            ex_rows = [grad_ex.unsqueeze(0)]
            exs_rows = [grad_exs.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                worker_ex, worker_exs = ctx.master_queue.get()
                ctx.master_queue.task_done()
                ex_rows.append(worker_ex.unsqueeze(0))
                exs_rows.append(worker_exs.unsqueeze(0))

            # Stack on the master device and average over the device axis.
            grad_ex = comm.gather(ex_rows).mean(0)
            grad_exs = comm.gather(exs_rows).mean(0)

            # Copy the averages to every worker device; index 0 is the master.
            destinations = [grad_ex.get_device()] + ctx.worker_ids
            replicas = comm.broadcast_coalesced((grad_ex, grad_exs), destinations)
            for replica, worker_queue in zip(replicas[1:], ctx.worker_queues):
                worker_queue.put(replica)
        elif needs_sync:
            # Worker: send the local terms and block until the averages arrive.
            ctx.master_queue.put((grad_ex, grad_exs))
            grad_ex, grad_exs = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        grad_input = syncbn_gpu.batch_norm_input_backward(
            x, grad_ouput, gamma, ex, exs, grad_ex, grad_exs, ctx.eps)

        return (grad_input, grad_gamma, grad_beta) + (None,) * 6
示例#2
0
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        """In-place activated batch norm forward with statistics shared across devices.

        ``x`` is normalized and activated in place (it is marked dirty and
        returned). In training mode each device computes a local mean and
        variance with ``_backend.mean_var``. The master gathers them from
        ``master_queue``, pools them (each device weighted equally), and
        broadcasts the result to the workers through ``worker_queues``.
        Running statistics are then updated with momentum ``ctx.momentum``;
        the variance gets a ``count / (count - 1)`` correction. In eval mode
        the running statistics are used directly.
        """
        cls._parse_extra(ctx, extra)
        ctx.training, ctx.momentum, ctx.eps = training, momentum, eps
        ctx.activation, ctx.slope = activation, slope
        ctx.affine = weight is not None and bias is not None

        # Local sample count scaled by (queued workers + 1); presumably the
        # number of participating devices.
        n_devices = ctx.master_queue.maxsize + 1
        count = _count_samples(x) * n_devices
        x = x.contiguous()
        if ctx.affine:
            weight, bias = weight.contiguous(), bias.contiguous()
        else:
            # The backend expects tensors, so pass empty placeholders.
            weight, bias = x.new_empty(0), x.new_empty(0)

        if not ctx.training:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)
        else:
            mean, var = _backend.mean_var(x)

            if ctx.is_master:
                mean_rows = [mean.unsqueeze(0)]
                var_rows = [var.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    w_mean, w_var = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    mean_rows.append(w_mean.unsqueeze(0))
                    var_rows.append(w_var.unsqueeze(0))

                all_means = comm.gather(mean_rows)
                all_vars = comm.gather(var_rows)

                # Pooled variance: average of per-device variances plus the
                # spread of per-device means around the pooled mean.
                mean = all_means.mean(0)
                var = (all_vars + (mean - all_means) ** 2).mean(0)

                replicas = comm.broadcast_coalesced(
                    (mean, var), [mean.get_device()] + ctx.worker_ids)
                for replica, worker_queue in zip(replicas[1:], ctx.worker_queues):
                    worker_queue.put(replica)
            else:
                ctx.master_queue.put((mean, var))
                mean, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Exponential moving averages; the variance is Bessel-corrected.
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))

            ctx.mark_dirty(x, running_mean, running_var)

        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        return x
示例#3
0
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        """Synchronized in-place activated batch norm, forward pass.

        Normalizes ``x`` in place using either batch statistics pooled across
        devices (training) or the running statistics (eval), applies the
        activation in place, and returns ``x``.
        """

        def _exchange_stats(local_mean, local_var):
            """Share (mean, var) across devices and return the pooled pair."""
            if not ctx.is_master:
                ctx.master_queue.put((local_mean, local_var))
                pooled_mean, pooled_var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
                return pooled_mean, pooled_var

            mean_list, var_list = [local_mean.unsqueeze(0)], [local_var.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                other_mean, other_var = ctx.master_queue.get()
                ctx.master_queue.task_done()
                mean_list.append(other_mean.unsqueeze(0))
                var_list.append(other_var.unsqueeze(0))

            stacked_means = comm.gather(mean_list)
            stacked_vars = comm.gather(var_list)
            pooled_mean = stacked_means.mean(0)
            # Equal-weight pooling: mean variance + variance of the means.
            pooled_var = (stacked_vars + (pooled_mean - stacked_means) ** 2).mean(0)

            copies = comm.broadcast_coalesced(
                (pooled_mean, pooled_var), [pooled_mean.get_device()] + ctx.worker_ids)
            for copy, out_queue in zip(copies[1:], ctx.worker_queues):
                out_queue.put(copy)
            return pooled_mean, pooled_var

        cls._parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
        x = x.contiguous()
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _exchange_stats(*_backend.mean_var(x))

            momentum_ = ctx.momentum
            running_mean.mul_((1 - momentum_)).add_(momentum_ * mean)
            running_var.mul_((1 - momentum_)).add_(momentum_ * var * count / (count - 1))
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean = running_mean.contiguous()
            var = running_var.contiguous()
            ctx.mark_dirty(x)

        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        return x
示例#4
0
    def backward(ctx, dz):
        """Backward pass of synchronized batch norm.

        Runs ``_C.batchnorm_backward`` on the contiguous upstream gradient. In
        training mode the ``_dex``/``_dexs`` terms (presumably gradients w.r.t.
        the saved batch statistics ``_ex``/``_exs``) are averaged across
        devices when ``ctx.sync`` is set. Their contribution, computed by
        ``_C.expectation_backward``, is then added to ``dx``.

        Returns:
            A 12-tuple ``(dx, dgamma, dbeta)`` followed by nine ``None``.
        """
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        grad_out = dz.contiguous()

        dx, _dex, _dexs, dgamma, dbeta = _C.batchnorm_backward(
            grad_out, x, _ex, _exs, gamma, beta, ctx.eps)
        no_grads = (None,) * 9

        if not ctx.training:
            return (dx, dgamma, dbeta) + no_grads

        sync_enabled = ctx.sync
        if sync_enabled and ctx.is_master:
            dex_rows = [_dex.unsqueeze(0)]
            dexs_rows = [_dexs.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                w_dex, w_dexs = ctx.master_queue.get()
                ctx.master_queue.task_done()
                dex_rows.append(w_dex.unsqueeze(0))
                dexs_rows.append(w_dexs.unsqueeze(0))

            _dex = comm.gather(dex_rows).mean(0)
            _dexs = comm.gather(dexs_rows).mean(0)

            replicas = comm.broadcast_coalesced(
                (_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
            for replica, queue in zip(replicas[1:], ctx.worker_queues):
                queue.put(replica)
        elif sync_enabled:
            ctx.master_queue.put((_dex, _dexs))
            _dex, _dexs = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        dx = dx + _C.expectation_backward(x, _dex, _dexs)
        return (dx, dgamma, dbeta) + no_grads
示例#5
0
    def backward(ctx, dz):
        """Backward pass of in-place synchronized batch norm (CUDA only).

        Undoes the activation on the saved output ``z`` with ``_act_backward``
        (its result is discarded, so it presumably works in place). It then
        runs ``lib.gpu.batchnorm_inp_backward``. In training mode the
        ``_dex``/``_dexs`` terms are averaged across devices (when
        ``ctx.sync``) before ``lib.gpu.expectation_inp_backward`` is applied to
        ``dx`` (its return value is unused; presumably it updates ``dx`` in
        place).

        Returns:
            A 12-tuple ``(dx, dgamma, dbeta)`` followed by nine ``None``.

        Raises:
            NotImplementedError: if the tensors are not CUDA tensors.
        """
        z, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        # BN backward. ``NotImplemented`` is a constant, not an exception;
        # raising it would surface as a confusing TypeError.
        if not dz.is_cuda:
            raise NotImplementedError(
                "batchnorm_inp_backward is only implemented for CUDA tensors")
        dx, _dex, _dexs, dgamma, dbeta = lib.gpu.batchnorm_inp_backward(
            dz, z, _ex, _exs, gamma, beta, ctx.eps
        )

        if ctx.training:
            if ctx.sync:
                if ctx.is_master:
                    # Master: collect one (_dex, _dexs) pair per queued worker.
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    # Send the averages back; index 0 is the master's own copy.
                    tensors = comm.broadcast_coalesced(
                        (_dex, _dexs), [_dex.get_device()] + ctx.worker_ids
                    )
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # Worker: send local terms, wait for the averaged ones.
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            if not z.is_cuda:
                raise NotImplementedError(
                    "expectation_inp_backward is only implemented for CUDA tensors")
            lib.gpu.expectation_inp_backward(
                dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps
            )

        return (
            dx,
            dgamma,
            dbeta,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
示例#6
0
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        """Synchronized batch-norm forward without a fused activation.

        Training mode: computes ``_ex``/``_exs`` with
        ``_C.expectation_forward``. When ``ctx.sync`` is set, these are
        averaged across devices through the master/worker queues. The running
        statistics are then updated, using ``_exs - _ex ** 2`` as the variance.
        Eval mode: rebuilds ``_exs`` from the running statistics as
        ``var + mean ** 2``. Returns the output of ``_C.batchnorm_forward``.
        Only ``activation == 'none'`` is accepted (checked by ``assert``).
        """
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        assert activation == 'none'

        x, gamma, beta = x.contiguous(), gamma.contiguous(), beta.contiguous()

        if not ctx.training:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2
        else:
            _ex, _exs = _C.expectation_forward(x)

            sync_enabled = ctx.sync
            if sync_enabled and ctx.is_master:
                ex_rows, exs_rows = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    w_ex, w_exs = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    ex_rows.append(w_ex.unsqueeze(0))
                    exs_rows.append(w_exs.unsqueeze(0))

                _ex = comm.gather(ex_rows).mean(0)
                _exs = comm.gather(exs_rows).mean(0)

                replicas = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                for replica, queue in zip(replicas[1:], ctx.worker_queues):
                    queue.put(replica)
            elif sync_enabled:
                ctx.master_queue.put((_ex, _exs))
                _ex, _exs = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats (exponential moving average).
            _var = _exs - _ex ** 2
            decay = ctx.momentum
            running_mean.mul_((1 - decay)).add_(decay * _ex)
            running_var.mul_((1 - decay)).add_(decay * _var)

            ctx.mark_dirty(running_mean, running_var)

        y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y
 def forward(ctx, target_device, dim, *inputs):
     """Concatenate CUDA ``inputs`` along ``dim`` onto ``target_device``.

     Records ``target_device``, ``dim``, each input's device
     (``input_gpus``) and each input's size along ``dim`` (``input_sizes``)
     on ``ctx``, then delegates to ``comm.gather``.
     """
     assert all(tensor.is_cuda for tensor in inputs)
     ctx.target_device = target_device
     ctx.dim = dim
     ctx.input_gpus = tuple(tensor.get_device() for tensor in inputs)
     ctx.input_sizes = tuple(tensor.size(ctx.dim) for tensor in inputs)
     return comm.gather(inputs, ctx.dim, ctx.target_device)
示例#8
0
 def forward(ctx, target_device, dim, *inputs):
     """Gather CUDA tensors from several devices onto ``target_device``.

     Every input must be a CUDA tensor. The source device of each input and
     its extent along ``dim`` are stored on ``ctx`` together with
     ``target_device`` and ``dim``.
     """
     for tensor in inputs:
         assert tensor.is_cuda
     ctx.target_device = target_device
     ctx.dim = dim
     ctx.input_gpus = tuple([tensor.get_device() for tensor in inputs])
     ctx.input_sizes = tuple([tensor.size(ctx.dim) for tensor in inputs])
     return comm.gather(inputs, ctx.dim, ctx.target_device)
示例#9
0
 def forward(ctx, target_device, dim, *inputs):
     """Gather CUDA ``inputs`` along ``dim`` onto ``target_device``.

     Saves the per-input devices and ``dim`` sizes on ``ctx``
     (``input_gpus`` / ``input_sizes``) before calling ``comm.gather``.
     """
     assert all(t.is_cuda for t in inputs)
     ctx.target_device = target_device
     ctx.dim = dim
     devices = tuple(t.get_device() for t in inputs)
     ctx.input_gpus = devices
     extents = tuple(t.size(ctx.dim) for t in inputs)
     ctx.input_sizes = extents
     return comm.gather(inputs, ctx.dim, ctx.target_device)
示例#10
0
    def backward(ctx, dz):
        """Backward pass of in-place activated batch norm with cross-device sync.

        ``z`` is the saved forward output. The activation is first undone on
        it with ``_act_backward`` (result discarded, so presumably in place).
        In training mode, ``edz``/``eydz`` from ``_backend.edz_eydz`` are
        averaged across devices and used for the input gradient. This
        device's own pre-averaging copies (``*_local``) feed the weight/bias
        gradients. In eval mode both sets of terms are zero tensors of size
        ``dz.size(1)``.

        Returns:
            An 11-tuple ``(dx, dweight, dbias)`` followed by eight ``None``.
            ``dweight``/``dbias`` are ``None`` when the layer is not affine.
        """
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine,
                                          ctx.eps)
            # Keep this device's own sums for the parameter gradients;
            # edz/eydz are replaced below by the cross-device average.
            edz_local = edz.clone()
            eydz_local = eydz.clone()

            if ctx.is_master:
                edzs, eydzs = [edz.unsqueeze(0)], [eydz.unsqueeze(0)]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w.unsqueeze(0))
                    eydzs.append(eydz_w.unsqueeze(0))

                edz = comm.gather(edzs)
                eydz = comm.gather(eydzs)
                edz = edz.mean(0)
                eydz = eydz.mean(0)

                tensors = comm.broadcast_coalesced(
                    (edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))
            # The *_local terms must be bound in eval mode too; otherwise the
            # dweight/dbias lines below raise UnboundLocalError.
            # NOTE(review): this gives zero weight/bias gradients in eval
            # mode. Confirm that is intended.
            edz_local = edz.clone()
            eydz_local = eydz.clone()

        dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine,
                               ctx.eps)
        dweight = eydz_local * weight.sign() if ctx.affine else None
        dbias = edz_local if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
示例#11
0
 def forward(ctx, target_device, dim, *inputs):
     """Gather CUDA ``inputs`` along ``dim`` onto ``target_device``.

     Special case: when every input is 0-d and ``dim == 0``, each input is
     viewed as a 1-element vector, a warning is emitted, and
     ``ctx.unsqueezed_scalar`` is set to True (otherwise False). Input
     devices and ``dim`` sizes are stored on ``ctx``.
     """
     assert all(t.is_cuda for t in inputs)
     ctx.target_device = target_device
     ctx.dim = dim
     ctx.input_gpus = tuple(t.get_device() for t in inputs)
     if not (all(t.dim() == 0 for t in inputs) and dim == 0):
         ctx.unsqueezed_scalar = False
     else:
         # 0-d tensors have no dimension 0 to concatenate along.
         inputs = tuple(t.view(1) for t in inputs)
         warnings.warn('Was asked to gather along dimension 0, but all '
                       'input tensors were scalars; will instead unsqueeze '
                       'and return a vector.')
         ctx.unsqueezed_scalar = True
     ctx.input_sizes = tuple(t.size(ctx.dim) for t in inputs)
     return comm.gather(inputs, ctx.dim, ctx.target_device)
示例#12
0
 def forward(ctx, target_device, dim, *inputs):
     """Concatenate CUDA tensors onto ``target_device`` along ``dim``.

     If all inputs are scalars (0-d) and ``dim`` is 0, they are reshaped to
     shape ``(1,)``, a warning is issued, and ``ctx.unsqueezed_scalar`` is
     True. Otherwise ``ctx.unsqueezed_scalar`` is False.
     """
     assert all(map(lambda t: t.is_cuda, inputs))
     ctx.target_device = target_device
     ctx.dim = dim
     ctx.input_gpus = tuple(t.get_device() for t in inputs)

     all_scalars = all(t.dim() == 0 for t in inputs)
     if all_scalars and dim == 0:
         inputs = tuple(t.view(1) for t in inputs)
         warnings.warn('Was asked to gather along dimension 0, but all '
                       'input tensors were scalars; will instead unsqueeze '
                       'and return a vector.')
         ctx.unsqueezed_scalar = True
     else:
         ctx.unsqueezed_scalar = False

     sizes = [t.size(ctx.dim) for t in inputs]
     ctx.input_sizes = tuple(sizes)
     return comm.gather(inputs, ctx.dim, ctx.target_device)
示例#13
0
 def gather_map_simple(outputs, target_device):
     """Merge per-device output dicts into the first one and return it.

     For each key of every dict after ``outputs[0]``, the value is merged
     into ``outputs[0]`` according to its type:

     * ``list``: extends the existing list in place.
     * ``torch.Tensor``: concatenated along dim 0 onto ``target_device``
       with ``comm.gather``.
     * ``np.ndarray``: appended along axis 0 with ``np.append``.
     * ``int`` / ``float``: added to the existing value.

     Note: ``outputs[0]`` (and the lists it holds) is mutated and is the
     object returned.

     Raises:
         KeyError: a later dict has a key that ``outputs[0]`` lacks.
         TypeError: a value has an unsupported type.
     """
     output = outputs[0]
     for out in outputs[1:]:
         for k, v in out.items():
             if isinstance(v, list):
                 output[k].extend(v)
             elif isinstance(v, torch.Tensor):
                 output[k] = comm.gather((output[k], v), 0, target_device)
             elif isinstance(v, np.ndarray):
                 output[k] = np.append(output[k], v, axis=0)
             elif isinstance(v, (int, float)):
                 output[k] += v
             else:
                 raise TypeError(
                     "gather_map_simple: unsupported value type {!r} "
                     "for key {!r}".format(type(v).__name__, k))
     return output
示例#14
0
    def _test_gather(self, dim):
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("only one GPU detected")
        x = torch.randn(2, 5).cuda(0)
        y = torch.randn(2, 5).cuda(1)
        result = comm.gather((x, y), dim)

        expected_size = list(x.size())
        expected_size[dim] += y.size(dim)
        expected_size = torch.Size(expected_size)
        self.assertEqual(result.get_device(), 0)
        self.assertEqual(result.size(), expected_size)

        index = [slice(None, None), slice(None, None)]
        index[dim] = slice(0, x.size(dim))
        self.assertEqual(result[tuple(index)], x)
        index[dim] = slice(x.size(dim), x.size(dim) + y.size(dim))
        self.assertEqual(result[tuple(index)], y)
示例#15
0
 def backward(self, grad_output):
     """Gather ``grad_output`` along ``self.dim`` onto ``self.input_device``."""
     dim, device = self.dim, self.input_device
     return comm.gather(grad_output, dim, device)
示例#16
0
 def forward(self, *inputs):
     """Gather CUDA ``inputs`` along ``self.dim`` onto ``self.target_device``.

     Records each input's device (``input_gpus``) and its size along
     ``self.dim`` (``input_sizes``) on ``self``.
     """
     assert all(tensor.is_cuda for tensor in inputs)
     self.input_gpus = tuple(tensor.get_device() for tensor in inputs)
     self.input_sizes = tuple(tensor.size(self.dim) for tensor in inputs)
     return comm.gather(inputs, self.dim, self.target_device)
示例#17
0
    def forward(cls,
                ctx,
                x,
                weight,
                bias,
                running_mean,
                running_var,
                extra,
                training=True,
                momentum=0.1,
                eps=1e-05,
                activation=ACT_LEAKY_RELU,
                slope=0.01):
        """In-place activated batch norm forward on the legacy ``_ext`` CUDA
        backend, with batch statistics pooled across devices.

        Training mode: allocates ``(1, C)`` mean/var buffers and fills them
        with ``_ext.bn_mean_var_cuda``. The master gathers the buffers from
        the workers, pools them, and broadcasts the result back. The running
        statistics are then updated (variance scaled by ``n / (n - 1)``).
        Eval mode: uses ``running_mean``/``running_var`` directly.
        ``x`` is normalized and activated in place, marked dirty, and returned.
        """
        cls._parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope

        total_count = _count_samples(x) * (ctx.master_queue.maxsize + 1)

        if not ctx.training:
            mean, var = running_mean, running_var
        else:
            # Output buffers for the kernel; contents are written by it.
            mean = x.new().resize_(1, running_mean.size(0))
            var = x.new().resize_(1, running_var.size(0))
            _check_contiguous(x, mean, var)
            _check(_ext.bn_mean_var_cuda, x, mean, var)

            if not ctx.is_master:
                ctx.master_queue.put((mean, var))
                mean, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
            else:
                mean_parts, var_parts = [mean], [var]
                for _ in range(ctx.master_queue.maxsize):
                    peer_mean, peer_var = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    mean_parts.append(peer_mean)
                    var_parts.append(peer_var)

                stacked_means = comm.gather(mean_parts)
                stacked_vars = comm.gather(var_parts)

                mean = stacked_means.mean(0)
                var = (stacked_vars + (mean - stacked_means)**2).mean(0)

                replicas = comm.broadcast_coalesced(
                    (mean, var), [mean.get_device()] + ctx.worker_ids)
                for replica, peer_queue in zip(replicas[1:], ctx.worker_queues):
                    peer_queue.put(replica)

            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_(
                (1 - ctx.momentum)).add_(ctx.momentum * var * total_count / (total_count - 1))

        _check_contiguous(x, mean, var, weight, bias)
        kernel = _ext.bn_forward_cuda
        # The kernel takes tensors; substitute empty ones when not affine.
        gamma_arg = weight if weight is not None else x.new()
        beta_arg = bias if bias is not None else x.new()
        _check(kernel, x, mean, var, gamma_arg, beta_arg, x, x, ctx.eps)

        _act_forward(ctx, x)

        ctx.var = var
        ctx.save_for_backward(x, weight, bias, running_mean, running_var)
        ctx.mark_dirty(x)
        return x
示例#18
0
 def forward(self, *inputs):
     """Concatenate the CUDA ``inputs`` along ``self.dim`` on ``self.target_device``.

     Stores the source devices and ``self.dim`` extents of the inputs on
     ``self`` for later use.
     """
     for tensor in inputs:
         assert tensor.is_cuda
     devices = [tensor.get_device() for tensor in inputs]
     self.input_gpus = tuple(devices)
     extents = [tensor.size(self.dim) for tensor in inputs]
     self.input_sizes = tuple(extents)
     return comm.gather(inputs, self.dim, self.target_device)
示例#19
0
 def backward(self, *grad_output):
     """Gather all incoming gradients along ``self.dim`` onto ``self.input_device``."""
     destination = self.input_device
     return comm.gather(grad_output, self.dim, destination)
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        """Synchronized in-place ABN forward using the legacy ``_ext`` kernels.

        Batch statistics are computed per device and pooled via the
        master/worker queues in training mode. Running statistics are used in
        eval mode. ``x`` is normalized and activated in place and returned.
        """

        def _pool_on_master(local_mean, local_var):
            """Collect workers' (mean, var), pool them, broadcast, return pooled."""
            mean_parts, var_parts = [local_mean], [local_var]
            for _ in range(ctx.master_queue.maxsize):
                peer_mean, peer_var = ctx.master_queue.get()
                ctx.master_queue.task_done()
                mean_parts.append(peer_mean)
                var_parts.append(peer_var)

            all_means = comm.gather(mean_parts)
            all_vars = comm.gather(var_parts)
            pooled_mean = all_means.mean(0)
            pooled_var = (all_vars + (pooled_mean - all_means) ** 2).mean(0)

            copies = comm.broadcast_coalesced((pooled_mean, pooled_var),
                                              [pooled_mean.get_device()] + ctx.worker_ids)
            for copy, peer_queue in zip(copies[1:], ctx.worker_queues):
                peer_queue.put(copy)
            return pooled_mean, pooled_var

        cls._parse_extra(ctx, extra)
        for attr, value in (("training", training), ("momentum", momentum),
                            ("eps", eps), ("activation", activation),
                            ("slope", slope)):
            setattr(ctx, attr, value)

        n = _count_samples(x) * (ctx.master_queue.maxsize + 1)

        if ctx.training:
            mean = x.new().resize_(1, running_mean.size(0))
            var = x.new().resize_(1, running_var.size(0))
            _check_contiguous(x, mean, var)
            _check(_ext.bn_mean_var_cuda, x, mean, var)

            if ctx.is_master:
                mean, var = _pool_on_master(mean, var)
            else:
                ctx.master_queue.put((mean, var))
                mean, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Moving averages; the variance gets the n / (n - 1) correction.
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * n / (n - 1))
        else:
            mean, var = running_mean, running_var

        _check_contiguous(x, mean, var, weight, bias)
        _check(_ext.bn_forward_cuda,
               x, mean, var,
               x.new() if weight is None else weight,
               x.new() if bias is None else bias,
               x, x, ctx.eps)

        _act_forward(ctx, x)

        ctx.var = var
        ctx.save_for_backward(x, weight, bias, running_mean, running_var)
        ctx.mark_dirty(x)
        return x
示例#21
0
    def forward(ctx,
                x,
                gamma,
                beta,
                running_mean,
                running_var,
                extra,
                training=True,
                momentum=0.1,
                eps=1e-5,
                sync=True):
        """Synchronized batch norm forward (``syncbn_gpu`` kernels).

        Training mode: computes ``ex``/``exs`` with
        ``syncbn_gpu.batch_norm_collect_statistics``. When ``sync`` is set,
        these are averaged across devices through the master/worker queues
        from ``extra``. The running statistics are then updated, with
        ``exs - ex**2`` as the variance.

        Eval mode: ``ex``/``exs`` are rebuilt from the running statistics
        (``exs = var + mean**2``), mirroring the training-mode relation.
        Previously this path left ``y`` unbound and raised UnboundLocalError.

        Returns the output of ``syncbn_gpu.batch_norm_transform_input``.
        """
        def parse_extra(ctx, extra):
            # Copy the queue wiring for this device's role onto ctx.
            ctx.is_master = extra["is_master"]
            if ctx.is_master:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queues = extra["worker_queues"]
                ctx.worker_ids = extra["worker_ids"]
            else:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queue = extra["worker_queue"]

        parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.sync = sync

        if ctx.training:
            ex, exs = syncbn_gpu.batch_norm_collect_statistics(x)

            if ctx.sync:
                if ctx.is_master:
                    # Master: gather every worker's statistics and average.
                    ex, exs = [ex.unsqueeze(0)], [exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        ex_w, exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        ex.append(ex_w.unsqueeze(0))
                        exs.append(exs_w.unsqueeze(0))
                    ex = comm.gather(ex).mean(0)
                    exs = comm.gather(exs).mean(0)

                    # Broadcast the averages; index 0 is the master's copy.
                    tensors = comm.broadcast_coalesced(
                        (ex, exs), [ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    # Worker: send local stats, wait for the averaged ones.
                    ctx.master_queue.put((ex, exs))
                    ex, exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            var = exs - ex**2
            running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * ex)
            running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * var)

            ctx.mark_dirty(running_mean, running_var)
        else:
            # Inference: normalize with the running statistics, expressed in
            # the (ex, exs) form the kernels take.
            ex = running_mean.contiguous()
            exs = running_var.contiguous() + ex**2

        y = syncbn_gpu.batch_norm_transform_input(x, gamma, beta, ex, exs,
                                                  ctx.eps)

        ctx.save_for_backward(x, ex, exs, gamma, beta)

        return y