Example 1
def forward(ctx, x):
    # Per-channel E[x] and E[x^2]; only a CUDA kernel is provided.
    if x.is_cuda:
        ex, ex2 = gpu.expectation_forward(x)
    else:
        raise NotImplementedError
    ctx.save_for_backward(x)
    return ex, ex2
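`gpu.expectation_forward` is a compiled extension whose source is not shown, but the later examples recover the variance as `_var = _exs - _ex**2`, so it evidently returns the per-channel first and second raw moments. A minimal pure-PyTorch sketch of that contract (the reference name and the channels-first layout are assumptions):

    import torch

    def expectation_forward_reference(x):
        # Per-channel E[x] and E[x^2], reducing over every
        # non-channel dimension; assumes (N, C, ...) layout.
        flat = x.transpose(0, 1).reshape(x.size(1), -1)
        ex = flat.mean(dim=1)
        ex2 = flat.pow(2).mean(dim=1)
        return ex, ex2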
Example 2
    def forward(cls,
                ctx,
                x,
                gamma,
                beta,
                running_mean,
                running_var,
                extra,
                sync=True,
                training=True,
                momentum=0.1,
                eps=1e-05,
                activation="none",
                slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        assert activation == 'none'  # only the identity activation is supported here

        # ensure inputs are contiguous
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            if x.is_cuda:
                _ex, _exs = gpu.expectation_forward(x)
            else:
                raise NotImplementedError

            if ctx.sync:
                if ctx.is_master:
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        _exs.append(_exs_w.unsqueeze(0))

                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)

                    tensors = comm.broadcast_coalesced(
                        (_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Update running stats
            _var = _exs - _ex**2
            running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
            running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)

            # Mark in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex**2

        # BN forward + activation
        if x.is_cuda:
            y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)

        ctx.mark_non_differentiable(running_mean, running_var)
        return y, running_mean, running_var
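Examples 2 and 3 hand the raw moments, not mean/variance, to `gpu.batchnorm_forward` / `cpu.batchnorm_forward`. A hedged pure-PyTorch sketch of what that normalization step computes (the reference name and broadcasting layout are assumptions; the real kernels are compiled extensions):

    import torch

    def batchnorm_forward_reference(x, ex, exs, gamma, beta, eps):
        # y = gamma * (x - E[x]) / sqrt(var + eps) + beta,
        # with var recovered as E[x^2] - E[x]^2.
        var = exs - ex.pow(2)
        shape = [1, x.size(1)] + [1] * (x.dim() - 2)  # broadcast along C
        inv_std = torch.rsqrt(var + eps)
        return (x - ex.view(shape)) * (gamma * inv_std).view(shape) \
            + beta.view(shape)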
Example 3

    def forward(ctx, x, gamma, beta, running_mean, running_var, eps, momentum,
                training, process_group):
        x = x.contiguous()
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.process_group = process_group

        if not ctx.training:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex**2
            if x.is_cuda:
                y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
            else:
                y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
            ctx.save_for_backward(x, _ex, _exs, gamma, beta)
            return y

        size = x.numel() // x.size(1)
        if size == 1:
            raise ValueError(
                'Expected more than 1 value per channel when training, got input size {}'
                .format(size))

        if x.is_cuda:
            _ex, _exs = gpu.expectation_forward(x)
        else:
            raise NotImplementedError

        # Each rank contributes 1, so after all_reduce `count` equals the group size.
        count = torch.ones(1, device=x.device)
        count_all_reduce = torch.distributed.all_reduce(count,
                                                        group=process_group,
                                                        async_op=True)
        _ex_all_reduce = torch.distributed.all_reduce(_ex,
                                                      group=process_group,
                                                      async_op=True)
        _exs_all_reduce = torch.distributed.all_reduce(_exs,
                                                       group=process_group,
                                                       async_op=True)

        count_all_reduce.wait()
        _ex_all_reduce.wait()
        _exs_all_reduce.wait()

        _ex = _ex / count
        _exs = _exs / count

        # Update running stats
        _var = _exs - _ex**2
        running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
        running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)

        # Mark in-place modified tensors
        ctx.mark_dirty(running_mean, running_var)

        # BN forward + activation
        if x.is_cuda:
            y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y
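Example 3 overlaps its three `all_reduce` calls by launching them with `async_op=True` and waiting on the returned work handles before using the results. The same pattern in isolation (the helper name is hypothetical; a process group must already be initialized via `torch.distributed.init_process_group`):

    import torch
    import torch.distributed as dist

    def all_reduce_moments(ex, exs, process_group=None):
        # Each rank contributes count=1, so the summed `count` equals
        # the group size and the division below yields the mean of the
        # per-rank moments (assumes equal batch size on every rank).
        count = torch.ones(1, device=ex.device)
        handles = [dist.all_reduce(t, group=process_group, async_op=True)
                   for t in (count, ex, exs)]
        for h in handles:
            h.wait()
        return ex / count, exs / count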