def main():
    args = get_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.set_num_threads(1)

    device = torch.device(args.device)
    utils.cleanup_log_dir(args.log_dir)

    env_make = make_pybullet_env(args.task, log_dir=args.log_dir,
                                 frame_skip=args.frame_skip)
    envs = make_vec_envs(env_make, args.num_processes, args.log_dir, device,
                         args.frame_stack)

    actor_critic = MetaPolicy(envs.observation_space, envs.action_space)

    loss_writer = LossWriter(args.log_dir, fieldnames=(
        'V_loss', 'action_loss', 'meta_action_loss', 'meta_value_loss',
        'meta_loss', 'loss'))

    if args.restart_model:
        actor_critic.load_state_dict(
            torch.load(args.restart_model, map_location=device))
    actor_critic.to(device)

    agent = MetaPPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    obs = envs.reset()
    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    # Rollouts live on the GPU; observations arrive as torch tensors,
    # already converted by the environment wrapper.
    rollouts.to(device)

    start = time.time()
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(num_updates):
        ppo_rollout(args.num_steps, envs, actor_critic, rollouts)

        value_loss, meta_value_loss, action_loss, meta_action_loss, loss, meta_loss = ppo_update(
            agent, actor_critic, rollouts, args.use_gae, args.gamma,
            args.gae_lambda)

        loss_writer.write_row({
            'V_loss': value_loss.item(),
            'action_loss': action_loss.item(),
            'meta_action_loss': meta_action_loss.item(),
            'meta_value_loss': meta_value_loss.item(),
            'meta_loss': meta_loss.item(),
            'loss': loss.item()})

        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.log_dir != "":
            ppo_save_model(actor_critic,
                           os.path.join(args.log_dir, "model.state_dict"), j)

        if j % args.log_interval == 0:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            s = "Update {}, num timesteps {}, FPS {} \n".format(
                j, total_num_steps,
                int(total_num_steps / (time.time() - start)))
            s += "Loss {:.5f}, meta loss {:.5f}, value_loss {:.5f}, meta_value_loss {:.5f}, action_loss {:.5f}, meta action loss {:.5f}".format(
                loss.item(), meta_loss.item(), value_loss.item(),
                meta_value_loss.item(), action_loss.item(),
                meta_action_loss.item())
            print(s, flush=True)
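# Quick sanity check of the update schedule used above (illustrative only, not
# part of the training script; the concrete numbers below are hypothetical):
# the loop runs num_updates times and the logged step count grows by
# num_processes * num_steps environment steps per update.
def update_schedule(num_env_steps, num_steps, num_processes):
    """Return (number of PPO updates, env steps consumed per update)."""
    num_updates = int(num_env_steps) // num_steps // num_processes
    steps_per_update = num_steps * num_processes
    return num_updates, steps_per_update

# Example: 10M env steps, 128-step rollouts, 16 parallel envs
# -> 4882 updates, 2048 env steps consumed per update.
assert update_schedule(1e7, 128, 16) == (4882, 2048)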
def evaluate(actor_critic, ob_rms, env_name, seed, num_processes,
             eval_log_dir, device):
    eval_envs = make_vec_envs(env_name, seed + num_processes, num_processes,
                              None, eval_log_dir, device, True)

    vec_norm = utils.get_vec_normalize(eval_envs)
    if vec_norm is not None:
        vec_norm.eval()
        vec_norm.ob_rms = ob_rms

    eval_episode_rewards = []

    obs = eval_envs.reset()
    eval_recurrent_hidden_states = torch.zeros(
        num_processes, actor_critic.recurrent_hidden_state_size, device=device)
    eval_masks = torch.zeros(num_processes, 1, device=device)

    while len(eval_episode_rewards) < 10:
        with torch.no_grad():
            _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                obs,
                eval_recurrent_hidden_states,
                eval_masks,
                deterministic=True)

        # Observe reward and next obs
        obs, _, done, infos = eval_envs.step(action)

        eval_masks = torch.tensor(
            [[0.0] if done_ else [1.0] for done_ in done],
            dtype=torch.float32,
            device=device)

        for info in infos:
            if 'episode' in info.keys():
                eval_episode_rewards.append(info['episode']['r'])

    eval_envs.close()

    print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
        len(eval_episode_rewards), np.mean(eval_episode_rewards)))
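# Illustrative snippet (not from the original file): how the eval loop above
# builds the recurrent-state masks. A mask of 0 at an episode boundary zeroes
# the hidden state before the next act() call, so finished episodes do not
# leak state into the next one.
import torch

done = [False, True, False]  # hypothetical per-process done flags from step()
eval_masks = torch.tensor([[0.0] if d else [1.0] for d in done],
                          dtype=torch.float32)
# tensor([[1.], [0.], [1.]]) -- process 1 just finished, so its recurrent
# hidden state is reset on the next forward pass.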
parser.add_argument('--cnn',
                    default='Fixup',
                    help='Type of CNN. Options are CNN, Impala, Fixup, State')
parser.add_argument('--state-stack',
                    type=int,
                    default=4,
                    help='Number of steps to stack in states')
parser.add_argument('--task',
                    default='HalfCheetahPyBulletEnv-v0',
                    help='Which of the PyBullet tasks to run')
args = parser.parse_args()

args.det = not args.non_det
device = torch.device(args.device)

env_make = make_pybullet_env(args.task, frame_skip=args.frame_skip)
env = make_vec_envs(env_make, 1, None, device, args.frame_stack)
env.render(mode="human")

base_kwargs = {'recurrent': args.recurrent_policy}
actor_critic = MetaPolicy(env.observation_space, env.action_space)

if args.load_model:
    actor_critic.load_state_dict(
        torch.load(args.load_model, map_location=device))
actor_critic.to(device)

recurrent_hidden_states = torch.zeros(
    1, actor_critic.recurrent_hidden_state_size).to(device)
masks = torch.zeros(1, 1).to(device)

obs = env.reset()
fig = plt.figure()
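# Rough sketch of the interaction loop that typically follows this setup,
# using only interfaces that appear above (actor_critic.act(), env.step(),
# env.render()); this is an illustration, not the author's exact loop.
while True:
    with torch.no_grad():
        _, action, _, recurrent_hidden_states = actor_critic.act(
            obs, recurrent_hidden_states, masks, deterministic=args.det)
    obs, reward, done, _ = env.step(action)
    # reset the recurrent state when the (single) episode ends
    masks.fill_(0.0 if done[0] else 1.0)
    env.render(mode="human")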
def train_policy_embedding():
    """
    Script for training the policy embeddings using a transformer.

    References:
        https://github.com/jadore801120/attention-is-all-you-need-pytorch/
    """
    args = get_args()
    os.environ['OMP_NUM_THREADS'] = '1'

    # Useful Variables
    best_eval_loss = sys.maxsize
    device = args.device
    if device != 'cpu':
        torch.cuda.empty_cache()

    # Create the Environment
    env = make_vec_envs(args, device)

    names = []
    for e in range(args.num_envs):
        for s in range(args.num_seeds):
            names.append('ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s))

    all_policies = []
    for name in names:
        actor_critic = Policy(env.observation_space.shape,
                              env.action_space,
                              base_kwargs={'recurrent': False})
        actor_critic.to(device)
        model = os.path.join(args.save_dir, name)
        actor_critic.load_state_dict(torch.load(model))
        all_policies.append(actor_critic)

    encoder_dim = args.num_attn_heads * args.policy_attn_head_dim
    enc_input_size = env.observation_space.shape[0] + env.action_space.shape[0]

    # Initialize the Transformer encoder and the decoder
    encoder = embedding_networks.make_encoder_oh(
        enc_input_size, N=args.num_layers, d_model=encoder_dim,
        h=args.num_attn_heads, dropout=args.dropout,
        d_emb=args.policy_embedding_dim)
    decoder = Policy(
        tuple([env.observation_space.shape[0] + args.policy_embedding_dim]),
        env.action_space,
        base_kwargs={'recurrent': False})

    embedding_networks.init_weights(encoder)
    embedding_networks.init_weights(decoder)

    encoder.train()
    decoder.train()
    encoder.to(device)
    decoder.to(device)

    # Loss and Optimizer
    criterion = nn.MSELoss(reduction='sum')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.lr_policy)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.lr_policy)

    # Create the Environment Sampler
    env_sampler = env_utils.EnvSamplerEmb(env, all_policies, args)

    # Collect Train Data
    src_batch = []
    tgt_batch = []
    state_batch = []
    mask_batch = []
    mask_batch_all = []

    train_policies = [i for i in range(int(3 / 4 * args.num_envs))]
    train_envs = [i for i in range(int(3 / 4 * args.num_envs))]

    # For each policy in our dataset
    for pi in train_policies:
        # For each environment in our dataset
        for env_idx in train_envs:
            # Sample a number of trajectories for this (policy, env) pair
            for _ in range(args.num_eps_policy):
                state_batch_t, tgt_batch_t, src_batch_t, mask_batch_t, \
                    mask_batch_all_t = env_sampler.sample_policy_data(
                        policy_idx=pi, env_idx=env_idx)
                state_batch.extend(state_batch_t)
                tgt_batch.extend(tgt_batch_t)
                src_batch.extend(src_batch_t)
                mask_batch.extend(mask_batch_t)
                mask_batch_all.extend(mask_batch_all_t)

    src_batch = torch.stack(src_batch)
    tgt_batch = torch.stack(tgt_batch).squeeze(1)
    state_batch = torch.stack(state_batch).squeeze(1)
    mask_batch = torch.stack(mask_batch)
    mask_batch_all = torch.stack(mask_batch_all)
    num_samples_train = src_batch.shape[0]

    # Collect Eval Data
    src_batch_eval = []
    tgt_batch_eval = []
    state_batch_eval = []
    mask_batch_eval = []
    mask_batch_all_eval = []

    eval_policies = [i for i in range(int(3 / 4 * args.num_envs))]
    eval_envs = [i for i in range(int(3 / 4 * args.num_envs))]

    # For each policy in our dataset
    for pi in eval_policies:
        # For each environment in our dataset
        for env_idx in eval_envs:
            # Sample a number of trajectories for this (policy, env) pair
            for _ in range(args.num_eps_policy):
                state_batch_t, tgt_batch_t, src_batch_t, mask_batch_t, \
                    mask_batch_all_t = env_sampler.sample_policy_data(
                        policy_idx=pi, env_idx=env_idx)
                state_batch_eval.extend(state_batch_t)
                tgt_batch_eval.extend(tgt_batch_t)
                src_batch_eval.extend(src_batch_t)
                mask_batch_eval.extend(mask_batch_t)
                mask_batch_all_eval.extend(mask_batch_all_t)

    src_batch_eval = torch.stack(src_batch_eval).detach()
    tgt_batch_eval = torch.stack(tgt_batch_eval).squeeze(1).detach()
    state_batch_eval = torch.stack(state_batch_eval).squeeze(1).detach()
    mask_batch_eval = torch.stack(mask_batch_eval).detach()
    mask_batch_all_eval = torch.stack(mask_batch_all_eval).detach()
    num_samples_eval = src_batch_eval.shape[0]

    # Training Loop
    for epoch in range(args.num_epochs_emb + 1):
        encoder.train()
        decoder.train()

        indices = [i for i in range(num_samples_train)]
        random.shuffle(indices)
        total_counts = 0
        total_loss = 0
        num_correct_actions = 0

        for nmb in range(0, len(indices), args.policy_batch_size):
            indices_mb = indices[nmb:nmb + args.policy_batch_size]

            source = src_batch[indices_mb].to(device)
            target = tgt_batch[indices_mb].to(device)
            state = state_batch[indices_mb].to(device).float()
            mask = mask_batch[indices_mb].to(device)
            mask_all = mask_batch_all[indices_mb].squeeze(2).unsqueeze(1).to(device)

            embedding = encoder(source.detach().to(device),
                                mask_all.detach().to(device))
            embedding = F.normalize(embedding, p=2, dim=1)

            state *= mask.to(device)
            embedding *= mask.to(device)

            recurrent_hidden_state = torch.zeros(
                args.policy_batch_size,
                decoder.recurrent_hidden_state_size,
                device=device,
                requires_grad=True).float()
            mask_dec = torch.zeros(args.policy_batch_size, 1,
                                   device=device,
                                   requires_grad=True).float()

            emb_state_input = torch.cat((embedding, state.to(device)),
                                        dim=1).to(device)
            action = decoder(emb_state_input, recurrent_hidden_state, mask_dec)
            action *= mask.to(device)
            target *= mask

            loss = criterion(action, target.to(device))
            total_loss += loss.item()
            total_counts += len(indices_mb)

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

        if epoch % args.log_interval == 0:
            avg_loss = total_loss / total_counts
            print("\n# Epoch %d: Train Loss = %.6f " % (epoch + 1, avg_loss))

        # Evaluation
        encoder.eval()
        decoder.eval()

        indices_eval = [i for i in range(num_samples_eval)]
        total_counts_eval = 0
        total_loss_eval = 0
        num_correct_actions_eval = 0

        for nmb in range(0, len(indices_eval), args.policy_batch_size):
            indices_mb_eval = indices_eval[nmb:nmb + args.policy_batch_size]

            source_eval = src_batch_eval[indices_mb_eval].to(device).detach()
            target_eval = tgt_batch_eval[indices_mb_eval].to(device).detach()
            state_eval = state_batch_eval[indices_mb_eval].float().to(device).detach()
            mask_eval = mask_batch_eval[indices_mb_eval].to(device).detach()
            mask_all_eval = mask_batch_all_eval[indices_mb_eval].squeeze(2).unsqueeze(1).to(device).detach()

            embedding_eval = encoder(source_eval.detach().to(device),
                                     mask_all_eval.detach().to(device)).detach()
            embedding_eval = F.normalize(embedding_eval, p=2, dim=1).detach()

            state_eval *= mask_eval.to(device).detach()
            embedding_eval *= mask_eval.to(device).detach()

            recurrent_hidden_state_eval = torch.zeros(
                args.policy_batch_size,
                decoder.recurrent_hidden_state_size,
                device='cpu').float()
            mask_dec_eval = torch.zeros(args.policy_batch_size, 1,
                                        device='cpu').float()

            emb_state_input_eval = torch.cat(
                (embedding_eval, state_eval.to(device)), dim=1)
            action_eval = decoder(emb_state_input_eval,
                                  recurrent_hidden_state_eval,
                                  mask_dec_eval,
                                  deterministic=True)
            action_eval *= mask_eval.to(device)
            target_eval *= mask_eval

            loss_eval = criterion(action_eval, target_eval.to(device))
            total_loss_eval += loss_eval.item()
            total_counts_eval += len(indices_mb_eval)

        avg_loss_eval = total_loss_eval / total_counts_eval
        # Save the models
        if avg_loss_eval <= best_eval_loss:
            best_eval_loss = avg_loss_eval
            pdvf_utils.save_model("policy-encoder.", encoder, encoder_optimizer,
                                  epoch, args, args.env_name,
                                  save_dir=args.save_dir_policy_embedding)
            pdvf_utils.save_model("policy-decoder.", decoder, decoder_optimizer,
                                  epoch, args, args.env_name,
                                  save_dir=args.save_dir_policy_embedding)

        if epoch % args.log_interval == 0:
            print("# Epoch %d: Eval Loss = %.6f " % (epoch + 1, avg_loss_eval))
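# Minimal sketch of the objective optimized above, with simplified stand-in
# modules (a GRU encoder and an MLP decoder instead of the repo's
# embedding_networks / Policy classes, and made-up sizes): the encoder maps a
# masked (state, action) trajectory of a source policy to a single embedding
# z_pi, and the decoder must reproduce that policy's actions from (z_pi, state)
# under an MSE loss, so z_pi becomes a behavioural summary of the policy.
import torch
import torch.nn as nn
import torch.nn.functional as F

state_dim, action_dim, emb_dim, T = 8, 2, 16, 32            # hypothetical sizes
encoder = nn.GRU(state_dim + action_dim, emb_dim, batch_first=True)
decoder = nn.Sequential(nn.Linear(emb_dim + state_dim, 64), nn.Tanh(),
                        nn.Linear(64, action_dim))

traj = torch.randn(1, T, state_dim + action_dim)             # (s, a) pairs
_, z_pi = encoder(traj)                                      # policy embedding
z_pi = F.normalize(z_pi[-1], p=2, dim=1)                     # unit norm, as above
states = traj[0, :, :state_dim]
target_actions = traj[0, :, state_dim:]
pred_actions = decoder(torch.cat([z_pi.expand(T, -1), states], dim=1))
loss = F.mse_loss(pred_actions, target_actions, reduction='sum')
loss.backward()                                              # trains encoder and decoder jointly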
def train_pdvf(): ''' Train the Policy-Dynamics Value Function of PD-VF which estimates the return for a family of policies in a family of environments with varying dynamics. To do this, it trains a network conditioned on an initial state, a (learned) policy embedding, and a (learned) dynamics embedding and outputs an estimate of the cumulative reward of the corresponding policy in the given environment. ''' args = get_args() os.environ['OMP_NUM_THREADS'] = '1' device = args.device if device != 'cpu': torch.cuda.empty_cache() # Create the environment envs = make_vec_envs(args, device) if args.seed: torch.manual_seed(args.seed) np.random.seed(args.seed) random.seed(args.seed) names = [] for e in range(args.num_envs): for s in range(args.num_seeds): names.append('ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s)) all_policies = [] for name in names: actor_critic = Policy(envs.observation_space.shape, envs.action_space, base_kwargs={'recurrent': False}) actor_critic.to(device) model = os.path.join(args.save_dir, name) actor_critic.load_state_dict(torch.load(model)) all_policies.append(actor_critic) # Load the collected interaction episodes for each agent policy_encoder, policy_decoder = pdvf_utils.load_policy_model(args, envs) env_encoder = pdvf_utils.load_dynamics_model(args, envs) policy_decoder.train() decoder_optimizer = optim.Adam(policy_decoder.parameters(), lr=args.lr_policy) decoder_optimizer2 = optim.Adam(policy_decoder.parameters(), lr=args.lr_policy) decoder_network = {'policy_decoder': policy_decoder, \ 'decoder_optimizer': decoder_optimizer, \ 'decoder_optimizer2': decoder_optimizer2} # Instantiate the PD-VF, Optimizer and Loss args.use_l2_loss = True value_net = PDVF(envs.observation_space.shape[0], args.dynamics_embedding_dim, args.hidden_dim_pdvf, args.policy_embedding_dim, device=device).to(device) optimizer = optim.Adam(value_net.parameters(), lr=args.lr_pdvf, eps=args.eps) optimizer2 = optim.Adam(value_net.parameters(), lr=args.lr_pdvf, eps=args.eps) network = { 'net': value_net, 'optimizer': optimizer, 'optimizer2': optimizer2 } value_net.train() train_policies = [i for i in range(int(3 / 4 * args.num_envs))] train_envs = [i for i in range(int(3 / 4 * args.num_envs))] eval_envs = [i for i in range(int(3 / 4 * args.num_envs), args.num_envs)] all_envs = [i for i in range(args.num_envs)] NUM_STAGES = args.num_stages NUM_TRAIN_EPS = args.num_train_eps NUM_TRAIN_SAMPLES = NUM_TRAIN_EPS * len(train_policies) * len(train_envs) NUM_EVAL_EPS = args.num_eval_eps NUM_EVAL_SAMPLES = NUM_EVAL_EPS * len(train_policies) * len(train_envs) env_enc_input_size = 2 * envs.observation_space.shape[ 0] + args.policy_embedding_dim sizes = pdvf_utils.DotDict({'state_dim': envs.observation_space.shape[0], \ 'action_dim': envs.action_space.shape[0], 'env_enc_input_size': env_enc_input_size, \ 'env_max_seq_length': args.max_num_steps * env_enc_input_size}) env_sampler = env_utils.EnvSamplerPDVF(envs, all_policies, args) decoder_env_sampler = env_utils.EnvSamplerEmb(envs, all_policies, args) #################### TRAIN PHASE 1 ######################## # Collect Eval Data for First Training Stage eval_memory = ReplayMemoryPDVF(NUM_EVAL_SAMPLES) decoder_eval_memory = ReplayMemoryPolicyDecoder(NUM_EVAL_SAMPLES) for i in range(NUM_EVAL_EPS): for ei in train_envs: for pi in train_policies: init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = 
env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] mask_env = res['mask_env'] mask_policy = res['mask_policy'] source_policy = res['source_policy'] episode_reward = res['episode_reward'] episode_reward_tensor = torch.tensor([episode_reward], device=device, dtype=torch.float) if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder(source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() pred_value = value_net( init_obs.unsqueeze(0).to(device), emb_env.to(device), emb_policy.to(device)).item() if 'ant' in args.env_name or 'swimmer' in args.env_name: decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] else: decoded_reward = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] decoded_reward_tensor = torch.tensor([decoded_reward], device=device, dtype=torch.float) eval_memory.push(init_obs.unsqueeze(0), emb_policy.unsqueeze(0), emb_env.unsqueeze(0), episode_reward_tensor) # Collect data for the decoder state_batch, tgt_batch, src_batch, mask_batch, _ = \ decoder_env_sampler.sample_policy_data(policy_idx=pi, env_idx=ei) for state, tgt, src, mask in zip(state_batch, tgt_batch, src_batch, mask_batch): state = state.to(device).float() mask = mask.to(device) state *= mask.to(device).detach() emb_policy *= mask.to(device).detach() recurrent_state = torch.zeros( state.shape[0], policy_decoder.recurrent_hidden_state_size, device=args.device).float() mask_dec = torch.zeros(state.shape[0], 1, device=args.device).float() emb_state = torch.cat((emb_policy, state.to(device)), dim=1) action = policy_decoder(emb_state, recurrent_state, mask_dec, deterministic=True) action *= mask.to(device) decoder_eval_memory.push(emb_state, recurrent_state, mask_dec, action) # Collect Train Data for Frist Training Stage memory = ReplayMemoryPDVF(NUM_TRAIN_SAMPLES) decoder_memory = ReplayMemoryPolicyDecoder(NUM_TRAIN_SAMPLES) for i in range(NUM_TRAIN_EPS): for ei in train_envs: for pi in train_policies: init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] mask_env = res['mask_env'] source_policy = res['source_policy'] mask_policy = res['mask_policy'] episode_reward = res['episode_reward'] episode_reward_tensor = torch.tensor([episode_reward], device=device, dtype=torch.float) if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( 
source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder(source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() pred_value = value_net( init_obs.unsqueeze(0).to(device), emb_env.to(device), emb_policy.to(device)).item() if 'ant' in args.env_name or 'swimmer' in args.env_name: decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] else: decoded_reward = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] decoded_reward_tensor = torch.tensor([decoded_reward], device=device, dtype=torch.float) memory.push(init_obs.unsqueeze(0), emb_policy.unsqueeze(0), emb_env.unsqueeze(0), episode_reward_tensor) # Collect data for the decoder state_batch, tgt_batch, src_batch, mask_batch, _ = \ decoder_env_sampler.sample_policy_data(policy_idx=pi, env_idx=ei) for state, tgt, src, mask in zip(state_batch, tgt_batch, src_batch, mask_batch): state = state.to(device).float() mask = mask.to(device) state *= mask.to(device).detach() emb_policy *= mask.to(device).detach() recurrent_state = torch.zeros( state.shape[0], policy_decoder.recurrent_hidden_state_size, device=args.device).float() mask_dec = torch.zeros(state.shape[0], 1, device=args.device).float() emb_state = torch.cat((emb_policy, state.to(device)), dim=1) action = policy_decoder(emb_state, recurrent_state, mask_dec, deterministic=True) action *= mask.to(device) decoder_memory.push(emb_state, recurrent_state, mask_dec, action) ### Train - Stage 1 ### total_train_loss = 0 total_eval_loss = 0 BEST_EVAL_LOSS = sys.maxsize decoder_total_train_loss = 0 decoder_total_eval_loss = 0 DECODER_BEST_EVAL_LOSS = sys.maxsize print("\nFirst Training Stage") for i in range(args.num_epochs_pdvf_phase1): train_loss = train_utils.optimize_model_pdvf( args, network, memory, num_opt_steps=args.num_opt_steps) if train_loss: total_train_loss += train_loss eval_loss = train_utils.optimize_model_pdvf( args, network, eval_memory, num_opt_steps=args.num_opt_steps, eval=True) if eval_loss: total_eval_loss += eval_loss if eval_loss < BEST_EVAL_LOSS: BEST_EVAL_LOSS = eval_loss pdvf_utils.save_model("pdvf-stage0.", value_net, optimizer, \ i, args, args.env_name, save_dir=args.save_dir_pdvf) if i % args.log_interval == 0: print("\n### PD-VF: Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \ i, total_train_loss / args.log_interval, total_eval_loss / args.log_interval)) total_train_loss = 0 total_eval_loss = 0 # Train the Policy Decoder on mixed data # from trajectories collected using the pretrained policies # and decoded trajectories by the current decoder decoder_train_loss = train_utils.optimize_decoder( args, decoder_network, decoder_memory, num_opt_steps=args.num_opt_steps) if decoder_train_loss: decoder_total_train_loss += decoder_train_loss decoder_eval_loss = train_utils.optimize_decoder( args, decoder_network, decoder_eval_memory, num_opt_steps=args.num_opt_steps, eval=True) if decoder_eval_loss: decoder_total_eval_loss += decoder_eval_loss if decoder_eval_loss < DECODER_BEST_EVAL_LOSS: DECODER_BEST_EVAL_LOSS = decoder_eval_loss pdvf_utils.save_model("policy-decoder-stage0.", policy_decoder, decoder_optimizer, \ i, args, args.env_name, 
save_dir=args.save_dir_pdvf) if i % args.log_interval == 0: print("### PolicyDecoder: Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \ i, decoder_total_train_loss / args.log_interval, decoder_total_eval_loss / args.log_interval)) decoder_total_train_loss = 0 decoder_total_eval_loss = 0 #################### TRAIN PHASE 2 ######################## for k in range(NUM_STAGES): print("Stage in Second Training Phase: ", k) # Collect Eval Data for Second Training Stage eval_memory2 = ReplayMemoryPDVF(NUM_EVAL_SAMPLES) decoder_eval_memory2 = ReplayMemoryPolicyDecoder(NUM_EVAL_SAMPLES) for i in range(NUM_EVAL_EPS): for ei in train_envs: for pi in train_policies: init_obs = torch.FloatTensor( env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] mask_env = res['mask_env'] source_policy = res['source_policy'] mask_policy = res['mask_policy'] init_episode_reward = res['episode_reward'] if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder( source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() qf = value_net.get_qf( init_obs.unsqueeze(0).to(device), emb_env) u, s, v = torch.svd(qf.squeeze()) opt_policy_pos = u[:, 0].unsqueeze(0) opt_policy_neg = -u[:, 0].unsqueeze(0) if 'ant' in args.env_name or 'swimmer' in args.env_name: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) else: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) if episode_reward_pos >= episode_reward_neg: episode_reward = episode_reward_pos opt_policy = opt_policy_pos else: episode_reward = episode_reward_neg opt_policy = opt_policy_neg episode_reward_tensor = torch.tensor([episode_reward], device=device, dtype=torch.float) eval_memory2.push(init_obs.unsqueeze(0), opt_policy.unsqueeze(0), emb_env.unsqueeze(0), episode_reward_tensor) if 'ant' in args.env_name or 'swimmer' in args.env_name: all_emb_state, all_recurrent_state, all_mask, all_action = \ decoder_env_sampler.get_decoded_traj_mujoco(args, init_state, \ init_obs, opt_policy_pos, policy_decoder, env_idx=ei) else: all_emb_state, all_recurrent_state, all_mask, all_action = \ decoder_env_sampler.get_decoded_traj(args, init_state, \ init_obs, opt_policy_pos, policy_decoder, env_idx=ei) for e, r, m, a in zip(all_emb_state, all_recurrent_state, all_mask, all_action): 
decoder_eval_memory2.push(e, r, m, a) # Collect Train Data for Second Training Stage memory2 = ReplayMemoryPDVF(NUM_TRAIN_SAMPLES) decoder_memory2 = ReplayMemoryPolicyDecoder(NUM_TRAIN_SAMPLES) for i in range(NUM_TRAIN_EPS): for ei in train_envs: for pi in train_policies: init_obs = torch.FloatTensor( env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] mask_env = res['mask_env'] source_policy = res['source_policy'] mask_policy = res['mask_policy'] init_episode_reward = res['episode_reward'] if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder( source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() qf = value_net.get_qf( init_obs.unsqueeze(0).to(device), emb_env) u, s, v = torch.svd(qf.squeeze()) opt_policy_pos = u[:, 0].unsqueeze(0) opt_policy_neg = -u[:, 0].unsqueeze(0) # include both solutions (positive and negative policy emb) in the train data # to enforce the correct shape and make it aware of the two if 'ant' in args.env_name or 'swimmer' in args.env_name: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) else: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) episode_reward_tensor_pos = torch.tensor( [episode_reward_pos], device=device, dtype=torch.float) episode_reward_tensor_neg = torch.tensor( [episode_reward_neg], device=device, dtype=torch.float) memory2.push(init_obs.unsqueeze(0), opt_policy_pos.unsqueeze(0), emb_env.unsqueeze(0), episode_reward_tensor_pos) memory2.push(init_obs.unsqueeze(0), opt_policy_neg.unsqueeze(0), emb_env.unsqueeze(0), episode_reward_tensor_neg) # collect PolicyDecoder train data for second training stage if 'ant' in args.env_name or 'swimmer' in args.env_name: all_emb_state, all_recurrent_state, all_mask, all_action = \ decoder_env_sampler.get_decoded_traj_mujoco(args, init_state, \ init_obs, opt_policy_pos, policy_decoder, env_idx=ei) else: all_emb_state, all_recurrent_state, all_mask, all_action = \ decoder_env_sampler.get_decoded_traj(args, init_state, \ init_obs, opt_policy_pos, policy_decoder, env_idx=ei) for e, r, m, a in zip(all_emb_state, all_recurrent_state, all_mask, all_action): decoder_memory2.push(e, r, m, a) ### Train - Stage 2 ### total_train_loss = 0 total_eval_loss = 0 BEST_EVAL_LOSS = sys.maxsize decoder_total_train_loss = 0 
decoder_total_eval_loss = 0 DECODER_BEST_EVAL_LOSS = sys.maxsize for i in range(args.num_epochs_pdvf_phase2): train_loss = train_utils.optimize_model_pdvf_phase2( args, network, memory, memory2, num_opt_steps=args.num_opt_steps) if train_loss: total_train_loss += train_loss eval_loss = train_utils.optimize_model_pdvf_phase2( args, network, eval_memory, eval_memory2, num_opt_steps=args.num_opt_steps, eval=True) if eval_loss: total_eval_loss += eval_loss if eval_loss < BEST_EVAL_LOSS: BEST_EVAL_LOSS = eval_loss pdvf_utils.save_model("pdvf-stage{}.".format(k+1), value_net, optimizer, \ i, args, args.env_name, save_dir=args.save_dir_pdvf) if i % args.log_interval == 0: print("\n### PDVF: Stage {} -- Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \ k, i, total_train_loss / args.log_interval, total_eval_loss / args.log_interval)) total_train_loss = 0 total_eval_loss = 0 # Train the Policy Decoder on mixed data # from trajectories collected using the pretrained policies # and decoded trajectories by the current decoder decoder_train_loss = train_utils.optimize_decoder_phase2( args, decoder_network, decoder_memory, decoder_memory2, num_opt_steps=args.num_opt_steps) if decoder_train_loss: decoder_total_train_loss += decoder_train_loss decoder_eval_loss = train_utils.optimize_decoder_phase2( args, decoder_network, decoder_eval_memory, decoder_eval_memory2, num_opt_steps=args.num_opt_steps, eval=True) if decoder_eval_loss: decoder_total_eval_loss += decoder_eval_loss if decoder_eval_loss < DECODER_BEST_EVAL_LOSS: DECODER_BEST_EVAL_LOSS = decoder_eval_loss pdvf_utils.save_model("policy-decoder-stage{}.".format(k+1), policy_decoder, decoder_optimizer, \ i, args, args.env_name, save_dir=args.save_dir_pdvf) if i % args.log_interval == 0: print("### PolicyDecoder: Stage {} -- Episode {}: Train Loss {:.6f} Eval Loss {:.6f}".format( \ k, i, decoder_total_train_loss / args.log_interval, decoder_total_eval_loss / args.log_interval)) decoder_total_train_loss = 0 decoder_total_eval_loss = 0 #################### EVAL ######################## # Eval on Train Envs value_net.eval() policy_decoder.eval() train_rewards = {} unnorm_train_rewards = {} for ei in range(len(all_envs)): train_rewards[ei] = [] unnorm_train_rewards[ei] = [] for i in range(NUM_EVAL_EPS): for ei in train_envs: for pi in train_policies: init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] mask_env = res['mask_env'] source_policy = res['source_policy'] init_episode_reward = res['episode_reward'] mask_policy = res['mask_policy'] if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder(source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() pred_value = value_net( init_obs.unsqueeze(0).to(device), emb_env.to(device), 
emb_policy.to(device)).item() if 'ant' in args.env_name or 'swimmer' in args.env_name: decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] else: decoded_reward = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] qf = value_net.get_qf( init_obs.unsqueeze(0).to(device), emb_env) u, s, v = torch.svd(qf.squeeze()) opt_policy_pos = u[:, 0].unsqueeze(0) opt_policy_neg = -u[:, 0].unsqueeze(0) if 'ant' in args.env_name or 'swimmer' in args.env_name: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) else: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) if episode_reward_pos >= episode_reward_neg: episode_reward = episode_reward_pos opt_policy = opt_policy_pos else: episode_reward = episode_reward_neg opt_policy = opt_policy_neg unnorm_episode_reward = episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_init_episode_reward = init_episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_decoded_reward = decoded_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_train_rewards[ei].append(unnorm_episode_reward) train_rewards[ei].append(episode_reward) if i % args.log_interval == 0: if 'ant' in args.env_name or 'swimmer' in args.env_name: print( f"\nTrain Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}" ) print( f"Train Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}" ) for ei in train_envs: if 'ant' in args.env_name or 'swimmer' in args.env_name: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(unnorm_train_rewards[ei]), np.std(unnorm_train_rewards[ei]))) else: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(train_rewards[ei]), np.std(train_rewards[ei]))) # Eval on Eval Envs value_net.eval() policy_decoder.eval() eval_rewards = {} unnorm_eval_rewards = {} for ei in range(len(all_envs)): eval_rewards[ei] = [] unnorm_eval_rewards[ei] = [] for i in range(NUM_EVAL_EPS): for ei in eval_envs: for pi in train_policies: init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] 
mask_env = res['mask_env'] source_policy = res['source_policy'] init_episode_reward = res['episode_reward'] mask_policy = res['mask_policy'] if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder(source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() pred_value = value_net( init_obs.unsqueeze(0).to(device), emb_env.to(device), emb_policy.to(device)).item() if 'ant' in args.env_name or 'swimmer' in args.env_name: decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] else: decoded_reward = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] qf = value_net.get_qf( init_obs.unsqueeze(0).to(device), emb_env) u, s, v = torch.svd(qf.squeeze()) opt_policy_pos = u[:, 0].unsqueeze(0) opt_policy_neg = -u[:, 0].unsqueeze(0) if 'ant' in args.env_name or 'swimmer' in args.env_name: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) else: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) if episode_reward_pos >= episode_reward_neg: episode_reward = episode_reward_pos opt_policy = opt_policy_pos else: episode_reward = episode_reward_neg opt_policy = opt_policy_neg unnorm_episode_reward = episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_init_episode_reward = init_episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_decoded_reward = decoded_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_eval_rewards[ei].append(unnorm_episode_reward) eval_rewards[ei].append(episode_reward) if i % args.log_interval == 0: if 'ant' in args.env_name or 'swimmer' in args.env_name: print( f"\nEval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}" ) print( f"Eval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}" ) for ei in train_envs: if 'ant' in args.env_name or 'swimmer' in args.env_name: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(unnorm_train_rewards[ei]), np.std(unnorm_train_rewards[ei]))) else: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, 
np.mean(train_rewards[ei]), np.std(train_rewards[ei]))) for ei in eval_envs: if 'ant' in args.env_name or 'swimmer' in args.env_name: print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(unnorm_eval_rewards[ei]), np.std(unnorm_eval_rewards[ei]))) else: print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(eval_rewards[ei]), np.std(eval_rewards[ei]))) envs.close()
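# Sketch of the policy-improvement step used repeatedly in train_pdvf above
# (illustrative, with made-up tensor shapes): the PD-VF exposes a matrix
# Q(s, z_env) via value_net.get_qf(), and the predicted return is roughly a
# quadratic form in the policy embedding, so the best unit-norm policy
# embedding is approximated by the top left-singular vector of Q. The sign of
# a singular vector is arbitrary, which is why the code rolls out both +u_0
# and -u_0 through the policy decoder and keeps whichever earns more reward.
import torch

emb_dim = 8                                    # hypothetical embedding size
qf = torch.randn(emb_dim, emb_dim)             # stand-in for value_net.get_qf(...)
u, s, v = torch.svd(qf)
opt_policy_pos = u[:, 0].unsqueeze(0)          # candidate policy embedding
opt_policy_neg = -u[:, 0].unsqueeze(0)         # sign-flipped candidate
# In the scripts, both candidates are decoded into behaviour with
# policy_decoder and evaluated in the environment; the higher-return one is
# treated as the inferred near-optimal policy embedding for that environment.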
def eval_pdvf(): ''' Evaluate the Policy-Dynamics Value Function. ''' args = get_args() torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) torch.set_num_threads(1) device = args.device if device != 'cpu': torch.cuda.empty_cache() if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True env = make_vec_envs(args, device) env.reset() names = [] for e in range(args.num_envs): for s in range(args.num_seeds): names.append('ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s)) source_policy = [] for name in names: actor_critic = Policy(env.observation_space.shape, env.action_space, base_kwargs={'recurrent': False}) actor_critic.to(device) model = os.path.join(args.save_dir, name) actor_critic.load_state_dict(torch.load(model)) source_policy.append(actor_critic) # Load the collected interaction episodes for each agent policy_encoder, policy_decoder = pdvf_utils.load_policy_model(args, env) env_encoder = pdvf_utils.load_dynamics_model(args, env) value_net = PDVF(env.observation_space.shape[0], args.dynamics_embedding_dim, args.hidden_dim_pdvf, args.policy_embedding_dim, device=device).to(device) value_net.to(device) path_to_pdvf = os.path.join(args.save_dir_pdvf, \ "pdvf-stage{}.{}.pt".format(args.stage, args.env_name)) value_net.load_state_dict(torch.load(path_to_pdvf)['state_dict']) value_net.eval() all_envs = [i for i in range(args.num_envs)] train_policies = [i for i in range(int(3 / 4 * args.num_envs))] train_envs = [i for i in range(int(3 / 4 * args.num_envs))] eval_envs = [i for i in range(int(3 / 4 * args.num_envs), args.num_envs)] env_enc_input_size = env.observation_space.shape[ 0] + env.action_space.shape[0] sizes = pdvf_utils.DotDict({'state_dim': env.observation_space.shape[0], \ 'action_dim': env.action_space.shape[0], 'env_enc_input_size': \ env_enc_input_size, 'env_max_seq_length': args.max_num_steps * env_enc_input_size}) env_sampler = env_utils.EnvSamplerPDVF(env, source_policy, args) all_mean_rewards = [[] for _ in range(args.num_envs)] all_mean_unnorm_rewards = [[] for _ in range(args.num_envs)] # Eval on Train Envs train_rewards = {} unnorm_train_rewards = {} for ei in range(len(all_envs)): train_rewards[ei] = [] unnorm_train_rewards[ei] = [] for ei in train_envs: for i in range(args.num_eval_eps): args.seed = i np.random.seed(seed=i) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) for pi in train_policies: init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] mask_env = res['mask_env'] source_policy = res['source_policy'] init_episode_reward = res['episode_reward'] mask_policy = res['mask_policy'] if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder(source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, 
dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() pred_value = value_net( init_obs.unsqueeze(0).to(device), emb_env.to(device), emb_policy.to(device)).item() if 'ant' in args.env_name or 'swimmer' in args.env_name: decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] else: decoded_reward = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] qf = value_net.get_qf( init_obs.unsqueeze(0).to(device), emb_env) u, s, v = torch.svd(qf.squeeze()) opt_policy_pos = u[:, 0].unsqueeze(0) opt_policy_neg = -u[:, 0].unsqueeze(0) if 'ant' in args.env_name or 'swimmer' in args.env_name: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) else: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) if episode_reward_pos >= episode_reward_neg: episode_reward = episode_reward_pos opt_policy = opt_policy_pos else: episode_reward = episode_reward_neg opt_policy = opt_policy_neg unnorm_episode_reward = episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_init_episode_reward = init_episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_decoded_reward = decoded_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_train_rewards[ei].append(unnorm_episode_reward) train_rewards[ei].append(episode_reward) if i % args.log_interval == 0: if 'ant' in args.env_name or 'swimmer' in args.env_name: print( f"\nTrain Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}" ) print( f"Train Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}" ) all_mean_rewards[ei].append(np.mean(train_rewards[ei])) all_mean_unnorm_rewards[ei].append(np.mean(unnorm_train_rewards[ei])) for ei in train_envs: if 'ant' in args.env_name or 'swimmer' in args.env_name: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei]))) else: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei]))) # Eval on Eval Envs eval_rewards = {} unnorm_eval_rewards = {} for ei in range(len(all_envs)): eval_rewards[ei] = [] unnorm_eval_rewards[ei] = [] for ei in eval_envs: for i in range(args.num_eval_eps): args.seed = i np.random.seed(seed=i) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) for pi in train_policies: init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei)) if 'ant' in args.env_name or 
'swimmer' in args.env_name: init_state = env_sampler.env.sim.get_state() res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco( args, init_obs, sizes, policy_idx=pi, env_idx=ei) else: init_state = env_sampler.env.state res = env_sampler.zeroshot_sample_src_from_pol_state( args, init_obs, sizes, policy_idx=pi, env_idx=ei) source_env = res['source_env'] mask_env = res['mask_env'] source_policy = res['source_policy'] init_episode_reward = res['episode_reward'] mask_policy = res['mask_policy'] if source_policy.shape[1] == 1: source_policy = source_policy.repeat(1, 2, 1) mask_policy = mask_policy.repeat(1, 1, 2) emb_policy = policy_encoder( source_policy.detach().to(device), mask_policy.detach().to(device)).detach() if source_env.shape[1] == 1: source_env = source_env.repeat(1, 2, 1) mask_env = mask_env.repeat(1, 1, 2) emb_env = env_encoder(source_env.detach().to(device), mask_env.detach().to(device)).detach() emb_policy = F.normalize(emb_policy, p=2, dim=1).detach() emb_env = F.normalize(emb_env, p=2, dim=1).detach() pred_value = value_net( init_obs.unsqueeze(0).to(device), emb_env.to(device), emb_policy.to(device)).item() if 'ant' in args.env_name or 'swimmer' in args.env_name: decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] else: decoded_reward = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0] qf = value_net.get_qf( init_obs.unsqueeze(0).to(device), emb_env) u, s, v = torch.svd(qf.squeeze()) opt_policy_pos = u[:, 0].unsqueeze(0) opt_policy_neg = -u[:, 0].unsqueeze(0) if 'ant' in args.env_name or 'swimmer' in args.env_name: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) else: episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei) episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state( args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei) if episode_reward_pos >= episode_reward_neg: episode_reward = episode_reward_pos opt_policy = opt_policy_pos else: episode_reward = episode_reward_neg opt_policy = opt_policy_neg unnorm_episode_reward = episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_init_episode_reward = init_episode_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_decoded_reward = decoded_reward * ( args.max_reward - args.min_reward) + args.min_reward unnorm_eval_rewards[ei].append(unnorm_episode_reward) eval_rewards[ei].append(episode_reward) if i % args.log_interval == 0: if 'ant' in args.env_name or 'swimmer' in args.env_name: print( f"\nEval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}" ) print( f"Eval Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}" ) print( f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- 
predicted: {pred_value: .3f}" ) all_mean_rewards[ei].append(np.mean(eval_rewards[ei])) all_mean_unnorm_rewards[ei].append(np.mean(unnorm_eval_rewards[ei])) for ei in train_envs: if 'ant' in args.env_name or 'swimmer' in args.env_name: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei]))) else: print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei]))) for ei in eval_envs: if 'ant' in args.env_name or 'swimmer' in args.env_name: print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei]))) else: print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\ .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei]))) env.close()
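# Helper view of the reward rescaling used in the evaluation loops above
# (a small sketch; args.min_reward / args.max_reward come from the command
# line): rewards inside the scripts are kept in a normalized range and mapped
# back to the environment's scale only for logging.
def unnormalize_reward(norm_reward, min_reward, max_reward):
    """Map a normalized reward back to the environment's reward scale."""
    return norm_reward * (max_reward - min_reward) + min_reward

# e.g. with the hypothetical bounds min_reward=-50, max_reward=150,
# a normalized reward of 0.6 corresponds to 0.6 * 200 - 50 = 70.0
assert unnormalize_reward(0.6, -50, 150) == 70.0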