Example #1
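Synchronizes the Chainer-side link with the wrapped PyTorch module: child links recurse first, then each parameter array and buffer is refreshed with a shared-memory view obtained from tensor.asarray.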
    def _sync_from_torch(self):
        for child in self.children():
            child._sync_from_torch()
        for name, param in self.module.named_parameters(recurse=False):
            getattr(self, name).array = tensor.asarray(param)
        for name, buffer in self.module.named_buffers(recurse=False):
            setattr(self, name, tensor.asarray(buffer))
Example #2
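Constructor of a TorchModule wrapper: child modules are wrapped recursively, each PyTorch parameter becomes a chainer.Parameter backed by tensor.asarray, buffers are registered as persistent values, and a gradient hook keeps the Chainer-side gradients in sync.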
    def __init__(self, module):
        super().__init__()
        self._module = module

        with self.init_scope():
            for name, child in module.named_children():
                setattr(self, name, TorchModule(child))
            for name, param in module.named_parameters(recurse=False):
                ch_param = chainer.Parameter(tensor.asarray(param))
                setattr(self, name, ch_param)
                # Gradient computed at PyTorch side is automatically
                # synchronized to Chainer side with this hook.
                param.register_hook(_get_grad_setter(ch_param))
            for name, buffer in module.named_buffers(recurse=False):
                self.add_persistent(name, tensor.asarray(buffer))
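Example #3
A multi-GPU test: a tensor placed on cuda:1 should convert to a cupy.ndarray on that device, and an in-place update to the array should be visible from the original tensor.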
def test_asarray_multi_gpu():
    if torch.cuda.device_count() < 2:
        pytest.skip('Not enough GPUs')
    t = torch.arange(5, dtype=torch.float32, device='cuda:1')
    a = tensor.asarray(t)
    assert isinstance(a, cupy.ndarray)
    with cupy.cuda.Device(1):
        a += 1
        numpy.testing.assert_array_equal(a.get(), t.cpu().numpy())
Example #4
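A variant of the constructor in Example #2 that renames a child called module to wrapped_module, because DataParallel stores the wrapped model under that attribute and the name would otherwise collide.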
    def __init__(self, module):
        super().__init__()
        self._module = module

        with self.init_scope():
            for name, child in module.named_children():
                if name == 'module':
                    # DataParallel objects have the model stored as `module`
                    # causing a conflict.
                    name = 'wrapped_module'
                setattr(self, name, TorchModule(child))
            for name, param in module.named_parameters(recurse=False):
                ch_param = chainer.Parameter(tensor.asarray(param))
                setattr(self, name, ch_param)
                # Gradient computed at PyTorch side is automatically
                # synchronized to Chainer side with this hook.
                param.register_hook(_get_grad_setter(ch_param))
            for name, buffer in module.named_buffers(recurse=False):
                self.add_persistent(name, tensor.asarray(buffer))
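Example #5
Builds a Chainer MLP, moves it to the target device, and runs a dummy forward pass so that lazily initialized parameters exist before the link is wrapped with cpm.LinkAsTorchModel.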
    def set_model(self):
        chainer_model = MLP()
        chainer_model.to_device(self.device)
        self.device.use()

        dummy_input = self.train_dataset[0][0]
        dummy_input = chainer.Variable(tensor.asarray(dummy_input))
        dummy_input.to_device(self.device)
        chainer_model(dummy_input)
        # dummy_input = iter(self.train_loader).next()
        # dummy_input = chainer_prepare_batch(dummy_input, self.device)
        # chainer_model(dummy_input[0][0])

        self.model = cpm.LinkAsTorchModel(chainer_model)
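For context, the wrapped link can then be handed to PyTorch code that expects an nn.Module-like object, for example a torch optimizer. A minimal sketch under stated assumptions: cpm is chainer_pytorch_migration as in the snippet above, LinkAsTorchModel exposes the link's parameters to PyTorch, and this MLP definition is a hypothetical stand-in for the one used in set_model.

import chainer
import chainer.functions as F
import chainer.links as L
import numpy
import torch
import chainer_pytorch_migration as cpm  # assumed import alias


class MLP(chainer.Chain):
    # Hypothetical stand-in for the MLP used in set_model() above.
    def __init__(self):
        super().__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, 10)
            self.l2 = L.Linear(10, 2)

    def forward(self, x):
        return self.l2(F.relu(self.l1(x)))


chainer_model = MLP()
# Lazily initialized parameters must exist before wrapping, which is why
# set_model() runs a dummy forward pass first.
chainer_model(numpy.zeros((1, 4), dtype=numpy.float32))

torch_model = cpm.LinkAsTorchModel(chainer_model)
# The wrapper exposes torch parameters, so a torch optimizer can drive them.
optimizer = torch.optim.SGD(torch_model.parameters(), lr=0.01)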
Example #6
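A collate function that batches samples with PyTorch's default_collate and then converts each batched tensor to an array.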
def collate_to_array(batch):
    data = torch.utils.data._utils.collate.default_collate(batch)
    return [tensor.asarray(x) for x in data]
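Such a collate function is normally passed to a DataLoader. A minimal sketch, assuming collate_to_array and the tensor module from the example above are importable; the dataset and batch size are illustrative.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative dataset: 100 samples of 3 features plus integer labels.
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))

# default_collate stacks the samples into batched tensors; collate_to_array
# then turns each batched tensor into a NumPy (or CuPy) array.
loader = DataLoader(dataset, batch_size=16, collate_fn=collate_to_array)

for features, labels in loader:
    pass  # features and labels are ndarray batches here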
Example #7
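The gradient hook itself: when PyTorch produces grad, it is mirrored into the Chainer parameter's .grad as a shared-memory array.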
def hook(grad):
    param.grad = tensor.asarray(grad)
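This hook is presumably the body of the _get_grad_setter helper referenced in Examples #2 and #4. A hypothetical reconstruction of that factory, assuming it only needs to close over the Chainer parameter:

def _get_grad_setter(param):
    # Hypothetical reconstruction: returns a closure suitable for
    # torch.Tensor.register_hook.  When PyTorch computes the gradient,
    # it is mirrored into the Chainer parameter's .grad as a
    # shared-memory array via tensor.asarray.
    def hook(grad):
        param.grad = tensor.asarray(grad)
    return hook

The remaining snippets are unit tests for tensor.asarray, covering CPU, GPU, and empty tensors; each checks the returned array type and, where applicable, that the array and the original tensor share memory.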
def test_asarray_cpu():
    t = torch.arange(5, dtype=torch.float32)
    a = tensor.asarray(t)
    assert isinstance(a, numpy.ndarray)
    a += 1
    numpy.testing.assert_array_equal(a, t.numpy())
def test_asarray_gpu():
    t = torch.arange(5, dtype=torch.float32, device='cuda')
    a = tensor.asarray(t)
    assert isinstance(a, cupy.ndarray)
    a += 1
    numpy.testing.assert_array_equal(a.get(), t.cpu().numpy())
def test_asarray_empty_gpu():
    t = torch.tensor([], dtype=torch.float32, device='cuda')
    a = tensor.asarray(t)
    assert isinstance(a, cupy.ndarray)
def test_asarray_empty_cpu():
    t = torch.tensor([], dtype=torch.float32)
    a = tensor.asarray(t)
    assert isinstance(a, numpy.ndarray)