def main():
    """Train an actor-critic agent, optionally with self-supervised attention.

    Reads the command line via ``get_args()`` and all hyper-parameters from
    the module-level ``config``.  Builds the vectorised environments, the
    ``Policy`` network and the A2C / PPO / ACKTR agent, then alternates
    rollout collection and agent updates for ``num_updates`` iterations.
    Training statistics are logged, the policy is evaluated every
    ``config.eval_interval`` updates and checkpointed every
    ``config.save_interval`` updates; the final weights are written to
    ``final_state.pth`` in the output directory.

    Raises:
        ValueError: if ``config.algo`` is not one of 'a2c', 'ppo', 'acktr'.
    """
    args = get_args()

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    if (config.cuda and torch.cuda.is_available()
            and config.cuda_deterministic):
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train', seed=config.seed)
    eval_log_dir = final_output_dir + "_eval"
    utils.cleanup_log_dir(final_output_dir)
    utils.cleanup_log_dir(eval_log_dir)

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    writer = SummaryWriter(tb_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:" + config.GPUS if config.cuda else "cpu")

    width = height = 84
    envs = make_vec_envs(config.env_name, config.seed, config.num_processes,
                         config.gamma, final_output_dir, device, False,
                         width=width, height=height, ram_wrapper=False)

    # create agent
    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={
            'recurrent': config.recurrent_policy,
            'hidden_size': config.hidden_size,
            'feat_from_selfsup_attention':
                config.feat_from_selfsup_attention,
            'feat_add_selfsup_attention': config.feat_add_selfsup_attention,
            'feat_mul_selfsup_attention_mask':
                config.feat_mul_selfsup_attention_mask,
            'selfsup_attention_num_keypoints':
                config.SELFSUP_ATTENTION.NUM_KEYPOINTS,
            'selfsup_attention_gauss_std':
                config.SELFSUP_ATTENTION.GAUSS_STD,
            'selfsup_attention_fix': config.selfsup_attention_fix,
            'selfsup_attention_fix_keypointer':
                config.selfsup_attention_fix_keypointer,
            'selfsup_attention_pretrain': config.selfsup_attention_pretrain,
            'selfsup_attention_keyp_maps_pool':
                config.selfsup_attention_keyp_maps_pool,
            'selfsup_attention_image_feat_only':
                config.selfsup_attention_image_feat_only,
            'selfsup_attention_feat_masked':
                config.selfsup_attention_feat_masked,
            'selfsup_attention_feat_masked_residual':
                config.selfsup_attention_feat_masked_residual,
            'selfsup_attention_feat_load_pretrained':
                config.selfsup_attention_feat_load_pretrained,
            'use_layer_norm': config.use_layer_norm,
            'selfsup_attention_keyp_cls_agnostic':
                config.SELFSUP_ATTENTION.KEYPOINTER_CLS_AGNOSTIC,
            'selfsup_attention_feat_use_ln':
                config.SELFSUP_ATTENTION.USE_LAYER_NORM,
            'selfsup_attention_use_instance_norm':
                config.SELFSUP_ATTENTION.USE_INSTANCE_NORM,
            'feat_mul_selfsup_attention_mask_residual':
                config.feat_mul_selfsup_attention_mask_residual,
            'bottom_up_form_objects': config.bottom_up_form_objects,
            'bottom_up_form_num_of_objects':
                config.bottom_up_form_num_of_objects,
            'gaussian_std': config.gaussian_std,
            'train_selfsup_attention': config.train_selfsup_attention,
            'block_selfsup_attention_grad':
                config.block_selfsup_attention_grad,
            'sep_bg_fg_feat': config.sep_bg_fg_feat,
            'mask_threshold': config.mask_threshold,
            'fix_feature': config.fix_feature,
        })

    # init / load parameter
    # `checkpoint` stays None unless a resume checkpoint was actually read,
    # so the optimizer restore below never touches an unbound name (the
    # original raised NameError when MODEL_FILE and RESUME were both set).
    checkpoint = None
    if config.MODEL_FILE:
        logger.info('=> loading model from {}'.format(config.MODEL_FILE))
        state_dict = torch.load(config.MODEL_FILE)
        # Entries whose key contains 'dist' are dropped before loading;
        # strict=False tolerates the resulting missing keys.
        state_dict = OrderedDict(
            (_k, _v) for _k, _v in state_dict.items() if 'dist' not in _k)
        actor_critic.load_state_dict(state_dict, strict=False)
    elif config.RESUME:
        checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
        if os.path.exists(checkpoint_file):
            logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
            checkpoint = torch.load(checkpoint_file)
            actor_critic.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_file, checkpoint['epoch']))

    actor_critic.to(device)

    if config.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            config.value_loss_coef,
            config.entropy_coef,
            lr=config.lr,
            eps=config.eps,
            alpha=config.alpha,
            max_grad_norm=config.max_grad_norm,
            train_selfsup_attention=config.train_selfsup_attention)
    elif config.algo == 'ppo':
        agent = algo.PPO(
            actor_critic,
            config.clip_param,
            config.ppo_epoch,
            config.num_mini_batch,
            config.value_loss_coef,
            config.entropy_coef,
            lr=config.lr,
            eps=config.eps,
            max_grad_norm=config.max_grad_norm)
    elif config.algo == 'acktr':
        agent = algo.A2C_ACKTR(
            actor_critic,
            config.value_loss_coef,
            config.entropy_coef,
            acktr=True,
            train_selfsup_attention=config.train_selfsup_attention,
            max_grad_norm=config.max_grad_norm)
    else:
        # Fail here instead of with a NameError on `agent` further down.
        raise ValueError(
            "unsupported config.algo: {!r}".format(config.algo))

    # rollouts: environment
    rollouts = RolloutStorage(
        config.num_steps,
        config.num_processes,
        envs.observation_space.shape,
        envs.action_space,
        actor_critic.recurrent_hidden_state_size,
        keep_buffer=config.train_selfsup_attention,
        buffer_size=config.train_selfsup_attention_buffer_size)

    if config.RESUME and checkpoint is not None:
        agent.optimizer.load_state_dict(checkpoint['optimizer'])

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        config.num_env_steps) // config.num_steps // config.num_processes
    best_perf = 0.0
    best_model = False
    # A checkpoint can be written before the first evaluation has run, so
    # the reported performance needs a defined starting value.
    perf_indicator = best_perf
    # Stays None until a self-supervised attention update has produced a
    # loss, so the log line never reads an unbound name.
    selfsup_attention_loss = None
    print('num updates', num_updates, 'num steps', config.num_steps)

    for j in range(num_updates):
        if config.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if config.algo == "acktr" else config.lr)

        for step in range(config.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = \
                    actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])
                # This policy returns the hidden state paired with a second
                # value that is not used here.
                recurrent_hidden_states, meta = recurrent_hidden_states

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            # NOTE(review): nothing appends to objects_locs in this loop, so
            # the tensor branch below never runs and None is always stored.
            objects_locs = []
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            if objects_locs:
                objects_locs = torch.FloatTensor(objects_locs)
                objects_locs = objects_locs * 2 - 1  # -1, 1
            else:
                objects_locs = None
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks,
                            objects_loc=objects_locs)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1],
            ).detach()

        rollouts.compute_returns(next_value, config.use_gae, config.gamma,
                                 config.gae_lambda,
                                 config.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # Self-supervised attention updates start after the 16th update
        # (j > 15) and draw image pairs from the rollout storage.
        if config.train_selfsup_attention and j > 15:
            for _iter in range(config.num_steps // 5):
                frame_x, frame_y = rollouts.generate_pair_image()
                selfsup_attention_loss, _, _ = \
                    agent.update_selfsup_attention(
                        frame_x, frame_y, config.SELFSUP_ATTENTION)

        if j % config.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * config.num_processes * config.num_steps
            end = time.time()
            msg = ('Updates {}, num timesteps {}, FPS {} \n'
                   'Last {} training episodes: mean/median reward '
                   '{:.1f}/{:.1f} '
                   'min/max reward {:.1f}/{:.1f} '
                   'dist entropy {:.1f}, value loss {:.1f}, '
                   'action loss {:.1f}\n').format(
                       j, total_num_steps,
                       int(total_num_steps / (end - start)),
                       len(episode_rewards),
                       np.mean(episode_rewards),
                       np.median(episode_rewards),
                       np.min(episode_rewards),
                       np.max(episode_rewards),
                       dist_entropy, value_loss, action_loss)
            if (config.train_selfsup_attention and j > 15
                    and selfsup_attention_loss is not None):
                msg = msg + 'selfsup attention loss {:.5f}\n'.format(
                    selfsup_attention_loss)
            logger.info(msg)

        if (config.eval_interval is not None and len(episode_rewards) > 1
                and j % config.eval_interval == 0):
            total_num_steps = (j + 1) * config.num_processes * config.num_steps
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            eval_mean_score, eval_max_score, eval_scores = evaluate(
                actor_critic, ob_rms, config.env_name, config.seed,
                config.num_processes, eval_log_dir, device,
                width=width, height=height)
            perf_indicator = eval_mean_score
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

            # record test scores
            with open(os.path.join(final_output_dir, 'test_scores'),
                      'a+') as f:
                out_s = "TEST: {}, {}, {}, {}\n".format(
                    str(total_num_steps), str(eval_mean_score),
                    str(eval_max_score),
                    [str(_eval_scores) for _eval_scores in eval_scores])
                print(out_s, end="", file=f)
                logger.info(out_s)
            writer.add_scalar('data/mean_score', eval_mean_score,
                              total_num_steps)
            writer.add_scalar('data/max_score', eval_max_score,
                              total_num_steps)
            writer.add_scalars('test', {'mean_score': eval_mean_score},
                               total_num_steps)

        # save for every interval-th episode or for the last epoch
        if ((j % config.save_interval == 0 or j == num_updates - 1)
                and config.save_dir != ""):
            logger.info(
                "=> saving checkpoint to {}".format(final_output_dir))
            # NOTE(review): true division, so 'epoch' is a float (and
            # fractional on the final update) - confirm that is intended.
            epoch = j / config.save_interval
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': get_model_name(config),
                    'state_dict': actor_critic.state_dict(),
                    'perf': perf_indicator,
                    'optimizer': agent.optimizer.state_dict(),
                    'ob_rms': getattr(utils.get_vec_normalize(envs),
                                      'ob_rms', None)
                }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(actor_critic.state_dict(), final_model_state_file)

    # export_scalars_to_json needs results from add scalars
    writer.export_scalars_to_json(os.path.join(tb_log_dir,
                                               'all_scalars.json'))
    writer.close()
def main():
    """Train an A2C / PPO / ACKTR agent, optionally with GAIL rewards.

    Hyper-parameters come from ``get_args()``.  When ``args.load`` is set,
    the policy, the optimizer state and the starting epoch are restored via
    ``load_checkpoint`` from ``<save_dir>/<algo>/<env_name>.pt``; the same
    file is rewritten every ``args.save_interval`` updates.  With
    ``args.gail`` the environment rewards in the rollout are replaced by the
    discriminator's predicted rewards before returns are computed.

    Raises:
        ValueError: if ``args.algo`` is not one of 'a2c', 'ppo', 'acktr'.
    """
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # Pass the expanded directory so the environments log into the same
    # location that was just cleaned up (the raw args.log_dir may still
    # contain an unexpanded "~").
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    save_path = os.path.join(args.save_dir, args.algo)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        # Fail here instead of with a NameError on `agent` further down.
        raise ValueError("unsupported args.algo: {!r}".format(args.algo))

    if args.gail:
        # The discriminator input is the concatenated observation/action
        # vector, so only flat observation spaces are accepted here.
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0],
            100, device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=4,
                                            subsample_frequency=20)
        # Only drop the trailing partial batch when at least one full batch
        # remains.
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    if args.load:
        actor_critic, agent.optimizer, start_epoch = load_checkpoint(
            actor_critic, agent.optimizer,
            os.path.join(save_path, args.env_name + ".pt"))
        actor_critic = actor_critic.to(device)
        # Move every tensor held in the restored optimizer state onto the
        # training device as well.
        for state in agent.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    else:
        start_epoch = 0

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(start_epoch, num_updates):
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = \
                    actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            # Overwrite the environment rewards with discriminator rewards.
            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda,
                                 args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if ((j % args.save_interval == 0 or j == num_updates - 1)
                and args.save_dir != ""):
            os.makedirs(save_path, exist_ok=True)

            state = {
                'epoch': j + 1,
                'state_dict': actor_critic.state_dict(),
                'optimizer': agent.optimizer.state_dict()
            }
            torch.save(state, os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            # The three loss values are passed but the template has no
            # placeholders for them; str.format ignores the extras.
            print(
                ("Updates {}, num timesteps {}, FPS {} \n Last {} training "
                 "episodes: mean/median reward {:.1f}/{:.1f}, min/max reward "
                 "{:.1f}/{:.1f}\n").format(
                     j, total_num_steps,
                     int(total_num_steps / (end - start)),
                     len(episode_rewards), np.mean(episode_rewards),
                     np.median(episode_rewards), np.min(episode_rewards),
                     np.max(episode_rewards), dist_entropy, value_loss,
                     action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            # get_vec_normalize may find no normalizer; fall back to None
            # rather than failing on the attribute access.
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
def main():
    """Train a policy by behavioural cloning on expert demonstrations.

    Hyper-parameters come from ``get_args()``.  If
    ``args.record_trajectories`` is set, trajectories are recorded and the
    function returns immediately.  Otherwise each epoch collects a rollout,
    runs ``agent.update_bc`` over the expert training loader, measures the
    action loss on the held-out expert loader, and saves the policy whenever
    the validation loss improves at a save step.  Training stops early after
    ten validation-loss increases without an intervening decrease.

    Raises:
        ValueError: if ``args.algo`` is not one of 'a2c', 'ppo', 'acktr', or
            if ``args.gail`` is not set (the expert data loaders and the
            discriminator are only built in that case).
    """
    args = get_args()

    # Record trajectories
    if args.record_trajectories:
        record_trajectories()
        return

    print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, log_dir, device, False)

    # Take activation for carracing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'recurrent': args.recurrent_policy,
                              'env': args.env_name
                          },
                          activation=activation)
    actor_critic.to(device)

    # Load from previous model
    if args.load_model_name:
        state = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name +
                         '_{}.pt'.format(args.seed)))[0]
        try:
            actor_critic.load_state_dict(state)
        except (AttributeError, TypeError):
            # The checkpoint holds a whole module rather than a state dict;
            # use it directly.  Other failures (e.g. mismatched keys) now
            # propagate instead of being silently swallowed.
            actor_critic = state

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        # Fail here instead of with a NameError on `agent` further down.
        raise ValueError("unsupported args.algo: {!r}".format(args.algo))

    if not args.gail:
        # The training loop below iterates gail_train_loader and
        # gail_test_loader unconditionally; without them it used to die with
        # a NameError only after the first rollout had been collected.
        raise ValueError(
            "behavioural cloning needs the expert data loaders; "
            "args.gail must be set")

    if len(envs.observation_space.shape) == 1:
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0],
            100, device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=3,
                                            subsample_frequency=1)
        expert_dataset_test = gail.ExpertDataset(file_name,
                                                 num_trajectories=1,
                                                 start=3,
                                                 subsample_frequency=1)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)
        gail_test_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset_test,
            batch_size=args.gail_batch_size,
            shuffle=False,
            drop_last=False)
        print(len(expert_dataset), len(expert_dataset_test))
    else:
        # env observation shape is 3 => its an image
        assert len(envs.observation_space.shape) == 3
        discr = gail.CNNDiscriminator(envs.observation_space.shape,
                                      envs.action_space, 100, device)
        file_name = os.path.join(args.gail_experts_dir, 'expert_data.pkl')

        expert_dataset = gail.ExpertImageDataset(file_name, train=True)
        test_dataset = gail.ExpertImageDataset(file_name, train=False)
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=len(expert_dataset) > args.gail_batch_size,
        )
        gail_test_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=args.gail_batch_size,
            shuffle=False,
            drop_last=len(test_dataset) > args.gail_batch_size,
        )
        # NOTE(review): placement of this print (image branch only vs. both
        # branches) is not recoverable from the collapsed source.
        print('Dataloader size', len(gail_train_loader))

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    start = time.time()
    # The number of epochs is taken directly from args.num_steps here
    # (instead of num_env_steps // num_steps // num_processes).
    num_updates = args.num_steps
    print(num_updates)

    # count the number of times validation loss increases
    val_loss_increase = 0
    prev_val_action = np.inf
    best_val_loss = np.inf

    for j in range(num_updates):
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = \
                    actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if j >= 10:
            try:
                envs.venv.eval()
            except AttributeError:
                # Best effort: not every wrapped env exposes venv.eval().
                pass

        # The discriminator is never trained in this script; its rewards are
        # still written into the rollout before returns are computed.
        for step in range(args.num_steps):
            rollouts.rewards[step] = discr.predict_reward(
                rollouts.obs[step], rollouts.actions[step], args.gamma,
                rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda,
                                 args.use_proper_time_limits)

        # No RL update is performed; these are placeholders for the log line.
        value_loss = 0
        dist_entropy = 0
        for data in gail_train_loader:
            expert_states, expert_actions = data
            expert_states = expert_states.to(device)
            expert_actions = expert_actions.to(device)
            loss = agent.update_bc(expert_states, expert_actions)
            action_loss = loss.data.cpu().numpy()
        # NOTE(review): printed once per epoch with the last batch's loss;
        # the collapsed source does not show whether this was per batch.
        print("Epoch: {}, Loss: {}".format(j, action_loss))

        with torch.no_grad():
            cnt = 0
            val_action_loss = 0
            for data in gail_test_loader:
                expert_states, expert_actions = data
                expert_states = expert_states.to(device)
                expert_actions = expert_actions.to(device)
                loss = agent.get_action_loss(expert_states, expert_actions)
                val_action_loss += loss.data.cpu().numpy()
                cnt += 1
            val_action_loss /= cnt
        print("Val Loss: {}".format(val_action_loss))

        # save for every interval-th episode or for the last epoch
        if ((j % args.save_interval == 0 or j == num_updates - 1)
                and args.save_dir != ""):

            if val_action_loss < best_val_loss:
                val_loss_increase = 0
                best_val_loss = val_action_loss
                save_path = os.path.join(args.save_dir, args.model_name)
                os.makedirs(save_path, exist_ok=True)

                torch.save([
                    actor_critic.state_dict(),
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None),
                    getattr(utils.get_vec_normalize(envs), 'ret_rms', None)
                ], os.path.join(
                    save_path,
                    args.model_name + "_{}.pt".format(args.seed)))
            elif val_action_loss > prev_val_action:
                val_loss_increase += 1
                if val_loss_increase == 10:
                    print("Val loss increasing too much, breaking here...")
                    break
            elif val_action_loss < prev_val_action:
                val_loss_increase = 0

            # Update prev val action
            # NOTE(review): kept inside the save-step branch; the collapsed
            # source does not show whether it ran on every epoch instead.
            prev_val_action = val_action_loss

        # log interval
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            # The three loss values are passed but the template has no
            # placeholders for them; str.format ignores the extras.
            print(
                ("Updates {}, num timesteps {}, FPS {} \n Last {} training "
                 "episodes: mean/median reward {:.1f}/{:.1f}, min/max reward "
                 "{:.1f}/{:.1f}\n").format(
                     j, total_num_steps,
                     int(total_num_steps / (end - start)),
                     len(episode_rewards), np.mean(episode_rewards),
                     np.median(episode_rewards), np.min(episode_rewards),
                     np.max(episode_rewards), dist_entropy, value_loss,
                     action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            # Same None-tolerant lookup as used when saving above.
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
def main():
    """Train a policy with GAIL, no-regret GAIL or correlated GAIL.

    Hyper-parameters come from ``get_args()``.  Depending on ``args.gail``,
    ``args.no_regret_gail`` and ``args.cor_gail`` a matching discriminator
    is built (plus a ``Correlator`` for the correlated variant), rollout
    rewards are replaced by discriminator rewards, and the agent is updated
    with ``agent.update`` or ``agent.mixed_update``.  Models are saved every
    ``args.save_interval`` updates and evaluated every
    ``args.eval_interval`` updates.

    Raises:
        ValueError: if ``args.algo`` is not one of 'a2c', 'ppo', 'acktr'.
    """
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # coinrun environments need to be treated differently.
    coinrun_envs = {
        'CoinRun': 'standard',
        'CoinRun-Platforms': 'platform',
        'Random-Mazes': 'maze'
    }

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         device,
                         False,
                         coin_run_level=args.num_levels,
                         difficulty=args.high_difficulty,
                         coin_run_seed=args.seed)

    if args.env_name in coinrun_envs.keys():
        observation_space_shape = (3, 64, 64)
        # Save the level info in the trained model name
        args.save_dir = args.save_dir + "/NUM_LEVELS_{}".format(
            args.num_levels)
    else:
        observation_space_shape = envs.observation_space.shape

    if args.continue_ppo_training:
        actor_critic, _ = torch.load(os.path.join(args.check_point,
                                                  args.env_name + ".pt"),
                                     map_location=torch.device(device))
        # The rollout storage below needs embed_size; this branch used to
        # leave it (and embeds) unbound, raising NameError.
        embed_size = 0
        embeds = None
    elif args.cor_gail:
        embed_size = args.embed_size
        actor_critic = Policy(observation_space_shape,
                              envs.action_space,
                              hidden_size=args.hidden_size,
                              embed_size=embed_size,
                              base_kwargs={'recurrent': args.recurrent_policy})
        actor_critic.to(device)
        correlator = Correlator(observation_space_shape,
                                envs.action_space,
                                hidden_dim=args.hidden_size,
                                embed_dim=embed_size,
                                lr=args.lr,
                                device=device)
        correlator.to(device)
        embeds = torch.zeros(1, embed_size)
    else:
        embed_size = 0
        actor_critic = Policy(observation_space_shape,
                              envs.action_space,
                              hidden_size=args.hidden_size,
                              base_kwargs={'recurrent': args.recurrent_policy})
        actor_critic.to(device)
        embeds = None

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm,
                         use_clipped_value_loss=True,
                         ftrl_mode=args.cor_gail or args.no_regret_gail,
                         correlated_mode=args.cor_gail)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        # Fail here instead of with a NameError on `agent` further down.
        raise ValueError("unsupported args.algo: {!r}".format(args.algo))

    imitation = args.gail or args.no_regret_gail or args.cor_gail

    if imitation:
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        # if subsample set to a different number,
        # grad_pen might need adjustment
        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=50,
                                            subsample_frequency=1)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)
        if args.gail:
            discr = gail.Discriminator(observation_space_shape,
                                       envs.action_space,
                                       device=device)
        if args.no_regret_gail or args.cor_gail:
            # Strategy Queues: Each element of a queue is a discr strategy
            queue = deque(maxlen=args.queue_size)
            # Strategy Queues: Each element of a queue is an agent strategy
            agent_queue = deque(maxlen=args.queue_size)
            pruning_frequency = 1
        if args.no_regret_gail:
            discr = regret_gail.NoRegretDiscriminator(
                observation_space_shape, envs.action_space, device=device)
        if args.cor_gail:
            discr = cor_gail.CorDiscriminator(observation_space_shape,
                                              envs.action_space,
                                              hidden_size=args.hidden_size,
                                              embed_size=embed_size,
                                              device=device)
            # NOTE(review): kept inside the cor_gail branch; the collapsed
            # source does not show whether it applied to every variant.
            discr.to(device)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              observation_space_shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              embed_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    if args.cor_gail:
        rollouts.embeds[0].copy_(embeds)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(num_updates):
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            # Roll-out
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = \
                    actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step], rollouts.embeds[step])

            obs, reward, done, infos = envs.step(action.to('cpu'))
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

            # Sample mediating/correlating actions
            # Correlated Roll-out
            if args.cor_gail:
                embeds, embeds_log_prob, mean = correlator.act(
                    rollouts.obs[step], rollouts.actions[step])
                rollouts.insert_embedding(embeds, embeds_log_prob)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1], rollouts.embeds[-1]).detach()

        if imitation:
            if args.env_name not in {'CoinRun', 'Random-Mazes'}:
                if j >= 10:
                    envs.venv.eval()

            gail_epoch = args.gail_epoch
            if args.gail:
                if j < 10:
                    gail_epoch = 100  # Warm up
            # no need for gail epoch or warm up in the no-regret case and
            # cor_gail.

            # The observation filter does not change between discriminator
            # epochs, so it is looked up once.
            vec_norm = utils.get_vec_normalize(envs)
            obfilt = vec_norm._obfilt if vec_norm else None

            for _ in range(gail_epoch):
                if args.gail:
                    discr.update(gail_train_loader, rollouts, obfilt)

                if args.no_regret_gail or args.cor_gail:
                    last_strategy = discr.update(gail_train_loader, rollouts,
                                                 queue, args.max_grad_norm,
                                                 obfilt, j)

            # Replace environment rewards with discriminator rewards.
            for step in range(args.num_steps):
                if args.gail:
                    rollouts.rewards[step] = discr.predict_reward(
                        rollouts.obs[step], rollouts.actions[step],
                        args.gamma, rollouts.masks[step])
                if args.no_regret_gail:
                    rollouts.rewards[step] = discr.predict_reward(
                        rollouts.obs[step], rollouts.actions[step],
                        args.gamma, rollouts.masks[step], queue)
                if args.cor_gail:
                    rollouts.rewards[step], correlator_reward = \
                        discr.predict_reward(
                            rollouts.obs[step], rollouts.actions[step],
                            rollouts.embeds[step], args.gamma,
                            rollouts.masks[step], queue)
                    rollouts.correlated_reward[step] = correlator_reward

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda,
                                 args.use_proper_time_limits)

        if args.gail:
            value_loss, action_loss, dist_entropy = agent.update(rollouts, j)
        elif args.no_regret_gail or args.cor_gail:
            value_loss, action_loss, dist_entropy, agent_gains, \
                agent_strategy = agent.mixed_update(rollouts, agent_queue, j)
        else:
            # With no imitation flag set nothing updated the policy and the
            # log line below read unbound loss names; fall back to the plain
            # agent update.
            # NOTE(review): confirm plain RL training is intended here.
            value_loss, action_loss, dist_entropy = agent.update(rollouts, j)

        if args.cor_gail:
            correlator.update(rollouts, agent_gains, args.max_grad_norm)

        if args.no_regret_gail or args.cor_gail:
            queue, _ = utils.queue_update(queue, pruning_frequency,
                                          args.queue_size, j, last_strategy)
            agent_queue, pruning_frequency = utils.queue_update(
                agent_queue, pruning_frequency, args.queue_size, j,
                agent_strategy)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if ((j % args.save_interval == 0 or j == num_updates - 1)
                and args.save_dir != ""):
            save_path = os.path.join(args.save_dir, args.algo)
            os.makedirs(save_path, exist_ok=True)

            if not args.cor_gail:
                torch.save([
                    actor_critic,
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                ], os.path.join(save_path, args.env_name + ".pt"))
            else:
                print("saving models in {}".format(
                    os.path.join(save_path, args.env_name)))
                torch.save(
                    correlator.state_dict(),
                    os.path.join(save_path, args.env_name + "correlator.pt"))
                torch.save([
                    actor_critic.state_dict(),
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                ], os.path.join(save_path, args.env_name + "actor.pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                ("Updates {}, num timesteps {}, FPS {} \n Last {} training "
                 "episodes: mean/median reward {:.1f}/{:.1f},"
                 " value loss/action loss {:.1f}/{}").format(
                     j, total_num_steps,
                     int(total_num_steps / (end - start)),
                     len(episode_rewards), np.mean(episode_rewards),
                     np.median(episode_rewards), value_loss, action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            # get_vec_normalize can return a falsy value (handled for obfilt
            # above); use the same tolerant lookup as in the save step.
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
def main():
    """Train a multi-task attention policy with PPO, validating and evaluating as it goes.

    Each update performs, in order:

    1. a training rollout on ``envs`` followed by a PPO update with ``agent``;
    2. ``args.val_agent_steps`` validation rollouts on ``val_envs``, each
       followed by a PPO update with ``val_agent`` (same network, separate
       optimizer, ``attention_policy=True``);
    3. optional checkpointing, console logging and evaluation on every
       environment in ``EVAL_ENVS``.

    After an evaluation, if the ``'many_arms'`` results did not improve by at
    least ``args.val_improvement_threshold`` (improved tasks minus worsened
    tasks), the network weights and both optimizer states are rolled back to
    the snapshot taken right after the previous evaluation.

    Raises:
        ValueError: if ``args.algo`` is anything other than ``'ppo'``.
    """
    args = get_args()
    import random

    # Seed every RNG the training loop may touch.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Run directory: <log_dir>/runs/<env>_<algo>_num_arms_<N>_<timestamp>[_<variant>]
    logdir = args.env_name + '_' + args.algo + '_num_arms_' + str(
        args.num_processes) + '_' + time.strftime("%d-%m-%Y_%H-%M-%S")
    if args.use_privacy:
        logdir = logdir + '_privacy'
    elif args.use_noisygrad:
        logdir = logdir + '_noisygrad'
    elif args.use_pcgrad:
        logdir = logdir + '_pcgrad'
    elif args.use_testgrad:
        logdir = logdir + '_testgrad'
    elif args.use_median_grad:
        logdir = logdir + '_mediangrad'
    logdir = os.path.join('runs', logdir)
    logdir = os.path.join(os.path.expanduser(args.log_dir), logdir)
    utils.cleanup_log_dir(logdir)

    # Hyper-parameters are recorded twice: in the pickled log and in TensorBoard.
    hparams = {
        'task_steps': args.task_steps,
        'grad_noise_ratio': args.grad_noise_ratio,
        'max_task_grad_norm': args.max_task_grad_norm,
        'use_noisygrad': args.use_noisygrad,
        'use_pcgrad': args.use_pcgrad,
        'use_testgrad': args.use_testgrad,
        'use_testgrad_median': args.use_testgrad_median,
        'testgrad_quantile': args.testgrad_quantile,
        'median_grad': args.use_median_grad,
        'use_meanvargrad': args.use_meanvargrad,
        'meanvar_beta': args.meanvar_beta,
        'no_special_grad_for_critic': args.no_special_grad_for_critic,
        'use_privacy': args.use_privacy,
        'seed': args.seed,
        'recurrent': args.recurrent_policy,
        'obs_recurrent': args.obs_recurrent,
        'cmd': ' '.join(sys.argv[1:]),
    }

    # Ugly but simple logging: hyper-parameters plus one result list per eval env.
    log_dict = dict(hparams)
    for eval_disp_name in EVAL_ENVS:
        log_dict[eval_disp_name] = []

    summary_writer = SummaryWriter()
    summary_writer.add_hparams(hparams, {})

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    print('making envs...')
    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         device,
                         False,
                         steps=args.task_steps,
                         free_exploration=args.free_exploration,
                         recurrent=args.recurrent_policy,
                         obs_recurrent=args.obs_recurrent,
                         multi_task=True)

    val_envs = make_vec_envs(args.val_env_name,
                             args.seed,
                             args.num_processes,
                             args.gamma,
                             args.log_dir,
                             device,
                             False,
                             steps=args.task_steps,
                             free_exploration=args.free_exploration,
                             recurrent=args.recurrent_policy,
                             obs_recurrent=args.obs_recurrent,
                             multi_task=True)

    eval_envs_dic = {}
    for eval_disp_name, eval_env_name in EVAL_ENVS.items():
        eval_envs_dic[eval_disp_name] = make_vec_envs(
            eval_env_name[0],
            args.seed,
            args.num_processes,
            None,
            logdir,
            device,
            True,
            steps=args.task_steps,
            recurrent=args.recurrent_policy,
            obs_recurrent=args.obs_recurrent,
            multi_task=True,
            free_exploration=args.free_exploration)
    prev_eval_r = {}
    print('done')

    # The two policy variants differ only in the attention base class.
    base_cls = MLPHardAttnBase if args.hard_attn else MLPAttnBase
    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base=base_cls,
        base_kwargs={'recurrent': args.recurrent_policy or args.obs_recurrent})
    actor_critic.to(device)

    if (args.continue_from_epoch > 0) and args.save_dir != "":
        save_path = os.path.join(args.save_dir, args.algo)
        # Checkpoints store [model, obs_rms]; only the weights are reused here.
        actor_critic_, loaded_obs_rms_ = torch.load(
            os.path.join(
                save_path, args.env_name +
                "-epoch-{}.pt".format(args.continue_from_epoch)))
        actor_critic.load_state_dict(actor_critic_.state_dict())

    if args.algo != 'ppo':
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError("only PPO is supported")

    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     num_tasks=args.num_processes,
                     attention_policy=False,
                     max_grad_norm=args.max_grad_norm,
                     weight_decay=args.weight_decay)
    val_agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.val_lr,
                         eps=args.eps,
                         num_tasks=args.num_processes,
                         attention_policy=True,
                         max_grad_norm=args.max_grad_norm,
                         weight_decay=args.weight_decay)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    val_rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  val_envs.observation_space.shape,
                                  val_envs.action_space,
                                  actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    val_obs = val_envs.reset()
    val_rollouts.obs[0].copy_(val_obs)
    val_rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    first_update = args.continue_from_epoch
    # The loop is offset by continue_from_epoch, so the final index is too.
    last_update = first_update + num_updates - 1

    save_copy = True
    for j in range(first_update, first_update + num_updates):

        # policy rollouts
        for step in range(args.num_steps):
            # Sample actions
            actor_critic.eval()
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            actor_critic.train()

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    for k, v in info['episode'].items():
                        summary_writer.add_scalar(
                            f'training/{k}', v,
                            j * args.num_processes * args.num_steps +
                            args.num_processes * step)

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        actor_critic.eval()
        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()
        actor_critic.train()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if save_copy:
            # Snapshot taken before the first update after an evaluation; this
            # is what a failed validation check rolls back to.
            prev_weights = copy.deepcopy(actor_critic.state_dict())
            prev_opt_state = copy.deepcopy(agent.optimizer.state_dict())
            prev_val_opt_state = copy.deepcopy(
                val_agent.optimizer.state_dict())
            save_copy = False

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # validation rollouts
        for val_iter in range(args.val_agent_steps):
            for step in range(args.num_steps):
                # Sample actions
                actor_critic.eval()
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        val_rollouts.obs[step],
                        val_rollouts.recurrent_hidden_states[step],
                        val_rollouts.masks[step])
                actor_critic.train()

                # Observe reward and next obs
                obs, reward, done, infos = val_envs.step(action)

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos])
                val_rollouts.insert(obs, recurrent_hidden_states, action,
                                    action_log_prob, value, reward, masks,
                                    bad_masks)

            actor_critic.eval()
            with torch.no_grad():
                next_value = actor_critic.get_value(
                    val_rollouts.obs[-1],
                    val_rollouts.recurrent_hidden_states[-1],
                    val_rollouts.masks[-1]).detach()
            actor_critic.train()

            val_rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                         args.gae_lambda,
                                         args.use_proper_time_limits)
            val_value_loss, val_action_loss, val_dist_entropy = val_agent.update(
                val_rollouts)
            val_rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == last_update) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            os.makedirs(save_path, exist_ok=True)

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
            ],
                       os.path.join(
                           save_path,
                           args.env_name + "-epoch-{}.pt".format(j)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards)))

        revert = False
        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            actor_critic.eval()
            # getattr mirrors the checkpoint code above and tolerates the
            # absence of a normalisation wrapper.
            obs_rms = getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
            eval_steps = (j + 1) * args.num_processes * args.num_steps
            eval_r = {}
            printout = f'Seed {args.seed} Iter {j} '
            for eval_disp_name, eval_env_name in EVAL_ENVS.items():
                eval_r[eval_disp_name] = evaluate(
                    actor_critic,
                    obs_rms,
                    eval_envs_dic,
                    eval_disp_name,
                    args.seed,
                    args.num_processes,
                    eval_env_name[1],
                    logdir,
                    device,
                    steps=args.task_steps,
                    recurrent=args.recurrent_policy,
                    obs_recurrent=args.obs_recurrent,
                    multi_task=True,
                    free_exploration=args.free_exploration)
                if eval_disp_name in prev_eval_r:
                    diff = np.array(eval_r[eval_disp_name]) - np.array(
                        prev_eval_r[eval_disp_name])
                    if eval_disp_name == 'many_arms':
                        # Net count of improved tasks must reach the threshold.
                        if np.sum(diff > 0) - np.sum(
                                diff < 0) < args.val_improvement_threshold:
                            print('no update')
                            revert = True

                summary_writer.add_scalar(f'eval/{eval_disp_name}',
                                          np.mean(eval_r[eval_disp_name]),
                                          eval_steps)
                log_dict[eval_disp_name].append(
                    [eval_steps, eval_r[eval_disp_name]])
                printout += eval_disp_name + ' ' + str(
                    np.mean(eval_r[eval_disp_name])) + ' '

            if revert:
                actor_critic.load_state_dict(prev_weights)
                agent.optimizer.load_state_dict(prev_opt_state)
                val_agent.optimizer.load_state_dict(prev_val_opt_state)
            else:
                print(printout)
                prev_eval_r = eval_r.copy()
            # Take a fresh snapshot before the next training update.
            save_copy = True
            actor_critic.train()

            # NOTE(review): the original nesting of the log dump could not be
            # recovered from the source; it is written after every evaluation
            # (so an interrupted run keeps its results) and once more at exit.
            save_obj(log_dict, os.path.join(logdir, 'log_dict.pkl'))

    save_obj(log_dict, os.path.join(logdir, 'log_dict.pkl'))
    envs.close()
    val_envs.close()
    for eval_disp_name in EVAL_ENVS:
        eval_envs_dic[eval_disp_name].close()
