def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
    """Allreduce gradients in dtype-grouped buckets.

    Fix: the original ignored a caller-supplied ``grads`` list and always
    rebuilt it from ``self.module``; now the parameter is honored, while the
    default (``grads=None``) path behaves exactly as before.

    Args:
        grads: optional pre-collected list of gradient tensors / CSRTensors.
            When ``None``, gradients are gathered from the module's
            parameters, wrapping sparse-enabled ones in ``CSRTensor``.
        elements_per_buffer: numel per allreduce bucket for dense tensors.
    """
    if grads is None:
        grads = []
        for param_name, param in self.module.named_parameters():
            # Parameters that never received a gradient are skipped.
            if param.grad is None:
                continue
            grad_data = param.grad.data
            if self.sparse_gradients_enabled() and param_name in self.csr_tensor_module_names:
                grads.append(CSRTensor(grad_data))
            else:
                grads.append(grad_data)

    # Bucket by type so sparse (CSR) and each dense dtype reduce separately.
    for bucket_type, bucket in split_half_float_double_csr(grads):
        if bucket_type == CSRTensor.type():
            self.csr_allreduce_no_retain(bucket)
        else:
            self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer)
def split_half_float_double_csr(tensors):
    """Partition *tensors* into per-type buckets.

    Returns a list of ``(type_string, tensors_of_that_type)`` pairs, in the
    fixed order half / float / double / CSR. Types outside that set are
    dropped, and empty buckets are omitted.
    """
    supported_types = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor",
        CSRTensor.type(),
    ]
    # Group in a single pass, then emit buckets in the canonical order.
    grouped = {}
    for tensor in tensors:
        grouped.setdefault(tensor.type(), []).append(tensor)
    return [(dtype, grouped[dtype]) for dtype in supported_types if dtype in grouped]
def test_csr_addition_different():
    """CSR + CSR with (potentially) different sparsity patterns matches dense addition.

    Fix: the random-row matrix construction was duplicated verbatim for ``x``
    and ``y``; it is now a local helper. The random stream consumption
    (9 draws per matrix, in the same order) is unchanged.
    """
    def _random_row_matrix(row_count, cols):
        # First row is always ones; each later row is ones with prob 0.25,
        # otherwise zeros — same scheme as the original inline loops.
        rows = [torch.ones(1, cols)]
        for _ in range(row_count - 1):
            if random.random() > 0.75:
                rows.append(torch.ones(1, cols))
            else:
                rows.append(torch.zeros(1, cols))
        return torch.cat(rows)

    random.seed(1234)
    x = _random_row_matrix(10, 5)
    y = _random_row_matrix(10, 5)

    # Keep dense copies before building CSR views, as the original did.
    dense_sum = x.clone() + y.clone()
    cx = CSRTensor(x)
    cy = CSRTensor(y)
    cx.add(cy)
    assert torch.all(dense_sum == cx.to_dense())
def test_csr_addition_self():
    """Adding a CSRTensor to itself yields the dense result of x + x."""
    rows, cols = 10, 5
    random.seed(1234)
    # First row is all ones; each remaining row is ones with prob 0.25,
    # otherwise zeros (one random draw per extra row, same as before).
    pieces = [torch.ones(1, cols)]
    for _ in range(rows - 1):
        pieces.append(torch.ones(1, cols) if random.random() > 0.75 else torch.zeros(1, cols))
    x = torch.cat(pieces)

    dense_x = x.clone()
    cx = CSRTensor(x)
    # Round-trip sanity check before mutating.
    assert torch.all(dense_x == cx.to_dense())
    cx.add(cx)
    assert torch.all(dense_x + dense_x == cx.to_dense())