def is_file_valid(model_path, save_file=False):
    assert model_path.endswith('.npy'), logger.error(
        'Invalid file provided {}'.format(model_path))
    if not save_file:
        assert os.path.exists(model_path), logger.error(
            'file not found: {}'.format(model_path))
    logger.info('[LOAD/SAVE] checkpoint path is {}'.format(model_path))
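# Usage sketch for is_file_valid above (a hypothetical call site; the checkpoint
# path is a placeholder).  save_file=True skips the os.path.exists check, so a
# path that will only be created later still passes the '.npy' suffix check.
is_file_valid('checkpoint/policy_weights.npy', save_file=True)   # about to save
# is_file_valid('checkpoint/policy_weights.npy')                 # must already exist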
    def _traj_backward_pass(self, i_traj):
        """ @brief: do the backward pass. Note that everytime a back fails, we
            increase the traj_kl_eta, and recompute everything
        """
        finished = False
        kl_eta_multiplier = self._op_data[i_traj]['kl_eta_multiplier']
        num_trials = 0
        while not finished:
            traj_kl_eta = self._op_data[i_traj]['traj_kl_eta']
            self._set_cost_kl_penalty(i_traj, traj_kl_eta)  # the kl penalty

            finished = self._ilqr_data_wrapper.backward_pass(
                i_traj, self._op_data[i_traj]['traj_kl_eta'])

            # if failed, increase the kl_eta
            if not finished:
                # recalculate the kl penalty
                self._op_data[i_traj]['traj_kl_eta'] += kl_eta_multiplier
                kl_eta_multiplier *= 2.0

            num_trials += 1
            if num_trials > self.args.gps_max_backward_pass_trials or \
                    self._op_data[i_traj]['traj_kl_eta'] > 1e16:
                logger.error('Failed update')
                break
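# Standalone sketch of the retry pattern used by _traj_backward_pass above: the
# KL penalty eta is bumped by an increment that doubles after every failed
# backward pass, until the pass succeeds or a trial / eta cap is reached.
# 'backward_pass_fn' and the caps below are illustrative stand-ins, not the
# project's actual API.
def run_with_increasing_eta(backward_pass_fn, eta, eta_multiplier,
                            max_trials=20, max_eta=1e16):
    for _ in range(max_trials):
        if backward_pass_fn(eta):
            return eta, True          # succeeded with the current penalty
        eta += eta_multiplier         # failed: raise the KL penalty ...
        eta_multiplier *= 2.0         # ... and make the next raise larger
        if eta > max_eta:
            break
    return eta, False                 # caller decides how to handle failure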
Example #3
        def reward_derivative(data_dict, target):
            num_data = len(data_dict['start_state'])
            if target == 'state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size']], dtype=np.float
                )

                # the speed_reward part
                derivative_data[:, velocity_ob_pos] += 1.0

                # the height reward part
                derivative_data[:, height_ob_pos] += - 2.0 * height_coeff * \
                    (data_dict['start_state'][:, height_ob_pos] - target_height)

            elif target == 'action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size']], dtype=np.float
                )

                # the control reward part
                derivative_data[:, :] += - 2.0 * ctrl_coeff * \
                    data_dict['action'][:, :]

            elif target == 'state-state':
                derivative_data = np.zeros(
                    [num_data,
                     self._env_info['ob_size'], self._env_info['ob_size']],
                    dtype=np.float
                )

                # the height reward
                derivative_data[:, height_ob_pos, height_ob_pos] += \
                    - 2.0 * height_coeff

            elif target == 'action-state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size'],
                     self._env_info['ob_size']],
                    dtype=np.float
                )
            elif target == 'state-action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size'],
                     self._env_info['action_size']],
                    dtype=np.float
                )

            elif target == 'action-action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size'],
                     self._env_info['action_size']],
                    dtype=np.float
                )
                for diagonal_id in range(self._env_info['action_size']):
                    derivative_data[:, diagonal_id, diagonal_id] += \
                        -2.0 * ctrl_coeff
            else:
                assert False, logger.error('Invalid target {}'.format(target))

            return derivative_data
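# Self-contained check of the derivative formulas above, assuming the reward
# being differentiated is
#     r(s, a) = s[vel] - height_coeff * (s[h] - target_height) ** 2
#               - ctrl_coeff * sum(a ** 2)
# (the observation indices and coefficients here are illustrative assumptions).
import numpy as np

velocity_ob_pos, height_ob_pos = 5, 0
height_coeff, ctrl_coeff, target_height = 3.0, 0.1, 1.25

def reward(state, action):
    return (state[:, velocity_ob_pos]
            - height_coeff * (state[:, height_ob_pos] - target_height) ** 2
            - ctrl_coeff * np.sum(action ** 2, axis=1))

rng = np.random.RandomState(0)
state, action = rng.randn(4, 8), rng.randn(4, 2)

# analytic d(reward)/d(state), mirroring the 'state' branch above
grad = np.zeros_like(state)
grad[:, velocity_ob_pos] += 1.0
grad[:, height_ob_pos] += \
    -2.0 * height_coeff * (state[:, height_ob_pos] - target_height)

# finite-difference check on the height coordinate
eps = 1e-6
bumped = state.copy()
bumped[:, height_ob_pos] += eps
numeric = (reward(bumped, action) - reward(state, action)) / eps
assert np.allclose(numeric, grad[:, height_ob_pos], atol=1e-4)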
Example #4
        def reward_derivative(data_dict, target):

            num_data = len(data_dict['start_state'])
            if target == 'state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size']], dtype=np.float
                )

                # the speed_reward part
                derivative_data[:, 22] += (0.25 / 0.015)

                # quad_impact_cost
                # cfrc_ext = data_dict['start_state'][-84:]
                if self._env_name == 'gym_humanoid':
                    derivative_data[:, -84:] += \
                        - 1e-6 * data_dict['start_state'][:, -84:]

            elif target == 'action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size']], dtype=np.float
                )

                # the control reward part
                derivative_data[:, :] += - 0.2 * data_dict['action'][:, :]

            elif target == 'state-state':
                derivative_data = np.zeros(
                    [num_data,
                     self._env_info['ob_size'], self._env_info['ob_size']],
                    dtype=np.float
                )
                if self._env_name == 'gym_humanoid':
                    for diagonal_id in range(-84, 0):
                        derivative_data[:, diagonal_id, diagonal_id] += - 1e-6

            elif target == 'action-state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size'],
                     self._env_info['ob_size']],
                    dtype=np.float
                )
            elif target == 'state-action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size'],
                     self._env_info['action_size']],
                    dtype=np.float
                )

            elif target == 'action-action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size'],
                     self._env_info['action_size']],
                    dtype=np.float
                )
                for diagonal_id in range(self._env_info['action_size']):
                    derivative_data[:, diagonal_id, diagonal_id] += -0.2
            else:
                assert False, logger.error('Invalid target {}'.format(target))

            return derivative_data
Example #5
            def discrete():
                if target == 'state':
                    derivative_data = np.zeros(
                        [num_data, self._env_info['ob_size']], dtype=np.float)
                    derivative_data[:, x_ob_pos] = 2 * state[:, 0]
                    derivative_data[:, y_ob_pos] = 2 * state[:, 1]
                    derivative_data[:, x_vel_pos] = 2 * state[:, 2]
                    derivative_data[:, y_vel_pos] = 2 * state[:, 3]
                    derivative_data[:, theta_pos] = 2 * state[:, 4]
                    derivative_data[:, contact_one_pos] = .1
                    derivative_data[:, contact_two_pos] = .1

                elif target == 'action':
                    derivative_data = np.zeros(
                        [num_data, self._env_info['action_size']],
                        dtype=np.float)
                    derivative_data[:, 0] = \
                        5 / (1 + np.exp(-5 * action[:, 0])) * \
                        (1 - 1 / (1 + np.exp(-5 * action[:, 0])))
                    derivative_data[:, 1] = \
                        5 / 12 * np.exp(5 * action[:, 1] - 2.5)

                elif target == 'state-state':
                    derivative_data = np.zeros([
                        num_data, self._env_info['ob_size'],
                        self._env_info['ob_size']
                    ],
                                               dtype=np.float)
                    derivative_data[:, x_ob_pos, x_ob_pos] = 2
                    derivative_data[:, y_ob_pos, y_ob_pos] = 2
                    derivative_data[:, x_vel_pos, x_vel_pos] = 2
                    derivative_data[:, y_vel_pos, y_vel_pos] = 2
                    derivative_data[:, theta_pos, theta_pos] = 2

                elif target == 'action-state':
                    derivative_data = np.zeros([
                        num_data, self._env_info['action_size'],
                        self._env_info['ob_size']
                    ],
                                               dtype=np.float)

                elif target == 'state-action':
                    derivative_data = np.zeros([
                        num_data, self._env_info['ob_size'],
                        self._env_info['action_size']
                    ],
                                               dtype=np.float)

                elif target == 'action-action':
                    derivative_data = np.zeros([
                        num_data, self._env_info['action_size'],
                        self._env_info['action_size']
                    ],
                                               dtype=np.float)

                else:
                    assert False, logger.error(
                        'Invalid target {}'.format(target))

                return derivative_data
Example #6
        def reward_derivative(data_dict, target):
            num_data = len(data_dict['start_state'])
            if target == 'state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size']], dtype=np.float)

                # the xpos reward part
                derivative_data[:, xpos_ob_pos] += - 2.0 * xpos_coeff * \
                    (data_dict['start_state'][:, xpos_ob_pos])

                # the ypos reward part
                derivative_data[:, ypos_ob_pos] += - 2.0 * \
                    (data_dict['start_state'][:, ypos_ob_pos] - ypos_target)

            elif target == 'action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size']], dtype=np.float)

            elif target == 'state-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)

                # the xpos reward
                derivative_data[:, xpos_ob_pos, xpos_ob_pos] += \
                    - 2.0 * xpos_coeff

                # the ypos reward
                derivative_data[:, ypos_ob_pos, ypos_ob_pos] += \
                    - 2.0

            elif target == 'action-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)
            elif target == 'state-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)

            elif target == 'action-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)
            else:
                assert False, logger.error('Invalid target {}'.format(target))

            return derivative_data
Example #7
    def train_initial_policy(self, data_dict, replay_buffer, training_info={}):
        # get the validation set
        # Hack the policy val percentage to 0.1 for policy initialization.
        self.args.policy_val_percentage = 0.1

        new_data_id = list(range(len(data_dict['start_state'])))
        self._npr.shuffle(new_data_id)
        num_val = int(len(new_data_id) * self.args.policy_val_percentage)
        val_data = {
            key: data_dict[key][new_data_id][:num_val]
            for key in ['start_state', 'end_state', 'action']
        }

        # get the training set
        train_data = {
            key: data_dict[key][new_data_id][num_val:]
            for key in ['start_state', 'end_state', 'action']
        }

        for i_epoch in range(self.args.dagger_epoch):
            # get the number of batches
            num_batches = len(train_data['action']) // \
                self.args.initial_policy_bs
            # from util.common.fpdb import fpdb; fpdb().set_trace()
            assert num_batches > 0, logger.error('batch_size > data_set')
            avg_training_loss = []

            for i_batch in range(num_batches):
                # train for each sub batch
                feed_dict = {
                    self._input_ph[key]: train_data[key][
                        i_batch * self.args.initial_policy_bs:
                        (i_batch + 1) * self.args.initial_policy_bs
                    ] for key in ['start_state', 'action']
                }
                fetch_dict = {
                    'update_op': self._update_operator['initial_update_op'],
                    'train_loss': self._update_operator['initial_policy_loss']
                }

                training_stat = self._session.run(fetch_dict, feed_dict)
                avg_training_loss.append(training_stat['train_loss'])

            val_loss = self.eval(val_data)

            logger.info(
                '[dynamics at epoch {}]: Val Loss: {}, Train Loss: {}'.format(
                    i_epoch, val_loss, np.mean(avg_training_loss)
                )
            )

        training_stat['val_loss'] = val_loss
        training_stat['avg_train_loss'] = np.mean(avg_training_loss)
        return training_stat
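# Small sketch of the shuffle-then-split pattern used in train_initial_policy
# above, on a toy data dictionary (the keys and the 10% validation fraction
# mirror the code; the arrays themselves are illustrative).
import numpy as np

data_dict = {'start_state': np.random.randn(100, 4),
             'end_state': np.random.randn(100, 4),
             'action': np.random.randn(100, 2)}
new_data_id = np.random.permutation(len(data_dict['start_state']))
num_val = int(len(new_data_id) * 0.1)
val_data = {key: data_dict[key][new_data_id][:num_val]
            for key in ['start_state', 'end_state', 'action']}
train_data = {key: data_dict[key][new_data_id][num_val:]
              for key in ['start_state', 'end_state', 'action']}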
Example #8
        def reward_derivative(data_dict, target):
            y_ob_pos = 0
            x_ob_pos = 1
            thetadot_ob_pos = 2
            num_data = len(data_dict['start_state'])
            if target == 'state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size']], dtype=np.float)
                derivative_data[:, y_ob_pos] += -1
                derivative_data[:, x_ob_pos] += \
                    -0.1 * np.sign(data_dict['start_state'][:, x_ob_pos])
                derivative_data[:, thetadot_ob_pos] += \
                    -0.2 * data_dict['start_state'][:, 2]

            elif target == 'action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size']], dtype=np.float)
                derivative_data[:, :] = -.002 * data_dict['action'][:, :]

            elif target == 'state-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)
                derivative_data[:, thetadot_ob_pos, thetadot_ob_pos] += -0.2

            elif target == 'action-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)

            elif target == 'state-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)

            elif target == 'action-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)
                for diagonal_id in range(self._env_info['action_size']):
                    derivative_data[:, diagonal_id, diagonal_id] += -0.002

            else:
                assert False, logger.error('Invalid target {}'.format(target))
            return derivative_data
    def __init__(self, sess, summary_name, enable=True,
                 scalar_var_list=dict(), summary_dir=None):
        super(self.__class__, self).__init__(sess, summary_name, enable=enable,
                                             summary_dir=summary_dir)
        if not self.enable:
            return
        assert type(scalar_var_list) == dict, logger.error(
            'We only take the dict where the name is given as the key')

        if len(scalar_var_list) > 0:
            self.summary_list = []
            for name, var in scalar_var_list.items():
                self.summary_list.append(tf.summary.scalar(name, var))
            self.summary = tf.summary.merge(self.summary_list)
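# Minimal sketch of what the constructor above builds, written as plain TF1
# graph code (the placeholder names below are illustrative assumptions).
import tensorflow as tf

scalar_var_list = {'train_loss': tf.placeholder(tf.float32, [], name='train_loss'),
                   'val_loss': tf.placeholder(tf.float32, [], name='val_loss')}
summary_list = [tf.summary.scalar(name, var)
                for name, var in scalar_var_list.items()]
merged_summary = tf.summary.merge(summary_list)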
        def reward_derivative(data_dict, target):
            num_data = len(data_dict['start_state'])
            if target == 'state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size']], dtype=np.float)
                derivative_data[:, 0] = -0.02 * data_dict['start_state'][:, 0]
                derivative_data[:, 2] = -np.sin(data_dict['start_state'][:, 2])

            elif target == 'action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size']], dtype=np.float)

            elif target == 'state-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)
                derivative_data[:, 0, 0] = -0.02
                derivative_data[:, 2, 2] = \
                    -np.cos(data_dict['start_state'][:, 2])

            elif target == 'action-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)

            elif target == 'state-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)

            elif target == 'action-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)

            else:
                assert False, logger.error('Invalid target {}'.format(target))

            return derivative_data
def load_expert_data(traj_data_name, traj_episode_num):
    # the start of the training
    traj_base_dir = init_path.get_abs_base_dir()

    if not traj_data_name.endswith('.npy'):
        traj_data_name = traj_data_name + '.npy'
    data_dir = os.path.join(traj_base_dir, traj_data_name)

    assert os.path.exists(data_dir), \
        logger.error('Invalid path: {}'.format(data_dir))
    expert_trajectory = np.load(data_dir, encoding="latin1")

    # choose only the top trajectories
    if len(expert_trajectory) > traj_episode_num:
        logger.warning('Using only %d trajs out of %d trajs' %
                       (traj_episode_num, len(expert_trajectory)))
    expert_trajectory = expert_trajectory[:min(traj_episode_num,
                                               len(expert_trajectory))]
    return expert_trajectory
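# Hypothetical usage sketch for load_expert_data; the data file name and the
# episode count are placeholders.  Note that recent NumPy releases default to
# allow_pickle=False, so loading an object array of trajectory dicts may also
# require passing allow_pickle=True to np.load above.
expert_trajectory = load_expert_data('expert_data/gym_cheetah', traj_episode_num=20)
logger.info('loaded {} expert trajectories'.format(len(expert_trajectory)))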
Example #12
        def reward_derivative(data_dict, target):
            """
            y_1_pos = 0
            x_1_pos = 1
            y_2_pos = 2
            x_2_pos = 3
            """
            num_data = len(data_dict['start_state'])
            if target == 'state':
                derivative_data = np.zeros(
                    [num_data, self._env_info['ob_size']], dtype=np.float)
                derivative_data[:, 0] += -1.0 - data_dict['start_state'][:, 2]
                derivative_data[:, 1] += data_dict['start_state'][:, 3]
                derivative_data[:, 2] += -data_dict['start_state'][:, 0]
                derivative_data[:, 3] += data_dict['start_state'][:, 1]

            elif target == 'action':
                derivative_data = np.zeros(
                    [num_data, self._env_info['action_size']], dtype=np.float)

            elif target == 'state-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)
                derivative_data[:, 0, 2] += -1.0
                derivative_data[:, 1, 3] += 1.0
                derivative_data[:, 2, 0] += -1.0
                derivative_data[:, 3, 1] += 1.0

            elif target == 'action-state':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['ob_size']
                ],
                                           dtype=np.float)

            elif target == 'state-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['ob_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)

            elif target == 'action-action':
                derivative_data = np.zeros([
                    num_data, self._env_info['action_size'],
                    self._env_info['action_size']
                ],
                                           dtype=np.float)

            else:
                assert False, logger.error('Invalid target {}'.format(target))

            if self._env_name == 'gym_acrobot':
                return derivative_data

            elif self._env_name == 'gym_acrobot_sparse':
                return np.zeros_like(derivative_data)

            else:
                raise ValueError("invalid env name")
    def run(self):
        self._build_model()

        while True:
            next_task = self._task_queue.get(block=True)

            if next_task[0] == parallel_util.WORKER_PLANNING:
                # collect rollouts
                plan = self._plan(next_task[1])
                self._task_queue.task_done()
                self._result_queue.put(plan)

            elif next_task[0] == parallel_util.WORKER_PLAYING:
                # collect rollouts
                traj_episode = self._play(next_task[1])
                self._task_queue.task_done()
                self._result_queue.put(traj_episode)

            elif next_task[0] == parallel_util.WORKER_RATE_ACTIONS:
                # predict reward of a sequence of action
                reward = self._rate_action(next_task[1])
                self._task_queue.task_done()
                self._result_queue.put(reward)

            elif next_task[0] == parallel_util.WORKER_GET_MODEL:
                # collect the gradients
                data_id = next_task[1]['data_id']

                if next_task[1]['type'] == 'dynamics_derivative':
                    model_data = self._dynamics_derivative(
                        next_task[1]['data_dict'], next_task[1]['target'])
                elif next_task[1]['type'] == 'reward_derivative':
                    model_data = self._reward_derivative(
                        next_task[1]['data_dict'], next_task[1]['target'])
                elif next_task[1]['type'] == 'forward_model':
                    # get the next state
                    model_data = self._dynamics(next_task[1]['data_dict'])
                    model_data.update(self._reward(next_task[1]['data_dict']))
                    if next_task[1]['end_of_traj']:
                        # get the start reward for the initial state
                        model_data['end_reward'] = self._reward({
                            'start_state':
                            model_data['end_state'],
                            'action':
                            next_task[1]['data_dict']['action'] * 0.0
                        })['reward']
                else:
                    assert False

                self._task_queue.task_done()
                self._result_queue.put({
                    'data': model_data,
                    'data_id': data_id
                })

            elif next_task[0] == parallel_util.AGENT_SET_WEIGHTS:
                # set parameters of the actor policy
                self._set_weights(next_task[1])
                time.sleep(0.001)  # yield the process
                self._task_queue.task_done()

            elif next_task[0] == parallel_util.END_ROLLOUT_SIGNAL or \
                    next_task[0] == parallel_util.END_SIGNAL:
                # kill all the thread
                logger.info("kill message for worker {}".format(
                    self._worker_id))
                # logger.info("kill message for worker")
                self._task_queue.task_done()
                break
            else:
                logger.error('Invalid task type {}'.format(next_task[0]))
        return
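# Sketch of how a parent process could drive the worker loop above, assuming
# the worker shares a multiprocessing.JoinableQueue / multiprocessing.Queue
# pair and that parallel_util is the same constants module the worker imports
# (worker construction and 'planning_data' are placeholders).
import multiprocessing

task_queue = multiprocessing.JoinableQueue()
result_queue = multiprocessing.Queue()
# ... start the worker process(es) with these two queues ...

task_queue.put((parallel_util.WORKER_PLANNING, planning_data))
task_queue.join()               # returns once the worker calls task_done()
plan = result_queue.get()       # the planning result pushed by the worker

task_queue.put((parallel_util.END_SIGNAL, None))   # ask the worker to exit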
    def get_tf_summary(self):
        assert self.summary is not None, logger.error(
            'tf summary not defined, call the summary object separately')
        return self.summary
    def _train_linear_gaussian_with_prior(self, data_dict):
        # parse the data back into episode to fit separate gaussians for each
        # timesteps
        assert (np.array(
            data_dict['episode_length']
        ) == data_dict['episode_length'][0]).all(), logger.error(
            'gps cannot handle cases where length of episode is not consistent'
        )

        episode_length = data_dict['episode_length'][0]
        num_episode = len(data_dict['action']) // episode_length

        # the linear coeff by x and u, the constant f_c
        dynamics_results = {'fm': [], 'fv': [], 'dyn_covar': []}
        # 'f_x': [], 'f_u': [], 'f_c': [], 'f_cov': [],
        # 'raw_f_xf_u': [], 'x0_mean': [], 'x0_cov': []
        dynamics_results['episode_length'] = episode_length

        # fit the init state NOTE: this is different from the code base though
        dynamics_results['x0mu'], dynamics_results['x0sigma'] = \
            self._fit_init_state(data_dict)

        # the normalization data
        '''
        whitening_stats = data_dict['whitening_stats']
        inv_sigma_x = np.diag(1.0 / whitening_stats['state']['std'])
        sigma_x = np.diag(whitening_stats['state']['std'])
        mu_x = whitening_stats['state']['mean']
        # mu_delta = whitening_stats['diff_state']['mean']
        # sigma_delta = np.diag(whitening_stats['diff_state']['std'])
        '''

        for i_pos in range(episode_length):
            i_pos_data_id = i_pos + \
                np.array(range(num_episode)) * episode_length
            train_data = np.concatenate([
                data_dict['start_state'][i_pos_data_id],
                data_dict['action'][i_pos_data_id],
                data_dict['end_state'][i_pos_data_id]
            ],
                                        axis=1)

            # get the gmm posterior
            pos_mean, pos_cov = gps_utils.get_gmm_posterior(
                self._gmm, self._gmm_weights, train_data)

            # fit a new linear gaussian dynamics (using the posterior as prior)
            i_dynamics_result = gps_utils.linear_gauss_dynamics_fit_with_prior(
                train_data, pos_mean, pos_cov, self._NIW_prior['m'],
                self._NIW_prior['n0'], self.args.gps_dynamics_cov_reg,
                self._action_size, self._observation_size)

            # unnormalize the data. we get the dynamics of the
            # p(x_t+1 - x_t | norm(x_t), u_t). Recover the original data
            '''
            i_dynamics_result['f_x'] = \
                sigma_x.dot(i_dynamics_result['f_x']).dot(inv_sigma_x)
            i_dynamics_result['f_u'] = sigma_x.dot(i_dynamics_result['f_u'])
            i_dynamics_result['f_c'] = sigma_x.dot(i_dynamics_result['f_c']) + \
                mu_x - i_dynamics_result['f_x'].dot(mu_x)
            i_dynamics_result['f_cov'] = \
                sigma_x.dot(i_dynamics_result['f_cov']).dot(sigma_x.T)
            '''

            dynamics_results['fm'].append(i_dynamics_result['raw_f_xf_u'])
            dynamics_results['fv'].append(i_dynamics_result['f_c'])
            dynamics_results['dyn_covar'].append(i_dynamics_result['f_cov'])
            '''
            for key in i_dynamics_result:
                dynamics_results[key].append(i_dynamics_result[key])
            '''
            '''
            from mbbl.util.common.vis_debug import vis_dynamics
            vis_dynamics(self.args, self._observation_size, self._action_size,
                         i_pos_data_id, data_dict, i_dynamics_result, 'state')
            vis_dynamics(self.args, self._observation_size, self._action_size,
                         i_pos_data_id, data_dict, i_dynamics_result, 'const')
            vis_dynamics(self.args, self._observation_size, self._action_size,
                         i_pos_data_id, data_dict, i_dynamics_result, 'action')
            '''

        for key in ['fm', 'fv', 'dyn_covar']:
            dynamics_results[key] = np.array(dynamics_results[key])

        return dynamics_results
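# Small sketch of the per-timestep indexing used above: with all episodes
# concatenated back to back, timestep i_pos of every episode sits at
# i_pos + k * episode_length for episode k.
import numpy as np

episode_length, num_episode = 5, 3
flat_rewards = np.arange(episode_length * num_episode)   # toy stand-in data
i_pos = 2
i_pos_data_id = i_pos + np.arange(num_episode) * episode_length
print(flat_rewards[i_pos_data_id])   # -> [ 2  7 12]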
Example #16
    def train(self, data_dict, replay_buffer, training_info={}):
        # update the whitening stats of the network
        self._set_whitening_var(data_dict['whitening_stats'])
        self._debug_it += 1

        # get the validation set
        new_data_id = list(range(len(data_dict['start_state'])))
        self._npr.shuffle(new_data_id)
        num_val = min(
            int(len(new_data_id) * self.args.dynamics_val_percentage),
            self.args.dynamics_val_max_size)
        val_data = {
            key: data_dict[key][new_data_id][:num_val]
            for key in ['start_state', 'end_state', 'action']
        }

        # get the training set
        replay_train_data = replay_buffer.get_all_data()
        train_data = {
            key: np.concatenate([
                data_dict[key][new_data_id][num_val:], replay_train_data[key]
            ])
            for key in ['start_state', 'end_state', 'action']
        }

        for i_epochs in range(self.args.dynamics_epochs):
            # get the number of batches
            num_batches = len(train_data['action']) // \
                self.args.dynamics_batch_size
            # from util.common.fpdb import fpdb; fpdb().set_trace()
            assert num_batches > 0, logger.error('batch_size > data_set')
            avg_training_loss = []

            for i_batch in range(num_batches):
                # train for each sub batch
                feed_dict = {
                    self._input_ph[key]:
                    train_data[key][i_batch *
                                    self.args.dynamics_batch_size:(i_batch +
                                                                   1) *
                                    self.args.dynamics_batch_size]
                    for key in ['start_state', 'end_state', 'action']
                }
                fetch_dict = {
                    'update_op': self._update_operator['update_op'],
                    'train_loss': self._update_operator['loss']
                }

                training_stat = self._session.run(fetch_dict, feed_dict)
                avg_training_loss.append(training_stat['train_loss'])

            val_loss = self.eval(val_data)

            logger.info('[dynamics]: Val Loss: {}, Train Loss: {}'.format(
                val_loss, np.mean(avg_training_loss)))
            '''
            if self._debug_it > 20:
                for i in range(2):
                    test_feed_dict = {
                        self._input_ph[key]: feed_dict[self._input_ph[key]][[i]]
                        for key in ['start_state', 'end_state', 'action']
                    }
                    t_end_state = self._session.run(self._tensor['pred_output'], test_feed_dict)
                    print('train_loss', training_stat)
                    # print('gt', test_feed_dict[self._input_ph['end_state']])
                    # print('pred', t_end_state)
                    diff = t_end_state - test_feed_dict[self._input_ph['end_state']]
                    print(i, 'pred_diff', diff, np.abs(diff).max())
                    print('end-start diff',
                          test_feed_dict[self._input_ph['end_state']] -
                          test_feed_dict[self._input_ph['start_state']])
                    print('pred_end-start diff',
                          t_end_state,
                          test_feed_dict[self._input_ph['start_state']])

                from util.common.fpdb import fpdb; fpdb().set_trace()
            '''
        training_stat['val_loss'] = val_loss
        training_stat['avg_train_loss'] = np.mean(avg_training_loss)
        return training_stat
    def train(self, data_dict, replay_buffer, training_info={}):
        self._set_whitening_var(data_dict['whitening_stats'])
        self._debug_it += 1

        # get the validation set
        new_data_id = list(range(len(data_dict['start_state'])))
        self._npr.shuffle(new_data_id)
        num_val = min(
            int(len(new_data_id) * self.args.dynamics_val_percentage),
            self.args.dynamics_val_max_size)
        val_data = {
            key: data_dict[key][new_data_id][:num_val]
            for key in ['start_state', 'end_state', 'action']
        }

        # get the training set
        replay_train_data = replay_buffer.get_all_data()
        train_data = {
            key: np.concatenate([
                data_dict[key][new_data_id][num_val:], replay_train_data[key]
            ])
            for key in ['start_state', 'end_state', 'action']
        }

        # training loop
        for ep_i in range(self.args.dynamics_epochs):

            num_batches = len(train_data['action']) // \
                    self.args.dynamics_batch_size
            assert num_batches > 0, logger.error('batch_size > data_set')
            avg_training_loss = []

            for i_batch in range(num_batches):
                idx_start = i_batch * self.args.dynamics_batch_size
                idx_end = (i_batch + 1) * self.args.dynamics_batch_size

                feed_dict = {
                    self._input_ph[key]: train_data[key][idx_start:idx_end]
                    for key in ['start_state', 'end_state', 'action']
                }
                feed_dict[
                    self._input_ph['keep_prob']] = self.args.ggnn_keep_prob
                if self.args.d_output:
                    discrete_label = self._process_next_ob(
                        train_data['start_state'][idx_start:idx_end],
                        train_data['end_state'][idx_start:idx_end])
                    feed_dict[
                        self._input_ph['discretize_label_ph']] = discrete_label

                fetch_dict = {
                    'update_op': self._update_operator['update_op'],
                    'train_loss': self._update_operator['loss']
                }

                training_stat = self._session.run(fetch_dict, feed_dict)
                avg_training_loss.append(training_stat['train_loss'])

            val_loss = self.eval(val_data)
            logger.info('[gnn dynamics]: val loss {}, trn loss: {}'.format(
                val_loss, np.mean(avg_training_loss)))

        training_stat['val_loss'] = val_loss
        training_stat['avg_train_loss'] = np.mean(avg_training_loss)

        return training_stat
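# Sketch of the mini-batch slicing used in both training loops above, pulled
# out as a small generator; the names are illustrative, not the project's API.
import numpy as np

def minibatch_slices(num_data, batch_size):
    num_batches = num_data // batch_size
    assert num_batches > 0, 'batch_size larger than the data set'
    for i_batch in range(num_batches):
        yield slice(i_batch * batch_size, (i_batch + 1) * batch_size)

# usage on a toy array
data = np.arange(10)
for sl in minibatch_slices(len(data), 4):
    print(data[sl])    # -> [0 1 2 3], then [4 5 6 7]; the tail is dropped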