def run(self):
        """Serve training commands from the parent process until finalized.

        Commands arrive on ``self.pipe`` as ``(job, data)`` pairs; ``data`` is
        not used by this loop.

        * ``'finalize'`` -- synchronize this worker's device and return.
        * ``'update'`` -- compute the negative ELBO on one batch drawn from
          ``self.iterator``, backpropagate, sum the gradients into root rank 0
          with NCCL ``reduce``, then broadcast the parameters from rank 0 and
          load them into the local model.

        Any other command is ignored.
        """
        self.device.use()

        self.setup()

        while True:
            job, _ = self.pipe.recv()
            if job == 'finalize':
                self.device.device.synchronize()
                break
            if job == 'update':
                model = self.model
                # Drop stale gradients before the forward pass to reduce
                # peak memory.
                model.cleargrads()

                x = self.converter(self.iterator.next(), self.device_index)
                images = x['image']
                viewpoints = x['viewpoint']
                # ``x`` is looked up by string keys (presumably a dict), so
                # ``len(x)`` would count keys rather than examples. Take the
                # batch size from the leading axis of the image batch.
                batch_size = len(images)
                xp = model.xp

                representation, query_images, query_viewpoints = encode_scene(
                    images, viewpoints, model, self.device_index)

                with self.reporter.scope({}):  # pass dummy observation
                    # Compute distribution parameters
                    (z_t_param_array,
                     pixel_mean) = model.sample_z_and_x_params_from_posterior(
                         query_images, query_viewpoints, representation)

                    # Compute ELBO
                    (ELBO, bits_per_pixel, negative_log_likelihood,
                     kl_divergence) = estimate_ELBO(xp, query_images,
                                                    z_t_param_array,
                                                    pixel_mean,
                                                    self.pixel_log_sigma,
                                                    batch_size)

                    # Maximizing the ELBO is minimizing its negation.
                    loss = -ELBO

                loss.backward()
                # The ELBO terms above still reference the computational
                # graph; cut it so its intermediate buffers can be freed
                # before the collective calls below.
                loss.unchain_backward()
                del loss

                # Sum the flattened gradients into root rank 0 (send and
                # receive buffers are the same pointer).
                gg = gather_grads(model)
                nccl_data_type = _get_nccl_data_type(gg.dtype)
                null_stream = cuda.Stream.null
                self.comm.reduce(gg.data.ptr, gg.data.ptr, gg.size,
                                 nccl_data_type, nccl.NCCL_SUM, 0,
                                 null_stream.ptr)
                del gg
                model.cleargrads()

                # Broadcast the flattened parameters from root rank 0 and
                # load them back into the local model.
                gp = gather_params(model)
                nccl_data_type = _get_nccl_data_type(gp.dtype)
                self.comm.bcast(gp.data.ptr, gp.size, nccl_data_type, 0,
                                null_stream.ptr)
                scatter_params(model, gp)
                del gp
    def test_gather_scatter_params(self):
        """Parameters gathered from one net and scattered into another match."""
        source = SimpleNet(dtype=self.dtype)
        target = SimpleNet(dtype=self.dtype)
        for net in (source, target):
            net.to_gpu()

        flat_params = mpu.gather_params(source)
        mpu.scatter_params(target, flat_params)

        check_equal = cuda.cupy.testing.assert_array_equal
        # Compare weight then bias of each layer, conv first, then fc.
        for layer_name in ('conv', 'fc'):
            expected = getattr(source, layer_name)
            actual = getattr(target, layer_name)
            check_equal(expected.W.data, actual.W.data)
            check_equal(expected.b.data, actual.b.data)
    def test_gather_scatter_params(self):
        # Round trip: flatten the first net's parameters into one buffer,
        # unpack that buffer into the second net, then compare every array.
        xp = cuda.cupy
        model_pair = [SimpleNet(dtype=self.dtype) for _ in range(2)]
        for m in model_pair:
            m.to_gpu()
        src, dst = model_pair

        mpu.scatter_params(dst, mpu.gather_params(src))

        param_pairs = (
            (src.conv.W, dst.conv.W),
            (src.conv.b, dst.conv.b),
            (src.fc.W, dst.fc.W),
            (src.fc.b, dst.fc.b),
        )
        for expected, actual in param_pairs:
            xp.testing.assert_array_equal(expected.data, actual.data)
Exemplo n.º 4
0
    def run(self):
        """Serve training commands from the parent process until finalized.

        Binds this process to GPU ``self.device`` and then reads
        ``(job, data)`` pairs from ``self.pipe``; ``data`` is not used here.

        * ``'finalize'`` -- synchronize the device and return.
        * ``'update'`` -- run forward/backward on one batch, sum the gradients
          into root rank 0 with NCCL ``reduce``, then broadcast the
          parameters from rank 0 and load them into the local model.

        Any other command is ignored.
        """
        from cupy.cuda import nccl
        dev = cuda.Device(self.device)
        dev.use()
        self.setup()
        while True:
            job, _ = self.pipe.recv()
            if job == 'finalize':
                dev.synchronize()
                break
            if job == 'update':
                # Release stale gradients before the forward pass to reduce
                # peak memory; the forward pass itself creates no gradients,
                # so no second clear is needed before ``backward``.
                self.model.cleargrads()

                batch = self.iterator.next()
                x = converter_kaldi(batch[0], self.reader)
                observation = {}
                with self.reporter.scope(observation):
                    loss = self.model(x)

                loss.backward()
                # Cut the graph so its intermediate buffers can be freed
                # before the collective calls below.
                loss.unchain_backward()
                del loss

                # NOTE(review): NCCL_FLOAT assumes gather_grads/gather_params
                # return float32 buffers -- confirm for non-fp32 models.
                gg = gather_grads(self.model)
                null_stream = cuda.Stream.null
                # Sum gradients into root rank 0 (in-place: same pointer).
                self.comm.reduce(gg.data.ptr, gg.data.ptr, gg.size,
                                 nccl.NCCL_FLOAT,
                                 nccl.NCCL_SUM, 0,
                                 null_stream.ptr)
                del gg
                self.model.cleargrads()

                # Broadcast parameters from root rank 0 and load them locally.
                gp = gather_params(self.model)
                self.comm.bcast(gp.data.ptr, gp.size,
                                nccl.NCCL_FLOAT, 0,
                                null_stream.ptr)
                scatter_params(self.model, gp)
                del gp

                delete_feat(x)