Example #1
0
    def compute_gradients(self, loss, var_list, **kwargs):
        """Compute Adam gradients and average them across all MPI ranks.

        Gradients come from ``tf.train.AdamOptimizer.compute_gradients``.
        Pairs whose gradient is ``None`` are dropped. The remaining gradients
        are flattened into one vector, which is summed over every rank of
        ``self.comm`` with ``MPI.SUM`` inside a ``tf.py_func``. The sum is
        divided by ``comm size * self.train_frac`` and split back into one
        tensor per variable. Ranks where ``Config.is_test_rank()`` is true
        contribute a zero vector to the sum.

        Args:
            loss: Tensor to differentiate.
            var_list: Variables to compute gradients for.
            **kwargs: Forwarded unchanged to
                ``tf.train.AdamOptimizer.compute_gradients``.

        Returns:
            List of ``(averaged_gradient, variable)`` pairs, one for each
            variable that received a non-``None`` gradient, in the original
            order.
        """
        raw_pairs = tf.train.AdamOptimizer.compute_gradients(
            self, loss, var_list, **kwargs)

        # Keep only the variables that actually received a gradient.
        pairs = []
        for grad, var in raw_pairs:
            if grad is not None:
                pairs.append((grad, var))

        flat_local = tf.concat(
            [tf.reshape(grad, (-1, )) for grad, _ in pairs], axis=0)

        if Config.is_test_rank():
            # Test ranks still join the Allreduce but add nothing to the sum.
            flat_local = tf.zeros_like(flat_local)

        # Element count of each variable, used to split the reduced vector.
        sizes = [int(np.prod(var.shape.as_list())) for _, var in pairs]

        world_size = self.comm.Get_size()
        # Receive buffer, allocated once and reused by every run of the
        # py_func below.
        # NOTE(review): the same ndarray object is returned on every call;
        # confirm that tf.py_func copies it before the next step overwrites it.
        reduced = np.zeros(sum(sizes), np.float32)

        def _allreduce_mean(local_vec):
            self.comm.Allreduce(local_vec, reduced, op=MPI.SUM)
            # Presumably world_size * train_frac is the number of training
            # (non-test) ranks, since test ranks send zeros -- confirm.
            np.divide(reduced, float(world_size) * self.train_frac,
                      out=reduced)
            return reduced

        flat_mean = tf.py_func(_allreduce_mean, [flat_local], tf.float32)
        # tf.py_func gives no static shape; restore it so tf.split can work.
        flat_mean.set_shape(flat_local.shape)
        pieces = tf.split(flat_mean, sizes, axis=0)

        return [(tf.reshape(piece, var.shape), var)
                for piece, (_, var) in zip(pieces, pairs)]
Example #2
0
    def step_wait(self):
        """Wait for the pending vectorized step and return its results.

        Steps performed:
        - Allocate fresh reward and done arrays.
        - Call ``lib.vec_wait`` with the observation, render, reward and done
          buffers.
        - Convert the observations to float32.
        - If ``Config.USE_BLACK_WHITE`` is set, average them over the last
          axis.
        - Pass them through ``slice_spectrum`` with ``Config.TEST_SPECTRUM``
          on test ranks or ``Config.TRAIN_SPECTRUM`` otherwise, plus
          ``Config.RADIUS``.

        Returns:
            Tuple ``(obs_frames, rewards, dones, info)``:
            - ``rewards`` and ``dones`` are the arrays passed to
              ``lib.vec_wait``; presumably it fills them, so confirm against
              the native library.
            - ``info`` is ``self.dummy_info``.
        """
        # Allocate new arrays every step, presumably so that arrays returned
        # by an earlier call are not overwritten by this one -- confirm.
        self.buf_rew = np.zeros_like(self.buf_rew)
        self.buf_done = np.zeros_like(self.buf_done)

        lib.vec_wait(self.handle, self.buf_rgb, self.buf_render_rgb,
                     self.buf_rew, self.buf_done)

        # astype copies by default, so obs_frames does not alias buf_rgb.
        obs_frames = self.buf_rgb.astype(np.float32)

        if Config.USE_BLACK_WHITE:
            # obs_frames is already float32, and np.mean returns float32 for
            # float32 input, so the extra casts the old code made were
            # redundant copies. keepdims=True keeps a trailing axis of size 1
            # (same result as mean(...)[..., None]).
            obs_frames = np.mean(obs_frames, axis=-1, keepdims=True)

        spectrum = (Config.TEST_SPECTRUM if Config.is_test_rank()
                    else Config.TRAIN_SPECTRUM)
        obs_frames = slice_spectrum(obs_frames, spectrum, Config.RADIUS)

        return obs_frames, self.buf_rew, self.buf_done, self.dummy_info