def main():
    """Train an online/target policy pair on occupancy + sign observations.

    Actions are always sampled from ``target_actor_critic`` while the
    optimizer updates ``online_actor_critic``.  With
    ``args.penetration_type == "constant"`` both names refer to the same
    network; with ``"linear"`` the target is re-synchronised from the online
    network every ``update_period`` updates.

    The function reads ``args``, ``occ_obs_shape``, ``sign_obs_shape``,
    ``num_updates``, ``update_period`` and ``save_path`` without defining
    them, so they are expected to exist at module level.

    Raises:
        ValueError: if ``args.algo`` is not one of ``a2c``, ``ppo``, ``acktr``.
    """
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    ## Make environments
    envs = make_vec_envs(args, device)

    ## Setup Policy / network architecture
    if args.load_path != '':
        # Prefer the best evaluated checkpoint when one exists.
        if os.path.isfile(os.path.join(args.load_path, "best_model.pt")):
            import_name = "best_model.pt"
        else:
            import_name = "model.pt"
        online_actor_critic = torch.load(
            os.path.join(args.load_path, import_name))
        target_actor_critic = torch.load(
            os.path.join(args.load_path, import_name))
        if args.cuda:
            target_actor_critic = target_actor_critic.cuda()
            online_actor_critic = online_actor_critic.cuda()
    else:
        online_actor_critic = Policy(occ_obs_shape, sign_obs_shape,
                                     args.state_rep, envs.action_space,
                                     args.recurrent_policy)
        online_actor_critic.to(device)
        target_actor_critic = Policy(occ_obs_shape, sign_obs_shape,
                                     args.state_rep, envs.action_space,
                                     args.recurrent_policy)
        target_actor_critic.to(device)
        # Start both networks from identical weights.
        target_actor_critic.load_state_dict(online_actor_critic.state_dict())

    if args.penetration_type == "constant":
        target_actor_critic = online_actor_critic

    ## Choose algorithm to use
    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(online_actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(online_actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(online_actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        # Fail here rather than with a NameError on first use of `agent`.
        raise ValueError("unsupported algo: {!r}".format(args.algo))

    def _save_target(filename):
        """Write the target network to ``save_path``, moved to CPU if needed."""
        save_model = target_actor_critic
        if args.cuda:
            # Deep-copy first so the live network stays on the GPU.
            save_model = copy.deepcopy(target_actor_critic).cpu()
        torch.save(save_model, os.path.join(save_path, filename))

    ## Initiate memory buffer
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              occ_obs_shape, sign_obs_shape, envs.action_space,
                              target_actor_critic.recurrent_hidden_state_size)

    ## Start env with first observation
    occ_obs, sign_obs = envs.reset()
    # NOTE(review): indentation was lost in the source; the occupancy copy is
    # assumed to be the only statement guarded by state_rep == 'full' - confirm.
    if args.state_rep == 'full':
        rollouts.occ_obs[0].copy_(occ_obs)
    rollouts.sign_obs[0].copy_(sign_obs)
    rollouts.to(device)

    # Window of recent per-step rewards used for the running statistics;
    # a different queue length gives a different averaging horizon.
    episode_rewards = deque(maxlen=args.num_steps)
    reward_track = []
    # Start below any achievable mean so the first evaluation always produces
    # a best_model.pt, even when all rewards are negative.
    best_eval_rewards = float('-inf')
    start = time.time()

    ## Loop over every policy update
    for j in range(num_updates):
        ## Setup parameter decays
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            if args.algo == "acktr":
                # use optimizer's learning rate since it's hard-coded in kfac.py
                update_linear_schedule(agent.optimizer, j, num_updates,
                                       agent.optimizer.lr)
            else:
                update_linear_schedule(agent.optimizer, j, num_updates,
                                       args.lr)

        if args.algo == 'ppo' and args.use_linear_clip_decay:
            agent.clip_param = args.clip_param * (1 - j / float(num_updates))

        ## Loop over num_steps environment updates to form trajectory
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                # Pass observation through network and get outputs
                value, action, action_log_prob, recurrent_hidden_states = target_actor_critic.act(
                    rollouts.occ_obs[step], rollouts.sign_obs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Do action in environment and save reward
            occ_obs, sign_obs, reward, done, _ = envs.step(action)
            episode_rewards.append(reward.numpy())

            # Masks the processes which are done
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            # Insert step information in buffer
            rollouts.insert(occ_obs, sign_obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)

        ## Get state value of current env state
        with torch.no_grad():
            next_value = target_actor_critic.get_value(
                rollouts.occ_obs[-1], rollouts.sign_obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        ## Computes the num_step return (next_value approximates reward after
        ## num_step), see Supp Material of https://arxiv.org/pdf/1804.02717.pdf
        ## Can use Generalized Advantage Estimation
        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        # Update the policy with the rollouts
        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        # Clean the rollout by cycling last elements to first ones
        rollouts.after_update()

        if (args.penetration_type == "linear") and (j % update_period == 0):
            target_actor_critic.load_state_dict(
                online_actor_critic.state_dict())

        ## Save model
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            _save_target("model.pt")

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if args.vis:
            # Add the average reward of update to reward tracker
            reward_track.append(np.mean(episode_rewards))

        ## Log progress
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.3f}/{:.3f}, min/max reward {:.3f}/{:.3f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards)))

        ## Evaluate model on a fresh single-process environment
        percentage = 100 * total_num_steps // args.num_env_steps
        if (args.eval_interval is not None and percentage > 1
                and (j % args.eval_interval == 0 or j == num_updates - 1)):
            print("###### EVALUATING #######")
            args_eval = copy.deepcopy(args)
            args_eval.num_processes = 1
            eval_envs = make_vec_envs(args_eval, device, no_logging=True)

            eval_episode_rewards = []

            occ_obs, sign_obs = eval_envs.reset()
            eval_recurrent_hidden_states = torch.zeros(
                args_eval.num_processes,
                target_actor_critic.recurrent_hidden_state_size,
                device=device)
            eval_masks = torch.zeros(args_eval.num_processes, 1, device=device)

            # One reward is recorded per environment step.
            while len(eval_episode_rewards) < 3000:
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = target_actor_critic.act(
                        occ_obs,
                        sign_obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # Observe reward and next obs
                occ_obs, sign_obs, reward, done, infos = eval_envs.step(action)
                # Keep the mask on the same device as the hidden state.
                eval_masks = torch.FloatTensor(
                    [[0.0] if done_ else [1.0] for done_ in done]).to(device)
                eval_episode_rewards.append(reward)

            eval_envs.close()

            mean_eval_reward = np.mean(eval_episode_rewards)
            if mean_eval_reward > best_eval_rewards:
                best_eval_rewards = mean_eval_reward
                _save_target('best_model.pt')

        ## Visualize tracked rewards (over num_steps) over time
        # NOTE(review): whether this ran once per update or once after training
        # could not be recovered from the source; per-update is assumed - confirm.
        if args.vis:
            visualize(reward_track, args.algo, save_path)
def main():
    """Train a policy with A2C / PPO / ACKTR, optionally with GAIL, RED or BC-GAIL.

    When ``args.record_trajectories`` is set the function only records
    trajectories and returns.  Otherwise it builds the environments and
    policy, optionally restores a previous model, and runs the update loop:
    rollout collection, optional discriminator training and reward
    relabelling (``args.gail``), the policy update (``args.bcgail`` passes the
    expert loader to the agent as well), checkpointing, logging and
    evaluation.

    Both flat (1-D) and image (3-D) observation spaces are handled in the
    GAIL path.

    Raises:
        ValueError: if ``args.algo`` is not one of ``a2c``, ``ppo``, ``acktr``.
    """
    args = get_args()

    # Record trajectories
    if args.record_trajectories:
        record_trajectories()
        return

    print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, log_dir, device, False)

    # Number of observation dimensions: 1 for flat vectors, 3 for images.
    obs_shape = len(envs.observation_space.shape)

    # Take activation for carracing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'recurrent': args.recurrent_policy,
                              'env': args.env_name
                          },
                          activation=activation)
    actor_critic.to(device)

    # Load from previous model
    if args.load_model_name:
        state = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name +
                         '_{}.pt'.format(args.seed)))[0]
        # The first checkpoint entry is either a state dict or a whole model.
        # Dispatching on its type (instead of a bare `except`) lets a genuine
        # state-dict mismatch surface as an error rather than silently
        # replacing the policy with a dict.
        if isinstance(state, dict):
            actor_critic.load_state_dict(state)
        else:
            actor_critic = state
            actor_critic.to(device)

    # If BCGAIL, then decay factor and gamma should be float
    if args.bcgail:
        assert isinstance(args.decay, float)
        assert isinstance(args.gailgamma, float)
        if args.decay < 0:
            args.decay = 1
        elif args.decay > 1:
            # Values above 1 are converted to 0.5 ** (1 / value).
            args.decay = 0.5**(1. / args.decay)

        print('Gamma: {}, decay: {}'.format(args.gailgamma, args.decay))
        print('BCGAIL used')

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         gamma=args.gailgamma,
                         decay=args.decay,
                         act_space=envs.action_space,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        # Fail here rather than with a NameError on first use of `agent`.
        raise ValueError("unsupported algo: {!r}".format(args.algo))

    if args.gail:
        if len(envs.observation_space.shape) == 1:
            # Load RED here
            red = None
            if args.red:
                red = gail.RED(
                    envs.observation_space.shape[0] +
                    envs.action_space.shape[0], 100, device, args.redsigma,
                    args.rediters)

            discr = gail.Discriminator(envs.observation_space.shape[0] +
                                       envs.action_space.shape[0],
                                       100,
                                       device,
                                       red=red,
                                       sail=args.sail,
                                       learn=args.learn)
            file_name = os.path.join(
                args.gail_experts_dir,
                "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

            expert_dataset = gail.ExpertDataset(
                file_name,
                num_trajectories=args.num_traj,
                subsample_frequency=1)
            # Never ask for a batch larger than the dataset itself.
            args.gail_batch_size = min(args.gail_batch_size,
                                       len(expert_dataset))
            drop_last = len(expert_dataset) > args.gail_batch_size
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=drop_last)
            print("Data loader size", len(expert_dataset))
        else:
            # env observation shape is 3 => its an image
            assert len(envs.observation_space.shape) == 3
            discr = gail.CNNDiscriminator(envs.observation_space.shape,
                                          envs.action_space, 100, device)
            file_name = os.path.join(args.gail_experts_dir, 'expert_data.pkl')

            expert_dataset = gail.ExpertImageDataset(file_name,
                                                     act=envs.action_space)
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=len(expert_dataset) > args.gail_batch_size,
            )
            print('Dataloader size', len(gail_train_loader))

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print(num_updates)

    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                try:
                    envs.venv.eval()
                except AttributeError:
                    # Best effort: not every wrapper stack exposes venv.eval().
                    pass

            gail_epoch = args.gail_epoch
            if j < 10 and obs_shape == 1:
                gail_epoch = 100  # Warm up

            for _ in range(gail_epoch):
                if obs_shape == 1:
                    discr.update(gail_train_loader, rollouts,
                                 utils.get_vec_normalize(envs)._obfilt)
                else:
                    discr.update(gail_train_loader, rollouts, None)

            if obs_shape == 3:
                obfilt = None
            else:
                obfilt = utils.get_vec_normalize(envs)._rev_obfilt

            for step in range(args.num_steps):
                # The reverse function is passed down for RED to receive
                # unnormalized obs which it is trained on
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step], obfilt)

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if args.bcgail:
            if obs_shape == 3:
                obfilt = None
            else:
                obfilt = utils.get_vec_normalize(envs)._obfilt
            value_loss, action_loss, dist_entropy = agent.update(
                rollouts, gail_train_loader, obfilt)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.model_name)
            os.makedirs(save_path, exist_ok=True)

            vec_norm = utils.get_vec_normalize(envs)
            torch.save([
                actor_critic.state_dict(),
                getattr(vec_norm, 'ob_rms', None),
                getattr(vec_norm, 'ret_rms', None)
            ],
                       os.path.join(
                           save_path,
                           args.model_name + "_{}.pt".format(args.seed)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards)))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            # Image environments may have no normalisation wrapper; getattr
            # mirrors the checkpoint code above instead of failing on None.
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
def main():
    """Fine-tune a policy whose leading layers come from a pretrained network.

    The policy's state dict is overwritten positionally with tensors from
    ``net_main_4rh_v2_64.pth`` up to and including the pretrained entry named
    ``enc_dense.bias``; the remaining layers keep their fresh initialisation.
    Training then proceeds with A2C / PPO / ACKTR (optionally GAIL).  At each
    log interval the running mean / median rewards are written to
    ``./plot_data`` and a reward curve to ``./imgs``.

    Raises:
        ValueError: if ``args.algo`` is not one of ``a2c``, ``ppo``, ``acktr``.
    """
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)
    save_path = os.path.join(args.save_dir, args.algo)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})

    # Copy pretrained tensors into the policy by position (parameter names
    # differ between the two networks), stopping after the encoder's last layer.
    my_model_state_dict = actor_critic.state_dict()
    pretrained_weights = torch.load('net_main_4rh_v2_64.pth')
    # Snapshot the keys up front: values are reassigned while iterating.
    for position, (key, (layer_name, weights)) in enumerate(
            zip(list(my_model_state_dict), pretrained_weights.items())):
        my_model_state_dict[key] = weights
        print(position)
        print(layer_name)
        if layer_name == 'enc_dense.bias':
            break
    actor_critic.load_state_dict(my_model_state_dict)

    start_epoch = 0
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        # Fail here rather than with a NameError on first use of `agent`.
        raise ValueError("unsupported algo: {!r}".format(args.algo))

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=4,
                                            subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    rewards_mean = []
    rewards_median = []
    val_loss = []
    act_loss = []
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(start_epoch, num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            os.makedirs(save_path, exist_ok=True)

            torch.save([
                actor_critic,
                actor_critic.state_dict(),
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + "_finetune.pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards)))

            # NOTE(review): indentation was lost in the source; the curve
            # bookkeeping below is assumed to run only at log intervals - confirm.
            rewards_mean.append(np.mean(episode_rewards))
            rewards_median.append(np.median(episode_rewards))
            val_loss.append(value_loss)
            act_loss.append(action_loss)

            # The output folders are not created anywhere else in this function.
            os.makedirs("./plot_data", exist_ok=True)
            os.makedirs("./imgs", exist_ok=True)
            torch.save(
                rewards_mean,
                "./plot_data/" + args.env_name + "_avg_rewards_finetune.pt")
            torch.save(
                rewards_median,
                "./plot_data/" + args.env_name + "_median_rewards_finetune.pt")

            # Clear the figure first; otherwise every log interval stacks one
            # more line on top of the previous ones.
            plt.clf()
            plt.plot(rewards_mean)
            plt.savefig("./imgs/" + args.env_name + "avg_reward_finetune.png")

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            # getattr mirrors the checkpoint code above and tolerates the
            # absence of a normalisation wrapper.
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)