def main(params):
    for k, v in params.items():
        assert v is not None, f"Value for {k} is None"

    # #
    # Setup logging
    # #
    metadata_schema = schema_from_dict(params)
    base_directory = params['out_dir']
    store = Store(base_directory)

    # redirect stderr, stdout to file
    """
    def make_err_redirector(stream_name):
        tee = Tee(os.path.join(store.path, stream_name + '.txt'), stream_name)
        return tee

    stderr_tee = make_err_redirector('stderr')
    stdout_tee = make_err_redirector('stdout')
    """

    # Store the experiment path and the git commit for this experiment
    metadata_schema.update({
        'store_path': str,
        'git_commit': str
    })

    repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                    search_parent_directories=True)

    metadata_table = store.add_table('metadata', metadata_schema)
    metadata_table.update_row(params)
    metadata_table.update_row({
        'store_path': store.path,
        'git_commit': repo.head.object.hexsha
    })
    metadata_table.flush_row()

    # Table for checkpointing models and envs
    if params['save_iters'] > 0:
        store.add_table('checkpoints', {
            'val_model': store.PYTORCH_STATE,
            'policy_model': store.PYTORCH_STATE,
            'envs': store.PICKLE,
            'policy_opt': store.PYTORCH_STATE,
            'val_opt': store.PYTORCH_STATE,
            'iteration': int
        })

    # The trainer object is in charge of sampling trajectories and
    # taking PPO/TRPO optimization steps
    p = Trainer.agent_from_params(params, store=store)
    if 'load_model' in params and params['load_model']:
        print('Loading pretrained model', params['load_model'])
        pretrained_models = torch.load(params['load_model'])
        p.policy_model.load_state_dict(pretrained_models['policy_model'])
        p.val_model.load_state_dict(pretrained_models['val_model'])
        # Note: optimizer states are not restored here.
        # p.POLICY_ADAM.load_state_dict(pretrained_models['policy_opt'])
        # p.val_opt.load_state_dict(pretrained_models['val_opt'])
        # Restore environment parameters, like mean and std.
        p.envs = pretrained_models['envs']

    rewards = []

    # Table for final results
    final_table = store.add_table('final_results', {
        'iteration': int,
        '5_rewards': float,
        'terminated_early': bool,
        'val_model': store.PYTORCH_STATE,
        'policy_model': store.PYTORCH_STATE,
        'envs': store.PICKLE,
        'policy_opt': store.PYTORCH_STATE,
        'val_opt': store.PYTORCH_STATE
    })

    def finalize_table(iteration, terminated_early, rewards):
        final_5_rewards = np.array(rewards)[-5:].mean()
        final_table.append_row({
            'iteration': iteration,
            '5_rewards': final_5_rewards,
            'terminated_early': terminated_early,
            'val_model': p.val_model.state_dict(),
            'policy_model': p.policy_model.state_dict(),
            'policy_opt': p.POLICY_ADAM.state_dict(),
            'val_opt': p.val_opt.state_dict(),
            'envs': p.envs
        })

    # Try-except so that we save if the user interrupts the process
    try:
        for i in range(params['train_steps']):
            print('Step %d' % (i,))
            if params['save_iters'] > 0 and i % params['save_iters'] == 0:
                store['checkpoints'].append_row({
                    'iteration': i,
                    'val_model': p.val_model.state_dict(),
                    'policy_model': p.policy_model.state_dict(),
                    'policy_opt': p.POLICY_ADAM.state_dict(),
                    'val_opt': p.val_opt.state_dict(),
                    'envs': p.envs
                })

            mean_reward = p.train_step()
            rewards.append(mean_reward)

        finalize_table(i, False, rewards)
    except KeyboardInterrupt:
        torch.save(p.val_model, 'saved_experts/%s-expert-vf' % (params['game'],))
        torch.save(p.policy_model, 'saved_experts/%s-expert-pol' % (params['game'],))
        finalize_table(i, True, rewards)
    store.close()
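# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original script): one plausible way to
# drive the training entry point above. The command-line flags and the JSON
# config layout below are assumptions for illustration only; the real project
# builds `params` from its own config files and argument parser.
if __name__ == '__main__':
    import argparse
    import json

    parser = argparse.ArgumentParser(
        description='Illustrative wrapper around main() (hypothetical flags).')
    parser.add_argument('--config-path', type=str, required=True,
                        help='JSON file whose keys become the params dict (assumed layout).')
    parser.add_argument('--out-dir', type=str, default=None,
                        help='Optional override for the output directory.')
    args = parser.parse_args()

    with open(args.config_path) as f:
        params = json.load(f)  # every value must be non-None (asserted in main)
    if args.out_dir is not None:
        params['out_dir'] = args.out_dir

    main(params)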
def main(params):
    override_params = copy.deepcopy(params)
    excluded_params = [
        'config_path', 'out_dir_prefix', 'num_episodes', 'row_id', 'exp_id',
        'load_model', 'seed', 'deterministic', 'noise_factor',
        'compute_kl_cert', 'use_full_backward', 'sqlite_path', 'early_terminate'
    ]
    sarsa_params = [
        'sarsa_enable', 'sarsa_steps', 'sarsa_eps', 'sarsa_reg',
        'sarsa_model_path'
    ]
    imit_params = ['imit_enable', 'imit_epochs', 'imit_model_path', 'imit_lr']

    # override_params contains the config-file flags that can be overridden via
    # the command line; drop the excluded ones.
    for k in list(override_params.keys()):
        if k in excluded_params:
            del override_params[k]

    if params['sqlite_path']:
        print(f"Will save results to sqlite database {params['sqlite_path']}")
        connection = sqlite3.connect(params['sqlite_path'])
        cur = connection.cursor()
        cur.execute('''create table if not exists attack_results
                       (method varchar(20),
                        mean_reward real,
                        std_reward real,
                        min_reward real,
                        max_reward real,
                        sarsa_eps real,
                        sarsa_reg real,
                        sarsa_steps integer,
                        deterministic bool,
                        early_terminate bool)''')
        connection.commit()

    # We will set this flag to True if we break early.
    early_terminate = False

    # Append a prefix for the output path.
    if params['out_dir_prefix']:
        params['out_dir'] = os.path.join(params['out_dir_prefix'], params['out_dir'])
        print(f"Setting output dir to {params['out_dir']}")

    if params['config_path']:
        # Load from a pretrained model using an existing config.
        # First we need to create the model using the given config file.
        json_params = json.load(open(params['config_path']))
        params = override_json_params(
            params, json_params,
            excluded_params + sarsa_params + imit_params)

    if params['sarsa_enable']:
        assert params['attack_method'] == "none" or params['attack_method'] is None, \
            "--train-sarsa is only available when --attack-method=none, but got {}".format(params['attack_method'])

    if 'load_model' in params and params['load_model']:
        for k, v in params.items():
            assert v is not None, f"Value for {k} is None"

        # Create the agent from the config file.
        p = Trainer.agent_from_params(params, store=None)
        print('Loading pretrained model', params['load_model'])
        pretrained_model = torch.load(params['load_model'])
        if 'policy_model' in pretrained_model:
            p.policy_model.load_state_dict(pretrained_model['policy_model'])
        if 'val_model' in pretrained_model:
            p.val_model.load_state_dict(pretrained_model['val_model'])
        if 'policy_opt' in pretrained_model:
            p.POLICY_ADAM.load_state_dict(pretrained_model['policy_opt'])
        if 'val_opt' in pretrained_model:
            p.val_opt.load_state_dict(pretrained_model['val_opt'])
        # Restore environment parameters, like mean and std.
        if 'envs' in pretrained_model:
            p.envs = pretrained_model['envs']
        for e in p.envs:
            e.normalizer_read_only = True
            e.setup_visualization(params['show_env'], params['save_frames'],
                                  params['save_frames_path'])
    else:
        # Load from the experiment directory. No need to use a config.
        base_directory = params['out_dir']
        store = Store(base_directory, params['exp_id'], mode='r')
        if params['row_id'] < 0:
            row = store['final_results'].df
        else:
            checkpoints = store['checkpoints'].df
            row_id = params['row_id']
            row = checkpoints.iloc[row_id:row_id + 1]
        print("row to test: ", row)
        if params['cpu'] is None:
            cpu = False
        else:
            cpu = params['cpu']
        p, _ = Trainer.agent_from_data(store, row, cpu, extra_params=params,
                                       override_params=override_params,
                                       excluded_params=excluded_params)
        store.close()

    rewards = []

    print('Gaussian noise in policy:')
    print(torch.exp(p.policy_model.log_stdev))
    original_stdev = p.policy_model.log_stdev.clone().detach()
    if params['noise_factor'] != 1.0:
        p.policy_model.log_stdev.data[:] += np.log(params['noise_factor'])
    if params['deterministic']:
        print('Policy runs in deterministic mode. Ignoring Gaussian noise.')
        p.policy_model.log_stdev.data[:] = -100
    print('Gaussian noise in policy (after adjustment):')
    print(torch.exp(p.policy_model.log_stdev))

    if params['sarsa_enable']:
        num_steps = params['sarsa_steps']
        # Learning rate scheduler: keep the learning rate constant for the
        # first 2/3 of the steps, then anneal it linearly.
        lr_decrease_point = num_steps * 2 / 3
        decreasing_steps = num_steps - lr_decrease_point
        lr_sch = lambda epoch: 1.0 if epoch < lr_decrease_point else \
            (decreasing_steps - epoch + lr_decrease_point) / decreasing_steps
        # Robust training scheduler. Currently using 1/3 of the epochs for
        # warmup, 1/3 for the schedule and 1/3 for final training.
        eps_start_point = int(num_steps * 1 / 3)
        robust_eps_scheduler = LinearScheduler(
            params['sarsa_eps'],
            f"start={eps_start_point},length={eps_start_point}")
        robust_beta_scheduler = LinearScheduler(
            1.0, f"start={eps_start_point},length={eps_start_point}")
        # Reinitialize the value model and run value function learning steps.
        p.setup_sarsa(lr_schedule=lr_sch,
                      eps_scheduler=robust_eps_scheduler,
                      beta_scheduler=robust_beta_scheduler)
        # Run Sarsa training.
        for i in range(num_steps):
            print(f'Step {i+1} / {num_steps}, lr={p.sarsa_scheduler.get_last_lr()}')
            mean_reward = p.sarsa_step()
            rewards.append(mean_reward)
            # for w in p.val_model.parameters():
            #     print(f'{w.size()}, {torch.norm(w.view(-1), 2)}')
        # Save the Sarsa model.
        saved_model = {
            'state_dict': p.sarsa_model.state_dict(),
            'metadata': params,
        }
        torch.save(saved_model, params['sarsa_model_path'])
    elif params['imit_enable']:
        num_epochs = params['imit_epochs']
        num_episodes = params['num_episodes']
        print('\n\n' + 'Start collecting data\n' + '-' * 80)
        for i in range(num_episodes):
            print('Collecting %d / %d episodes' % (i + 1, num_episodes))
            ep_length, ep_reward, actions, action_means, states, kl_certificates = p.run_test(
                compute_bounds=params['compute_kl_cert'],
                use_full_backward=params['use_full_backward'],
                original_stdev=original_stdev)
            not_dones = np.ones(len(actions))
            not_dones[-1] = 0
            if i == 0:
                all_actions = actions.copy()
                all_states = states.copy()
                all_not_dones = not_dones.copy()
            else:
                all_actions = np.concatenate((all_actions, actions), axis=0)
                all_states = np.concatenate((all_states, states), axis=0)
                all_not_dones = np.concatenate((all_not_dones, not_dones))
        print('Collected actions shape:', all_actions.shape)
        print('Collected states shape:', all_states.shape)
        p.setup_imit(lr=params['imit_lr'])
        p.imit_steps(torch.from_numpy(all_actions), torch.from_numpy(all_states),
                     torch.from_numpy(all_not_dones), num_epochs)
        saved_model = {
            'state_dict': p.imit_network.state_dict(),
            'metadata': params,
        }
        torch.save(saved_model, params['imit_model_path'])
    else:
        num_episodes = params['num_episodes']
        all_rewards = []
        all_lens = []
        all_kl_certificates = []

        for i in range(num_episodes):
            print('Episode %d / %d' % (i + 1, num_episodes))
            ep_length, ep_reward, actions, action_means, states, kl_certificates = p.run_test(
                compute_bounds=params['compute_kl_cert'],
                use_full_backward=params['use_full_backward'],
                original_stdev=original_stdev)
            if i == 0:
                all_actions = actions.copy()
                all_states = states.copy()
            else:
                all_actions = np.concatenate((all_actions, actions), axis=0)
                all_states = np.concatenate((all_states, states), axis=0)
            if params['compute_kl_cert']:
                print('Epoch KL certificates:', kl_certificates)
                all_kl_certificates.append(kl_certificates)
            all_rewards.append(ep_reward)
            all_lens.append(ep_length)
            # Mean, std, min and max of the rewards collected so far.
            mean_reward, std_reward, min_reward, max_reward = (
                np.mean(all_rewards), np.std(all_rewards),
                np.min(all_rewards), np.max(all_rewards))
            if i > num_episodes // 5 and params['early_terminate'] and params[
                    'sqlite_path'] and params['attack_method'] != 'none':
                # Attempt to terminate early if another attack has already
                # finished with a low reward.
                cur.execute(
                    "SELECT MIN(mean_reward) FROM attack_results WHERE deterministic=?;",
                    (params['deterministic'], ))
                current_best_reward = cur.fetchone()[0]
                print(f'current best: {current_best_reward}, ours: {mean_reward} +/- {std_reward}, min: {min_reward}')
                # Terminate if our mean - 2*std is worse than the best, or our
                # min is worse than the best.
                if current_best_reward is not None and (
                        (current_best_reward < mean_reward - 2 * std_reward)
                        or (min_reward > current_best_reward)):
                    print('terminating early!')
                    early_terminate = True
                    break

        attack_dir = 'attack-{}-eps-{}'.format(params['attack_method'],
                                               params['attack_eps'])
        if 'sarsa' in params['attack_method']:
            attack_dir += '-sarsa_steps-{}-sarsa_eps-{}-sarsa_reg-{}'.format(
                params['sarsa_steps'], params['sarsa_eps'], params['sarsa_reg'])
            if 'action' in params['attack_method']:
                attack_dir += '-attack_sarsa_action_ratio-{}'.format(
                    params['attack_sarsa_action_ratio'])
        save_path = os.path.join(params['out_dir'], params['exp_id'], attack_dir)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        for name, value in [('actions', all_actions), ('states', all_states),
                            ('rewards', all_rewards), ('length', all_lens)]:
            with open(os.path.join(save_path, '{}.pkl'.format(name)), 'wb') as f:
                pickle.dump(value, f)
        print(params)
        with open(os.path.join(save_path, 'params.json'), 'w') as f:
            json.dump(params, f, indent=4)

        mean_reward, std_reward, min_reward, max_reward = (
            np.mean(all_rewards), np.std(all_rewards),
            np.min(all_rewards), np.max(all_rewards))

        if params['compute_kl_cert']:
            print('KL certificates stats: mean: {}, std: {}, min: {}, max: {}'.format(
                np.mean(all_kl_certificates), np.std(all_kl_certificates),
                np.min(all_kl_certificates), np.max(all_kl_certificates)))

        # Write results to sqlite.
        if params['sqlite_path']:
            method = params['attack_method']
            if params['attack_method'] == "sarsa":
                # Load sarsa parameters from the checkpoint.
                sarsa_ckpt = torch.load(params['attack_sarsa_network'])
                sarsa_meta = sarsa_ckpt['metadata']
                sarsa_eps = sarsa_meta['sarsa_eps'] if 'sarsa_eps' in sarsa_meta else -1.0
                sarsa_reg = sarsa_meta['sarsa_reg'] if 'sarsa_reg' in sarsa_meta else -1.0
                sarsa_steps = sarsa_meta['sarsa_steps'] if 'sarsa_steps' in sarsa_meta else -1
            elif params['attack_method'] == "sarsa+action":
                sarsa_eps = -1.0
                sarsa_reg = params['attack_sarsa_action_ratio']
                sarsa_steps = -1
            else:
                sarsa_eps = -1.0
                sarsa_reg = -1.0
                sarsa_steps = -1
            try:
                cur.execute(
                    "INSERT INTO attack_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);",
                    (method, mean_reward, std_reward, min_reward, max_reward,
                     sarsa_eps, sarsa_reg, sarsa_steps, params['deterministic'],
                     early_terminate))
                connection.commit()
            except sqlite3.OperationalError as e:
                import traceback
                traceback.print_exc()
                print('Cannot insert into the SQLite table. Give up.')
            else:
                print(f'results saved to database {params["sqlite_path"]}')
            connection.close()

        print('\n')
        print('all rewards:', all_rewards)
        print('rewards stats:\nmean: {}, std:{}, min:{}, max:{}'.format(
            mean_reward, std_reward, min_reward, max_reward))
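# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script): reading back the
# `attack_results` table that the evaluation above writes when a sqlite path
# is given. Only the schema created above is assumed; the database path in the
# usage example is a placeholder.
import sqlite3

def summarize_attack_results(sqlite_path):
    """Print the lowest (i.e. strongest-attack) mean reward per attack method."""
    connection = sqlite3.connect(sqlite_path)
    cur = connection.cursor()
    cur.execute(
        "SELECT method, MIN(mean_reward), AVG(std_reward) "
        "FROM attack_results GROUP BY method ORDER BY MIN(mean_reward);")
    for method, best_mean, avg_std in cur.fetchall():
        print(f"{method:>15s}: best mean reward {best_mean:.2f} (avg std {avg_std:.2f})")
    connection.close()

# Example (path is hypothetical):
# summarize_attack_results('attack_results.sqlite')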
def main(params):
    override_params = copy.deepcopy(params)
    excluded_params = [
        'config_path', 'out_dir_prefix', 'num_episodes', 'row_id', 'exp_id',
        'load_model', 'seed', 'deterministic', 'scan_config',
        'compute_kl_cert', 'use_full_backward'
    ]
    sarsa_params = [
        'sarsa_enable', 'sarsa_steps', 'sarsa_eps', 'sarsa_reg',
        'sarsa_model_path'
    ]

    # override_params contains the config-file flags that can be overridden via
    # the command line; drop the excluded ones.
    for k in list(override_params.keys()):
        if k in excluded_params:
            del override_params[k]

    # Append a prefix for the output path.
    if params['out_dir_prefix']:
        params['out_dir'] = os.path.join(params['out_dir_prefix'], params['out_dir'])
        print(f"Setting output dir to {params['out_dir']}")

    if params['config_path']:
        # Load from a pretrained model using an existing config.
        # First we need to create the model using the given config file.
        json_params = json.load(open(params['config_path']))
        params = override_json_params(params, json_params,
                                      excluded_params + sarsa_params)

    if params['sarsa_enable']:
        assert params['attack_method'] == "none" or params['attack_method'] is None, \
            "--train-sarsa is only available when --attack-method=none, but got {}".format(params['attack_method'])

    if 'load_model' in params and params['load_model']:
        for k, v in params.items():
            assert v is not None, f"Value for {k} is None"

        # Create the agent from the config file.
        p = Trainer.agent_from_params(params, store=None)
        print('Loading pretrained model', params['load_model'])
        pretrained_model = torch.load(params['load_model'])
        if 'policy_model' in pretrained_model:
            p.policy_model.load_state_dict(pretrained_model['policy_model'])
        if 'val_model' in pretrained_model:
            p.val_model.load_state_dict(pretrained_model['val_model'])
        if 'policy_opt' in pretrained_model:
            p.POLICY_ADAM.load_state_dict(pretrained_model['policy_opt'])
        if 'val_opt' in pretrained_model:
            p.val_opt.load_state_dict(pretrained_model['val_opt'])
        # Restore environment parameters, like mean and std.
        if 'envs' in pretrained_model:
            p.envs = pretrained_model['envs']
        for e in p.envs:
            e.normalizer_read_only = True
            e.setup_visualization(params['show_env'], params['save_frames'],
                                  params['save_frames_path'])
    else:
        # Load from the experiment directory. No need to use a config.
        base_directory = params['out_dir']
        store = Store(base_directory, params['exp_id'], mode='r')
        if params['row_id'] < 0:
            row = store['final_results'].df
        else:
            checkpoints = store['checkpoints'].df
            row_id = params['row_id']
            row = checkpoints.iloc[row_id:row_id + 1]
        print("row to test: ", row)
        if params['cpu'] is None:
            cpu = False
        else:
            cpu = params['cpu']
        p, _ = Trainer.agent_from_data(store, row, cpu, extra_params=params,
                                       override_params=override_params,
                                       excluded_params=excluded_params)
        store.close()

    rewards = []

    if params['sarsa_enable']:
        num_steps = params['sarsa_steps']
        # Learning rate scheduler: keep the learning rate constant for the
        # first 2/3 of the steps, then anneal it linearly.
        lr_decrease_point = num_steps * 2 / 3
        decreasing_steps = num_steps - lr_decrease_point
        lr_sch = lambda epoch: 1.0 if epoch < lr_decrease_point else \
            (decreasing_steps - epoch + lr_decrease_point) / decreasing_steps
        # Robust training scheduler. Currently using 1/3 of the epochs for
        # warmup, 1/3 for the schedule and 1/3 for final training.
        eps_start_point = int(num_steps * 1 / 3)
        robust_eps_scheduler = LinearScheduler(
            params['sarsa_eps'],
            f"start={eps_start_point},length={eps_start_point}")
        robust_beta_scheduler = LinearScheduler(
            1.0, f"start={eps_start_point},length={eps_start_point}")
        # Reinitialize the value model and run value function learning steps.
        p.setup_sarsa(lr_schedule=lr_sch,
                      eps_scheduler=robust_eps_scheduler,
                      beta_scheduler=robust_beta_scheduler)
        # Run Sarsa training.
        for i in range(num_steps):
            print(f'Step {i+1} / {num_steps}, lr={p.sarsa_scheduler.get_last_lr()}')
            mean_reward = p.sarsa_step()
            rewards.append(mean_reward)
            # for w in p.val_model.parameters():
            #     print(f'{w.size()}, {torch.norm(w.view(-1), 2)}')
        # Save the Sarsa model.
        saved_model = {
            'state_dict': p.sarsa_model.state_dict(),
            'metadata': params,
        }
        torch.save(saved_model, params['sarsa_model_path'])
    else:
        print('Gaussian noise in policy:')
        print(torch.exp(p.policy_model.log_stdev))
        if params['deterministic']:
            print('Policy runs in deterministic mode. Ignoring Gaussian noise.')
            p.policy_model.log_stdev.data[:] = -100

        num_episodes = params['num_episodes']
        all_rewards = []
        all_lens = []
        all_kl_certificates = []

        for i in range(num_episodes):
            print('Episode %d / %d' % (i + 1, num_episodes))
            ep_length, ep_reward, actions, action_means, states, kl_certificates = p.run_test(
                compute_bounds=params['compute_kl_cert'],
                use_full_backward=params['use_full_backward'])
            if i == 0:
                all_actions = actions.copy()
                all_states = states.copy()
            else:
                all_actions = np.concatenate((all_actions, actions), axis=0)
                all_states = np.concatenate((all_states, states), axis=0)
            if params['compute_kl_cert']:
                print('Epoch KL certificates:', kl_certificates)
                all_kl_certificates.append(kl_certificates)
            all_rewards.append(ep_reward)
            all_lens.append(ep_length)

        attack_dir = 'attack-{}-eps-{}'.format(params['attack_method'],
                                               params['attack_eps'])
        if 'sarsa' in params['attack_method']:
            attack_dir += '-sarsa_steps-{}-sarsa_eps-{}-sarsa_reg-{}'.format(
                params['sarsa_steps'], params['sarsa_eps'], params['sarsa_reg'])
            if 'action' in params['attack_method']:
                attack_dir += '-attack_sarsa_action_ratio-{}'.format(
                    params['attack_sarsa_action_ratio'])
        save_path = os.path.join(params['out_dir'], params['exp_id'], attack_dir)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        for name, value in [('actions', all_actions), ('states', all_states),
                            ('rewards', all_rewards), ('length', all_lens)]:
            with open(os.path.join(save_path, '{}.pkl'.format(name)), 'wb') as f:
                pickle.dump(value, f)
        print(params)
        with open(os.path.join(save_path, 'params.json'), 'w') as f:
            json.dump(params, f, indent=4)

        print('\n')
        print('all rewards:', all_rewards)
        print('rewards stats:\nmean: {}, std:{}, min:{}, max:{}'.format(
            np.mean(all_rewards), np.std(all_rewards), np.min(all_rewards),
            np.max(all_rewards)))
        if params['compute_kl_cert']:
            print('KL certificates stats: mean: {}, std: {}, min: {}, max: {}'.format(
                np.mean(all_kl_certificates), np.std(all_kl_certificates),
                np.min(all_kl_certificates), np.max(all_kl_certificates)))
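# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original scripts): the SARSA learning-rate
# schedule used above, evaluated in isolation. The multiplier is 1.0 for the
# first 2/3 of the steps and then decays linearly toward 0 at the end; the
# `num_steps` value in the example below is arbitrary.
def sarsa_lr_multiplier(epoch, num_steps):
    lr_decrease_point = num_steps * 2 / 3
    decreasing_steps = num_steps - lr_decrease_point
    if epoch < lr_decrease_point:
        return 1.0
    return (decreasing_steps - epoch + lr_decrease_point) / decreasing_steps

# For num_steps=30 the multiplier stays at 1.0 for epochs 0..19, then falls
# linearly: epoch 20 -> 1.0, epoch 25 -> 0.5, epoch 29 -> 0.1.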
def main(params):
    for k, v in params.items():
        assert v is not None, f"Value for {k} is None"

    # #
    # Setup logging
    # #
    metadata_schema = schema_from_dict(params)
    base_directory = params['out_dir']
    store = Store(base_directory)

    # redirect stderr, stdout to file
    """
    def make_err_redirector(stream_name):
        tee = Tee(os.path.join(store.path, stream_name + '.txt'), stream_name)
        return tee

    stderr_tee = make_err_redirector('stderr')
    stdout_tee = make_err_redirector('stdout')
    """

    # Store the experiment path and the git commit for this experiment
    metadata_schema.update({
        'store_path': str,
        'git_commit': str
    })

    repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                    search_parent_directories=True)

    metadata_table = store.add_table('metadata', metadata_schema)
    metadata_table.update_row(params)
    metadata_table.update_row({
        'store_path': store.path,
        'git_commit': repo.head.object.hexsha
    })
    metadata_table.flush_row()

    # Extra items in the tables when minimax (adversarial) training is enabled.
    if params['mode'] in ("adv_ppo", "adv_trpo", "adv_sa_ppo"):
        adversary_table_dict = {
            'adversary_policy_model': store.PYTORCH_STATE,
            'adversary_policy_opt': store.PYTORCH_STATE,
            'adversary_val_model': store.PYTORCH_STATE,
            'adversary_val_opt': store.PYTORCH_STATE,
        }
    else:
        adversary_table_dict = {}

    # Table for checkpointing models and envs
    if params['save_iters'] > 0:
        checkpoint_dict = {
            'val_model': store.PYTORCH_STATE,
            'policy_model': store.PYTORCH_STATE,
            'envs': store.PICKLE,
            'policy_opt': store.PYTORCH_STATE,
            'val_opt': store.PYTORCH_STATE,
            'iteration': int,
            '5_rewards': float,
        }
        checkpoint_dict.update(adversary_table_dict)
        store.add_table('checkpoints', checkpoint_dict)

    # The trainer object is in charge of sampling trajectories and
    # taking PPO/TRPO optimization steps
    p = Trainer.agent_from_params(params, store=store)
    if params['initial_std'] != 1.0:
        p.policy_model.log_stdev.data[:] = np.log(params['initial_std'])
    if 'load_model' in params and params['load_model']:
        print('Loading pretrained model', params['load_model'])
        pretrained_model = torch.load(params['load_model'])
        if 'policy_model' in pretrained_model:
            p.policy_model.load_state_dict(pretrained_model['policy_model'])
        if params['deterministic']:
            print('Policy runs in deterministic mode. Ignoring Gaussian noise.')
            p.policy_model.log_stdev.data[:] = -100
        else:
            print('Policy runs in non-deterministic mode with Gaussian noise.')
        if 'val_model' in pretrained_model:
            p.val_model.load_state_dict(pretrained_model['val_model'])
        if 'policy_opt' in pretrained_model:
            p.POLICY_ADAM.load_state_dict(pretrained_model['policy_opt'])
        if 'val_opt' in pretrained_model:
            p.val_opt.load_state_dict(pretrained_model['val_opt'])
        # Load adversary models.
        if 'no_load_adv_policy' in params and params['no_load_adv_policy']:
            print('Skipping loading adversary models.')
        else:
            if 'adversary_policy_model' in pretrained_model and hasattr(p, 'adversary_policy_model'):
                p.adversary_policy_model.load_state_dict(pretrained_model['adversary_policy_model'])
            if 'adversary_val_model' in pretrained_model and hasattr(p, 'adversary_val_model'):
                p.adversary_val_model.load_state_dict(pretrained_model['adversary_val_model'])
            if 'adversary_policy_opt' in pretrained_model and hasattr(p, 'adversary_policy_opt'):
                p.adversary_policy_opt.load_state_dict(pretrained_model['adversary_policy_opt'])
            if 'adversary_val_opt' in pretrained_model and hasattr(p, 'adversary_val_opt'):
                p.adversary_val_opt.load_state_dict(pretrained_model['adversary_val_opt'])
        # Restore environment parameters, like mean and std.
        if 'envs' in pretrained_model:
            p.envs = pretrained_model['envs']
        for e in p.envs:
            e.setup_visualization(params['show_env'], params['save_frames'],
                                  params['save_frames_path'])

    rewards = []

    # Table for final results
    final_dict = {
        'iteration': int,
        '5_rewards': float,
        'terminated_early': bool,
        'val_model': store.PYTORCH_STATE,
        'policy_model': store.PYTORCH_STATE,
        'envs': store.PICKLE,
        'policy_opt': store.PYTORCH_STATE,
        'val_opt': store.PYTORCH_STATE,
    }
    final_dict.update(adversary_table_dict)
    final_table = store.add_table('final_results', final_dict)

    def add_adversary_to_table(p, table_dict):
        if params['mode'] in ("adv_ppo", "adv_trpo", "adv_sa_ppo"):
            table_dict["adversary_policy_model"] = p.adversary_policy_model.state_dict()
            table_dict["adversary_policy_opt"] = p.ADV_POLICY_ADAM.state_dict()
            table_dict["adversary_val_model"] = p.adversary_val_model.state_dict()
            table_dict["adversary_val_opt"] = p.adversary_val_opt.state_dict()
        return table_dict

    def finalize_table(iteration, terminated_early, rewards):
        final_5_rewards = np.array(rewards)[-5:].mean()
        final_dict = {
            'iteration': iteration,
            '5_rewards': final_5_rewards,
            'terminated_early': terminated_early,
            'val_model': p.val_model.state_dict(),
            'policy_model': p.policy_model.state_dict(),
            'policy_opt': p.POLICY_ADAM.state_dict(),
            'val_opt': p.val_opt.state_dict(),
            'envs': p.envs
        }
        final_dict = add_adversary_to_table(p, final_dict)
        final_table.append_row(final_dict)

    ret = 0
    # Try-except so that we save if the user interrupts the process
    try:
        for i in range(params['train_steps']):
            print('Step %d' % (i,))
            if params['save_iters'] > 0 and i % params['save_iters'] == 0 and i != 0:
                final_5_rewards = np.array(rewards)[-5:].mean()
                print(f'Saving checkpoints to {store.path} with reward {final_5_rewards:.5g}')
                checkpoint_dict = {
                    'iteration': i,
                    'val_model': p.val_model.state_dict(),
                    'policy_model': p.policy_model.state_dict(),
                    'policy_opt': p.POLICY_ADAM.state_dict(),
                    'val_opt': p.val_opt.state_dict(),
                    'envs': p.envs,
                    '5_rewards': final_5_rewards,
                }
                checkpoint_dict = add_adversary_to_table(p, checkpoint_dict)
                store['checkpoints'].append_row(checkpoint_dict)

            mean_reward = p.train_step()
            rewards.append(mean_reward)

            # For debugging and tuning, we can break in the middle.
            if i == params['force_stop_step']:
                print('Terminating early because --force-stop-step is set.')
                raise KeyboardInterrupt

        finalize_table(i, False, rewards)
    except KeyboardInterrupt:
        finalize_table(i, True, rewards)
        ret = 1
    except Exception:
        print("An error occurred during training:")
        traceback.print_exc()
        # Other errors: make sure to finalize the cox store before exiting.
        finalize_table(i, True, rewards)
        ret = -1
    print(f'Models saved to {store.path}')
    store.close()
    return ret
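# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script): inspecting the results the
# training run above writes into its cox Store. It reuses the same read-mode
# pattern already used by the evaluation scripts (Store(base_directory,
# exp_id, mode='r') and a table's .df accessor); the directory and experiment
# id in the usage example are placeholders.
from cox.store import Store

def print_final_rewards(base_directory, exp_id):
    store = Store(base_directory, exp_id, mode='r')
    df = store['final_results'].df
    print(df[['iteration', '5_rewards', 'terminated_early']])
    store.close()

# Example (both arguments are hypothetical):
# print_final_rewards('results/ppo_hopper', '0a1b2c3d')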