Example #1
    def learn(self):
        logger.info("Training")
        n_steps = 0
        # best_success_rate = 0.

        for epoch in range(self.args.n_epochs):
            residual_losses = []
            for _ in range(self.args.n_cycles):
                # Collect trajectories
                self.controller.reconfigure_heuristic(self.get_residual)
                n_steps += self.collect_trajectories(
                    self.args.num_rollouts_per_mpi)
                # Update residual
                logger.debug("Updating")
                for _ in range(self.args.n_batches):
                    residual_loss = self._update_residual()
                    residual_losses.append(
                        residual_loss.detach().cpu().numpy())
                    logger.debug('Loss', residual_loss)

                self._update_target_network(self.residual_target,
                                            self.residual)

            success_rate = self.eval_agent()
            if MPI.COMM_WORLD.Get_rank() == 0:
                print(
                    '[{}] epoch is: {}, Num steps: {}, eval success rate is: {:.3f}'
                    .format(datetime.now(), epoch, n_steps, success_rate))
                logger.record_tabular('epoch', epoch)
                logger.record_tabular('n_steps', n_steps)
                logger.record_tabular('success_rate', success_rate)
                logger.record_tabular('residual_loss',
                                      np.mean(residual_losses))
                logger.dump_tabular()
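
Note: the `_update_target_network` helper used above is not part of this snippet. A minimal sketch of a Polyak-averaged (soft) target update is shown below; the `polyak` coefficient is an assumed value, and the original code may instead perform a hard parameter copy.

import torch

def update_target_network(target, source, polyak=0.95):
    # Soft update: target <- polyak * target + (1 - polyak) * source.
    # A hard update is the special case polyak = 0 (a plain copy).
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.mul_(polyak)
            target_param.data.add_((1.0 - polyak) * param.data)
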
Example #2
    def eval_agent(self):
        total_success_rate = []
        for _ in range(self.args.n_test_rollouts):
            per_success_rate = []
            self.env.reset()
            for _ in range(self.env_params['max_timesteps']):
                # Take a zero (no-op) action at every step
                actions = np.zeros(self.env_params['action'])
                observation_new, _, _, info = self.env.step(actions)
                per_success_rate.append(info['is_success'])
            total_success_rate.append(per_success_rate)
        total_success_rate = np.array(total_success_rate)
        local_success_rate = np.mean(total_success_rate[:, -1])
        global_success_rate = MPI.COMM_WORLD.allreduce(
            local_success_rate, op=MPI.SUM)
        success_rate = global_success_rate / MPI.COMM_WORLD.Get_size()
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.record_tabular('success_rate', success_rate)
            logger.dump_tabular()
        return success_rate
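
The evaluation above averages the final-step success flag over local test rollouts and then over all MPI workers. A minimal standalone sketch of that cross-worker averaging pattern with mpi4py (the local value here is a placeholder):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
# Placeholder: each worker would compute this from its own evaluation rollouts.
local_success_rate = float(np.random.rand())
# Sum across workers, then divide by the worker count to get the global mean.
global_success_rate = comm.allreduce(local_success_rate, op=MPI.SUM)
success_rate = global_success_rate / comm.Get_size()
if comm.Get_rank() == 0:
    print('mean success rate across workers: {:.3f}'.format(success_rate))
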
Example #3
    def learn_online_in_real_world(self, max_timesteps=None):
        # If any pre-existing model is given, load it
        if self.args.load_dir:
            self.load()
        # Reset the environment
        observation = self.env.reset()

        # Count of total number of steps
        total_n_steps = 0
        while True:
            # Rollout for a few steps to collect some transitions
            n_steps, final_observation, done = self.online_rollout(observation)
            # Update the counter
            total_n_steps += n_steps
            # Check if we have reached the goal
            if done:
                break
            # Batch updates
            losses = []
            # for _ in range(self.args.planning_rollout_length):
            for _ in range(self.args.n_online_planning_updates):
                # Update the state-action value residual
                loss = self._update_state_action_value_residual()
                losses.append(loss)
                # Update the target network
                self._update_target_network(self.state_action_value_target_residual,
                                            self.state_action_value_residual)
            # Log
            logger.record_tabular('n_steps', total_n_steps)
            logger.record_tabular('residual_loss', np.mean(losses))
            # logger.dump_tabular()
            # Move to next iteration
            observation = copy.deepcopy(final_observation)

            if max_timesteps and total_n_steps >= max_timesteps:
                break

        return total_n_steps
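
The `online_rollout` helper is not shown in this example. Based on its return signature `(n_steps, final_observation, done)`, a purely hypothetical sketch might look like the following; the controller/environment/memory interfaces and the `rollout_length` parameter are illustrative assumptions:

def online_rollout(env, controller, memory, observation, rollout_length=10):
    # Hypothetical sketch: step the real environment with the controller for a
    # few steps, store the transitions, and report whether the goal was reached.
    n_steps = 0
    done = False
    for _ in range(rollout_length):
        ac, _ = controller.act(observation)
        next_observation, _, _, info = env.step(ac)
        memory.store_real_world_transition([observation, ac, next_observation])
        n_steps += 1
        observation = next_observation
        if info.get('is_success', False):
            done = True
            break
    return n_steps, observation, done
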
Example #4
    def learn_online_in_real_world(self, max_timesteps=None):
        # If any pre-existing model is given, load it
        if self.args.load_dir:
            self.load()

        # Reset the environment
        observation = self.env.reset()
        # Configure heuristic for controller
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(obs, self.preproc_inputs, self
                                                 .state_value_residual))
        # Configure discrepancy model for controller (RTS)
        if self.args.agent == 'rts':
            self.controller.reconfigure_discrepancy(
                lambda obs, ac: get_discrepancy_neighbors(
                    obs, ac, self.construct_4d_point, self.kdtrees, self.args.
                    neighbor_radius))

        # Configure residual dynamics for controller (MBPO variants)
        if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
            self.controller.reconfigure_residual_dynamics(
                self.get_residual_dynamics)
        # Count of total number of steps
        total_n_steps = 0
        while True:
            obs = observation['observation']
            g = observation['desired_goal']
            qpos = observation['sim_state'].qpos
            qvel = observation['sim_state'].qvel
            # Get action from the controller
            ac, info = self.controller.act(observation)
            if self.args.agent == 'rts':
                assert self.controller.residual_dynamics_fn is None
            if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                assert self.controller.discrepancy_fn is None
            # Get discrete action index
            ac_ind = self.env.discrete_actions[tuple(ac)]
            # Get the next observation
            next_observation, rew, _, _ = self.env.step(ac)
            # if np.array_equal(obs, next_observation['observation']):
            #     import ipdb
            #     ipdb.set_trace()
            # print('ACTION', ac)
            # print('VALUE PREDICTED', info['start_node_h'])
            # print('COST', -rew)
            if self.args.render:
                self.env.render()
            total_n_steps += 1
            # Check if we reached the goal
            if self.env.env._is_success(next_observation['achieved_goal'], g):
                print('REACHED GOAL!')
                break
            # Get the next obs
            obs_next = next_observation['observation']
            # Get the sim next obs
            set_sim_state_and_goal(self.planning_env, qpos.copy(), qvel.copy(),
                                   g.copy())
            next_observation_sim, _, _, _ = self.planning_env.step(ac)
            obs_sim_next = next_observation_sim['observation']
            # Store transition
            transition = [
                obs.copy(),
                g.copy(), ac_ind,
                qpos.copy(),
                qvel.copy(),
                obs_next.copy(),
                obs_sim_next.copy()
            ]
            dynamics_losses = []
            # RTS
            if self.args.agent == 'rts' and self._check_dynamics_transition(
                    transition):
                # print('DISCREPANCY IN DYNAMICS')
                self.memory.store_real_world_transition(transition)
                # Fit the discrepancy model
                self._update_discrepancy_model()
            # MBPO
            elif self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                self.memory.store_real_world_transition(transition)
                # Update the dynamics
                if self.args.agent == 'mbpo':
                    for _ in range(self.args.n_online_planning_updates):
                        # Update dynamics
                        loss = self._update_batch_residual_dynamics()
                        dynamics_losses.append(loss.item())
                else:
                    loss = self._update_residual_dynamics()
                    dynamics_losses.append(loss)
            # Plan in the model
            value_loss = self.plan_online_in_model(
                n_planning_updates=self.args.n_online_planning_updates,
                initial_observation=copy.deepcopy(observation))

            # Log
            logger.record_tabular('n_steps', total_n_steps)
            if self.args.agent in ('mbpo', 'mbpo_knn', 'mbpo_gp'):
                logger.record_tabular('dynamics_loss',
                                      np.mean(dynamics_losses))
            logger.record_tabular('residual_loss', value_loss)
            # logger.dump_tabular()
            # Move to next iteration
            observation = copy.deepcopy(next_observation)

            if max_timesteps and total_n_steps >= max_timesteps:
                break

        return total_n_steps
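
For the 'rts' agent, `get_discrepancy_neighbors` is referenced but not defined here. A minimal sketch of a KD-tree lookup that counts previously flagged state-action points within `neighbor_radius` of a query point, using `scipy.spatial.cKDTree` (the 4-D point construction mirrors the `construct_4d_point` name and is an assumption):

import numpy as np
from scipy.spatial import cKDTree

def count_discrepancy_neighbors(point_4d, kdtree, neighbor_radius):
    # Number of stored discrepant points within neighbor_radius of the query.
    if kdtree is None:
        return 0
    return len(kdtree.query_ball_point(point_4d, r=neighbor_radius))

# Example usage: build a tree from previously flagged 4-D points and query it.
flagged_points = np.random.rand(100, 4)
tree = cKDTree(flagged_points)
print(count_discrepancy_neighbors(np.zeros(4), tree, neighbor_radius=0.1))
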
Example #5
    def learn(self):
        """
        train the network

        """
        logger.info("Training..")
        n_steps = 0
        best_success_rate = 0.
        # start to collect samples
        for epoch in range(self.args.n_epochs):
            actor_losses = []
            critic_losses = []
            switch_losses = []
            for _ in range(self.args.n_cycles):
                mb_obs, mb_ag, mb_g, mb_actions, mb_switch_actions = [], [], [], [], []
                for _ in range(self.args.num_rollouts_per_mpi):
                    # reset the rollouts
                    ep_obs, ep_ag, ep_g, ep_actions, ep_switch_actions = [], [], [], [], []
                    # reset the environment
                    observation = self.env.reset()
                    obs = observation['observation']
                    ag = observation['achieved_goal']
                    g = observation['desired_goal']
                    # start to collect samples
                    for _ in range(self.env_params['max_timesteps']):
                        n_steps += 1
                        with torch.no_grad():
                            input_tensor = self._preproc_inputs(obs, g)
                            pi = self.actor_network(input_tensor)
                            _, switch_actions_q_values = self.critic_switch_network(
                                input_tensor, pi)
                            switch_action = self._select_switch_actions(
                                switch_actions_q_values)
                        # feed the actions into the environment
                        if switch_action == 0:
                            # Hardcoded action
                            action = self.controller.act(observation)
                        else:
                            # Learned policy action
                            action = self._select_actions(pi)
                        observation_new, _, _, info = self.env.step(action)
                        obs_new = observation_new['observation']
                        ag_new = observation_new['achieved_goal']
                        # append rollouts
                        ep_obs.append(obs.copy())
                        ep_ag.append(ag.copy())
                        ep_g.append(g.copy())
                        ep_actions.append(action.copy())
                        ep_switch_actions.append(switch_action)
                        # re-assign the observation
                        obs = obs_new
                        ag = ag_new
                        observation = observation_new
                    ep_obs.append(obs.copy())
                    ep_ag.append(ag.copy())
                    mb_obs.append(ep_obs)
                    mb_ag.append(ep_ag)
                    mb_g.append(ep_g)
                    mb_actions.append(ep_actions)
                    mb_switch_actions.append(ep_switch_actions)
                # convert them into arrays
                mb_obs = np.array(mb_obs)
                mb_ag = np.array(mb_ag)
                mb_g = np.array(mb_g)
                mb_actions = np.array(mb_actions)
                mb_switch_actions = np.array(mb_switch_actions)
                # store the episodes
                self.buffer.store_episode(
                    [mb_obs, mb_ag, mb_g, mb_actions, mb_switch_actions])
                self._update_normalizer(
                    [mb_obs, mb_ag, mb_g, mb_actions, mb_switch_actions])
                for _ in range(self.args.n_batches):
                    # train the network
                    critic_loss, actor_loss, switch_loss = self._update_network(
                    )
                    actor_losses.append(actor_loss.detach().numpy())
                    critic_losses.append(critic_loss.detach().numpy())
                    switch_losses.append(switch_loss.detach().numpy())
                # soft update
                self._soft_update_target_network(self.actor_target_network,
                                                 self.actor_network)
                self._soft_update_target_network(
                    self.critic_switch_target_network,
                    self.critic_switch_network)
            # start to do the evaluation
            success_rate, prop_hardcoded = self._eval_agent()
            if MPI.COMM_WORLD.Get_rank() == 0:
                print(
                    '[{}] epoch is: {}, Num steps: {}, eval success rate is: {:.3f}'
                    .format(datetime.now(), epoch, n_steps, success_rate))
                logger.record_tabular('epoch', epoch)
                logger.record_tabular('n_steps', n_steps)
                logger.record_tabular('success_rate', success_rate)
                logger.record_tabular('prop_hardcoded', prop_hardcoded)
                logger.record_tabular('actor_loss', np.mean(actor_losses))
                logger.record_tabular('critic_loss', np.mean(critic_losses))
                logger.record_tabular('switch_loss', np.mean(switch_losses))
                logger.dump_tabular()
                if success_rate > best_success_rate:
                    logger.info("Better success rate... Saving policy")
                    torch.save([
                        self.o_norm.mean, self.o_norm.std, self.g_norm.mean,
                        self.g_norm.std,
                        self.actor_network.state_dict()
                    ], self.model_path + '/model.pt')
                    best_success_rate = success_rate
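
The `_select_switch_actions` helper, which decides between the hard-coded controller (switch action 0) and the learned policy (switch action 1), is not included in this snippet. A minimal epsilon-greedy sketch is given below; the exploration rate `eps` and the two-way switch space are assumptions:

import numpy as np
import torch

def select_switch_action(switch_q_values, eps=0.2):
    # switch_q_values: tensor of shape (1, 2) holding Q-values for
    # [hard-coded controller, learned policy].
    if np.random.rand() < eps:
        return int(np.random.randint(switch_q_values.shape[-1]))
    return int(torch.argmax(switch_q_values, dim=-1).item())
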
Example #6
    def learn_online_in_real_world(self, max_timesteps=None):
        # If any pre-existing model is given, load it
        if self.args.load_dir:
            self.load()

        # Reset the environment
        observation = self.env.reset()
        # Configure heuristic for controller
        self.controller.reconfigure_heuristic(
            lambda obs: get_state_value_residual(obs, self.preproc_inputs, self
                                                 .state_value_residual))

        # Configure dynamics for controller
        self.controller.reconfigure_residual_dynamics(
            self.get_residual_dynamics)

        # Count total number of steps
        total_n_steps = 0
        while True:
            obs = observation['observation']
            g = observation['desired_goal']
            qpos = observation['sim_state'].qpos
            qvel = observation['sim_state'].qvel

            # Get action from controller
            ac, info = self.controller.act(observation)
            # Get discrete action index
            ac_ind = self.env.discrete_actions[tuple(ac)]
            # Get next observation
            next_observation, rew, _, _ = self.env.step(ac)
            # Increment counter
            total_n_steps += 1
            if self.env.env._is_success(next_observation['achieved_goal'], g):
                print('REACHED GOAL!')
                break
            if self.args.render:
                self.env.render()
            # Get next obs
            obs_next = next_observation['observation']
            # Get sim next obs
            set_sim_state_and_goal(self.planning_env, qpos.copy(), qvel.copy(),
                                   g.copy())
            next_observation_sim, _, _, _ = self.planning_env.step(ac)
            obs_sim_next = next_observation_sim['observation']
            # Store transition in real world memory
            transition = [
                obs.copy(),
                g.copy(), ac_ind,
                qpos.copy(),
                qvel.copy(),
                obs_next.copy(),
                obs_sim_next.copy()
            ]
            self.memory.store_real_world_transition(transition)

            # Update the dynamics
            dynamics_losses = []
            for _ in range(self.args.n_online_planning_updates):
                # Update dynamics
                loss = self._update_residual_dynamics()
                dynamics_losses.append(loss.item())

            # Update state value residual
            value_loss = self.plan_online_in_model(
                self.args.n_online_planning_updates,
                initial_observation=copy.deepcopy(observation))
            # log
            logger.record_tabular('n_steps', total_n_steps)
            logger.record_tabular('dynamics_loss', np.mean(dynamics_losses))
            logger.record_tabular('residual_loss', value_loss)
            logger.dump_tabular()

            # Move to next iteration
            observation = copy.deepcopy(next_observation)

            if max_timesteps and total_n_steps >= max_timesteps:
                break

        return total_n_steps
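
Each stored transition contains both the real next observation (`obs_next`) and the simulator's prediction from the same state and action (`obs_sim_next`), so the residual dynamics model can be trained to predict their difference. A minimal sketch of such a model and one MSE update step (layer sizes and optimizer settings are assumptions):

import torch
import torch.nn as nn

class ResidualDynamics(nn.Module):
    """Predicts the correction obs_next - obs_sim_next from (obs, action)."""

    def __init__(self, obs_dim, ac_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + ac_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, obs_dim))

    def forward(self, obs, ac):
        return self.net(torch.cat([obs, ac], dim=-1))

def update_residual_dynamics(model, optimizer, obs, ac, obs_next, obs_sim_next):
    # One MSE step on the sim-to-real correction.
    loss = nn.functional.mse_loss(model(obs, ac), obs_next - obs_sim_next)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
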
Example #7
    def learn(self):
        logger.info("Training")
        # ILC loop
        # 1. Train the model on real environment transitions
        # 2. Plan in the model to get the optimal policy, and the direction of improvement
        # 3. Do line search on the real environment to find the right step size
        initial_residual_parameters = copy.deepcopy(self.residual.state_dict())
        for epoch in range(self.args.n_epochs):
            # 0. Fix start and goals
            self.populate_sim_states_and_goals()
            # 1. Plan in the model to get the optimal policy
            logger.info("Improving policy in the model")
            residual_losses = []
            for _ in range(self.args.n_cycles):
                # Collect trajectories
                self.controller.reconfigure_heuristic(self.get_residual)
                self.controller.reconfigure_dynamics(
                    self.get_dynamics_residual)
                self.collect_trajectories(
                    self.args.num_rollouts_per_mpi)
                # Update residual
                logger.info("Updating")
                for _ in range(self.args.n_batches):
                    residual_loss = self._update_residual()
                    residual_losses.append(
                        residual_loss.detach().cpu().numpy())
                    logger.info('Residual Loss', residual_loss.item())

                self._update_target_network(
                    self.residual_target, self.residual)

            if not self.args.planning:
                # Get the direction of improvement
                logger.info("Computing direction of improvement")
                final_residual_parameters = copy.deepcopy(
                    self.residual.state_dict())
                gradient = {}
                for key in initial_residual_parameters.keys():
                    gradient[key] = final_residual_parameters[key] - \
                        initial_residual_parameters[key]

                # 2. Line search in the real world
                logger.info("Line search in the real world")
                logger.info("Evaluating initial policy in the real world")
                initial_real_value_estimate = self.evaluate_real_world(
                    initial_residual_parameters)
                logger.info("Initial cost-to-go", initial_real_value_estimate)
                alpha = 1.0
                while True:
                    logger.info("Alpha", alpha)
                    current_residual_parameters = {}
                    for key in initial_residual_parameters.keys():
                        current_residual_parameters[key] = initial_residual_parameters[key] + \
                            alpha * gradient[key]

                    current_real_value_estimate = self.evaluate_real_world(
                        current_residual_parameters)
                    logger.info("Current cost-to-go",
                                current_real_value_estimate)

                    if current_real_value_estimate < initial_real_value_estimate:
                        # Cost to go decreased - found an alpha
                        logger.info("Initial cost-to-go", initial_real_value_estimate,
                                    "Final cost-to-go", current_real_value_estimate)
                        initial_real_value_estimate = current_real_value_estimate
                        initial_residual_parameters = copy.deepcopy(
                            current_residual_parameters)
                        break
                    else:
                        # Decrease alpha
                        alpha *= 0.5

                    if alpha < self.args.alpha_threshold:
                        # If alpha is really really small
                        # Don't update the residual
                        logger.info(
                            "Alpha really small. Not updating residual")
                        logger.info("Best cost-to-go so far",
                                    initial_real_value_estimate)
                        break

                # Assign chosen residual parameters for the residual
                self.residual.load_state_dict(initial_residual_parameters)
                self.residual_target.load_state_dict(
                    initial_residual_parameters)

            logger.info("Evaluating")
            success_rate = self.eval_agent()

            if not self.args.planning:
                # 3. Train model on real world transitions collected so far
                logger.info("Training model residual using real world samples")
                model_losses = []
                for _ in range(self.args.n_model_batches):
                    model_loss = self._update_model()
                    model_losses.append(model_loss.detach().cpu().numpy())
                    logger.info('Model Loss', model_loss.item())
            else:
                model_losses = [0]

            if MPI.COMM_WORLD.Get_rank() == 0:
                print('[{}] epoch is: {}, Num planning steps: {}, Num real steps: {}, eval success rate is: {:.3f}'.format(
                    datetime.now(), epoch, self.n_planning_steps, self.n_real_steps, success_rate))
                logger.record_tabular('epoch', epoch)
                logger.record_tabular('n_planning_steps',
                                      self.n_planning_steps)
                logger.record_tabular('n_real_steps', self.n_real_steps)
                logger.record_tabular('success_rate', success_rate)
                logger.record_tabular(
                    'residual_loss', np.mean(residual_losses))
                logger.record_tabular('model_loss', np.mean(model_losses))
                # logger.record_tabular(
                #     'cost-to-go', initial_real_value_estimate)
                logger.dump_tabular()
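
The `evaluate_real_world` helper used by the line search is not shown. A hypothetical sketch, assuming it loads the candidate residual parameters, rolls the controller out in the real environment from the fixed starts and goals, and returns the mean cost-to-go (the interfaces and cost convention are assumptions based on the surrounding examples):

import numpy as np

def evaluate_real_world(env, controller, residual, residual_parameters,
                        max_timesteps, n_rollouts=5):
    # Hypothetical sketch: score a candidate residual in the real environment.
    residual.load_state_dict(residual_parameters)
    costs_to_go = []
    for _ in range(n_rollouts):
        observation = env.reset()
        cost_to_go = 0.0
        for _ in range(max_timesteps):
            ac, _ = controller.act(observation)
            observation, rew, _, info = env.step(ac)
            cost_to_go += -rew  # rewards are negative costs in these examples
            if info.get('is_success', False):
                break
        costs_to_go.append(cost_to_go)
    return float(np.mean(costs_to_go))
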
Example #8
    def learn(self):
        """
        train the network

        """
        logger.info("Training..")
        n_steps = 0
        best_success_rate = 0.
        prev_actor_losses = [0.0]
        actor_losses = [0.0]
        critic_losses = []
        original_actor_lr = self.args.lr_actor
        coin_flipping = False
        # start to collect samples
        for epoch in range(self.args.n_epochs):
            # For the residual actor, account for the burn-in period by monitoring the change in actor loss
            if (epoch == 0
                    or abs(np.mean(actor_losses) - np.mean(prev_actor_losses))
                    > self.args.threshold):
                # Do not update actor, just update critic
                logger.info('Only training critic')
                self.change_actor_lr(0.0)
                coin_flipping = True
            else:
                # Update actor as well
                self.change_actor_lr(original_actor_lr)
                coin_flipping = False

            prev_actor_losses = actor_losses
            actor_losses = []
            critic_losses = []
            for _ in range(self.args.n_cycles):
                mb_obs, mb_ag, mb_g, mb_actions, mb_f = [], [], [], [], []
                for _ in range(self.args.num_rollouts_per_mpi):
                    # reset the rollouts
                    ep_obs, ep_ag, ep_g, ep_actions, ep_f = [], [], [], [], []
                    # reset the environment
                    observation = self.env.reset()
                    obs = observation['observation']
                    ag = observation['achieved_goal']
                    g = observation['desired_goal']
                    f = self.env.extract_features(obs, g)
                    # start to collect samples
                    for t in range(self.env_params['max_timesteps']):
                        n_steps += 1
                        with torch.no_grad():
                            input_tensor = self._preproc_inputs(obs, g)
                            pi = self.actor_network(input_tensor)
                            action = self._select_actions(pi, coin_flipping)
                        # feed the actions into the environment
                        observation_new, _, _, info = self.env.step(action)
                        obs_new = observation_new['observation']
                        ag_new = observation_new['achieved_goal']
                        # append rollouts
                        ep_obs.append(obs.copy())
                        ep_ag.append(ag.copy())
                        ep_g.append(g.copy())
                        ep_actions.append(action.copy())
                        ep_f.append(f.copy())
                        # re-assign the observation
                        obs = obs_new
                        ag = ag_new
                        f = self.env.extract_features(obs, g)
                    ep_obs.append(obs.copy())
                    ep_ag.append(ag.copy())
                    ep_f.append(f.copy())
                    mb_obs.append(ep_obs)
                    mb_ag.append(ep_ag)
                    mb_g.append(ep_g)
                    mb_actions.append(ep_actions)
                    mb_f.append(ep_f)
                # convert them into arrays
                mb_obs = np.array(mb_obs)
                mb_ag = np.array(mb_ag)
                mb_g = np.array(mb_g)
                mb_actions = np.array(mb_actions)
                mb_f = np.array(mb_f)
                # store the episodes
                self.buffer.store_episode(
                    [mb_obs, mb_ag, mb_g, mb_actions, mb_f])
                self._update_normalizer(
                    [mb_obs, mb_ag, mb_g, mb_actions, mb_f])
                for _ in range(self.args.n_batches):
                    # train the network
                    critic_loss, actor_loss = self._update_network()
                    actor_losses.append(actor_loss.detach().numpy())
                    critic_losses.append(critic_loss.detach().numpy())
                # soft update
                self._soft_update_target_network(self.actor_target_network,
                                                 self.actor_network)
                self._soft_update_target_network(self.critic_target_network,
                                                 self.critic_network)
            # start to do the evaluation
            success_rate = self._eval_agent()
            if MPI.COMM_WORLD.Get_rank() == 0:
                print(
                    '[{}] epoch is: {}, Num steps: {}, eval success rate is: {:.3f}'
                    .format(datetime.now(), epoch, n_steps, success_rate))
                logger.record_tabular('epoch', epoch)
                logger.record_tabular('n_steps', n_steps)
                logger.record_tabular('success_rate', success_rate)
                logger.record_tabular('actor_loss', np.mean(actor_losses))
                logger.record_tabular('critic_loss', np.mean(critic_losses))
                logger.dump_tabular()
                if success_rate > best_success_rate:
                    logger.info("Better success rate... Saving policy")
                    torch.save([
                        self.o_norm.mean, self.o_norm.std, self.g_norm.mean,
                        self.g_norm.std,
                        self.actor_network.state_dict()
                    ], self.model_path + '/model.pt')
                    best_success_rate = success_rate
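
The `_select_actions` helper takes the actor output plus a `coin_flipping` flag. A minimal sketch in the usual DDPG+HER exploration style is given below; while the actor is still being burned in (`coin_flipping=True`) it falls back to uniformly random actions. The `action_max` bound, noise scale and random-action probability are assumptions:

import numpy as np

def select_actions(pi, action_max, coin_flipping=False,
                   noise_eps=0.2, random_eps=0.3):
    action = pi.cpu().numpy().squeeze()
    if coin_flipping:
        # Actor is not being updated yet: act randomly to gather data.
        return np.random.uniform(-action_max, action_max, size=action.shape)
    # Gaussian exploration noise plus occasional fully random actions.
    action = action + noise_eps * action_max * np.random.randn(*action.shape)
    action = np.clip(action, -action_max, action_max)
    if np.random.rand() < random_eps:
        action = np.random.uniform(-action_max, action_max, size=action.shape)
    return action
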
Example #9
    def learn_offline_in_model(self):
        if not self.args.offline:
            warnings.warn('SHOULD NOT BE USED ONLINE')

        best_success_rate = 0.0
        n_steps = 0
        for epoch in range(self.args.n_epochs):
            # Reset the environment
            observation = self.env.reset()
            obs = observation['observation']
            g = observation['desired_goal']
            for _ in range(self.env_params['offline_max_timesteps']):
                # Get action
                ac, info = self.controller.act(observation)
                ac_ind = self.env.discrete_actions[tuple(ac)]
                # Get the next observation and reward from the environment
                observation_new, rew, _, _ = self.env.step(ac)
                n_steps += 1
                obs_new = observation_new['observation']
                # Store the transition in memory
                self.memory.store_real_world_transition(
                    [obs, g, ac_ind, obs_new], sim=False)
                observation = copy.deepcopy(observation_new)
                obs = obs_new.copy()

            # Update state value residual from model rollouts
            transitions = self.memory.sample_real_world_memory(
                batch_size=self.args.n_cycles)
            losses = []
            model_losses = []
            for i in range(self.args.n_cycles):
                observation = {}
                observation['observation'] = transitions['obs'][i].copy()
                observation['achieved_goal'] = transitions['obs'][i][:3].copy()
                observation['desired_goal'] = transitions['g'][i].copy()
                # Collect model rollouts
                self.collect_internal_model_trajectories(
                    num_rollouts=1,
                    rollout_length=self.env_params['offline_max_timesteps'],
                    initial_observations=[observation])
                # Update state value residuals
                for _ in range(self.args.n_batches):
                    state_value_residual_loss = self._update_state_value_residual(
                    ).item()
                    losses.append(state_value_residual_loss)
                self._update_target_network(self.state_value_target_residual,
                                            self.state_value_residual)

                # Update dynamics model
                for _ in range(self.args.n_batches):
                    loss = self._update_learned_dynamics_model().item()
                    model_losses.append(loss)

            # Evaluate agent in the model
            mean_success_rate, mean_return = self.eval_agent_in_model()
            # Check if this is a better residual
            if mean_success_rate > best_success_rate:
                best_success_rate = mean_success_rate
                print('Best success rate so far', best_success_rate)
                if self.args.save_dir is not None:
                    print('Saving residual')
                    self.save(epoch, best_success_rate)

            # log
            logger.record_tabular('epoch', epoch)
            logger.record_tabular('n_steps', n_steps)
            logger.record_tabular('success_rate', mean_success_rate)
            logger.record_tabular('return', mean_return)
            logger.record_tabular('state_value_residual_loss', np.mean(losses))
            logger.record_tabular('dynamics_loss', np.mean(model_losses))
            logger.dump_tabular()
Example #10
    def learn(self):
        """
        train the network

        """
        logger.info("Training..")
        n_psdp_iters = 0
        n_steps = 0
        best_success_rate = 0.
        epoch = 0
        success_rate = self._eval_agent()
        if MPI.COMM_WORLD.Get_rank() == 0:
            print(
                '[{}] epoch is: {}, Num steps: {}, eval success rate is: {:.3f}'
                .format(datetime.now(), epoch, n_steps, success_rate))
        # start to collect samples
        #assert self.args.n_cycles == self.T, "Number of cycles should be equal to horizon"
        #actor_losses, prev_actor_losses = [0.], [0.]
        #critic_losses, prev_critic_losses = [0.], [0.]
        current_t = self.T
        for epoch in range(self.args.n_epochs):
            # TODO: Burn-in critic?
            #prev_actor_losses = actor_losses
            #prev_critic_losses = critic_losses
            actor_losses = []
            critic_losses = []
            if epoch % 10 == 0:
                current_t = current_t - 1
                logger.info("Training residual policy at time step {}".format(
                    current_t))
            # TODO: Update actors one at a time by monitoring corresponding critic loss
            # Once the critic has been sufficiently trained, then we can start training the actor
            # at that time-step before moving onto the next time-step
            for _ in range(self.args.n_cycles):
                # current_t -= 1
                # if (current_t + 1) % 10 == 0:
                #     logger.info(
                #         "Training residual policy at time step {}".format(current_t))
                # for current_t in range(self.T-1, -1, -1):
                mb_obs, mb_ag, mb_g, mb_actions = [], [], [], []
                for _ in range(self.args.num_rollouts_per_mpi):
                    # reset the rollouts
                    ep_obs, ep_ag, ep_g, ep_actions = [], [], [], []
                    # reset the environment
                    observation = self.env.reset()
                    obs = observation['observation']
                    ag = observation['achieved_goal']
                    g = observation['desired_goal']
                    # start to collect samples
                    for t in range(self.env_params['max_timesteps']):
                        n_steps += 1
                        with torch.no_grad():
                            input_tensor = self._preproc_inputs(obs, g)
                            if t == current_t:
                                # Use untrained residual policy
                                pi = self.actor_networks[t](input_tensor)
                                action = self._select_actions(pi)
                            # elif t > current_t:
                            else:
                                # Use current trained policy
                                # If it has not been trained, it will predict zeros as
                                # a result of our initialization
                                pi = self.actor_networks[t](input_tensor)
                                action = pi.cpu().numpy().squeeze()
                        # feed the actions into the environment
                        observation_new, _, _, info = self.env.step(action)
                        obs_new = observation_new['observation']
                        ag_new = observation_new['achieved_goal']
                        # append rollouts
                        ep_obs.append(obs.copy())
                        ep_ag.append(ag.copy())
                        ep_g.append(g.copy())
                        ep_actions.append(action.copy())
                        # re-assign the observation
                        obs = obs_new
                        ag = ag_new
                    ep_obs.append(obs.copy())
                    ep_ag.append(ag.copy())
                    mb_obs.append(ep_obs)
                    mb_ag.append(ep_ag)
                    mb_g.append(ep_g)
                    mb_actions.append(ep_actions)
                # convert them into arrays
                mb_obs = np.array(mb_obs)
                mb_ag = np.array(mb_ag)
                mb_g = np.array(mb_g)
                mb_actions = np.array(mb_actions)
                # store the episodes
                self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions])
                self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions],
                                        current_t)
                for _ in range(self.args.n_batches):
                    # train the network
                    critic_loss, actor_loss = self._update_network(current_t)
                    critic_losses.append(critic_loss.detach().numpy())
                    actor_losses.append(actor_loss.detach().numpy())
                # soft update
                # self._soft_update_target_network(
                #     self.actor_target_networks[current_t], self.actor_networks[current_t])
            # FIX: No target network updates
            # self._soft_update_target_network(
            #     self.critic_target_network, self.critic_network)
            # self._hard_update_target_network(
            #     self.critic_target_network, self.critic_network)
            # start to do the evaluation
            success_rate = self._eval_agent()
            if MPI.COMM_WORLD.Get_rank() == 0:
                print(
                    '[{}] epoch is: {}, Current time step : {}, Num steps: {}, eval success rate is: {:.3f}'
                    .format(datetime.now(), epoch, current_t, n_steps,
                            success_rate))
                logger.record_tabular('epoch', epoch)
                logger.record_tabular('n_steps', n_steps)
                logger.record_tabular('success_rate', success_rate)
                logger.record_tabular('actor_loss', np.mean(actor_losses))
                logger.record_tabular('critic_loss', np.mean(critic_losses))
                logger.dump_tabular()
                if success_rate > best_success_rate:
                    logger.info("Better success rate... Saving policy")
                    # torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std, self.actor_network.state_dict()],
                    #            self.model_path + '/model.pt')
                    torch.save([
                        self.o_norm.mean, self.o_norm.std, self.g_norm.mean,
                        self.g_norm.std
                    ] + [
                        self.actor_networks[t].state_dict()
                        for t in range(self.T)
                    ], self.model_path + '/model.pt')
                    best_success_rate = success_rate
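
The comment above about untrained per-timestep policies predicting zeros "as a result of our initialization" suggests that the final layer of each residual actor is zero-initialized. A minimal sketch of such an initialization (the architecture and sizes are assumptions):

import torch.nn as nn

def make_residual_actor(input_dim, action_dim, hidden=256):
    actor = nn.Sequential(
        nn.Linear(input_dim, hidden), nn.ReLU(),
        nn.Linear(hidden, action_dim), nn.Tanh())
    # Zero-initialize the last linear layer so an untrained residual policy
    # outputs exactly zero actions (tanh(0) = 0).
    nn.init.zeros_(actor[2].weight)
    nn.init.zeros_(actor[2].bias)
    return actor

# One residual actor per time step, as in the example above
# (T, obs_dim, goal_dim and action_dim are illustrative values).
T, obs_dim, goal_dim, action_dim = 50, 10, 3, 4
actor_networks = [make_residual_actor(obs_dim + goal_dim, action_dim)
                  for _ in range(T)]
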