示例#1
0
    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
        """Reduce gradients bucket by bucket, grouped by tensor type.

        Args:
            grads: Optional list of gradients to reduce. Every entry must
                expose ``type()`` because the list is handed to
                ``split_half_float_double_csr``. When ``None`` (the default)
                the gradients are gathered from
                ``self.module.named_parameters()``; parameters whose ``grad``
                is ``None`` are skipped.
            elements_per_buffer: Forwarded as ``numel_per_bucket`` to
                ``self.allreduce_no_retain`` for every non-CSR bucket.
        """
        # Only gather from the module when the caller did not supply the
        # gradients; previously the argument was overwritten and ignored.
        if grads is None:
            sparse_enabled = self.sparse_gradients_enabled()
            grads = []
            for param_name, param in self.module.named_parameters():
                if param.grad is None:
                    continue
                grad_data = param.grad.data
                # Gradients of the registered CSR modules are wrapped so they
                # end up in the CSR bucket below.
                if sparse_enabled and param_name in self.csr_tensor_module_names:
                    grad_data = CSRTensor(grad_data)
                grads.append(grad_data)

        for bucket_type, bucket in split_half_float_double_csr(grads):
            if bucket_type == CSRTensor.type():
                self.csr_allreduce_no_retain(bucket)
            else:
                self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer)
示例#2
0
def split_half_float_double_csr(tensors):
    """Group ``tensors`` into per-type buckets.

    Each element is matched by the string returned from its ``type()`` method
    against, in this order: CUDA half, CUDA float, CUDA double and
    ``CSRTensor.type()``.

    Returns:
        A list of ``(type_name, bucket)`` tuples in the order above. Types
        with no matching element are left out, and elements whose type is not
        one of the four listed do not appear in any bucket.
    """
    supported_types = (
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor",
        CSRTensor.type(),
    )
    grouped = []
    for type_name in supported_types:
        members = [tensor for tensor in tensors if tensor.type() == type_name]
        if not members:
            # Empty buckets are never reported.
            continue
        grouped.append((type_name, members))
    return grouped
示例#3
0
def test_csr_addition_different():
    """Adding one CSRTensor to a different one must match the dense sum."""
    num_rows = 10
    random.seed(1234)

    def build_rows():
        # The first row is always ones; each later row is ones when the
        # random draw exceeds 0.75 and zeros otherwise.
        rows = [torch.ones(1, 5)]
        for _ in range(num_rows - 1):
            make_row = torch.ones if random.random() > 0.75 else torch.zeros
            rows.append(make_row(1, 5))
        return torch.cat(rows)

    lhs = build_rows()
    lhs_dense = lhs.clone()  # dense copy used to build the reference sum
    lhs_csr = CSRTensor(lhs)

    rhs = build_rows()
    rhs_dense = rhs.clone()
    rhs_csr = CSRTensor(rhs)

    expected = lhs_dense + rhs_dense
    lhs_csr.add(rhs_csr)

    assert torch.all(expected == lhs_csr.to_dense())
示例#4
0
def test_csr_addition_self():
    """A CSRTensor must round-trip to dense and double when added to itself."""
    num_rows = 10
    random.seed(1234)

    # The first row is always ones; each later row is ones when the random
    # draw exceeds 0.75 and zeros otherwise.
    rows = [torch.ones(1, 5)]
    rows.extend(
        torch.ones(1, 5) if random.random() > 0.75 else torch.zeros(1, 5)
        for _ in range(num_rows - 1))
    dense = torch.cat(rows)
    reference = dense.clone()
    csr = CSRTensor(dense)

    # Converting back to dense must reproduce the input.
    assert torch.all(reference == csr.to_dense())

    csr.add(csr)
    doubled = reference + reference
    assert torch.all(doubled == csr.to_dense())