Example No. 1
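# Assumed context for this snippet: itertools, a ReplayBuffer class and a module-level
# logger (with record_tabular/dump_tabular) are imported elsewhere, and the self.*
# hyperparameters are set in the agent's __init__.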
    def train(self, env):
        # Memory
        memory = ReplayBuffer(capacity=self.replay_size)
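        # Transitions are pushed here each step and later sampled in minibatches
        # by update_parameters().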

        # Training Loop
        total_numsteps = 0
        updates = 0

        for i_episode in itertools.count(1):
            episode_reward = 0
            episode_steps = 0
            done = False
            state = env.reset()

            while not done:
                if total_numsteps < self.start_steps:
                    action = env.action_space.sample()  # Sample random action
                else:
                    # Sample action from policy
                    action = self.select_action(state)

                if len(memory) > self.batch_size:
                    # Number of updates per step in environment
                    for i in range(self.updates_per_step):
                        # Update parameters of all the networks
                        q1_loss, q2_loss, policy_loss, alpha_loss = self.update_parameters(
                            memory, self.batch_size, updates)
                        updates += 1

                next_state, reward, done, _ = env.step(action)  # Step
                episode_steps += 1
                total_numsteps += 1
                episode_reward += reward

                if self.render:
                    env.render()

                # Ignore the "done" signal if it comes from hitting the time horizon.
                # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
                done = False if episode_steps == env._max_episode_steps else done

                memory.push(state, action, reward, next_state,
                            done)  # Append transition to memory

                state = next_state

            # These losses only exist once at least one gradient update has run,
            # which may not happen during the earliest episodes.
            if updates > 0:
                logger.info('UPDATE')
                logger.record_tabular('q1_loss', q1_loss)
                logger.record_tabular('q2_loss', q2_loss)
                logger.record_tabular('policy_loss', policy_loss)
                logger.record_tabular('alpha_loss', alpha_loss)
                logger.dump_tabular()

            logger.info('STATUS')
            logger.record_tabular('i_episode', i_episode)
            logger.record_tabular('episode_steps', episode_steps)
            logger.record_tabular('total_numsteps', total_numsteps)
            logger.record_tabular('episode_reward', episode_reward)
            logger.dump_tabular()

            if i_episode % 100 == 0:
                logger.info('SAVE')
                self.save_model('../saved/sac')

            if total_numsteps > self.num_steps:
                return
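
# A minimal usage sketch for the loop above. The agent class name ("SACAgent"), its
# constructor arguments and the environment id are assumptions for illustration; only
# the self.* attributes and methods used inside train() come from the snippet itself.
import gym

agent = SACAgent(replay_size=1000000, batch_size=256, start_steps=10000,
                 updates_per_step=1, num_steps=1000000, render=False)

env = gym.make('Pendulum-v0')  # any continuous-control task with the old gym step API
agent.train(env)
env.close()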
Example No. 2
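# Module-level names assumed by this snippet: time, os, errno, random, pickle,
# numpy as np, gym, torch, torch.optim as optim, torch.nn.functional as F,
# torch.autograd.Variable, itertools.count, plus the project's args (argparse
# namespace), Logger, SavedInfo and ActorCritic.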
def main():
    time_str = time.strftime("%Y%m%d-%H%M%S")
    print('time_str: ', time_str)

    exp_count = 0

    if args.experiment == 'a|s':
        direc_name_ = '_'.join([args.env, args.experiment])
    else:
        direc_name_ = '_'.join(
            [args.env, args.experiment, 'bp2VAE',
             str(args.bp2VAE)])

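    # Pick the first unused run index so earlier experiment directories are not overwritten.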
    direc_name_exist = True

    while direc_name_exist:
        exp_count += 1
        direc_name = '/'.join([direc_name_, str(exp_count)])
        direc_name_exist = os.path.exists(direc_name)

    try:
        os.makedirs(direc_name)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    if args.tensorboard_dir is None:
        logger = Logger('/'.join([direc_name, time_str]))
    else:
        logger = Logger(args.tensorboard_dir)

    env = gym.make(args.env)

    if args.wrapper:
        if args.video_dir is None:
            args.video_dir = '/'.join([direc_name, 'videos'])
        env = gym.wrappers.Monitor(env, args.video_dir, force=True)

    print('observation_space: ', env.observation_space)
    print('action_space: ', env.action_space)
    env.seed(args.seed)
    torch.manual_seed(args.seed)

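    # The policy consumes either the raw observation ('a|s') or a z_dim-dimensional VAE latent code.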
    if args.experiment == 'a|s':
        dim_x = env.observation_space.shape[0]
    elif args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)' or \
            args.experiment == 'a|z(a_prev, s, s_next)':
        dim_x = args.z_dim

    policy = ActorCritic(input_size=dim_x,
                         hidden1_size=3 * dim_x,
                         hidden2_size=6 * dim_x,
                         action_size=env.action_space.n)

    if args.use_cuda:
        Tensor = torch.cuda.FloatTensor
        torch.cuda.manual_seed_all(args.seed)
        policy.cuda()
    else:
        Tensor = torch.FloatTensor

    policy_optimizer = optim.Adam(policy.parameters(), lr=args.policy_lr)

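    # Latent-code experiments additionally need a (C)VAE, its optimizer, and a replay
    # buffer of transitions to train it on.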
    if args.experiment != 'a|s':
        from util import ReplayBuffer, vae_loss_function

        dim_s = env.observation_space.shape[0]

        if args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)':
            from model import VAE
            vae = VAE(input_size=dim_s,
                      hidden1_size=3 * args.z_dim,
                      hidden2_size=args.z_dim)

        elif args.experiment == 'a|z(a_prev, s, s_next)':
            from model import CVAE
            vae = CVAE(input_size=dim_s,
                       class_size=1,
                       hidden1_size=3 * args.z_dim,
                       hidden2_size=args.z_dim)

        if args.use_cuda:
            vae.cuda()
        vae_optimizer = optim.Adam(vae.parameters(), lr=args.vae_lr)

        if args.experiment == 'a|z(s)':
            from util import Transition_S2S as Transition
        elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
            from util import Transition_S2SNext as Transition

        buffer = ReplayBuffer(args.buffer_capacity, Transition)

        update_vae = True

    if args.experiment == 'a|s':
        from util import Record_S
    elif args.experiment == 'a|z(s)':
        from util import Record_S2S
    elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
        from util import Record_S2SNext

    def train_actor_critic(n):
        saved_info = policy.saved_info

        R = 0
        cum_returns_ = []

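        # Discounted returns, computed backwards over the episode's rewards.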
        for r in policy.rewards[::-1]:
            R = r + args.gamma * R
            cum_returns_.insert(0, R)

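        # Normalize returns to zero mean / unit variance as a simple variance-reduction step.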
        cum_returns = Tensor(cum_returns_)
        cum_returns = (cum_returns - cum_returns.mean()) \
                      / (cum_returns.std() + np.finfo(np.float32).eps)
        cum_returns = Variable(cum_returns, requires_grad=False).unsqueeze(1)

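        # Stack the per-step (log_prob, value) pairs recorded by the policy during the rollout.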
        batch_info = SavedInfo(*zip(*saved_info))
        batch_log_prob = torch.cat(batch_info.log_prob)
        batch_value = torch.cat(batch_info.value)

        # The advantage is detached so the policy-gradient term does not also backpropagate
        # into the critic; the critic is trained only through the value loss below.
        batch_adv = (cum_returns - batch_value).detach()
        policy_loss = -torch.sum(batch_log_prob * batch_adv)
        value_loss = F.smooth_l1_loss(batch_value,
                                      cum_returns,
                                      size_average=False)

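        # One joint backward pass over the policy-gradient loss and the smooth-L1 value loss.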
        policy_optimizer.zero_grad()
        total_loss = policy_loss + value_loss
        total_loss.backward()
        policy_optimizer.step()

        if args.use_cuda:
            logger.scalar_summary('value_loss', value_loss.data.cpu()[0], n)
            logger.scalar_summary('policy_loss', policy_loss.data.cpu()[0], n)

            all_value_loss.append(value_loss.data.cpu()[0])
            all_policy_loss.append(policy_loss.data.cpu()[0])
        else:
            logger.scalar_summary('value_loss', value_loss.data[0], n)
            logger.scalar_summary('policy_loss', policy_loss.data[0], n)

            all_value_loss.append(value_loss.data[0])
            all_policy_loss.append(policy_loss.data[0])

        del policy.rewards[:]
        del policy.saved_info[:]

    if args.experiment != 'a|s':

        def train_vae(n):

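            # Global VAE-update counter, reconstructed from the number of completed
            # update rounds; used as the logging step.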
            train_times = (n // args.vae_update_frequency -
                           1) * args.vae_update_times

            for i in range(args.vae_update_times):
                train_times += 1

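                # Sample a minibatch of stored transitions for one VAE update.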
                sample = buffer.sample(args.batch_size)
                batch = Transition(*zip(*sample))
                state_batch = torch.cat(batch.state)

                if args.experiment == 'a|z(s)':
                    recon_batch, mu, log_var = vae.forward(state_batch)

                    mse_loss, kl_loss = vae_loss_function(
                        recon_batch,
                        state_batch,
                        mu,
                        log_var,
                        logger,
                        train_times,
                        kl_discount=args.kl_weight,
                        mode=args.experiment)

                elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                    next_state_batch = Variable(torch.cat(batch.next_state),
                                                requires_grad=False)
                    predicted_batch, mu, log_var = vae.forward(state_batch)
                    mse_loss, kl_loss = vae_loss_function(
                        predicted_batch,
                        next_state_batch,
                        mu,
                        log_var,
                        logger,
                        train_times,
                        kl_discount=args.kl_weight,
                        mode=args.experiment)

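                # Total VAE objective: reconstruction/prediction term plus the KL term
                # returned by vae_loss_function.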
                vae_loss = mse_loss + kl_loss

                vae_optimizer.zero_grad()
                vae_loss.backward()
                vae_optimizer.step()

                logger.scalar_summary('vae_loss', vae_loss.data[0],
                                      train_times)
                all_vae_loss.append(vae_loss.data[0])
                all_mse_loss.append(mse_loss.data[0])
                all_kl_loss.append(kl_loss.data[0])

    # To store cum_reward, value_loss and policy_loss from each episode
    all_cum_reward = []
    all_last_hundred_average = []
    all_value_loss = []
    all_policy_loss = []

    if args.experiment != 'a|s':
        # Store each vae_loss calculated
        all_vae_loss = []
        all_mse_loss = []
        all_kl_loss = []

    for episode in count(1):
        done = False
        state_ = torch.Tensor([env.reset()])
        cum_reward = 0

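        # For the a|z(a_prev, s, s_next) experiment, take one initial action so the
        # previous action can be appended to the observation fed to the CVAE.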
        if args.experiment == 'a|z(a_prev, s, s_next)':
            action = random.randint(0, 2)
            state_, reward, done, info = env.step(action)
            cum_reward += reward
            state_ = torch.Tensor([np.append(state_, action)])

        while not done:
            if args.experiment == 'a|s':
                state = Variable(state_, requires_grad=False)
            elif args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)' \
                    or args.experiment == 'a|z(a_prev, s, s_next)':
                state_ = Variable(state_, requires_grad=False)
                mu, log_var = vae.encode(state_)

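                # With bp2VAE set (and the VAE still being trained), policy gradients may
                # flow back into the encoder; otherwise the latent code is detached.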
                if args.bp2VAE and update_vae:
                    state = vae.reparametrize(mu, log_var)
                else:
                    state = vae.reparametrize(mu, log_var).detach()

            action_ = policy.select_action(state)

            if args.use_cuda:
                action = action_.cpu()[0, 0]
            else:
                action = action_[0, 0]

            next_state_, reward, done, info = env.step(action)
            next_state_ = torch.Tensor([next_state_])
            cum_reward += reward

            if args.render:
                env.render()

            policy.rewards.append(reward)

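            # Store VAE training inputs: states for a|z(s), (state, next_state) pairs otherwise.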
            if args.experiment == 'a|z(s)':
                buffer.push(state_)
            elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                if not done:
                    buffer.push(state_, next_state_)

            if args.experiment == 'a|z(a_prev, s, s_next)':
                next_state_ = torch.cat(
                    [next_state_, torch.Tensor([action])], 1)

            state_ = next_state_

        train_actor_critic(episode)
        last_hundred_average = sum(all_cum_reward[-100:]) / 100

        logger.scalar_summary('cum_reward', cum_reward, episode)
        logger.scalar_summary('last_hundred_average', last_hundred_average,
                              episode)

        all_cum_reward.append(cum_reward)
        all_last_hundred_average.append(last_hundred_average)

        # VAE updates apply only to the latent-code experiments; 'update_vae' is never
        # defined when the experiment is 'a|s'.
        if args.experiment != 'a|s' and update_vae:
            if episode % args.vae_update_frequency == 0:
                assert len(buffer) >= args.batch_size
                train_vae(episode)

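            # Freeze VAE training once its loss plateaus: compare the mean of the last
            # 500 recorded losses with the mean of the 500 before that.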
            if len(all_vae_loss) > 1000:
                if abs(
                        sum(all_vae_loss[-500:]) / 500 -
                        sum(all_vae_loss[-1000:-500]) /
                        500) < args.vae_update_threshold:
                    update_vae = False

        if episode % args.log_interval == 0:
            print(
                'Episode {}\tLast cum return: {:.5f}\t100-episodes average cum return: {:.2f}'
                .format(episode, cum_reward, last_hundred_average))

        if episode > args.num_episodes:
            print("100-episodes average cum return is now {} and "
                  "the last episode runs to {} time steps!".format(
                      last_hundred_average, cum_reward))
            env.close()
            torch.save(policy, '/'.join([direc_name, 'model']))

            if args.experiment == 'a|s':
                record = Record_S(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average)
            elif args.experiment == 'a|z(s)':
                record = Record_S2S(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average,
                    mse_recon_loss=all_mse_loss,
                    kl_loss=all_kl_loss,
                    vae_loss=all_vae_loss)
            elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                record = Record_S2SNext(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average,
                    mse_pred_loss=all_mse_loss,
                    kl_loss=all_kl_loss,
                    vae_loss=all_vae_loss)

            pickle.dump(record, open('/'.join([direc_name, 'record']), 'wb'))

            break