import os
import traceback

import git
import numpy as np
import torch

from cox.store import Store, schema_from_dict


def setup_general_agent(params: dict, save_git: bool = True):
    """
    General agent setup for logging results.

    Args:
        params: dict of experiment parameters
        save_git: save the git hash so the experiment setting can be restored

    Returns:
        cox store for logging
    """
    for k, v in params.items():
        assert v is not None, f"Value for {k} is None"

    # Ensure that, when the entropy constraint is not used, the covariance
    # is not accidentally projected to -inf.
    if not params['entropy_schedule']:
        params['entropy_eq'] = False

    store = None
    if params['log_interval'] <= params['train_steps']:
        # Setup logging
        metadata_schema = schema_from_dict(params)
        base_directory = params['out_dir']
        exp_name = params.get('exp_name')

        # Store the experiment path and, optionally, the git commit.
        # Both columns must be in the schema before the table is created.
        metadata_schema.update({'store_path': str})
        if save_git:
            metadata_schema.update({'git_commit': str})

        store = CustomStore(storage_folder=base_directory, exp_id=exp_name, new=True)
        metadata_table = store.add_table('metadata', metadata_schema)
        metadata_table.update_row(params)
        metadata_table.update_row({'store_path': store.path})

        if save_git:
            # The git commit for this experiment
            repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                            search_parent_directories=True)
            metadata_table.update_row({'git_commit': repo.head.object.hexsha})

        metadata_table.flush_row()

    # save_interval == 0 saves the last model only;
    # save_interval == -1 disables saving entirely.
    if params['save_interval'] == 0:
        params['save_interval'] = params['train_steps']

    return store
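# Usage sketch (illustration only, not from the original file). The
# parameter values below are hypothetical and cover only the keys that
# setup_general_agent itself reads; a real run needs the full config.
example_params = {
    'out_dir': 'runs',
    'exp_name': 'ppo_hopper',
    'log_interval': 10,
    'train_steps': 1000,
    'save_interval': 0,         # rewritten to train_steps: save last model only
    'entropy_schedule': False,  # falsy, so entropy_eq is forced to False
    'entropy_eq': True,
}
example_store = setup_general_agent(example_params, save_git=False)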
def main(params):
    for k, v in params.items():
        assert v is not None, f"Value for {k} is None"

    # #
    # Setup logging
    # #
    metadata_schema = schema_from_dict(params)
    base_directory = params['out_dir']
    store = Store(base_directory)

    # Redirect stderr/stdout to file (disabled)
    """
    def make_err_redirector(stream_name):
        tee = Tee(os.path.join(store.path, stream_name + '.txt'), stream_name)
        return tee

    stderr_tee = make_err_redirector('stderr')
    stdout_tee = make_err_redirector('stdout')
    """

    # Store the experiment path and the git commit for this experiment
    metadata_schema.update({
        'store_path': str,
        'git_commit': str
    })
    repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                    search_parent_directories=True)
    metadata_table = store.add_table('metadata', metadata_schema)
    metadata_table.update_row(params)
    metadata_table.update_row({
        'store_path': store.path,
        'git_commit': repo.head.object.hexsha
    })
    metadata_table.flush_row()

    # Table for checkpointing models and envs
    if params['save_iters'] > 0:
        store.add_table('checkpoints', {
            'val_model': store.PYTORCH_STATE,
            'policy_model': store.PYTORCH_STATE,
            'envs': store.PICKLE,
            'policy_opt': store.PYTORCH_STATE,
            'val_opt': store.PYTORCH_STATE,
            'iteration': int
        })

    # The trainer object is in charge of sampling trajectories and
    # taking PPO/TRPO optimization steps
    p = Trainer.agent_from_params(params, store=store)
    if 'load_model' in params and params['load_model']:
        print('Loading pretrained model', params['load_model'])
        pretrained_models = torch.load(params['load_model'])
        p.policy_model.load_state_dict(pretrained_models['policy_model'])
        p.val_model.load_state_dict(pretrained_models['val_model'])
        # Load optimizer states (kept disabled):
        # p.POLICY_ADAM.load_state_dict(pretrained_models['policy_opt'])
        # p.val_opt.load_state_dict(pretrained_models['val_opt'])
        # Restore environment parameters, like observation mean and std.
        p.envs = pretrained_models['envs']

    rewards = []

    # Table for final results
    final_table = store.add_table('final_results', {
        'iteration': int,
        '5_rewards': float,
        'terminated_early': bool,
        'val_model': store.PYTORCH_STATE,
        'policy_model': store.PYTORCH_STATE,
        'envs': store.PICKLE,
        'policy_opt': store.PYTORCH_STATE,
        'val_opt': store.PYTORCH_STATE
    })

    def finalize_table(iteration, terminated_early, rewards):
        final_5_rewards = np.array(rewards)[-5:].mean()
        final_table.append_row({
            'iteration': iteration,
            '5_rewards': final_5_rewards,
            'terminated_early': terminated_early,
            'val_model': p.val_model.state_dict(),
            'policy_model': p.policy_model.state_dict(),
            'policy_opt': p.POLICY_ADAM.state_dict(),
            'val_opt': p.val_opt.state_dict(),
            'envs': p.envs
        })

    # Try-except so that we save if the user interrupts the process
    i = 0  # defined up front so the except block can reference it
    try:
        for i in range(params['train_steps']):
            print('Step %d' % (i,))
            if params['save_iters'] > 0 and i % params['save_iters'] == 0:
                store['checkpoints'].append_row({
                    'iteration': i,
                    'val_model': p.val_model.state_dict(),
                    'policy_model': p.policy_model.state_dict(),
                    'policy_opt': p.POLICY_ADAM.state_dict(),
                    'val_opt': p.val_opt.state_dict(),
                    'envs': p.envs
                })

            mean_reward = p.train_step()
            rewards.append(mean_reward)

        finalize_table(i, False, rewards)
    except KeyboardInterrupt:
        os.makedirs('saved_experts', exist_ok=True)
        torch.save(p.val_model, 'saved_experts/%s-expert-vf' % (params['game'],))
        torch.save(p.policy_model, 'saved_experts/%s-expert-pol' % (params['game'],))
        finalize_table(i, True, rewards)
    store.close()
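# A minimal sketch (not from the original file) of reading the logged
# results back out of the store, assuming the cox CollectionReader API;
# 'runs' stands in for the params['out_dir'] used above.
from cox.readers import CollectionReader

reader = CollectionReader('runs')
results = reader.df('final_results')  # one row per completed experiment
print(results[['iteration', '5_rewards', 'terminated_early']])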
def main(params):
    for k, v in params.items():
        assert v is not None, f"Value for {k} is None"

    # #
    # Setup logging
    # #
    metadata_schema = schema_from_dict(params)
    base_directory = params['out_dir']
    store = Store(base_directory)

    # Redirect stderr/stdout to file (disabled)
    """
    def make_err_redirector(stream_name):
        tee = Tee(os.path.join(store.path, stream_name + '.txt'), stream_name)
        return tee

    stderr_tee = make_err_redirector('stderr')
    stdout_tee = make_err_redirector('stdout')
    """

    # Store the experiment path and the git commit for this experiment
    metadata_schema.update({
        'store_path': str,
        'git_commit': str
    })
    repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                    search_parent_directories=True)
    metadata_table = store.add_table('metadata', metadata_schema)
    metadata_table.update_row(params)
    metadata_table.update_row({
        'store_path': store.path,
        'git_commit': repo.head.object.hexsha
    })
    metadata_table.flush_row()

    # Extra columns in the tables when minimax (adversarial) training is enabled.
    if params['mode'] in ('adv_ppo', 'adv_trpo', 'adv_sa_ppo'):
        adversary_table_dict = {
            'adversary_policy_model': store.PYTORCH_STATE,
            'adversary_policy_opt': store.PYTORCH_STATE,
            'adversary_val_model': store.PYTORCH_STATE,
            'adversary_val_opt': store.PYTORCH_STATE,
        }
    else:
        adversary_table_dict = {}

    # Table for checkpointing models and envs
    if params['save_iters'] > 0:
        checkpoint_dict = {
            'val_model': store.PYTORCH_STATE,
            'policy_model': store.PYTORCH_STATE,
            'envs': store.PICKLE,
            'policy_opt': store.PYTORCH_STATE,
            'val_opt': store.PYTORCH_STATE,
            'iteration': int,
            '5_rewards': float,
        }
        checkpoint_dict.update(adversary_table_dict)
        store.add_table('checkpoints', checkpoint_dict)

    # The trainer object is in charge of sampling trajectories and
    # taking PPO/TRPO optimization steps
    p = Trainer.agent_from_params(params, store=store)
    if params['initial_std'] != 1.0:
        p.policy_model.log_stdev.data[:] = np.log(params['initial_std'])

    if 'load_model' in params and params['load_model']:
        print('Loading pretrained model', params['load_model'])
        pretrained_model = torch.load(params['load_model'])
        if 'policy_model' in pretrained_model:
            p.policy_model.load_state_dict(pretrained_model['policy_model'])
        if params['deterministic']:
            print('Policy runs in deterministic mode. Ignoring Gaussian noise.')
            p.policy_model.log_stdev.data[:] = -100
        else:
            print('Policy runs in non-deterministic mode with Gaussian noise.')
        if 'val_model' in pretrained_model:
            p.val_model.load_state_dict(pretrained_model['val_model'])
        if 'policy_opt' in pretrained_model:
            p.POLICY_ADAM.load_state_dict(pretrained_model['policy_opt'])
        if 'val_opt' in pretrained_model:
            p.val_opt.load_state_dict(pretrained_model['val_opt'])
        # Load adversary models.
        if 'no_load_adv_policy' in params and params['no_load_adv_policy']:
            print('Skipping loading adversary models.')
        else:
            if 'adversary_policy_model' in pretrained_model and hasattr(p, 'adversary_policy_model'):
                p.adversary_policy_model.load_state_dict(pretrained_model['adversary_policy_model'])
            if 'adversary_val_model' in pretrained_model and hasattr(p, 'adversary_val_model'):
                p.adversary_val_model.load_state_dict(pretrained_model['adversary_val_model'])
            if 'adversary_policy_opt' in pretrained_model and hasattr(p, 'adversary_policy_opt'):
                p.adversary_policy_opt.load_state_dict(pretrained_model['adversary_policy_opt'])
            if 'adversary_val_opt' in pretrained_model and hasattr(p, 'adversary_val_opt'):
                p.adversary_val_opt.load_state_dict(pretrained_model['adversary_val_opt'])
        # Load optimizer states (kept disabled; the guarded loads above
        # already restore them when present):
        # p.POLICY_ADAM.load_state_dict(pretrained_model['policy_opt'])
        # p.val_opt.load_state_dict(pretrained_model['val_opt'])
        # Restore environment parameters, like observation mean and std.
        if 'envs' in pretrained_model:
            p.envs = pretrained_model['envs']

    for e in p.envs:
        e.setup_visualization(params['show_env'], params['save_frames'],
                              params['save_frames_path'])

    rewards = []

    # Table for final results
    final_dict = {
        'iteration': int,
        '5_rewards': float,
        'terminated_early': bool,
        'val_model': store.PYTORCH_STATE,
        'policy_model': store.PYTORCH_STATE,
        'envs': store.PICKLE,
        'policy_opt': store.PYTORCH_STATE,
        'val_opt': store.PYTORCH_STATE,
    }
    final_dict.update(adversary_table_dict)
    final_table = store.add_table('final_results', final_dict)

    def add_adversary_to_table(p, table_dict):
        if params['mode'] in ('adv_ppo', 'adv_trpo', 'adv_sa_ppo'):
            table_dict['adversary_policy_model'] = p.adversary_policy_model.state_dict()
            table_dict['adversary_policy_opt'] = p.ADV_POLICY_ADAM.state_dict()
            table_dict['adversary_val_model'] = p.adversary_val_model.state_dict()
            table_dict['adversary_val_opt'] = p.adversary_val_opt.state_dict()
        return table_dict

    def finalize_table(iteration, terminated_early, rewards):
        final_5_rewards = np.array(rewards)[-5:].mean()
        final_dict = {
            'iteration': iteration,
            '5_rewards': final_5_rewards,
            'terminated_early': terminated_early,
            'val_model': p.val_model.state_dict(),
            'policy_model': p.policy_model.state_dict(),
            'policy_opt': p.POLICY_ADAM.state_dict(),
            'val_opt': p.val_opt.state_dict(),
            'envs': p.envs
        }
        final_dict = add_adversary_to_table(p, final_dict)
        final_table.append_row(final_dict)

    ret = 0
    # Try-except so that we save if the user interrupts the process
    i = 0  # defined up front so the except blocks can reference it
    try:
        for i in range(params['train_steps']):
            print('Step %d' % (i,))
            if params['save_iters'] > 0 and i % params['save_iters'] == 0 and i != 0:
                final_5_rewards = np.array(rewards)[-5:].mean()
                print(f'Saving checkpoints to {store.path} with reward {final_5_rewards:.5g}')
                checkpoint_dict = {
                    'iteration': i,
                    'val_model': p.val_model.state_dict(),
                    'policy_model': p.policy_model.state_dict(),
                    'policy_opt': p.POLICY_ADAM.state_dict(),
                    'val_opt': p.val_opt.state_dict(),
                    'envs': p.envs,
                    '5_rewards': final_5_rewards,
                }
                checkpoint_dict = add_adversary_to_table(p, checkpoint_dict)
                store['checkpoints'].append_row(checkpoint_dict)

            mean_reward = p.train_step()
            rewards.append(mean_reward)

            # For debugging and tuning, we can break in the middle.
            if i == params['force_stop_step']:
                print('Terminating early because --force-stop-step is set.')
                raise KeyboardInterrupt

        finalize_table(i, False, rewards)
    except KeyboardInterrupt:
        finalize_table(i, True, rewards)
        ret = 1
    except Exception:
        print('An error occurred during training:')
        traceback.print_exc()
        # On other errors, make sure to finalize the cox store before exiting.
        finalize_table(i, True, rewards)
        ret = -1
    print(f'Models saved to {store.path}')
    store.close()
    return ret
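def save_pretrained(p, path):
    """Sketch (not in the original file): write a checkpoint in the dict
    format that main() consumes via params['load_model']. Every key is
    optional, since main() guards each load with `in pretrained_model`."""
    torch.save({
        'policy_model': p.policy_model.state_dict(),
        'val_model': p.val_model.state_dict(),
        'policy_opt': p.POLICY_ADAM.state_dict(),
        'val_opt': p.val_opt.state_dict(),
        'envs': p.envs,  # pickled envs carry normalization statistics
    }, path)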
indices = ch.randperm(X_tr.size(0))
X_val, X_tr = X_tr[indices[:val_sz]], X_tr[indices[val_sz:]]
y_val, y_tr = y_tr[indices[:val_sz]], y_tr[indices[val_sz:]]

ds_tr = IndexedTensorDataset(X_tr, y_tr)
ds_val = TensorDataset(X_val, y_val)
ds_te = TensorDataset(X_te, y_te)
ld_tr = DataLoader(ds_tr, batch_size=args.batch_size, shuffle=True)
ld_val = DataLoader(ds_val, batch_size=args.batch_size, shuffle=False)
ld_te = DataLoader(ds_te, batch_size=args.batch_size, shuffle=False)

Nfeatures = X_tr.size(1)
args_dict['Nfeatures'] = Nfeatures
args_dict['feature_indices'] = idx_sub
schema = store.schema_from_dict(args_dict)
out_store.add_table("metadata", schema)
out_store["metadata"].append_row(args_dict)

print("Initializing linear model...")
linear = nn.Linear(Nfeatures, NUM_CLASSES).cuda()
weight = linear.weight
bias = linear.bias
for p in [weight, bias]:
    p.data.zero_()

print("Calculating the regularization path")
params = glm_saga(linear, ld_tr, args.lr, args.max_epochs, args.alpha,
                  n_classes=NUM_CLASSES, checkpoint=f"{res_dir}/params",
                  verbose=args.verbose, tol=args.tol, group=args.group,
                  epsilon=args.lam_factor,
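# Evaluation sketch (illustration only, not part of the original script).
# Assuming glm_saga updates `linear` in place along the path, the test
# accuracy of the sparse linear layer could be measured like this:
#
#     linear.eval()
#     correct = total = 0
#     with ch.no_grad():
#         for X, y in ld_te:
#             preds = linear(X.cuda()).argmax(dim=1)
#             correct += (preds == y.cuda()).sum().item()
#             total += y.size(0)
#     print(f"Test accuracy: {correct / total:.4f}")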