Code example #1
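# Assumed imports for this snippet (not shown in the excerpt). The project-
# specific helpers (get_args, make_pybullet_env, make_vec_envs, MetaPolicy,
# MetaPPO, LossWriter, RolloutStorage, ppo_rollout, ppo_update, ppo_save_model,
# utils) are expected to come from the surrounding repository.
import os
import time

import numpy as np
import torch
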
def main():
    args = get_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.set_num_threads(1)

    device = torch.device(args.device)
    utils.cleanup_log_dir(args.log_dir)
    env_make = make_pybullet_env(args.task, log_dir=args.log_dir, frame_skip=args.frame_skip)
    envs = make_vec_envs(env_make, args.num_processes, args.log_dir, device, args.frame_stack)
    actor_critic = MetaPolicy(envs.observation_space, envs.action_space)
    loss_writer = LossWriter(
        args.log_dir,
        fieldnames=('V_loss', 'action_loss', 'meta_action_loss',
                    'meta_value_loss', 'meta_loss', 'loss'))

    if args.restart_model:
        actor_critic.load_state_dict(
            torch.load(args.restart_model, map_location=device))

    actor_critic.to(device)

    agent = MetaPPO(
        actor_critic, args.clip_param, args.ppo_epoch,
        args.num_mini_batch, args.value_loss_coef,
        args.entropy_coef, lr=args.lr, eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    obs = envs.reset()
    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs, envs.action_space, actor_critic.recurrent_hidden_state_size)
    rollouts.to(device)  # keep rollouts on the device; obs already arrive as torch tensors from the env wrapper

    start = time.time()
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(num_updates):

        ppo_rollout(args.num_steps, envs, actor_critic, rollouts)

        value_loss, meta_value_loss, action_loss, meta_action_loss, loss, meta_loss = ppo_update(
            agent, actor_critic, rollouts, args.use_gae, args.gamma, args.gae_lambda)
        
        loss_writer.write_row({
            'V_loss': value_loss.item(),
            'action_loss': action_loss.item(),
            'meta_action_loss': meta_action_loss.item(),
            'meta_value_loss': meta_value_loss.item(),
            'meta_loss': meta_loss.item(),
            'loss': loss.item()})
        
        if (j % args.save_interval == 0 or j == num_updates - 1) and args.log_dir != "":
            ppo_save_model(actor_critic, os.path.join(args.log_dir, "model.state_dict"), j)

        if j % args.log_interval == 0:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            s = "Update {}, num timesteps {}, FPS {} \n".format(
                j, total_num_steps, int(total_num_steps / (time.time() - start)))
            s += "Loss {:.5f}, meta loss {:.5f}, value_loss {:.5f}, meta_value_loss {:.5f}, action_loss {:.5f}, meta action loss {:.5f}".format(
                loss.item(), meta_loss.item(), value_loss.item(), meta_value_loss.item(), action_loss.item(), meta_action_loss.item())
            print(s, flush=True)
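
# Entry point (assumed; the original script presumably ends with the usual guard).
if __name__ == "__main__":
    main()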
Code example #2
def evaluate(actor_critic, ob_rms, env_name, seed, num_processes, eval_log_dir,
             device):
    eval_envs = make_vec_envs(env_name, seed + num_processes, num_processes,
                              None, eval_log_dir, device, True)

    vec_norm = utils.get_vec_normalize(eval_envs)
    if vec_norm is not None:
        vec_norm.eval()
        vec_norm.ob_rms = ob_rms

    eval_episode_rewards = []

    obs = eval_envs.reset()
    eval_recurrent_hidden_states = torch.zeros(
        num_processes, actor_critic.recurrent_hidden_state_size, device=device)
    eval_masks = torch.zeros(num_processes, 1, device=device)

    while len(eval_episode_rewards) < 10:
        with torch.no_grad():
            _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                obs,
                eval_recurrent_hidden_states,
                eval_masks,
                deterministic=True)

        # Observe reward and next obs
        obs, _, done, infos = eval_envs.step(action)

        eval_masks = torch.tensor([[0.0] if done_ else [1.0]
                                   for done_ in done],
                                  dtype=torch.float32,
                                  device=device)

        for info in infos:
            if 'episode' in info.keys():
                eval_episode_rewards.append(info['episode']['r'])

    eval_envs.close()

    print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
        len(eval_episode_rewards), np.mean(eval_episode_rewards)))
Code example #3
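# Note: this is a partial excerpt. `parser` is assumed to be an
# argparse.ArgumentParser created earlier in the original script, and `plt`
# to be matplotlib.pyplot imported at the top; torch is also assumed.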
parser.add_argument('--cnn',
                    default='Fixup',
                    help='Type of CNN. Options are CNN, Impala, Fixup, State')
parser.add_argument('--state-stack',
                    type=int,
                    default=4,
                    help='Number of steps to stack in states')
parser.add_argument('--task',
                    default='HalfCheetahPyBulletEnv-v0',
                    help='Which PyBullet task to run')

args = parser.parse_args()
args.det = not args.non_det
device = torch.device(args.device)
env_make = make_pybullet_env(args.task, frame_skip=args.frame_skip)
env = make_vec_envs(env_make, 1, None, device, args.frame_stack)

env.render(mode="human")
base_kwargs = {'recurrent': args.recurrent_policy}
actor_critic = MetaPolicy(env.observation_space, env.action_space)
if args.load_model:
    actor_critic.load_state_dict(
        torch.load(args.load_model, map_location=device))
actor_critic.to(device)

recurrent_hidden_states = torch.zeros(
    1, actor_critic.recurrent_hidden_state_size).to(device)
masks = torch.zeros(1, 1).to(device)

obs = env.reset()
fig = plt.figure()
Code example #4
def train_policy_embedding():
    """
    Script for training the policy embeddings
    using a transformer encoder.

    References:
    https://github.com/jadore801120/attention-is-all-you-need-pytorch/
    """
    args = get_args()
    os.environ['OMP_NUM_THREADS'] = '1'

    # Useful Variables
    best_eval_loss = sys.maxsize
    device = args.device
    if device != 'cpu':
        torch.cuda.empty_cache()

    # Create the Environment
    env = make_vec_envs(args, device)

    names = []
    for e in range(args.num_envs):
        for s in range(args.num_seeds):
            names.append('ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s))

    all_policies = []
    for name in names:
        actor_critic = Policy(env.observation_space.shape,
                              env.action_space,
                              base_kwargs={'recurrent': False})
        actor_critic.to(device)
        model = os.path.join(args.save_dir, name)
        actor_critic.load_state_dict(torch.load(model))
        all_policies.append(actor_critic)

    encoder_dim = args.num_attn_heads * args.policy_attn_head_dim
    enc_input_size = env.observation_space.shape[0] + env.action_space.shape[0]

    # Initialize the Transformer encoder and decoders
    encoder = embedding_networks.make_encoder_oh(enc_input_size, N=args.num_layers, \
                                                 d_model=encoder_dim, h=args.num_attn_heads, \
                                                 dropout=args.dropout, d_emb=args.policy_embedding_dim)

    decoder = Policy(tuple(
        [env.observation_space.shape[0] + args.policy_embedding_dim]),
                     env.action_space,
                     base_kwargs={'recurrent': False})

    embedding_networks.init_weights(encoder)
    embedding_networks.init_weights(decoder)

    encoder.train()
    decoder.train()

    encoder.to(device)
    decoder.to(device)

    # Loss and Optimizer
    criterion = nn.MSELoss(reduction='sum')

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.lr_policy)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.lr_policy)

    # Create the Environment
    env_sampler = env_utils.EnvSamplerEmb(env, all_policies, args)

    # Collect Train Data
    src_batch = []
    tgt_batch = []
    state_batch = []
    mask_batch = []
    mask_batch_all = []
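
    # The embedding data is collected from the first 3/4 of the environments and
    # of the policies; the remaining environments are held out for evaluation in
    # train_pdvf.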

    train_policies = [i for i in range(int(3 / 4 * args.num_envs))]
    train_envs = [i for i in range(int(3 / 4 * args.num_envs))]
    # For each policy in our dataset
    for pi in train_policies:
        # For each environment in our dataset
        for env in train_envs:
            # Sample a number of trajectories for this (policy, env) pair
            for _ in range(args.num_eps_policy):
                state_batch_t, tgt_batch_t, src_batch_t, mask_batch_t,\
                    mask_batch_all_t = env_sampler.sample_policy_data(\
                    policy_idx=pi, env_idx=env)
                state_batch.extend(state_batch_t)
                tgt_batch.extend(tgt_batch_t)
                src_batch.extend(src_batch_t)
                mask_batch.extend(mask_batch_t)
                mask_batch_all.extend(mask_batch_all_t)

    src_batch = torch.stack(src_batch)
    tgt_batch = torch.stack(tgt_batch).squeeze(1)
    state_batch = torch.stack(state_batch).squeeze(1)
    mask_batch = torch.stack(mask_batch)
    mask_batch_all = torch.stack(mask_batch_all)
    num_samples_train = src_batch.shape[0]

    # Collect Eval Data
    src_batch_eval = []
    tgt_batch_eval = []
    state_batch_eval = []
    mask_batch_eval = []
    mask_batch_all_eval = []

    eval_policies = [i for i in range(int(3 / 4 * args.num_envs))]
    eval_envs = [i for i in range(int(3 / 4 * args.num_envs))]
    # For each policy in our dataset
    for pi in eval_policies:
        # For each environment in our dataset
        for env in eval_envs:
            # Sample a number of trajectories for this (policy, env) pair
            for _ in range(args.num_eps_policy):
                state_batch_t, tgt_batch_t, src_batch_t, mask_batch_t, \
                    mask_batch_all_t = env_sampler.sample_policy_data(\
                    policy_idx=pi, env_idx=env)

                state_batch_eval.extend(state_batch_t)
                tgt_batch_eval.extend(tgt_batch_t)
                src_batch_eval.extend(src_batch_t)
                mask_batch_eval.extend(mask_batch_t)
                mask_batch_all_eval.extend(mask_batch_all_t)

    src_batch_eval = torch.stack(src_batch_eval).detach()
    tgt_batch_eval = torch.stack(tgt_batch_eval).squeeze(1).detach()
    state_batch_eval = torch.stack(state_batch_eval).squeeze(1).detach()
    mask_batch_eval = torch.stack(mask_batch_eval).detach()
    mask_batch_all_eval = torch.stack(mask_batch_all_eval).detach()
    num_samples_eval = src_batch_eval.shape[0]

    # Training Loop
    for epoch in range(args.num_epochs_emb + 1):
        encoder.train()
        decoder.train()

        indices = [i for i in range(num_samples_train)]
        random.shuffle(indices)
        total_counts = 0
        total_loss = 0
        num_correct_actions = 0
        for nmb in range(0, len(indices), args.policy_batch_size):
            indices_mb = indices[nmb:nmb + args.policy_batch_size]

            source = src_batch[indices_mb].to(device)
            target = tgt_batch[indices_mb].to(device)
            state = state_batch[indices_mb].to(device).float()
            mask = mask_batch[indices_mb].to(device)
            mask_all = mask_batch_all[indices_mb].squeeze(2).unsqueeze(1).to(
                device)
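
            # The encoder turns the masked (state, action) source sequence into a
            # single policy embedding; the decoder is then trained, via the MSE
            # loss below, to reproduce the source policy's actions from the
            # current state concatenated with that (L2-normalized) embedding.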

            embedding = encoder(source.detach().to(device),
                                mask_all.detach().to(device))
            embedding = F.normalize(embedding, p=2, dim=1)

            state *= mask.to(device)
            embedding *= mask.to(device)
            recurrent_hidden_state = torch.zeros(
                args.policy_batch_size,
                decoder.recurrent_hidden_state_size,
                device=device,
                requires_grad=True).float()
            mask_dec = torch.zeros(args.policy_batch_size,
                                   1,
                                   device=device,
                                   requires_grad=True).float()
            emb_state_input = torch.cat((embedding, state.to(device)),
                                        dim=1).to(device)
            action = decoder(emb_state_input, recurrent_hidden_state, mask_dec)
            action *= mask.to(device)
            target *= mask

            loss = criterion(action, target.to(device))
            total_loss += loss.item()
            total_counts += len(indices_mb)

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

        if epoch % args.log_interval == 0:
            avg_loss = total_loss / total_counts
            print("\n# Epoch %d: Train Loss = %.6f " % (epoch + 1, avg_loss))

        # Evaluation
        encoder.eval()
        decoder.eval()

        indices_eval = [i for i in range(num_samples_eval)]
        total_counts_eval = 0
        total_loss_eval = 0
        num_correct_actions_eval = 0
        for nmb in range(0, len(indices_eval), args.policy_batch_size):

            indices_mb_eval = indices_eval[nmb:nmb + args.policy_batch_size]

            source_eval = src_batch_eval[indices_mb_eval].to(device).detach()
            target_eval = tgt_batch_eval[indices_mb_eval].to(device).detach()
            state_eval = state_batch_eval[indices_mb_eval].float().to(
                device).detach()
            mask_eval = mask_batch_eval[indices_mb_eval].to(device).detach()
            mask_all_eval = mask_batch_all_eval[indices_mb_eval].squeeze(
                2).unsqueeze(1).to(device).detach()

            embedding_eval = encoder(
                source_eval.detach().to(device),
                mask_all_eval.detach().to(device)).detach()
            embedding_eval = F.normalize(embedding_eval, p=2, dim=1).detach()

            state_eval *= mask_eval.to(device).detach()
            embedding_eval *= mask_eval.to(device).detach()
            recurrent_hidden_state_eval = torch.zeros(
                args.policy_batch_size,
                decoder.recurrent_hidden_state_size,
                device='cpu').float()
            mask_dec_eval = torch.zeros(args.policy_batch_size,
                                        1,
                                        device='cpu').float()
            emb_state_input_eval = torch.cat(
                (embedding_eval, state_eval.to(device)), dim=1)
            action_eval = decoder(emb_state_input_eval,
                                  recurrent_hidden_state_eval,
                                  mask_dec_eval,
                                  deterministic=True)
            action_eval *= mask_eval.to(device)
            target_eval *= mask_eval

            loss_eval = criterion(action_eval, target_eval.to(device))

            total_loss_eval += loss_eval.item()
            total_counts_eval += len(indices_mb_eval)

        avg_loss_eval = total_loss_eval / total_counts_eval

        # Save the models
        if avg_loss_eval <= best_eval_loss:
            best_eval_loss = avg_loss_eval
            pdvf_utils.save_model("policy-encoder.", encoder, encoder_optimizer, \
                             epoch, args, args.env_name, save_dir=args.save_dir_policy_embedding)
            pdvf_utils.save_model("policy-decoder.", decoder, decoder_optimizer, \
                             epoch, args, args.env_name, save_dir=args.save_dir_policy_embedding)

        if epoch % args.log_interval == 0:
            print("# Epoch %d: Eval Loss = %.6f " % (epoch + 1, avg_loss_eval))
Code example #5
def train_pdvf():
    '''
    Train the Policy-Dynamics Value Function of PD-VF
    which estimates the return for a family of policies 
    in a family of environments with varying dynamics.

    To do this, it trains a network conditioned on an initial state,
    a (learned) policy embedding, and a (learned) dynamics embedding
    and outputs an estimate of the cumulative reward of the 
    corresponding policy in the given environment. 
    '''
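    # Rough structure (as implemented below): phase 1 fits the value network to
    # the true returns obtained by the pretrained policies, given their policy
    # and dynamics embeddings; phase 2 repeatedly extracts candidate "optimal"
    # policy embeddings from the learned value function (via an SVD of its
    # state/dynamics-conditioned matrix, see the comment at the SVD step) and
    # retrains on the returns those embeddings actually achieve. The policy
    # decoder is trained alongside on mixed data. Helpers such as pdvf_utils,
    # env_utils, train_utils, PDVF, Policy and the replay memories come from the
    # PD-VF code base.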
    args = get_args()
    os.environ['OMP_NUM_THREADS'] = '1'

    device = args.device
    if device != 'cpu':
        torch.cuda.empty_cache()

    # Create the environment
    envs = make_vec_envs(args, device)

    if args.seed:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    names = []
    for e in range(args.num_envs):
        for s in range(args.num_seeds):
            names.append('ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s))

    all_policies = []
    for name in names:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              base_kwargs={'recurrent': False})
        actor_critic.to(device)
        model = os.path.join(args.save_dir, name)
        actor_critic.load_state_dict(torch.load(model))
        all_policies.append(actor_critic)

    # Load the collected interaction episodes for each agent
    policy_encoder, policy_decoder = pdvf_utils.load_policy_model(args, envs)
    env_encoder = pdvf_utils.load_dynamics_model(args, envs)

    policy_decoder.train()
    decoder_optimizer = optim.Adam(policy_decoder.parameters(),
                                   lr=args.lr_policy)
    decoder_optimizer2 = optim.Adam(policy_decoder.parameters(),
                                    lr=args.lr_policy)
    decoder_network = {'policy_decoder': policy_decoder, \
                        'decoder_optimizer': decoder_optimizer, \
                        'decoder_optimizer2': decoder_optimizer2}

    # Instantiate the PD-VF, Optimizer and Loss
    args.use_l2_loss = True
    value_net = PDVF(envs.observation_space.shape[0],
                     args.dynamics_embedding_dim,
                     args.hidden_dim_pdvf,
                     args.policy_embedding_dim,
                     device=device).to(device)
    optimizer = optim.Adam(value_net.parameters(),
                           lr=args.lr_pdvf,
                           eps=args.eps)
    optimizer2 = optim.Adam(value_net.parameters(),
                            lr=args.lr_pdvf,
                            eps=args.eps)
    network = {
        'net': value_net,
        'optimizer': optimizer,
        'optimizer2': optimizer2
    }
    value_net.train()

    train_policies = [i for i in range(int(3 / 4 * args.num_envs))]
    train_envs = [i for i in range(int(3 / 4 * args.num_envs))]
    eval_envs = [i for i in range(int(3 / 4 * args.num_envs), args.num_envs)]

    all_envs = [i for i in range(args.num_envs)]

    NUM_STAGES = args.num_stages
    NUM_TRAIN_EPS = args.num_train_eps
    NUM_TRAIN_SAMPLES = NUM_TRAIN_EPS * len(train_policies) * len(train_envs)
    NUM_EVAL_EPS = args.num_eval_eps
    NUM_EVAL_SAMPLES = NUM_EVAL_EPS * len(train_policies) * len(train_envs)
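
    # Each input to the dynamics encoder concatenates two observation-sized
    # vectors with a policy embedding (presumably the state, the next state and
    # the acting policy's embedding), hence the input size below.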

    env_enc_input_size = 2 * envs.observation_space.shape[
        0] + args.policy_embedding_dim
    sizes = pdvf_utils.DotDict({'state_dim': envs.observation_space.shape[0], \
        'action_dim': envs.action_space.shape[0], 'env_enc_input_size': env_enc_input_size, \
        'env_max_seq_length': args.max_num_steps * env_enc_input_size})

    env_sampler = env_utils.EnvSamplerPDVF(envs, all_policies, args)
    decoder_env_sampler = env_utils.EnvSamplerEmb(envs, all_policies, args)

    ####################    TRAIN PHASE 1      ########################
    # Collect Eval Data for First Training Stage
    eval_memory = ReplayMemoryPDVF(NUM_EVAL_SAMPLES)
    decoder_eval_memory = ReplayMemoryPolicyDecoder(NUM_EVAL_SAMPLES)
    for i in range(NUM_EVAL_EPS):
        for ei in train_envs:
            for pi in train_policies:
                init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    init_state = env_sampler.env.sim.get_state()
                    res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                else:
                    init_state = env_sampler.env.state
                    res = env_sampler.zeroshot_sample_src_from_pol_state(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                source_env = res['source_env']
                mask_env = res['mask_env']
                mask_policy = res['mask_policy']
                source_policy = res['source_policy']
                episode_reward = res['episode_reward']
                episode_reward_tensor = torch.tensor([episode_reward],
                                                     device=device,
                                                     dtype=torch.float)

                if source_policy.shape[1] == 1:
                    source_policy = source_policy.repeat(1, 2, 1)
                    mask_policy = mask_policy.repeat(1, 1, 2)
                emb_policy = policy_encoder(
                    source_policy.detach().to(device),
                    mask_policy.detach().to(device)).detach()
                if source_env.shape[1] == 1:
                    source_env = source_env.repeat(1, 2, 1)
                    mask_env = mask_env.repeat(1, 1, 2)
                emb_env = env_encoder(source_env.detach().to(device),
                                      mask_env.detach().to(device)).detach()

                emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                pred_value = value_net(
                    init_obs.unsqueeze(0).to(device), emb_env.to(device),
                    emb_policy.to(device)).item()
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                else:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                decoded_reward_tensor = torch.tensor([decoded_reward],
                                                     device=device,
                                                     dtype=torch.float)

                eval_memory.push(init_obs.unsqueeze(0),
                                 emb_policy.unsqueeze(0), emb_env.unsqueeze(0),
                                 episode_reward_tensor)

                # Collect data for the decoder
                state_batch, tgt_batch, src_batch, mask_batch, _ = \
                    decoder_env_sampler.sample_policy_data(policy_idx=pi, env_idx=ei)

                for state, tgt, src, mask in zip(state_batch, tgt_batch,
                                                 src_batch, mask_batch):
                    state = state.to(device).float()
                    mask = mask.to(device)

                    state *= mask.to(device).detach()
                    emb_policy *= mask.to(device).detach()
                    recurrent_state = torch.zeros(
                        state.shape[0],
                        policy_decoder.recurrent_hidden_state_size,
                        device=args.device).float()
                    mask_dec = torch.zeros(state.shape[0],
                                           1,
                                           device=args.device).float()
                    emb_state = torch.cat((emb_policy, state.to(device)),
                                          dim=1)
                    action = policy_decoder(emb_state,
                                            recurrent_state,
                                            mask_dec,
                                            deterministic=True)
                    action *= mask.to(device)
                    decoder_eval_memory.push(emb_state, recurrent_state,
                                             mask_dec, action)

    # Collect Train Data for First Training Stage
    memory = ReplayMemoryPDVF(NUM_TRAIN_SAMPLES)
    decoder_memory = ReplayMemoryPolicyDecoder(NUM_TRAIN_SAMPLES)
    for i in range(NUM_TRAIN_EPS):
        for ei in train_envs:
            for pi in train_policies:
                init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    init_state = env_sampler.env.sim.get_state()
                    res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                else:
                    init_state = env_sampler.env.state
                    res = env_sampler.zeroshot_sample_src_from_pol_state(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                source_env = res['source_env']
                mask_env = res['mask_env']
                source_policy = res['source_policy']
                mask_policy = res['mask_policy']
                episode_reward = res['episode_reward']

                episode_reward_tensor = torch.tensor([episode_reward],
                                                     device=device,
                                                     dtype=torch.float)

                if source_policy.shape[1] == 1:
                    source_policy = source_policy.repeat(1, 2, 1)
                    mask_policy = mask_policy.repeat(1, 1, 2)
                emb_policy = policy_encoder(
                    source_policy.detach().to(device),
                    mask_policy.detach().to(device)).detach()
                if source_env.shape[1] == 1:
                    source_env = source_env.repeat(1, 2, 1)
                    mask_env = mask_env.repeat(1, 1, 2)
                emb_env = env_encoder(source_env.detach().to(device),
                                      mask_env.detach().to(device)).detach()

                emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                pred_value = value_net(
                    init_obs.unsqueeze(0).to(device), emb_env.to(device),
                    emb_policy.to(device)).item()
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                else:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                decoded_reward_tensor = torch.tensor([decoded_reward],
                                                     device=device,
                                                     dtype=torch.float)

                memory.push(init_obs.unsqueeze(0), emb_policy.unsqueeze(0),
                            emb_env.unsqueeze(0), episode_reward_tensor)

                # Collect data for the decoder
                state_batch, tgt_batch, src_batch, mask_batch, _ = \
                    decoder_env_sampler.sample_policy_data(policy_idx=pi, env_idx=ei)

                for state, tgt, src, mask in zip(state_batch, tgt_batch,
                                                 src_batch, mask_batch):
                    state = state.to(device).float()
                    mask = mask.to(device)

                    state *= mask.to(device).detach()
                    emb_policy *= mask.to(device).detach()
                    recurrent_state = torch.zeros(
                        state.shape[0],
                        policy_decoder.recurrent_hidden_state_size,
                        device=args.device).float()
                    mask_dec = torch.zeros(state.shape[0],
                                           1,
                                           device=args.device).float()
                    emb_state = torch.cat((emb_policy, state.to(device)),
                                          dim=1)
                    action = policy_decoder(emb_state,
                                            recurrent_state,
                                            mask_dec,
                                            deterministic=True)
                    action *= mask.to(device)
                    decoder_memory.push(emb_state, recurrent_state, mask_dec,
                                        action)

    ### Train - Stage 1 ###
    total_train_loss = 0
    total_eval_loss = 0
    BEST_EVAL_LOSS = sys.maxsize

    decoder_total_train_loss = 0
    decoder_total_eval_loss = 0
    DECODER_BEST_EVAL_LOSS = sys.maxsize
    print("\nFirst Training Stage")
    for i in range(args.num_epochs_pdvf_phase1):
        train_loss = train_utils.optimize_model_pdvf(
            args, network, memory, num_opt_steps=args.num_opt_steps)
        if train_loss:
            total_train_loss += train_loss

        eval_loss = train_utils.optimize_model_pdvf(
            args,
            network,
            eval_memory,
            num_opt_steps=args.num_opt_steps,
            eval=True)

        if eval_loss:
            total_eval_loss += eval_loss
            if eval_loss < BEST_EVAL_LOSS:
                BEST_EVAL_LOSS = eval_loss
                pdvf_utils.save_model("pdvf-stage0.", value_net, optimizer, \
                                i, args, args.env_name, save_dir=args.save_dir_pdvf)

        if i % args.log_interval == 0:
            print("\n### PD-VF: Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \
                i, total_train_loss / args.log_interval, total_eval_loss / args.log_interval))
            total_train_loss = 0
            total_eval_loss = 0

        # Train the Policy Decoder on mixed data
        # from trajectories collected using the pretrained policies
        # and decoded trajectories by the current decoder
        decoder_train_loss = train_utils.optimize_decoder(
            args,
            decoder_network,
            decoder_memory,
            num_opt_steps=args.num_opt_steps)
        if decoder_train_loss:
            decoder_total_train_loss += decoder_train_loss

        decoder_eval_loss = train_utils.optimize_decoder(
            args,
            decoder_network,
            decoder_eval_memory,
            num_opt_steps=args.num_opt_steps,
            eval=True)

        if decoder_eval_loss:
            decoder_total_eval_loss += decoder_eval_loss

            if decoder_eval_loss < DECODER_BEST_EVAL_LOSS:
                DECODER_BEST_EVAL_LOSS = decoder_eval_loss
                pdvf_utils.save_model("policy-decoder-stage0.", policy_decoder, decoder_optimizer, \
                                i, args, args.env_name, save_dir=args.save_dir_pdvf)

        if i % args.log_interval == 0:
            print("### PolicyDecoder: Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \
                i, decoder_total_train_loss / args.log_interval, decoder_total_eval_loss / args.log_interval))
            decoder_total_train_loss = 0
            decoder_total_eval_loss = 0

    ####################    TRAIN PHASE 2    ########################
    for k in range(NUM_STAGES):
        print("Stage in Second Training Phase: ", k)
        # Collect Eval Data for Second Training Stage
        eval_memory2 = ReplayMemoryPDVF(NUM_EVAL_SAMPLES)
        decoder_eval_memory2 = ReplayMemoryPolicyDecoder(NUM_EVAL_SAMPLES)
        for i in range(NUM_EVAL_EPS):
            for ei in train_envs:
                for pi in train_policies:
                    init_obs = torch.FloatTensor(
                        env_sampler.env.reset(env_id=ei))
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        init_state = env_sampler.env.sim.get_state()
                        res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                            args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                    else:
                        init_state = env_sampler.env.state
                        res = env_sampler.zeroshot_sample_src_from_pol_state(
                            args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                    source_env = res['source_env']
                    mask_env = res['mask_env']
                    source_policy = res['source_policy']
                    mask_policy = res['mask_policy']
                    init_episode_reward = res['episode_reward']

                    if source_policy.shape[1] == 1:
                        source_policy = source_policy.repeat(1, 2, 1)
                        mask_policy = mask_policy.repeat(1, 1, 2)
                    emb_policy = policy_encoder(
                        source_policy.detach().to(device),
                        mask_policy.detach().to(device)).detach()
                    if source_env.shape[1] == 1:
                        source_env = source_env.repeat(1, 2, 1)
                        mask_env = mask_env.repeat(1, 1, 2)
                    emb_env = env_encoder(
                        source_env.detach().to(device),
                        mask_env.detach().to(device)).detach()

                    emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                    emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                    qf = value_net.get_qf(
                        init_obs.unsqueeze(0).to(device), emb_env)
                    u, s, v = torch.svd(qf.squeeze())
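                    # get_qf returns a matrix conditioned on the initial state and
                    # the dynamics embedding; its top left-singular vector is taken
                    # as a candidate optimal policy embedding. Singular vectors are
                    # only defined up to sign, so both +u[:, 0] and -u[:, 0] are
                    # rolled out below and the better-performing one is kept.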

                    opt_policy_pos = u[:, 0].unsqueeze(0)
                    opt_policy_neg = -u[:, 0].unsqueeze(0)

                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_pos,
                            policy_decoder,
                            env_idx=ei)
                        episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_neg,
                            policy_decoder,
                            env_idx=ei)
                    else:
                        episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_pos,
                            policy_decoder,
                            env_idx=ei)
                        episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_neg,
                            policy_decoder,
                            env_idx=ei)

                    if episode_reward_pos >= episode_reward_neg:
                        episode_reward = episode_reward_pos
                        opt_policy = opt_policy_pos
                    else:
                        episode_reward = episode_reward_neg
                        opt_policy = opt_policy_neg

                    episode_reward_tensor = torch.tensor([episode_reward],
                                                         device=device,
                                                         dtype=torch.float)

                    eval_memory2.push(init_obs.unsqueeze(0),
                                      opt_policy.unsqueeze(0),
                                      emb_env.unsqueeze(0),
                                      episode_reward_tensor)

                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        all_emb_state, all_recurrent_state, all_mask, all_action = \
                            decoder_env_sampler.get_decoded_traj_mujoco(args, init_state, \
                                init_obs, opt_policy_pos, policy_decoder, env_idx=ei)
                    else:
                        all_emb_state, all_recurrent_state, all_mask, all_action = \
                            decoder_env_sampler.get_decoded_traj(args, init_state, \
                                init_obs, opt_policy_pos, policy_decoder, env_idx=ei)

                    for e, r, m, a in zip(all_emb_state, all_recurrent_state,
                                          all_mask, all_action):
                        decoder_eval_memory2.push(e, r, m, a)

        # Collect Train Data for Second Training Stage
        memory2 = ReplayMemoryPDVF(NUM_TRAIN_SAMPLES)
        decoder_memory2 = ReplayMemoryPolicyDecoder(NUM_TRAIN_SAMPLES)
        for i in range(NUM_TRAIN_EPS):
            for ei in train_envs:
                for pi in train_policies:
                    init_obs = torch.FloatTensor(
                        env_sampler.env.reset(env_id=ei))
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        init_state = env_sampler.env.sim.get_state()
                        res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                            args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                    else:
                        init_state = env_sampler.env.state
                        res = env_sampler.zeroshot_sample_src_from_pol_state(
                            args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                    source_env = res['source_env']
                    mask_env = res['mask_env']
                    source_policy = res['source_policy']
                    mask_policy = res['mask_policy']
                    init_episode_reward = res['episode_reward']

                    if source_policy.shape[1] == 1:
                        source_policy = source_policy.repeat(1, 2, 1)
                        mask_policy = mask_policy.repeat(1, 1, 2)
                    emb_policy = policy_encoder(
                        source_policy.detach().to(device),
                        mask_policy.detach().to(device)).detach()
                    if source_env.shape[1] == 1:
                        source_env = source_env.repeat(1, 2, 1)
                        mask_env = mask_env.repeat(1, 1, 2)
                    emb_env = env_encoder(
                        source_env.detach().to(device),
                        mask_env.detach().to(device)).detach()

                    emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                    emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                    qf = value_net.get_qf(
                        init_obs.unsqueeze(0).to(device), emb_env)
                    u, s, v = torch.svd(qf.squeeze())

                    opt_policy_pos = u[:, 0].unsqueeze(0)
                    opt_policy_neg = -u[:, 0].unsqueeze(0)

                    # Include both solutions (positive and negative policy embedding)
                    # in the train data, since the singular vector is only defined up
                    # to sign; this lets the value function learn about both.
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_pos,
                            policy_decoder,
                            env_idx=ei)
                        episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_neg,
                            policy_decoder,
                            env_idx=ei)
                    else:
                        episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_pos,
                            policy_decoder,
                            env_idx=ei)
                        episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
                            args,
                            init_state,
                            init_obs,
                            opt_policy_neg,
                            policy_decoder,
                            env_idx=ei)

                    episode_reward_tensor_pos = torch.tensor(
                        [episode_reward_pos], device=device, dtype=torch.float)
                    episode_reward_tensor_neg = torch.tensor(
                        [episode_reward_neg], device=device, dtype=torch.float)

                    memory2.push(init_obs.unsqueeze(0),
                                 opt_policy_pos.unsqueeze(0),
                                 emb_env.unsqueeze(0),
                                 episode_reward_tensor_pos)
                    memory2.push(init_obs.unsqueeze(0),
                                 opt_policy_neg.unsqueeze(0),
                                 emb_env.unsqueeze(0),
                                 episode_reward_tensor_neg)

                    # collect PolicyDecoder train data for second training stage
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        all_emb_state, all_recurrent_state, all_mask, all_action = \
                            decoder_env_sampler.get_decoded_traj_mujoco(args, init_state, \
                                init_obs, opt_policy_pos, policy_decoder, env_idx=ei)
                    else:
                        all_emb_state, all_recurrent_state, all_mask, all_action = \
                            decoder_env_sampler.get_decoded_traj(args, init_state, \
                                init_obs, opt_policy_pos, policy_decoder, env_idx=ei)

                    for e, r, m, a in zip(all_emb_state, all_recurrent_state,
                                          all_mask, all_action):
                        decoder_memory2.push(e, r, m, a)

        ### Train - Stage 2 ###
        total_train_loss = 0
        total_eval_loss = 0
        BEST_EVAL_LOSS = sys.maxsize

        decoder_total_train_loss = 0
        decoder_total_eval_loss = 0
        DECODER_BEST_EVAL_LOSS = sys.maxsize
        for i in range(args.num_epochs_pdvf_phase2):
            train_loss = train_utils.optimize_model_pdvf_phase2(
                args,
                network,
                memory,
                memory2,
                num_opt_steps=args.num_opt_steps)
            if train_loss:
                total_train_loss += train_loss

            eval_loss = train_utils.optimize_model_pdvf_phase2(
                args,
                network,
                eval_memory,
                eval_memory2,
                num_opt_steps=args.num_opt_steps,
                eval=True)

            if eval_loss:
                total_eval_loss += eval_loss

                if eval_loss < BEST_EVAL_LOSS:
                    BEST_EVAL_LOSS = eval_loss
                    pdvf_utils.save_model("pdvf-stage{}.".format(k+1), value_net, optimizer, \
                                    i, args, args.env_name, save_dir=args.save_dir_pdvf)

            if i % args.log_interval == 0:
                print("\n### PDVF: Stage {} -- Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \
                    k, i, total_train_loss / args.log_interval, total_eval_loss / args.log_interval))
                total_train_loss = 0
                total_eval_loss = 0

            # Train the Policy Decoder on mixed data
            # from trajectories collected using the pretrained policies
            # and decoded trajectories by the current decoder
            decoder_train_loss = train_utils.optimize_decoder_phase2(
                args,
                decoder_network,
                decoder_memory,
                decoder_memory2,
                num_opt_steps=args.num_opt_steps)
            if decoder_train_loss:
                decoder_total_train_loss += decoder_train_loss

            decoder_eval_loss = train_utils.optimize_decoder_phase2(
                args,
                decoder_network,
                decoder_eval_memory,
                decoder_eval_memory2,
                num_opt_steps=args.num_opt_steps,
                eval=True)
            if decoder_eval_loss:
                decoder_total_eval_loss += decoder_eval_loss

                if decoder_eval_loss < DECODER_BEST_EVAL_LOSS:
                    DECODER_BEST_EVAL_LOSS = decoder_eval_loss
                    pdvf_utils.save_model("policy-decoder-stage{}.".format(k+1), policy_decoder, decoder_optimizer, \
                                    i, args, args.env_name, save_dir=args.save_dir_pdvf)

            if i % args.log_interval == 0:
                print("### PolicyDecoder: Stage {} -- Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \
                    k, i, decoder_total_train_loss / args.log_interval, decoder_total_eval_loss / args.log_interval))
                decoder_total_train_loss = 0
                decoder_total_eval_loss = 0

    ####################         EVAL        ########################

    # Eval on Train Envs
    value_net.eval()
    policy_decoder.eval()
    train_rewards = {}
    unnorm_train_rewards = {}
    for ei in range(len(all_envs)):
        train_rewards[ei] = []
        unnorm_train_rewards[ei] = []
    for i in range(NUM_EVAL_EPS):
        for ei in train_envs:
            for pi in train_policies:
                init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    init_state = env_sampler.env.sim.get_state()
                    res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                else:
                    init_state = env_sampler.env.state
                    res = env_sampler.zeroshot_sample_src_from_pol_state(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                source_env = res['source_env']
                mask_env = res['mask_env']
                source_policy = res['source_policy']
                init_episode_reward = res['episode_reward']
                mask_policy = res['mask_policy']

                if source_policy.shape[1] == 1:
                    source_policy = source_policy.repeat(1, 2, 1)
                    mask_policy = mask_policy.repeat(1, 1, 2)
                emb_policy = policy_encoder(
                    source_policy.detach().to(device),
                    mask_policy.detach().to(device)).detach()
                if source_env.shape[1] == 1:
                    source_env = source_env.repeat(1, 2, 1)
                    mask_env = mask_env.repeat(1, 1, 2)
                emb_env = env_encoder(source_env.detach().to(device),
                                      mask_env.detach().to(device)).detach()

                emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                pred_value = value_net(
                    init_obs.unsqueeze(0).to(device), emb_env.to(device),
                    emb_policy.to(device)).item()
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                else:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]

                qf = value_net.get_qf(
                    init_obs.unsqueeze(0).to(device), emb_env)
                u, s, v = torch.svd(qf.squeeze())

                opt_policy_pos = u[:, 0].unsqueeze(0)
                opt_policy_neg = -u[:, 0].unsqueeze(0)

                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)
                else:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)

                if episode_reward_pos >= episode_reward_neg:
                    episode_reward = episode_reward_pos
                    opt_policy = opt_policy_pos
                else:
                    episode_reward = episode_reward_neg
                    opt_policy = opt_policy_neg

                unnorm_episode_reward = episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_init_episode_reward = init_episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_decoded_reward = decoded_reward * (
                    args.max_reward - args.min_reward) + args.min_reward

                unnorm_train_rewards[ei].append(unnorm_episode_reward)
                train_rewards[ei].append(episode_reward)
                if i % args.log_interval == 0:
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        print(
                            f"\nTrain Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}"
                        )
                        print(
                            f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                        )
                    print(
                        f"Train Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}"
                    )
                    print(
                        f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                    )

    for ei in train_envs:
        if 'ant' in args.env_name or 'swimmer' in args.env_name:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(unnorm_train_rewards[ei]), np.std(unnorm_train_rewards[ei])))
        else:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(train_rewards[ei]), np.std(train_rewards[ei])))

    # Eval on Eval Envs
    value_net.eval()
    policy_decoder.eval()
    eval_rewards = {}
    unnorm_eval_rewards = {}
    for ei in range(len(all_envs)):
        eval_rewards[ei] = []
        unnorm_eval_rewards[ei] = []

    for i in range(NUM_EVAL_EPS):
        for ei in eval_envs:
            for pi in train_policies:
                init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    init_state = env_sampler.env.sim.get_state()
                    res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                else:
                    init_state = env_sampler.env.state
                    res = env_sampler.zeroshot_sample_src_from_pol_state(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                source_env = res['source_env']
                mask_env = res['mask_env']
                source_policy = res['source_policy']
                init_episode_reward = res['episode_reward']
                mask_policy = res['mask_policy']

                if source_policy.shape[1] == 1:
                    source_policy = source_policy.repeat(1, 2, 1)
                    mask_policy = mask_policy.repeat(1, 1, 2)
                emb_policy = policy_encoder(
                    source_policy.detach().to(device),
                    mask_policy.detach().to(device)).detach()
                if source_env.shape[1] == 1:
                    source_env = source_env.repeat(1, 2, 1)
                    mask_env = mask_env.repeat(1, 1, 2)
                emb_env = env_encoder(source_env.detach().to(device),
                                      mask_env.detach().to(device)).detach()

                emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                pred_value = value_net(
                    init_obs.unsqueeze(0).to(device), emb_env.to(device),
                    emb_policy.to(device)).item()
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                else:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]

                qf = value_net.get_qf(
                    init_obs.unsqueeze(0).to(device), emb_env)
                u, s, v = torch.svd(qf.squeeze())

                opt_policy_pos = u[:, 0].unsqueeze(0)
                opt_policy_neg = -u[:, 0].unsqueeze(0)

                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)
                else:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)

                if episode_reward_pos >= episode_reward_neg:
                    episode_reward = episode_reward_pos
                    opt_policy = opt_policy_pos
                else:
                    episode_reward = episode_reward_neg
                    opt_policy = opt_policy_neg

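                # Rewards are assumed to be normalized to [0, 1] during training; map them
                # back to the environment's original reward range for logging.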
                unnorm_episode_reward = episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_init_episode_reward = init_episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_decoded_reward = decoded_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_eval_rewards[ei].append(unnorm_episode_reward)
                eval_rewards[ei].append(episode_reward)
                if i % args.log_interval == 0:
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        print(
                            f"\nEval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}"
                        )
                        print(
                            f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                        )
                    print(
                        f"Eval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}"
                    )
                    print(
                        f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                    )

    for ei in train_envs:
        if 'ant' in args.env_name or 'swimmer' in args.env_name:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(unnorm_train_rewards[ei]), np.std(unnorm_train_rewards[ei])))
        else:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(train_rewards[ei]), np.std(train_rewards[ei])))

    for ei in eval_envs:
        if 'ant' in args.env_name or 'swimmer' in args.env_name:
            print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(unnorm_eval_rewards[ei]), np.std(unnorm_eval_rewards[ei])))
        else:
            print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(eval_rewards[ei]), np.std(eval_rewards[ei])))

    envs.close()
Code example #6
def eval_pdvf():
    '''
    Evaluate the Policy-Dynamics Value Function. 
    '''
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.set_num_threads(1)
    device = args.device
    if device != 'cpu':
        torch.cuda.empty_cache()

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    env = make_vec_envs(args, device)
    env.reset()

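    # Checkpoint names of the pretrained PPO policies, one per (environment, seed) pair.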
    names = []
    for e in range(args.num_envs):
        for s in range(args.num_seeds):
            names.append('ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s))

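    # Load each pretrained PPO actor-critic; these source policies are rolled out by the
    # sampler below to produce the trajectories that get encoded into policy embeddings.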
    source_policy = []
    for name in names:
        actor_critic = Policy(env.observation_space.shape,
                              env.action_space,
                              base_kwargs={'recurrent': False})
        actor_critic.to(device)
        model = os.path.join(args.save_dir, name)
        actor_critic.load_state_dict(torch.load(model))
        source_policy.append(actor_critic)

    # Load the pretrained policy encoder/decoder and the dynamics (environment) encoder
    policy_encoder, policy_decoder = pdvf_utils.load_policy_model(args, env)
    env_encoder = pdvf_utils.load_dynamics_model(args, env)

    value_net = PDVF(env.observation_space.shape[0],
                     args.dynamics_embedding_dim,
                     args.hidden_dim_pdvf,
                     args.policy_embedding_dim,
                     device=device).to(device)
    value_net.to(device)
    path_to_pdvf = os.path.join(args.save_dir_pdvf, \
        "pdvf-stage{}.{}.pt".format(args.stage, args.env_name))
    value_net.load_state_dict(torch.load(path_to_pdvf)['state_dict'])
    value_net.eval()

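    # The first 3/4 of the environments (and their policies) are used for training,
    # the remaining 1/4 are held out for evaluation.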
    all_envs = [i for i in range(args.num_envs)]
    train_policies = [i for i in range(int(3 / 4 * args.num_envs))]
    train_envs = [i for i in range(int(3 / 4 * args.num_envs))]
    eval_envs = [i for i in range(int(3 / 4 * args.num_envs), args.num_envs)]

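    # The dynamics encoder consumes concatenated (state, action) features, so its input
    # size is state_dim + action_dim; the maximum sequence length is expressed in
    # flattened features (max_num_steps * env_enc_input_size).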
    env_enc_input_size = env.observation_space.shape[
        0] + env.action_space.shape[0]
    sizes = pdvf_utils.DotDict({'state_dim': env.observation_space.shape[0], \
        'action_dim': env.action_space.shape[0], 'env_enc_input_size': \
        env_enc_input_size, 'env_max_seq_length': args.max_num_steps * env_enc_input_size})

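    # The sampler wraps the environment and the loaded source policies; below it is used
    # to reset to a given env_idx, roll out source policies, and evaluate decoded policy
    # embeddings in the environment.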
    env_sampler = env_utils.EnvSamplerPDVF(env, source_policy, args)

    all_mean_rewards = [[] for _ in range(args.num_envs)]
    all_mean_unnorm_rewards = [[] for _ in range(args.num_envs)]

    # Eval on Train Envs
    train_rewards = {}
    unnorm_train_rewards = {}
    for ei in range(len(all_envs)):
        train_rewards[ei] = []
        unnorm_train_rewards[ei] = []
    for ei in train_envs:
        for i in range(args.num_eval_eps):
            args.seed = i
            np.random.seed(seed=i)
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)
            for pi in train_policies:
                init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    init_state = env_sampler.env.sim.get_state()
                    res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                else:
                    init_state = env_sampler.env.state
                    res = env_sampler.zeroshot_sample_src_from_pol_state(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                source_env = res['source_env']
                mask_env = res['mask_env']
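                # Note: this rebinds `source_policy` (previously the list of loaded PPO
                # policies) to the sampled policy trajectory; env_sampler was constructed
                # with the list and presumably keeps its own reference, so the list is not
                # needed again here.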
                source_policy = res['source_policy']
                init_episode_reward = res['episode_reward']
                mask_policy = res['mask_policy']

                if source_policy.shape[1] == 1:
                    source_policy = source_policy.repeat(1, 2, 1)
                    mask_policy = mask_policy.repeat(1, 1, 2)
                emb_policy = policy_encoder(
                    source_policy.detach().to(device),
                    mask_policy.detach().to(device)).detach()
                if source_env.shape[1] == 1:
                    source_env = source_env.repeat(1, 2, 1)
                    mask_env = mask_env.repeat(1, 1, 2)
                emb_env = env_encoder(source_env.detach().to(device),
                                      mask_env.detach().to(device)).detach()

                emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                pred_value = value_net(
                    init_obs.unsqueeze(0).to(device), emb_env.to(device),
                    emb_policy.to(device)).item()
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                else:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]

                qf = value_net.get_qf(
                    init_obs.unsqueeze(0).to(device), emb_env)
                u, s, v = torch.svd(qf.squeeze())

                opt_policy_pos = u[:, 0].unsqueeze(0)
                opt_policy_neg = -u[:, 0].unsqueeze(0)

                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)
                else:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)

                if episode_reward_pos >= episode_reward_neg:
                    episode_reward = episode_reward_pos
                    opt_policy = opt_policy_pos
                else:
                    episode_reward = episode_reward_neg
                    opt_policy = opt_policy_neg

                unnorm_episode_reward = episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_init_episode_reward = init_episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_decoded_reward = decoded_reward * (
                    args.max_reward - args.min_reward) + args.min_reward

                unnorm_train_rewards[ei].append(unnorm_episode_reward)
                train_rewards[ei].append(episode_reward)
                if i % args.log_interval == 0:
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        print(
                            f"\nTrain Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}"
                        )
                        print(
                            f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                        )
                    print(
                        f"Train Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}"
                    )
                    print(
                        f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                    )

        all_mean_rewards[ei].append(np.mean(train_rewards[ei]))
        all_mean_unnorm_rewards[ei].append(np.mean(unnorm_train_rewards[ei]))

    for ei in train_envs:
        if 'ant' in args.env_name or 'swimmer' in args.env_name:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei])))
        else:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei])))

    # Eval on Eval Envs
    eval_rewards = {}
    unnorm_eval_rewards = {}
    for ei in range(len(all_envs)):
        eval_rewards[ei] = []
        unnorm_eval_rewards[ei] = []
    for ei in eval_envs:
        for i in range(args.num_eval_eps):
            args.seed = i
            np.random.seed(seed=i)
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)

            for pi in train_policies:
                init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    init_state = env_sampler.env.sim.get_state()
                    res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)
                else:
                    init_state = env_sampler.env.state
                    res = env_sampler.zeroshot_sample_src_from_pol_state(
                        args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                source_env = res['source_env']
                mask_env = res['mask_env']
                source_policy = res['source_policy']
                init_episode_reward = res['episode_reward']
                mask_policy = res['mask_policy']

                if source_policy.shape[1] == 1:
                    source_policy = source_policy.repeat(1, 2, 1)
                    mask_policy = mask_policy.repeat(1, 1, 2)
                emb_policy = policy_encoder(
                    source_policy.detach().to(device),
                    mask_policy.detach().to(device)).detach()
                if source_env.shape[1] == 1:
                    source_env = source_env.repeat(1, 2, 1)
                    mask_env = mask_env.repeat(1, 1, 2)
                emb_env = env_encoder(source_env.detach().to(device),
                                      mask_env.detach().to(device)).detach()

                emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                pred_value = value_net(
                    init_obs.unsqueeze(0).to(device), emb_env.to(device),
                    emb_policy.to(device)).item()
                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]
                else:
                    decoded_reward = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        emb_policy,
                        policy_decoder,
                        env_idx=ei)[0]

                qf = value_net.get_qf(
                    init_obs.unsqueeze(0).to(device), emb_env)
                u, s, v = torch.svd(qf.squeeze())

                opt_policy_pos = u[:, 0].unsqueeze(0)
                opt_policy_neg = -u[:, 0].unsqueeze(0)

                if 'ant' in args.env_name or 'swimmer' in args.env_name:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)
                else:
                    episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_pos,
                        policy_decoder,
                        env_idx=ei)
                    episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
                        args,
                        init_state,
                        init_obs,
                        opt_policy_neg,
                        policy_decoder,
                        env_idx=ei)

                if episode_reward_pos >= episode_reward_neg:
                    episode_reward = episode_reward_pos
                    opt_policy = opt_policy_pos
                else:
                    episode_reward = episode_reward_neg
                    opt_policy = opt_policy_neg

                unnorm_episode_reward = episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_init_episode_reward = init_episode_reward * (
                    args.max_reward - args.min_reward) + args.min_reward
                unnorm_decoded_reward = decoded_reward * (
                    args.max_reward - args.min_reward) + args.min_reward

                unnorm_eval_rewards[ei].append(unnorm_episode_reward)
                eval_rewards[ei].append(episode_reward)
                if i % args.log_interval == 0:
                    if 'ant' in args.env_name or 'swimmer' in args.env_name:
                        print(
                            f"\nEval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}"
                        )
                        print(
                            f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                        )
                    print(
                        f"Eval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}"
                    )
                    print(
                        f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}"
                    )

        all_mean_rewards[ei].append(np.mean(eval_rewards[ei]))
        all_mean_unnorm_rewards[ei].append(np.mean(unnorm_eval_rewards[ei]))

    for ei in train_envs:
        if 'ant' in args.env_name or 'swimmer' in args.env_name:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei])))
        else:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei])))

    for ei in eval_envs:
        if 'ant' in args.env_name or 'swimmer' in args.env_name:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei])))
        else:
            print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
                .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei])))

    env.close()
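

# Assumed entry point (not part of the original snippet): run the evaluation when this
# example is executed as a script.
if __name__ == '__main__':
    eval_pdvf()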