def _log(self, dataset):
    """Log dataset statistics and per-phase timings, then restart the total timer."""
    timeit.stop('total')
    dataset.log()
    logger.dump_tabular(print_func=logger.info)
    logger.debug('')
    for line in str(timeit).split('\n'):
        logger.debug(line)
    timeit.reset()
    timeit.start('total')
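# Note: `timeit` above is a project-local timer utility, not the standard
# library `timeit` module (which has no start/stop interface). A minimal
# sketch of the interface these methods rely on, inferred only from the
# calls actually made above:
import time

class _TimeItSketch(object):
    """Named start/stop timers; str() reports accumulated elapsed seconds."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._starts, self._elapsed = {}, {}

    def start(self, name):
        self._starts[name] = time.time()

    def stop(self, name):
        elapsed = time.time() - self._starts.pop(name)
        self._elapsed[name] = self._elapsed.get(name, 0.) + elapsed

    def __str__(self):
        return '\n'.join('{0}: {1:.3f}s'.format(k, v)
                         for k, v in sorted(self._elapsed.items()))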
def _gather_rollouts_cross_entropy(self, policy, num_rollouts):
    """Collect num_rollouts episodes, selecting actions with the policy's
    cross-entropy-method planner, and return them as a Dataset."""
    dataset = utils.Dataset()
    for _ in range(num_rollouts):
        state = self._env.reset()
        done = False
        t = 0
        while not done:
            if self._render:
                timeit.start('render')
                self._env.render()
                timeit.stop('render')
            timeit.start('get action')
            action = policy.get_action_cross_entropy(state)
            timeit.stop('get action')
            timeit.start('env step')
            next_state, reward, done, _ = self._env.step(action)
            timeit.stop('env step')
            # Also terminate once the rollout reaches the maximum length.
            done = done or (t >= self._max_rollout_length)
            dataset.add(state, action, next_state, reward, done)
            state = next_state
            t += 1
    return dataset
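# `policy.get_action_cross_entropy` is implemented elsewhere; for context,
# here is a minimal sketch of cross-entropy-method (CEM) action selection
# against a learned dynamics model. `predict_next_state` and `cost_fn` are
# assumed callables, not names from this codebase:
import numpy as np

def _cem_action_sketch(state, predict_next_state, cost_fn, action_dim,
                       num_candidates=100, num_elites=10, num_iters=4):
    """Return the elite-mean action after a few CEM refit iterations."""
    mean, std = np.zeros(action_dim), np.ones(action_dim)
    for _ in range(num_iters):
        # Sample candidate action vectors from the current Gaussian.
        actions = mean + std * np.random.randn(num_candidates, action_dim)
        states = np.tile(state, (num_candidates, 1))
        next_states = predict_next_state(states, actions)
        costs = cost_fn(states, actions, next_states)
        # Refit the Gaussian to the lowest-cost (elite) candidates.
        elites = actions[np.argsort(costs)[:num_elites]]
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mean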
def _debug_rollout_and_record(self, policy, num_rollouts):
    """Like a normal rollout, but also record the policy's one-step state
    predictions and plot them against the true states for each rollout."""
    dataset = utils.Dataset()
    for r_num in range(num_rollouts):
        state = self._env.reset()
        done = False
        t = 0
        states = [state]
        pred_states = [state]
        while not done:
            if self._render:
                timeit.start('render')
                self._env.render()
                timeit.stop('render')
            timeit.start('get action')
            action, next_state_pred = policy.get_action(state, True)
            timeit.stop('get action')
            timeit.start('env step')
            next_state, reward, done, _ = self._env.step(action)
            timeit.stop('env step')
            done = done or (t >= self._max_rollout_length)
            dataset.add(state, action, next_state, reward, done)
            state = next_state
            t += 1
            pred_states.append(next_state_pred)
            states.append(next_state)
        states = np.array(states)
        pred_states = np.array(pred_states)
        self._debug_plot_states(states, pred_states, r_num)
    return dataset
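# `_debug_plot_states` is called above but not shown; a plausible sketch,
# assuming one subplot per state dimension comparing the true rollout to
# the model's one-step predictions (the output filename is an assumption):
import os
import matplotlib.pyplot as plt

def _debug_plot_states_sketch(states, pred_states, r_num, out_dir):
    """Plot actual vs. predicted state trajectories, one panel per dimension."""
    state_dim = states.shape[1]
    fig, axes = plt.subplots(state_dim, 1, figsize=(6, 2 * state_dim),
                             squeeze=False)
    for i in range(state_dim):
        axes[i, 0].plot(states[:, i], label='actual')
        axes[i, 0].plot(pred_states[:, i], '--', label='predicted')
        axes[i, 0].set_ylabel('dim {0}'.format(i))
    axes[0, 0].legend()
    axes[-1, 0].set_xlabel('timestep')
    fig.savefig(os.path.join(out_dir, 'rollout_{0}.png'.format(r_num)))
    plt.close(fig)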
def _train_policy(self, dataset):
    """
    Train the model-based policy

    implementation details:
        (a) Train for self._training_epochs number of epochs
        (b) The dataset.random_iterator(...) method will iterate through
            the dataset once in a random order
        (c) Use self._training_batch_size for iterating through the dataset
        (d) Keep track of the loss values by appending them to the losses
            array
    """
    timeit.start('train policy')

    losses = []
    ### PROBLEM 1
    ### YOUR CODE HERE
    for _ in range(self._training_epochs):
        current_batches = dataset.random_iterator(self._training_batch_size)
        while True:
            # next() with a sentinel default ends the loop cleanly once the
            # iterator is exhausted, instead of catching StopIteration.
            state, action, next_state, reward, _ = \
                next(current_batches, [None] * 5)
            if state is None:
                break
            # train_step expects rewards as a column vector.
            reward = np.atleast_2d(reward).T
            loss = self._policy.train_step(state, action, next_state, reward)
            losses.append(loss)

    logger.record_tabular('TrainingLossStart', losses[0])
    logger.record_tabular('TrainingLossFinal', losses[-1])
    timeit.stop('train policy')

    plt.figure()
    plt.plot(losses)
    plt.savefig(os.path.join(logger.dir, 'training.png'))
    plt.close()  # avoid accumulating open figures across training calls
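# `dataset.random_iterator` is provided by the Dataset class; a minimal
# sketch of such an iterator, assuming the dataset stores parallel numpy
# arrays (the field names here are assumptions, not the real ones):
import numpy as np

def _random_iterator_sketch(states, actions, next_states, rewards, dones,
                            batch_size):
    """Yield shuffled minibatches that cover the data exactly once."""
    idxs = np.random.permutation(len(states))
    for start in range(0, len(idxs), batch_size):
        batch = idxs[start:start + batch_size]
        yield (states[batch], actions[batch], next_states[batch],
               rewards[batch], dones[batch])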