Example No. 1
    def backward(ctx, dz):
        x, weight, bias, mean, var = ctx.saved_tensors
        dz = dz.contiguous()
        if ctx.needs_input_grad[0]:
            dx = dz.new().resize_as_(dz)
        else:
            dx = None
        if ctx.needs_input_grad[1]:
            dweight = dz.new().resize_as_(mean).zero_()
        else:
            dweight = None
        if ctx.needs_input_grad[2]:
            dbias = dz.new().resize_as_(mean).zero_()
        else:
            dbias = None
        _check_contiguous(x, dz, weight, bias, mean, var)

        # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
        num_features = mean.size(0)
        sum_dz = x.new().resize_(num_features)
        sum_dz_xhat = x.new().resize_(num_features)
        _check_contiguous(sum_dz, sum_dz_xhat)
        _lib_bn.syncbn_backward_xhat_cuda(
            dz, x, mean, var, sum_dz, sum_dz_xhat, ctx.eps)
        if ctx.is_master:
            sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
            # master : gather from slaves
            for _ in range(ctx.master_queue.maxsize):
                sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                sum_dzs.append(sum_dz_w)
                sum_dz_xhats.append(sum_dz_xhat_w)
            # master : compute global stats
            sum_dz = comm.reduce_add(sum_dzs)
            sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
            sum_dz /= ctx.N
            sum_dz_xhat /= ctx.N
            # master : broadcast global stats
            tensors = comm.broadcast_coalesced(
                (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            # slave : send to master
            ctx.master_queue.put((sum_dz, sum_dz_xhat))
            # slave : get global stats
            sum_dz, sum_dz_xhat = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # do batch norm backward
        _lib_bn.syncbn_backard_cuda(
            dz, x, weight if weight is not None else dz.new(),
            bias if bias is not None else dz.new(),
            mean, var, sum_dz, sum_dz_xhat,
            dx if dx is not None else dz.new(),
            dweight if dweight is not None else dz.new(),
            dbias if dbias is not None else dz.new(), ctx.eps)

        return dx, dweight, dbias, None, None, None, \
            None, None, None, None, None
Example No. 2
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                extra, compute_stats=True, momentum=0.1, eps=1e-05):
        # Save context
        if extra is not None:
            cls._parse_extra(ctx, extra)
        ctx.compute_stats = compute_stats
        ctx.momentum = momentum
        ctx.eps = eps
        if ctx.compute_stats:
            N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
            assert N > 1
            num_features = running_mean.size(0)
            # 1. compute sum(x) and sum(x^2)
            xsum = x.new().resize_(num_features)
            xsqsum = x.new().resize_(num_features)
            _check_contiguous(x, xsum, xsqsum)
            _lib_bn.syncbn_sum_sqsum_cuda(x.detach(), xsum, xsqsum)
            if ctx.is_master:
                xsums, xsqsums = [xsum], [xsqsum]
                # master : gather all sum(x) and sum(x^2) from slaves
                for _ in range(ctx.master_queue.maxsize):
                    xsum_w, xsqsum_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    xsums.append(xsum_w)
                    xsqsums.append(xsqsum_w)
                xsum = comm.reduce_add(xsums)
                xsqsum = comm.reduce_add(xsqsums)
                mean = xsum / N
                sumvar = xsqsum - xsum * mean
                var = sumvar / N
                uvar = sumvar / (N - 1)
                # master : broadcast global mean, variance to all slaves
                tensors = comm.broadcast_coalesced(
                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                # slave : send sum(x) and sum(x^2) to master
                ctx.master_queue.put((xsum, xsqsum))
                # slave : get global mean and variance
                mean, uvar, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
            ctx.N = N
            ctx.save_for_backward(x, weight, bias, mean, var)
        else:
            mean, var = running_mean, running_var

        output = x.new().resize_as_(x)
        _check_contiguous(output, x, mean, var, weight, bias)
        # do batch norm forward
        _lib_bn.syncbn_forward_cuda(
            output, x, weight if weight is not None else x.new(),
            bias if bias is not None else x.new(), mean, var, ctx.eps)
        return output
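
The forward pass above derives the global statistics from per-GPU sums: with xsum = sum(x) and xsqsum = sum(x^2), it uses sumvar = xsqsum - xsum * mean, var = sumvar / N and uvar = sumvar / (N - 1). A small CPU check of that algebra (illustrative only, not part of the original code):

import torch

x = torch.randn(1000, dtype=torch.float64)
N = x.numel()
xsum, xsqsum = x.sum(), (x * x).sum()
mean = xsum / N
sumvar = xsqsum - xsum * mean          # equals sum((x - mean)^2)
assert torch.allclose(sumvar / N, x.var(unbiased=False))
assert torch.allclose(sumvar / (N - 1), x.var(unbiased=True))
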
Example No. 3
    def forward(ctx, *args):
        '''
        Args:
            args[0] (torch.Tensor): compute loss flag
            args[1] (torch.Tensor): fp16 flag
            args[2:num_splits + 2] (each is a torch.sparse.LongTensor): one-hot label parts, located on different GPUs
            args[num_splits + 2:] (each is a torch.Tensor): fc logit parts, located on different GPUs

        Returns:
            loss
        '''
        ctx.num_splits = (len(args) - 2) // 2
        ctx.compute_loss = args[0]
        ctx.fp16 = args[1]
        ctx.batch_size = args[2].size()[0]
        ctx.label_split = args[2:ctx.num_splits + 2]
        # for numerical stability
        max_list = []
        for arg in args[ctx.num_splits + 2:]:
            m, _ = torch.max(arg, dim=1, keepdim=True)
            max_list.append(m)
        mc = torch.cat(max_list, dim=1)
        m, _ = torch.max(mc, dim=1, keepdim=True)
        nargs = [
            arg - m.to(gpu_id)
            for gpu_id, arg in enumerate(args[ctx.num_splits + 2:])
        ]

        # get exp sum
        exp_logit_list = []
        exp_sum_list = []
        for gpu_id, narg in enumerate(nargs):
            exp_logit = torch.exp(narg)
            exp_logit_list.append(exp_logit)
            exp_sum = torch.sum(exp_logit, dim=1, keepdim=True)
            exp_sum_list.append(exp_sum)
        exp_sum_all = comm.reduce_add(exp_sum_list, 0)

        # compute softmax output
        softmax_list = []
        for gpu_id, narg in enumerate(nargs):
            softmax = exp_logit_list[gpu_id] / exp_sum_all.to(gpu_id)
            softmax_list.append(softmax)
        ctx.save_for_backward(*softmax_list)

        loss = torch.ones(1)
        if ctx.compute_loss:
            _loss_list = []
            for gpu_id, softmax in enumerate(softmax_list):
                idx = ctx.label_split[gpu_id]._indices()
                _loss = torch.zeros(ctx.batch_size).to(gpu_id)
                _loss.scatter_(dim=0, index=idx[0], src=softmax[tuple(idx)])
                _loss_list.append(_loss)
            _loss = comm.reduce_add(_loss_list, destination=0)
            log_loss = -torch.log(_loss)
            loss = torch.mean(log_loss)

        return loss
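
The per-GPU exp-sums combined with comm.reduce_add form the shared softmax denominator. A small CPU analogue (no GPUs required; shapes are illustrative and not from the source) showing that summing the per-split exp-sums reproduces the full softmax:

import torch

logits = torch.randn(4, 10)
splits = torch.split(logits, 5, dim=1)               # two "GPU" shards of 5 classes each
m = logits.max(dim=1, keepdim=True).values           # shared max for numerical stability
exp_sums = [torch.exp(s - m).sum(dim=1, keepdim=True) for s in splits]
denom = sum(exp_sums)                                # plays the role of comm.reduce_add
softmax = torch.cat([torch.exp(s - m) / denom for s in splits], dim=1)
assert torch.allclose(softmax, torch.softmax(logits, dim=1), atol=1e-6)
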
Example No. 4
def comm_all_reduce(inputs):
    # comm backend
    result = comm.reduce_add(inputs)
    results = []
    for i in range(len(inputs)):
        results.append(result.clone().cuda(i))
    return results
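
A minimal usage sketch for comm_all_reduce above (assumes at least two CUDA devices; the tensor values are illustrative): each returned tensor holds the element-wise sum, replicated on its own GPU.

import torch

if torch.cuda.device_count() >= 2:
    a = torch.ones(3, device="cuda:0")
    b = torch.full((3,), 2.0, device="cuda:1")
    outs = comm_all_reduce([a, b])
    # outs[0] lives on cuda:0, outs[1] on cuda:1, both equal to a + b
    print([o.device for o in outs], outs[0].cpu())
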
Example No. 5
    def test_reduce_add(self):
        x = torch.randn(5, 5)
        y = torch.randn(5, 5)
        x_cuda = x.cuda(0)
        y_cuda = y.cuda(1)
        result = comm.reduce_add((x_cuda, y_cuda))
        self.assertEqual(result.get_device(), 0)
        self.assertEqual(result.cpu(), x + y)
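
As several examples on this page show, reduce_add also accepts a destination device as its second argument. A minimal sketch (assumes two CUDA devices):

import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    x = torch.randn(5, 5).cuda(0)
    y = torch.randn(5, 5).cuda(1)
    s = comm.reduce_add((x, y), destination=1)   # place the sum on device 1
    assert s.get_device() == 1
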
Example No. 6
def reduce_average_params(net_s, net_w, device_s):
    params_s = list(net_s.parameters())
    params_w = [list(net.parameters()) for net in net_w]
    num_workers = len(net_w)
    for j, param_s in enumerate(params_s):
        param_w_sum = comm.reduce_add(
            [params_w[i][j] for i in range(num_workers)], device_s)
        param_s.data.mul_(0)
        param_s.data.add_(1 / num_workers, param_w_sum.data)
    del param_w_sum
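
An illustrative call to reduce_average_params above (module shape, device layout, and the use of deepcopy are assumptions; it also assumes a PyTorch version that still accepts the deprecated add_(alpha, tensor) overload used inside the function): the parameters of two worker replicas are averaged into the server copy on device 0.

import copy
import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    net_s = nn.Linear(8, 4).cuda(0)                           # server copy to be overwritten
    net_w = [copy.deepcopy(net_s).cuda(i) for i in range(2)]  # worker replicas on devices 0 and 1
    reduce_average_params(net_s, net_w, device_s=0)
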
Example No. 7
    def forward(ctx, *inputs):
        # ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        # modified: each input is the tuple returned by EdgeDetectionReweightedLosses,
        # so the device is taken from its first element
        ctx.target_gpus = [inputs[i][0].get_device() for i in range(len(inputs))]
        # inputs = sorted(inputs, key=lambda i: i.get_device())
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        outputs = []
        # EdgeDetectionReweightedLosses returns a 3-tuple: (loss_side5, loss_fuse, loss)
        for i in range(3):
            temp = [inputs[j][i] for j in range(len(inputs))]
            outputs.append(comm.reduce_add(temp) / len(inputs))
        # return comm.reduce_add(inputs)
        return tuple(outputs)
Example No. 8
    def backward(ctx, dz):
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced(
                    (edz, eydz), [edz.get_device()] + ctx.worker_ids
                )
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))

        dx, dweight, dbias = _backend.backward(
            z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps
        )
        dweight = dweight if ctx.affine else None
        dbias = dbias if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example No. 9
    def backward(ctx, dz):
        x, weight, bias, mean, var = ctx.saved_tensors
        dz = dz.contiguous()

        # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
        sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
            dz, x, mean, var, ctx.eps)
        if ctx.is_master:
            sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
            # master : gather from slaves
            for _ in range(ctx.master_queue.maxsize):
                sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                sum_dzs.append(sum_dz_w)
                sum_dz_xhats.append(sum_dz_xhat_w)
            # master : compute global stats
            sum_dz = comm.reduce_add(sum_dzs)
            sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
            sum_dz /= ctx.N
            sum_dz_xhat /= ctx.N
            # master : broadcast global stats
            tensors = comm.broadcast_coalesced(
                (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            # slave : send to master
            ctx.master_queue.put((sum_dz, sum_dz_xhat))
            # slave : get global stats
            sum_dz, sum_dz_xhat = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # do batch norm backward
        dx, dweight, dbias = _backend.syncbn_backward(dz, x, weight, bias,
                                                      mean, var, sum_dz,
                                                      sum_dz_xhat, ctx.affine,
                                                      ctx.eps)

        return dx, dweight, dbias, \
            None, None, None, None, None, None
Example No. 10
    def forward(self, *args):  # args is list of logit parts
        # for numerical stability
        max_list = []
        for arg in args:
            m, _ = torch.max(arg, dim=1, keepdim=True)
            max_list.append(m)
        mc = torch.cat(max_list, dim=1)
        m, _ = torch.max(mc, dim=1, keepdim=True)
        nargs = [arg - m.to(gpu_id) for gpu_id, arg in enumerate(args)]

        # get exp sum
        exp_logit_list = []
        exp_sum_list = []
        for gpu_id, narg in enumerate(nargs):
            exp_logit = torch.exp(narg)
            exp_logit_list.append(exp_logit)
            exp_sum = torch.sum(exp_logit, dim=1, keepdim=True)
            exp_sum_list.append(exp_sum)
        exp_sum_all = comm.reduce_add(exp_sum_list, 0)

        # compute softmax output
        softmax_list = []
        for gpu_id, narg in enumerate(nargs):
            softmax = exp_logit_list[gpu_id] / exp_sum_all.to(gpu_id)
            softmax_list.append(softmax)
        self.save_for_backward(*softmax_list)

        loss = torch.zeros(1)
        if self.compute_loss:
            _loss_list = []
            for gpu_id, softmax in enumerate(softmax_list):
                _loss = torch.sum(softmax * self.label_split[gpu_id], dim=1)
                _loss_list.append(_loss)
            _loss = comm.reduce_add(_loss_list, 0)
            log_loss = -torch.log(_loss)
            loss = torch.mean(log_loss)

        return loss
Example No. 11
    def backward(ctx, dz):
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))

        dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
        dweight = dweight if ctx.affine else None
        dbias = dbias if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example No. 12
    def forward(ctx,
                x,
                weight,
                bias,
                running_mean,
                running_var,
                extra,
                compute_stats=True,
                momentum=0.1,
                eps=1e-05):
        def _parse_extra(ctx, extra):
            ctx.is_master = extra["is_master"]
            if ctx.is_master:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queues = extra["worker_queues"]
                ctx.worker_ids = extra["worker_ids"]
            else:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queue = extra["worker_queue"]

        # Save context
        if extra is not None:
            _parse_extra(ctx, extra)
        ctx.compute_stats = compute_stats
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.affine = weight is not None and bias is not None
        if ctx.compute_stats:
            N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
            assert N > 1
            # 1. compute sum(x) and sum(x^2)
            xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
            if ctx.is_master:
                xsums, xsqsums = [xsum], [xsqsum]
                # master : gather all sum(x) and sum(x^2) from slaves
                for _ in range(ctx.master_queue.maxsize):
                    xsum_w, xsqsum_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    xsums.append(xsum_w)
                    xsqsums.append(xsqsum_w)
                xsum = comm.reduce_add(xsums)
                xsqsum = comm.reduce_add(xsqsums)
                mean = xsum / N
                sumvar = xsqsum - xsum * mean
                var = sumvar / N
                uvar = sumvar / (N - 1)
                # master : broadcast global mean, variance to all slaves
                tensors = comm.broadcast_coalesced(
                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                # slave : send sum(x) and sum(x^2) to master
                ctx.master_queue.put((xsum, xsqsum))
                # slave : get global mean and variance
                mean, uvar, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
            ctx.N = N
            ctx.save_for_backward(x, weight, bias, mean, var)
        else:
            mean, var = running_mean, running_var

        # do batch norm forward
        z = _backend.syncbn_forward(x, weight, bias, mean, var, ctx.affine,
                                    ctx.eps)
        return z
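
A hedged sketch of how the extra dict consumed by _parse_extra above could be assembled (the helper name build_extra and the queue sizes are assumptions, modeled only on the keys read in the code): the master gets the list of worker queues and ids, while each worker gets the shared master queue plus its own private queue.

from queue import Queue

def build_extra(device_ids):
    # one shared queue the workers push into, one private queue per worker
    master_queue = Queue(maxsize=len(device_ids) - 1)
    worker_queues = [Queue(maxsize=1) for _ in device_ids[1:]]
    extras = [{
        "is_master": True,
        "master_queue": master_queue,
        "worker_queues": worker_queues,
        "worker_ids": device_ids[1:],
    }]
    for queue in worker_queues:
        extras.append({
            "is_master": False,
            "master_queue": master_queue,
            "worker_queue": queue,
        })
    return extras   # extras[i] goes to the replica running on device_ids[i]
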
Example No. 13
    def backward(ctx, dz):
        z, weight, bias, running_mean, running_var = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.needs_input_grad[0]:
            dx = dz.new().resize_as_(dz)
        else:
            dx = None

        if ctx.needs_input_grad[1]:
            dweight = dz.new().resize_as_(running_mean).zero_()
        else:
            dweight = None

        if ctx.needs_input_grad[2]:
            dbias = dz.new().resize_as_(running_mean).zero_()
        else:
            dbias = None

        if ctx.training:
            edz = dz.new().resize_as_(running_mean)
            eydz = dz.new().resize_as_(running_mean)
            _check_contiguous(z, dz, weight, bias, edz, eydz)
            _check(_ext.bn_edz_eydz_cuda, z, dz,
                   weight if weight is not None else dz.new(),
                   bias if bias is not None else dz.new(), edz, eydz, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced(
                    (edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new().resize_as_(running_mean).zero_()
            eydz = dz.new().resize_as_(running_mean).zero_()

        _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight,
                          dbias)
        _check(_ext.bn_backard_cuda, dz, z, ctx.var,
               weight if weight is not None else dz.new(),
               bias if bias is not None else dz.new(), edz, eydz,
               dx if dx is not None else dz.new(),
               dweight if dweight is not None else dz.new(),
               dbias if dbias is not None else dz.new(), ctx.eps)

        del ctx.var

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example No. 14
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        return comm.reduce_add(inputs)
Example No. 15
    def forward(ctx, *inputs):
        ctx.save_for_backward(*inputs)
        if len(inputs) == 1:
            return inputs[0]
        return comm.reduce_add(inputs)
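
A self-contained sketch of wrapping a forward like the one above in a complete autograd.Function (the class name ReduceAdd and the backward that copies the gradient back to each source device are assumptions, not taken from the original):

import torch
from torch.autograd import Function
from torch.cuda import comm

class ReduceAdd(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [t.get_device() for t in inputs]
        if len(inputs) == 1:
            return inputs[0]
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, grad_output):
        # hand each input back a copy of the gradient on its own device
        return tuple(grad_output.cuda(gpu) for gpu in ctx.target_gpus)

if torch.cuda.device_count() >= 2:
    xs = [torch.randn(4, device=f"cuda:{i}", requires_grad=True) for i in range(2)]
    ReduceAdd.apply(*xs).sum().backward()
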
Example No. 16
    def backward(ctx, dz):
        z, weight, bias, running_mean, running_var = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.needs_input_grad[0]:
            dx = dz.new().resize_as_(dz)
        else:
            dx = None

        if ctx.needs_input_grad[1]:
            dweight = dz.new().resize_as_(running_mean).zero_()
        else:
            dweight = None

        if ctx.needs_input_grad[2]:
            dbias = dz.new().resize_as_(running_mean).zero_()
        else:
            dbias = None

        if ctx.training:
            edz = dz.new().resize_as_(running_mean)
            eydz = dz.new().resize_as_(running_mean)
            _check_contiguous(z, dz, weight, bias, edz, eydz)
            _check(_ext.bn_edz_eydz_cuda,
                   z, dz,
                   weight if weight is not None else dz.new(),
                   bias if bias is not None else dz.new(),
                   edz, eydz, ctx.eps)

            if ctx.is_master:
                edzs, eydzs = [edz], [eydz]
                for _ in range(len(ctx.worker_queues)):
                    edz_w, eydz_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    edzs.append(edz_w)
                    eydzs.append(eydz_w)

                edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
                eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

                tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((edz, eydz))
                edz, eydz = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
        else:
            edz = dz.new().resize_as_(running_mean).zero_()
            eydz = dz.new().resize_as_(running_mean).zero_()

        _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight, dbias)
        _check(_ext.bn_backard_cuda,
               dz, z, ctx.var,
               weight if weight is not None else dz.new(),
               bias if bias is not None else dz.new(),
               edz, eydz,
               dx if dx is not None else dz.new(),
               dweight if dweight is not None else dz.new(),
               dbias if dbias is not None else dz.new(),
               ctx.eps)

        del ctx.var

        return dx, dweight, dbias, None, None, None, None, None, None, None, None
Example No. 17
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)
Example No. 18
    def backward(self, *grad_output):
        return comm.reduce_add(grad_output, self.input_device)
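
This backward belongs to a broadcast-style Function: the forward copies one tensor to several GPUs, and the backward sums the per-device gradients back onto the original input device. A hedged reconstruction of that pattern (class and attribute names are assumptions):

import torch
from torch.autograd import Function
from torch.cuda import comm

class Broadcast(Function):
    @staticmethod
    def forward(ctx, target_gpus, input):
        ctx.input_device = input.get_device()
        return comm.broadcast(input, target_gpus)   # one copy per target GPU

    @staticmethod
    def backward(ctx, *grad_output):
        # no gradient for target_gpus; sum the per-copy gradients onto the input device
        return None, comm.reduce_add(grad_output, ctx.input_device)
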