Example #1
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        # Save context
        cls._parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
        x = x.contiguous()
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)

            if ctx.is_master:
                means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    mean_w, var_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    means.append(mean_w.unsqueeze(0))
                    vars.append(var_w.unsqueeze(0))

                means = comm.gather(means)
                vars = comm.gather(vars)

                mean = means.mean(0)
                var = (vars + (mean - means) ** 2).mean(0)

                tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((mean, var))
                mean, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        return x
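The extra argument parsed above is a small dict of synchronization handles (see the _parse_extra helper spelled out in Example #16). Below is a minimal sketch of how those dicts might be assembled for one master GPU and its workers; build_extra_dicts and the worker queue sizes are illustrative assumptions, not part of the original code.

import queue

def build_extra_dicts(device_ids):
    # One master (device_ids[0]); master_queue.maxsize equals the number of
    # workers, which is what the forward/backward loops above rely on.
    master_queue = queue.Queue(maxsize=len(device_ids) - 1)
    worker_queues = [queue.Queue(maxsize=1) for _ in device_ids[1:]]
    extras = [{"is_master": True,
               "master_queue": master_queue,
               "worker_queues": worker_queues,
               "worker_ids": list(device_ids[1:])}]
    for wq in worker_queues:
        extras.append({"is_master": False,
                       "master_queue": master_queue,
                       "worker_queue": wq})
    return extras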
Example #2
 def forward(self, *inputs):
     if not all(input.is_cuda for input in inputs):
         raise TypeError('Broadcast function not implemented for CPU tensors')
     if len(inputs) == 0:
         return tuple()
     self.num_inputs = len(inputs)
     self.input_device = inputs[0].get_device()
     outputs = comm.broadcast_coalesced(inputs, self.target_gpus)
     return tuple([t for tensors in outputs for t in tensors])
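For reference, a minimal sketch of the torch.cuda.comm.broadcast_coalesced call this example wraps, assuming two visible GPUs; the result holds one tuple of copies per target device, which the return statement above flattens.

import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    tensors = [torch.randn(4, device="cuda:0"), torch.arange(6, device="cuda:0")]
    per_device = comm.broadcast_coalesced(tensors, [0, 1])
    # per_device[d] holds the copies living on device d, in input order.
    flat = tuple(t for copies in per_device for t in copies)
    assert per_device[1][0].get_device() == 1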
Example #3
    def _sync_params(self):
        params = [p.data for p in self.module.parameters()]
        result = broadcast_coalesced(params, self.device_ids, self.broadcast_bucket_size)
        for tensors, module in zip(result[1:], self._module_copies[1:]):
            for tensor, param in zip(tensors, module.parameters()):
                param.data.set_(tensor)

        # cross-node buffer sync
        buffers = list(self.module._all_buffers())
        flat_buffers = _flatten_tensors(buffers)
        dist.broadcast(flat_buffers, 0)
        for buf, synced in zip(buffers, _unflatten_tensors(flat_buffers, buffers)):
            buf.copy_(synced)

        # intra-node buffer sync
        result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
        for tensors, module in zip(result[1:], self._module_copies[1:]):
            for tensor, buf in zip(tensors, module._all_buffers()):
                buf.set_(tensor)
Example #4
    def _sync_params(self):
        if len(self.device_ids) > 1:
            # intra-node parameter sync
            params = [p.data for p in self.module.parameters()]
            result = broadcast_coalesced(params, self.device_ids, self.broadcast_bucket_size)
            for tensors, module in zip(result[1:], self._module_copies[1:]):
                for tensor, param in zip(tensors, module.parameters()):
                    param.data.set_(tensor)

        # module buffer sync
        if self.broadcast_buffers:
            buffers = list(self.module._all_buffers())
            if len(buffers) > 0:
                # cross-node buffer sync
                self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)

                if len(self.device_ids) > 1:
                    # intra-node buffer sync
                    result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
                    for tensors, module in zip(result[1:], self._module_copies[1:]):
                        for tensor, buf in zip(tensors, module._all_buffers()):
                            buf.set_(tensor)
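A rough sketch of what a helper like _dist_broadcast_coalesced might do for the cross-node step, assuming an initialized process group, a single bucket, and buffers of a common dtype (the real helper splits buffers into size-bounded buckets); the function name here is illustrative only.

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def dist_broadcast_coalesced_sketch(buffers, src=0):
    # Flatten, broadcast once from rank src, then copy the synced values
    # back into the original buffer tensors.
    flat = _flatten_dense_tensors(buffers)
    dist.broadcast(flat, src)
    for buf, synced in zip(buffers, _unflatten_dense_tensors(flat, buffers)):
        buf.copy_(synced)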
Example #5
 def forward(ctx, target_gpus, *inputs):
     if not all(input.is_cuda for input in inputs):
         raise TypeError('Broadcast function not implemented for CPU tensors')
     ctx.target_gpus = target_gpus
     if len(inputs) == 0:
         return tuple()
     ctx.num_inputs = len(inputs)
     ctx.input_device = inputs[0].get_device()
     outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
     non_differentiables = []
     for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
         if not input_requires_grad:
             for output in outputs:
                 non_differentiables.append(output[idx])
     ctx.mark_non_differentiable(*non_differentiables)
     return tuple([t for tensors in outputs for t in tensors])
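The flat tuple returned here is typically regrouped into one tuple per device, as the replicate examples further below do; a minimal sketch, assuming two visible GPUs.

import torch
from torch.nn.parallel._functions import Broadcast

if torch.cuda.device_count() >= 2:
    devices = [0, 1]
    params = [torch.randn(3, device="cuda:0", requires_grad=True) for _ in range(2)]
    flat = Broadcast.apply(devices, *params)   # length == len(devices) * len(params)
    per_device = [flat[i:i + len(params)]
                  for i in range(0, len(flat), len(params))]
    # per_device[1] are the copies on cuda:1; gradients flow back to the originals.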
Example #6
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    assert isinstance(device_ids, list)
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, stats, mode)

    params_all = Broadcast.apply(device_ids, *params.values())
    params_replicas = [{k: params_all[i + j*len(params)] for i, k in enumerate(params.keys())}
                       for j in range(len(device_ids))]
    stats_replicas = [dict(zip(stats.keys(), p))
                      for p in comm.broadcast_coalesced(list(stats.values()), device_ids)]

    replicas = [partial(f, params=p, stats=s, mode=mode)
                for p, s in zip(params_replicas, stats_replicas)]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
Example #7
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward
        if dz.is_cuda:
            dx, _dex, _dexs, dgamma, dbeta = lib.gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            raise NotImplementedError

        if ctx.training:
            if ctx.sync:
                if ctx.is_master:
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            if x.is_cuda:
                dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
            else:
                raise NotImplementedError
            dx = dx + dx_

        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None
Example #8
    def update_replicas(self, parallel=False):

        params = [
            para.data for para in filter_para_grad(self.module.parameters())
        ]

        if len(params) > 0:
            param_copies = tuple([
                t for tensors in comm.broadcast_coalesced(
                    params, self.device_ids) for t in tensors
            ])

            # currently, pytorch broadcast binds parameters between self.nets[0] and self.module, so the following line ensures correctness but is less efficient
            #for module, param_copy in zip(self.nets, [param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params))]):
            for module, param_copy in zip(self.nets[1:], [
                    param_copies[i:i + len(params)]
                    for i in range(len(params), len(param_copies), len(params))
            ]):
                for mp, para in zip(filter_para_grad(module.parameters()),
                                    param_copy):
                    mp.data, mp.grad = para, None

        self.ngradev = 0
Example #9
 def forward_nonsparse(ctx, target_gpus, *inputs, context_update=True):
     if not all(input.is_cuda for input in inputs):
         raise TypeError(
             'Broadcast function not implemented for CPU tensors')
     target_gpus = list(
         map(lambda x: _get_device_index(x, True), target_gpus))
     if context_update:
         ctx.target_gpus = target_gpus
     if len(inputs) == 0:
         return tuple()
     if context_update:
         ctx.num_inputs = len(inputs)
         ctx.input_device = inputs[0].get_device()
     outputs = comm.broadcast_coalesced(inputs, target_gpus)
     non_differentiables = []
     if context_update:
         for idx, input_requires_grad in enumerate(
                 ctx.needs_input_grad[1:]):
             if not input_requires_grad:
                 for output in outputs:
                     non_differentiables.append(output[idx])
         ctx.mark_non_differentiable(*non_differentiables)
     return tuple([t for tensors in outputs for t in tensors])
Example #10
    def backward(ctx, dz):
        z, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        # BN backward
        dx, _dex, _dexs, dgamma, dbeta = \
            encoding_ext.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            if ctx.sync:
                if ctx.is_master:
                    _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _dex_w, _dexs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _dex.append(_dex_w.unsqueeze(0))
                        _dexs.append(_dexs_w.unsqueeze(0))

                    _dex = comm.gather(_dex).mean(0)
                    _dexs = comm.gather(_dexs).mean(0)

                    tensors = comm.broadcast_coalesced(
                        (_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_dex, _dexs))
                    _dex, _dexs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            encoding_ext.expectation_inp_backward(dx, z, _dex, _dexs, _ex,
                                                  _exs, gamma, beta, ctx.eps)

        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None
Example #11
    def backward(ctx, dz):
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))

        dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
        dweight = dweight if ctx.affine else None
        dbias = dbias if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example #12
    def backward(ctx, dz):
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))

        dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
        dweight = dweight if ctx.affine else None
        dbias = dbias if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example #13
    def test_broadcast_coalesced(self):
        numel = 5
        num_bytes = numel * 8
        tensors = [
            torch.randn(numel).long().cuda(),
            torch.randn(numel).cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel * 2).int().cuda(),  # int is 2x shorter
            torch.randn(numel).cuda(),
        ]

        b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
        for (_, bt), t in zip(b_tensors, tensors):
            self.assertEqual(bt.get_device(), 1)
            self.assertEqual(bt, t)
            self.assertIsInstance(bt, type(t))

        bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=num_bytes * 5 // 2)
        bc_tensors_t = list(zip(*bc_tensors))
        self.assertEqual(b_tensors, bc_tensors_t)
        for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
            self.assertEqual(bt.get_device(), bct.get_device())
            self.assertIsInstance(bct, type(bt))
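The backward passes in the surrounding examples undo the broadcast by summing per-device results with comm.reduce_add or comm.reduce_add_coalesced; a minimal sketch, again assuming two visible GPUs.

import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    a = torch.ones(3, device="cuda:0")
    b = torch.ones(3, device="cuda:1")
    total = comm.reduce_add([a, b], destination=0)            # tensor of 2s on cuda:0
    # Coalesced variant (one inner list per device), as in Example #17's backward.
    totals = comm.reduce_add_coalesced([[a], [b]], destination=0)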
Example #14
    def update_replicas(self):

        if self.optm_splt is None:
            if self.is_contiguous_parameters:
                params = [
                    para.data
                    for para in get_all_contiguous_parameters_m(self.module)
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                with torch.no_grad():
                    for module, param_copy in zip(self.nets[1:],
                                                  param_copies[1:]):
                        for mp, para in zip(
                                get_all_contiguous_parameters_m(module),
                                param_copy):
                            mp.data.copy_(para)
                            mp.grad.zero_()
            else:
                params = [
                    para.data
                    for para in filter_para_grad(self.module.parameters())
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                # currently, pytorch broadcast binds parameters between self.nets[0] and self.module, so the following line ensures correctness but is less efficient
                #for module, param_copy in zip(self.nets, param_copies):
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(filter_para_grad(module.parameters()),
                                        param_copy):
                        mp.data, mp.grad = para, None
        else:
            for i, (
                    net,
                (
                    lind,
                    rind,
                ),
            ) in enumerate(zip(self.nets, self.optm_splt)):
                _dev_params = [
                    para.data
                    for para in get_contiguous_parameters_m(net, index=i)
                ] if self.is_contiguous_parameters else [
                    para.data for para in range_parameter_iter(
                        net, lind, rind, func=filter_para_grad_iter)
                ]
                if i > 0:
                    _devices = self.device_ids[:]
                    _devices.insert(0, _devices.pop(i))
                else:
                    _devices = self.device_ids
                _dev_param_copies = comm.broadcast_coalesced(
                    _dev_params, _devices)
                if i > 0:
                    _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                    for pc, _dpc in zip(param_copies, _dev_param_copies):
                        pc.extend(_dpc)
                else:
                    param_copies = _dev_param_copies
            if self.is_contiguous_parameters:
                with torch.no_grad():
                    for module, param_copy in zip(self.nets, param_copies):
                        for mp, para in zip(
                                get_all_contiguous_parameters_m(module),
                                param_copy):
                            mp.data.copy_(para)
                            mp.grad.zero_()
            else:
                for module, param_copy in zip(self.nets, param_copies):
                    for mp, para in zip(filter_para_grad(module.parameters()),
                                        param_copy):
                        mp.data, mp.grad = para, None

        self.ngradev = 0
Example #15
    def backward(ctx, dz):
        z, weight, bias, running_mean, running_var = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.needs_input_grad[0]:
            dx = dz.new().resize_as_(dz)
        else:
            dx = None

        if ctx.needs_input_grad[1]:
            dweight = dz.new().resize_as_(running_mean).zero_()
        else:
            dweight = None

        if ctx.needs_input_grad[2]:
            dbias = dz.new().resize_as_(running_mean).zero_()
        else:
            dbias = None

        if ctx.training:
            edz = dz.new().resize_as_(running_mean)
            eydz = dz.new().resize_as_(running_mean)
            _check_contiguous(z, dz, weight, bias, edz, eydz)
            _check(_ext.bn_edz_eydz_cuda, z, dz,
                   weight if weight is not None else dz.new(),
                   bias if bias is not None else dz.new(), edz, eydz, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced(
                    (edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new().resize_as_(running_mean).zero_()
            eydz = dz.new().resize_as_(running_mean).zero_()

        _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight,
                          dbias)
        _check(_ext.bn_backard_cuda, dz, z, ctx.var,
               weight if weight is not None else dz.new(),
               bias if bias is not None else dz.new(), edz, eydz,
               dx if dx is not None else dz.new(),
               dweight if dweight is not None else dz.new(),
               dbias if dbias is not None else dz.new(), ctx.eps)

        del ctx.var

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example #16
    def forward(ctx,
                x,
                weight,
                bias,
                running_mean,
                running_var,
                extra,
                compute_stats=True,
                momentum=0.1,
                eps=1e-05):
        def _parse_extra(ctx, extra):
            ctx.is_master = extra["is_master"]
            if ctx.is_master:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queues = extra["worker_queues"]
                ctx.worker_ids = extra["worker_ids"]
            else:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queue = extra["worker_queue"]

        # Save context
        if extra is not None:
            _parse_extra(ctx, extra)
        ctx.compute_stats = compute_stats
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.affine = weight is not None and bias is not None
        if ctx.compute_stats:
            N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
            assert N > 1
            # 1. compute sum(x) and sum(x^2)
            xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
            if ctx.is_master:
                xsums, xsqsums = [xsum], [xsqsum]
                # master : gather all sum(x) and sum(x^2) from slaves
                for _ in range(ctx.master_queue.maxsize):
                    xsum_w, xsqsum_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    xsums.append(xsum_w)
                    xsqsums.append(xsqsum_w)
                xsum = comm.reduce_add(xsums)
                xsqsum = comm.reduce_add(xsqsums)
                mean = xsum / N
                sumvar = xsqsum - xsum * mean
                var = sumvar / N
                uvar = sumvar / (N - 1)
                # master : broadcast global mean, variance to all slaves
                tensors = comm.broadcast_coalesced(
                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                # slave : send sum(x) and sum(x^2) to master
                ctx.master_queue.put((xsum, xsqsum))
                # slave : get global mean and variance
                mean, uvar, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
            ctx.N = N
            ctx.save_for_backward(x, weight, bias, mean, var)
        else:
            mean, var = running_mean, running_var

        # do batch norm forward
        z = _backend.syncbn_forward(x, weight, bias, mean, var, ctx.affine,
                                    ctx.eps)
        return z
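A small CPU-side sanity check (hedged, with an assumed (N, C, L) layout and per-channel reduction) of the statistics computed above: mean = sum(x)/N, biased var = sumvar/N, and unbiased uvar = sumvar/(N - 1).

import torch

x = torch.randn(4, 3, 8)                      # (N, C, L); stats are per channel
N = x.numel() // x.size(1)
xsum = x.sum(dim=(0, 2))
xsqsum = (x * x).sum(dim=(0, 2))
mean = xsum / N
sumvar = xsqsum - xsum * mean
var, uvar = sumvar / N, sumvar / (N - 1)
assert torch.allclose(var, x.var(dim=(0, 2), unbiased=False), atol=1e-5)
assert torch.allclose(uvar, x.var(dim=(0, 2), unbiased=True), atol=1e-5)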
Example #17
 def backward(ctx, *inputs):
     inputs = [i.data for i in inputs]
     inputs = [inputs[i:i + ctx.num_inputs] for i in range(0, len(inputs), ctx.num_inputs)]
     results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
     outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
     return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
Example #18
    def forward(cls,
                ctx,
                x,
                weight,
                bias,
                running_mean,
                running_var,
                extra,
                training=True,
                momentum=0.1,
                eps=1e-05,
                activation=ACT_LEAKY_RELU,
                slope=0.01):
        # Save context
        cls._parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope

        n = _count_samples(x) * (ctx.master_queue.maxsize + 1)

        if ctx.training:
            mean = x.new().resize_(1, running_mean.size(0))
            var = x.new().resize_(1, running_var.size(0))
            _check_contiguous(x, mean, var)
            _check(_ext.bn_mean_var_cuda, x, mean, var)

            if ctx.is_master:
                means, vars = [mean], [var]
                for _ in range(ctx.master_queue.maxsize):
                    mean_w, var_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    means.append(mean_w)
                    vars.append(var_w)

                means = comm.gather(means)
                vars = comm.gather(vars)

                mean = means.mean(0)
                var = (vars + (mean - means)**2).mean(0)

                tensors = comm.broadcast_coalesced(
                    (mean, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((mean, var))
                mean, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_(
                (1 - ctx.momentum)).add_(ctx.momentum * var * n / (n - 1))
        else:
            mean, var = running_mean, running_var

        _check_contiguous(x, mean, var, weight, bias)
        _check(_ext.bn_forward_cuda, x, mean, var,
               weight if weight is not None else x.new(),
               bias if bias is not None else x.new(), x, x, ctx.eps)

        # Activation
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        ctx.save_for_backward(x, weight, bias, running_mean, running_var)
        ctx.mark_dirty(x)
        return x
Example #19
    def backward(ctx, dz):
        z, weight, bias, running_mean, running_var = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.needs_input_grad[0]:
            dx = dz.new().resize_as_(dz)
        else:
            dx = None

        if ctx.needs_input_grad[1]:
            dweight = dz.new().resize_as_(running_mean).zero_()
        else:
            dweight = None

        if ctx.needs_input_grad[2]:
            dbias = dz.new().resize_as_(running_mean).zero_()
        else:
            dbias = None

        if ctx.training:
            edz = dz.new().resize_as_(running_mean)
            eydz = dz.new().resize_as_(running_mean)
            _check_contiguous(z, dz, weight, bias, edz, eydz)
            _check(_ext.bn_edz_eydz_cuda,
                   z, dz,
                   weight if weight is not None else dz.new(),
                   bias if bias is not None else dz.new(),
                   edz, eydz, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new().resize_as_(running_mean).zero_()
            eydz = dz.new().resize_as_(running_mean).zero_()

        _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight, dbias)
        _check(_ext.bn_backard_cuda,
               dz, z, ctx.var,
               weight if weight is not None else dz.new(),
               bias if bias is not None else dz.new(),
               edz, eydz,
               dx if dx is not None else dz.new(),
               dweight if dweight is not None else dz.new(),
               dbias if dbias is not None else dz.new(),
               ctx.eps)

        del ctx.var

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example #20
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        # Save context
        cls._parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope

        n = _count_samples(x) * (ctx.master_queue.maxsize + 1)

        if ctx.training:
            mean = x.new().resize_(1, running_mean.size(0))
            var = x.new().resize_(1, running_var.size(0))
            _check_contiguous(x, mean, var)
            _check(_ext.bn_mean_var_cuda, x, mean, var)

            if ctx.is_master:
                means, vars = [mean], [var]
                for _ in range(ctx.master_queue.maxsize):
                    mean_w, var_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    means.append(mean_w)
                    vars.append(var_w)

                means = comm.gather(means)
                vars = comm.gather(vars)

                mean = means.mean(0)
                var = (vars + (mean - means) ** 2).mean(0)

                tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((mean, var))
                mean, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * n / (n - 1))
        else:
            mean, var = running_mean, running_var

        _check_contiguous(x, mean, var, weight, bias)
        _check(_ext.bn_forward_cuda,
               x, mean, var,
               weight if weight is not None else x.new(),
               bias if bias is not None else x.new(),
               x, x, ctx.eps)

        # Activation
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        ctx.save_for_backward(x, weight, bias, running_mean, running_var)
        ctx.mark_dirty(x)
        return x
Example #21
def replicate(network, devices, no_gradient=False):
    def clear_gradient(para):
        para.grad = None
        return para

    num_replicas = len(devices)

    params = [clear_gradient(para) for para in network.parameters()
              ] if no_gradient else list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = comm.broadcast_coalesced(params, devices)

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {
        "_parameters", "_buffers", "_modules", "forward", "_c"
    }

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if isinstance(module, ScriptModule):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                replica = ScriptModule()

                attribute_names = set(entry[0]
                                      for entry in module._c._get_attributes())

                keys = set(module.__dict__.keys()
                           ) - scriptmodule_skip_attr - attribute_names
                for key in keys:
                    if not isinstance(module.__dict__[key], ScriptMethod):
                        replica.__dict__[key] = module.__dict__[key]
                for name, the_type, value in module._c._get_attributes():
                    if not name in module._buffers.keys():
                        replica._c._register_attribute(name, the_type, value)
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][
                        param_idx].requires_grad_(param.requires_grad)
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        for i, module in enumerate(modules):
            if isinstance(module, ScriptModule):
                replica = module_copies[j][i]
                for method_name in module._c._method_names():
                    replica._c.clone_method(module._c, method_name)

    return [module_copies[j][0] for j in range(num_replicas)]
Example #22
def replicate(network, devices, detach=False):
    from ._functions import Broadcast

    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if len(params) > 0:
        param_copies = [
            param_copies[i:i + len(params)]
            for i in range(0, len(param_copies), len(params))
        ]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                replica = _init_script_module()
                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr
                for key in keys:
                    replica.__dict__[key] = module.__dict__[key]
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].detach() \
                        if detach else param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    return [module_copies[j][0] for j in range(num_replicas)]
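A hedged usage sketch of the replicate above, assuming two visible GPUs, a toy module, and that replicate is importable from the module shown; each replica shares the Python structure of the original, while its parameters and buffers are the broadcast copies on the corresponding device.

import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    net = nn.Sequential(nn.Linear(8, 4), nn.ReLU()).cuda(0)
    replicas = replicate(net, devices=[0, 1])
    # Forward on the second replica runs entirely on cuda:1.
    out = replicas[1](torch.randn(2, 8, device="cuda:1"))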
Example #23
def replicate_model(network, devices, no_gradient=False):
    def clear_gradient(para):
        if no_gradient:
            para.grad = None
        return para

    num_replicas = len(devices)

    params = [clear_gradient(para) for para in network.parameters()
              ] if no_gradient else list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = tuple([
        t for tensors in comm.broadcast_coalesced(params, devices)
        for t in tensors
    ])
    if len(params) > 0:
        param_copies = [
            param_copies[i:i + len(params)]
            for i in range(0, len(param_copies), len(params))
        ]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][
                        param_idx].requires_grad_(param.requires_grad)
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
Example #24
def replicate(network, devices):
    from ._functions import Broadcast

    devices = tuple(devices)
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast(devices)(*params)
    if len(params) > 0:
        param_copies = [param_copies[i:i + len(params)]
                        for i in range(0, len(param_copies), len(params))]

    buffers = list(network._all_buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
Example #25
    def forward(ctx,
                x,
                gamma,
                beta,
                running_mean,
                running_var,
                extra,
                training=True,
                momentum=0.1,
                eps=1e-5,
                sync=True):
        def parse_extra(ctx, extra):
            ctx.is_master = extra["is_master"]
            if ctx.is_master:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queues = extra["worker_queues"]
                ctx.worker_ids = extra["worker_ids"]
            else:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queue = extra["worker_queue"]

        parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.sync = sync

        if ctx.training:
            ex, exs = syncbn_gpu.batch_norm_collect_statistics(x)

            if ctx.sync:
                if ctx.is_master:
                    ex, exs = [ex.unsqueeze(0)], [exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        ex_w, exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        ex.append(ex_w.unsqueeze(0))
                        exs.append(exs_w.unsqueeze(0))
                    ex = comm.gather(ex).mean(0)
                    exs = comm.gather(exs).mean(0)

                    tensors = comm.broadcast_coalesced(
                        (ex, exs), [ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((ex, exs))
                    ex, exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            var = exs - ex**2
            running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * ex)
            running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * var)

            ctx.mark_dirty(running_mean, running_var)

            y = syncbn_gpu.batch_norm_transform_input(x, gamma, beta, ex, exs,
                                                      ctx.eps)

            ctx.save_for_backward(x, ex, exs, gamma, beta)
        else:
            # Inference path: reuse the running statistics; inverting
            # var = exs - ex**2 from above gives exs = running_var + running_mean**2.
            ex = running_mean
            exs = running_var + running_mean ** 2
            y = syncbn_gpu.batch_norm_transform_input(x, gamma, beta, ex, exs,
                                                      ctx.eps)

        return y
Example #26
def replicate(network, devices, detach=False):
    from ._functions import Broadcast

    devices = tuple(devices)
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if len(params) > 0:
        param_copies = [
            param_copies[i:i + len(params)]
            for i in range(0, len(param_copies), len(params))
        ]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].detach() \
                        if detach else param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
Example #27
    def forward(cls,
                ctx,
                x,
                weight,
                bias,
                running_mean,
                running_var,
                extra,
                training=True,
                momentum=0.1,
                eps=1e-05,
                activation=ACT_LEAKY_RELU,
                slope=0.01):
        # Save context
        cls._parse_extra(ctx, extra)
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
        x = x.contiguous()
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)

            if ctx.is_master:
                means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    mean_w, var_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    means.append(mean_w.unsqueeze(0))
                    vars.append(var_w.unsqueeze(0))

                means = comm.gather(means)
                vars = comm.gather(vars)

                mean = means.mean(0)
                var = (vars + (mean - means)**2).mean(0)

                tensors = comm.broadcast_coalesced(
                    (mean, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((mean, var))
                mean, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var *
                                                      count / (count - 1))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        return x
Example #28
    def update_replicas_para(self):

        if self.optm_splt is None:
            if self.is_contiguous_parameters:
                params = [
                    para.data
                    for para in get_all_contiguous_parameters_m(self.module)
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                with torch.no_grad():
                    for module, param_copy in zip(self.nets[1:],
                                                  param_copies[1:]):
                        for mp, para in zip(
                                get_all_contiguous_parameters_m(module),
                                param_copy):
                            mp.data.copy_(para)
            else:
                params = [
                    para.data
                    for para in filter_para_grad(self.module.parameters())
                ]
                param_copies = comm.broadcast_coalesced(
                    params, self.device_ids)
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(filter_para_grad(module.parameters()),
                                        param_copy):
                        mp.data = para
        else:
            for i, (
                    net,
                (
                    lind,
                    rind,
                ),
            ) in enumerate(zip(self.nets, self.optm_splt)):
                _dev_params = [
                    para.data
                    for para in get_contiguous_parameters_m(net, index=i)
                ] if self.is_contiguous_parameters else [
                    para.data for para in range_parameter_iter(
                        net, lind, rind, func=filter_para_grad_iter)
                ]
                if i > 0:
                    _devices = self.device_ids[:]
                    _devices.insert(0, _devices.pop(i))
                else:
                    _devices = self.device_ids
                _dev_param_copies = comm.broadcast_coalesced(
                    _dev_params, _devices)
                if i > 0:
                    _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                    for pc, _dpc in zip(param_copies, _dev_param_copies):
                        pc.extend(_dpc)
                else:
                    param_copies = _dev_param_copies
                if self.is_contiguous_parameters:
                    with torch.no_grad():
                        for module, param_copy in zip(self.nets, param_copies):
                            for mp, para in zip(
                                    get_all_contiguous_parameters_m(module),
                                    param_copy):
                                mp.data.copy_(para)
                else:
                    for module, param_copy in zip(self.nets, param_copies):
                        for mp, para in zip(
                                filter_para_grad(module.parameters()),
                                param_copy):
                            mp.data = para

        self.ngradev = 0
Example #29
def replicate(network, devices):

    devices = tuple(devices)
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = broadcast_coalesced(params, devices)

    buffers = list(network._all_buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    # Create a shallow copy of every module for every device; the parameter,
    # buffer and child-module dicts are copied so they can be rewired.
    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    # Rewire every replica so its children, parameters and buffers point at
    # the per-device copies created above.
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
                    replica._parameters[
                        key].requires_grad = param.requires_grad
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    # The root module (index 0) of each per-device list is that device's replica.
    return [module_copies[j][0] for j in range(num_replicas)]
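A hedged usage sketch (hypothetical model and device ids; note that `_all_buffers()` is a private helper from older PyTorch releases, so the function above only runs as-is on those versions):

import torch
import torch.nn as nn

net = nn.Linear(16, 4).cuda(0)
replicas = replicate(net, devices=[0, 1])
# Each replica runs independently on its own device.
inputs = [torch.randn(8, 16, device='cuda:{}'.format(d)) for d in (0, 1)]
outputs = [replica(inp) for replica, inp in zip(replicas, inputs)]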
Example #30
    def forward(cls,
                ctx,
                x,
                gamma,
                beta,
                running_mean,
                running_var,
                extra,
                sync=True,
                training=True,
                momentum=0.1,
                eps=1e-05,
                activation="none",
                slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        assert activation == 'none'

        # make inputs contiguous
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            _ex, _exs = _C.expectation_forward(x)

            if ctx.sync:
                if ctx.is_master:
                    _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                    for _ in range(ctx.master_queue.maxsize):
                        _ex_w, _exs_w = ctx.master_queue.get()
                        ctx.master_queue.task_done()
                        _ex.append(_ex_w.unsqueeze(0))
                        _exs.append(_exs_w.unsqueeze(0))

                    _ex = comm.gather(_ex).mean(0)
                    _exs = comm.gather(_exs).mean(0)

                    tensors = comm.broadcast_coalesced(
                        (_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                    for ts, queue in zip(tensors[1:], ctx.worker_queues):
                        queue.put(ts)
                else:
                    ctx.master_queue.put((_ex, _exs))
                    _ex, _exs = ctx.worker_queue.get()
                    ctx.worker_queue.task_done()

            # Update running stats
            _var = _exs - _ex**2
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

            # Mark in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex**2

        # BN forward
        y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y
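The kernels above pass around E[x] and E[x^2] rather than mean and variance; a plain-PyTorch rendering of the same arithmetic (reference only, not the `_C` CUDA path, with the channel dimension assumed to be dim 1):

import torch

def reference_batchnorm(x, _ex, _exs, gamma, beta, eps):
    # var = E[x^2] - E[x]^2, then normalize and apply the affine transform.
    _var = _exs - _ex ** 2
    shape = [1, -1] + [1] * (x.dim() - 2)
    inv_std = torch.rsqrt(_var.view(shape) + eps)
    return gamma.view(shape) * (x - _ex.view(shape)) * inv_std + beta.view(shape)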
Example #31
    def forward(cls,
                ctx,
                x,
                weight,
                bias,
                running_mean,
                running_var,
                extra,
                compute_stats=True,
                momentum=0.1,
                eps=1e-05):
        # Save context
        if extra is not None:
            cls._parse_extra(ctx, extra)
        ctx.compute_stats = compute_stats
        ctx.momentum = momentum
        ctx.eps = eps
        if ctx.compute_stats:
            N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
            assert N > 1
            num_features = running_mean.size(0)
            # 1. compute sum(x) and sum(x^2)
            xsum = x.new().resize_(num_features)
            xsqsum = x.new().resize_(num_features)
            _check_contiguous(x, xsum, xsqsum)
            _lib_bn.syncbn_sum_sqsum_cuda(x.detach(), xsum, xsqsum)
            if ctx.is_master:
                xsums, xsqsums = [xsum], [xsqsum]
                # master : gather all sum(x) and sum(x^2) from slaves
                for _ in range(ctx.master_queue.maxsize):
                    xsum_w, xsqsum_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    xsums.append(xsum_w)
                    xsqsums.append(xsqsum_w)
                xsum = comm.reduce_add(xsums)
                xsqsum = comm.reduce_add(xsqsums)
                mean = xsum / N
                sumvar = xsqsum - xsum * mean
                var = sumvar / N
                uvar = sumvar / (N - 1)
                # master : broadcast global mean, variance to all slaves
                tensors = comm.broadcast_coalesced(
                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                # slave : send sum(x) and sum(x^2) to master
                ctx.master_queue.put((xsum, xsqsum))
                # slave : get global mean and variance
                mean, uvar, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
            ctx.N = N
            ctx.save_for_backward(x, weight, bias, mean, var)
        else:
            mean, var = running_mean, running_var

        output = x.new().resize_as_(x)
        _check_contiguous(output, x, mean, var, weight, bias)
        # do batch norm forward
        _lib_bn.syncbn_forward_cuda(output, x,
                                    weight if weight is not None else x.new(),
                                    bias if bias is not None else x.new(),
                                    mean, var, ctx.eps)
        return output