Example #1
def main():
    np.set_printoptions(suppress=True, precision=5, linewidth=1000)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=MODES, required=True)
    # Expert dataset
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--resume_training', action='store_true', help="Resume training from the checkpoint given by --checkpoint. Currently only supported for GAIL with an nn policy, reward and vf")
    parser.add_argument('--checkpoint', type=str, help="Checkpoint to load from when --resume_training is set")
    parser.add_argument('--limit_trajs', type=int, required=True, help="Number of expert trajectories to use for training. If None, the full dataset is used.")
    parser.add_argument('--data_subsamp_freq', type=int, required=True, help="A number between 0 and max_traj_len. Subsampling rate applied to expert trajectories when building the dataset of expert (state, action) transitions.")
    # MDP options
    parser.add_argument('--env_name', type=str, required=True)
    parser.add_argument('--max_traj_len', type=int, default=None)
    # Policy architecture
    parser.add_argument('--policy_hidden_spec', type=str, default=SIMPLE_ARCHITECTURE)
    parser.add_argument('--tiny_policy', action='store_true')
    parser.add_argument('--obsnorm_mode', choices=OBSNORM_MODES, default='expertdata')
    # Behavioral cloning optimizer
    parser.add_argument('--bclone_lr', type=float, default=1e-3)
    parser.add_argument('--bclone_batch_size', type=int, default=128)
    # parser.add_argument('--bclone_eval_nsa', type=int, default=128*100)
    parser.add_argument('--bclone_eval_ntrajs', type=int, default=20)
    parser.add_argument('--bclone_eval_freq', type=int, default=1000)
    parser.add_argument('--bclone_train_frac', type=float, default=.7)
    # Imitation optimizer
    parser.add_argument('--discount', type=float, default=.995)
    parser.add_argument('--lam', type=float, default=.97)
    parser.add_argument('--max_iter', type=int, default=1000000)
    parser.add_argument('--policy_max_kl', type=float, default=.01)
    parser.add_argument('--policy_cg_damping', type=float, default=.1)
    parser.add_argument('--no_vf', type=int, default=0)
    parser.add_argument('--vf_max_kl', type=float, default=.01)
    parser.add_argument('--vf_cg_damping', type=float, default=.1)
    parser.add_argument('--policy_ent_reg', type=float, default=0.)
    parser.add_argument('--reward_type', type=str, default='nn')
    # parser.add_argument('--linear_reward_bin_features', type=int, default=0)
    parser.add_argument('--reward_max_kl', type=float, default=.01)
    parser.add_argument('--reward_lr', type=float, default=.01)
    parser.add_argument('--reward_steps', type=int, default=1)
    parser.add_argument('--reward_ent_reg_weight', type=float, default=.001)
    parser.add_argument('--reward_include_time', type=int, default=0)
    parser.add_argument('--sim_batch_size', type=int, default=None)
    parser.add_argument('--min_total_sa', type=int, default=50000)
    parser.add_argument('--favor_zero_expert_reward', type=int, default=0)
    # Saving stuff
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--save_freq', type=int, default=20)
    parser.add_argument('--plot_freq', type=int, default=0)
    parser.add_argument('--log', type=str, required=False)

    args = parser.parse_args()

    # Initialize the MDP
    if args.tiny_policy:
        assert args.policy_hidden_spec == SIMPLE_ARCHITECTURE, 'policy_hidden_spec must remain unspecified if --tiny_policy is set'
        args.policy_hidden_spec = TINY_ARCHITECTURE
    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    print(argstr)
    print "\n\n========== Policy network specifications loaded ===========\n\n"

    mdp = rlgymenv.RLGymMDP(args.env_name)
    util.header('MDP observation space, action space sizes: %d, %d\n' % (mdp.obs_space.dim, mdp.action_space.storage_size))

    print "\n\n========== MDP initialized ===========\n\n"

    # Initialize the policy
    enable_obsnorm = args.obsnorm_mode != 'none'
    if isinstance(mdp.action_space, policyopt.ContinuousSpace):
        policy_cfg = rl.GaussianPolicyConfig(
            hidden_spec=args.policy_hidden_spec,
            min_stdev=0.,
            init_logstdev=0.,
            enable_obsnorm=enable_obsnorm)
        policy = rl.GaussianPolicy(policy_cfg, mdp.obs_space, mdp.action_space, 'GaussianPolicy')
    else:
        policy_cfg = rl.GibbsPolicyConfig(
            hidden_spec=args.policy_hidden_spec,
            enable_obsnorm=enable_obsnorm)
        policy = rl.GibbsPolicy(policy_cfg, mdp.obs_space, mdp.action_space, 'GibbsPolicy')

    # Load from checkpoint if provided
    if args.resume_training:
        if args.checkpoint is not None:
            file, policy_key = util.split_h5_name(args.checkpoint)
            policy_file = file[:-3]+'_policy.h5'
            policy.load_h5(policy_file, policy_key)

    util.header('Policy architecture')
    for v in policy.get_trainable_variables():
        util.header('- %s (%d parameters)' % (v.name, v.get_value().size))
    util.header('Total: %d parameters' % (policy.get_num_params(),))

    print "\n\n========== Policy initialized ===========\n\n"

    # Load expert data
    exobs_Bstacked_Do, exa_Bstacked_Da, ext_Bstacked = load_dataset(
        args.data, args.limit_trajs, args.data_subsamp_freq)
    assert exobs_Bstacked_Do.shape[1] == mdp.obs_space.storage_size
    assert exa_Bstacked_Da.shape[1] == mdp.action_space.storage_size
    assert ext_Bstacked.ndim == 1

    print "\n\n========== Expert data loaded ===========\n\n"

    # Start optimization
    max_traj_len = args.max_traj_len if args.max_traj_len is not None else mdp.env_spec.timestep_limit
    print 'Max traj len:', max_traj_len

    if args.mode == 'bclone':
        # For behavioral cloning, only print output when evaluating
        args.print_freq = args.bclone_eval_freq
        args.save_freq = args.bclone_eval_freq

        reward, vf = None, None  # The reward function and value function play no role in behavioral cloning
        opt = imitation.BehavioralCloningOptimizer(
            mdp, policy,
            lr=args.bclone_lr,
            batch_size=args.bclone_batch_size,
            obsfeat_fn=lambda o:o,
            ex_obs=exobs_Bstacked_Do, ex_a=exa_Bstacked_Da,
            eval_sim_cfg=policyopt.SimConfig(
                min_num_trajs=args.bclone_eval_ntrajs, min_total_sa=-1,
                batch_size=args.sim_batch_size, max_traj_len=max_traj_len),
            eval_freq=args.bclone_eval_freq,
            train_frac=args.bclone_train_frac)

        print "======= Behavioral Cloning optimizer initialized ======="

    elif args.mode == 'ga':
        if args.reward_type == 'nn':
            reward = imitation.TransitionClassifier( #Add resume training functionality
                hidden_spec=args.policy_hidden_spec,
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                max_kl=args.reward_max_kl,
                adam_lr=args.reward_lr,
                adam_steps=args.reward_steps,
                ent_reg_weight=args.reward_ent_reg_weight,
                enable_inputnorm=True,
                include_time=bool(args.reward_include_time),
                time_scale=1./mdp.env_spec.timestep_limit,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                varscope_name='TransitionClassifier')
            # Load from checkpoint if provided
            if args.resume_training:
                if args.checkpoint is not None:
                    file, reward_key = util.split_h5_name(args.checkpoint)
                    reward_file = file[:-3]+'_reward.h5'
                    print reward_file
                    reward.load_h5(reward_file, reward_key)

        elif args.reward_type in ['l2ball', 'simplex']:
            reward = imitation.LinearReward(
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                mode=args.reward_type,
                enable_inputnorm=True,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                include_time=bool(args.reward_include_time),
                time_scale=1./mdp.env_spec.timestep_limit,
                exobs_Bex_Do=exobs_Bstacked_Do,
                exa_Bex_Da=exa_Bstacked_Da,
                ext_Bex=ext_Bstacked)
        else:
            raise NotImplementedError(args.reward_type)

        vf = None if bool(args.no_vf) else rl.ValueFunc( #Add resume training functionality
            hidden_spec=args.policy_hidden_spec,
            obsfeat_space=mdp.obs_space,
            enable_obsnorm=args.obsnorm_mode != 'none',
            enable_vnorm=True,
            max_kl=args.vf_max_kl,
            damping=args.vf_cg_damping,
            time_scale=1./mdp.env_spec.timestep_limit,
            varscope_name='ValueFunc')
        if args.resume_training:
            if args.checkpoint is not None:
                file, vf_key = util.split_h5_name(args.checkpoint)
                vf_file = file[:-3]+'_vf.h5'
                vf.load_h5(vf_file, vf_key)

        opt = imitation.ImitationOptimizer(
            mdp=mdp,
            discount=args.discount,
            lam=args.lam,
            policy=policy,
            sim_cfg=policyopt.SimConfig(
                min_num_trajs=-1, min_total_sa=args.min_total_sa,
                batch_size=args.sim_batch_size, max_traj_len=max_traj_len),
            step_func=rl.TRPO(max_kl=args.policy_max_kl, damping=args.policy_cg_damping),
            reward_func=reward,
            value_func=vf,
            policy_obsfeat_fn=lambda obs: obs,
            reward_obsfeat_fn=lambda obs: obs,
            policy_ent_reg=args.policy_ent_reg,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            ex_t=ext_Bstacked)

    # Set observation normalization
    if args.obsnorm_mode == 'expertdata':
        policy.update_obsnorm(exobs_Bstacked_Do)
        if reward is not None: reward.update_inputnorm(opt.reward_obsfeat_fn(exobs_Bstacked_Do), exa_Bstacked_Da)
        if vf is not None: vf.update_obsnorm(opt.policy_obsfeat_fn(exobs_Bstacked_Do))

        print "======== Observation normalization done ========"

    # Run optimizer
    print "======== Optimization begins ========"

    # Trial: make checkpoints for policy, reward and vf
    policy_log = nn.TrainingLog(args.log[:-3]+'_policy.h5', [('args', argstr)])
    reward_log = nn.TrainingLog(args.log[:-3]+'_reward.h5', [('args', argstr)])
    vf_log = nn.TrainingLog(args.log[:-3]+'_vf.h5', [('args', argstr)])
    

    for i in xrange(args.max_iter):
        
        #Optimization step
        iter_info = opt.step() 

        #Log and plot
        #pdb.set_trace()
        policy_log.write(iter_info,
                print_header=i % (20*args.print_freq) == 0, 
                display=i % args.print_freq == 0 ## FIXME: AS remove comment
                )
        reward_log.write(iter_info, 
                print_header=i % (20*args.print_freq) == 0, 
                display=i % args.print_freq == 0 ## FIXME: AS remove comment
                )
        vf_log.write(iter_info, 
                print_header=i % (20*args.print_freq) == 0, 
                display=i % args.print_freq == 0 ## FIXME: AS remove comment
                )
        

        if args.save_freq != 0 and i % args.save_freq == 0 and args.log is not None:
            policy_log.write_snapshot(policy, i)
            reward_log.write_snapshot(reward, i)
            vf_log.write_snapshot(vf, i)

        if args.plot_freq != 0 and i % args.plot_freq == 0:
            exdata_N_Doa = np.concatenate([exobs_Bstacked_Do, exa_Bstacked_Da], axis=1)
            pdata_M_Doa = np.concatenate([opt.last_sampbatch.obs.stacked, opt.last_sampbatch.a.stacked], axis=1)

            # Plot reward
            import matplotlib.pyplot as plt
            _, ax = plt.subplots()
            idx1, idx2 = 0,1
            range1 = (min(exdata_N_Doa[:,idx1].min(), pdata_M_Doa[:,idx1].min()), max(exdata_N_Doa[:,idx1].max(), pdata_M_Doa[:,idx1].max()))
            range2 = (min(exdata_N_Doa[:,idx2].min(), pdata_M_Doa[:,idx2].min()), max(exdata_N_Doa[:,idx2].max(), pdata_M_Doa[:,idx2].max()))
            reward.plot(ax, idx1, idx2, range1, range2, n=100)

            # Plot expert data
            ax.scatter(exdata_N_Doa[:,idx1], exdata_N_Doa[:,idx2], color='blue', s=1, label='expert')

            # Plot policy samples
            ax.scatter(pdata_M_Doa[:,idx1], pdata_M_Doa[:,idx2], color='red', s=1, label='apprentice')

            ax.legend()
            plt.show()
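The resume and logging paths in this example derive companion file names by stripping the trailing `.h5` from the checkpoint or log path and appending `_policy.h5`, `_reward.h5`, or `_vf.h5`. A minimal sketch of that naming convention, assuming `util.split_h5_name` splits a `file.h5/key` string into the file path and the in-file key (the paths below are hypothetical placeholders):

# Sketch only: mirrors the file[:-3] + '_policy.h5' pattern used above.
def companion_files(checkpoint):
    # Assumed split: 'runs/gail.h5/snapshots/iter0000100'
    #   -> ('runs/gail.h5', 'snapshots/iter0000100')
    filename, key = checkpoint.split('.h5/')[0] + '.h5', checkpoint.split('.h5/')[1]
    base = filename[:-3]  # strip the trailing '.h5'
    return base + '_policy.h5', base + '_reward.h5', base + '_vf.h5', key

# companion_files('runs/gail.h5/snapshots/iter0000100')
# -> ('runs/gail_policy.h5', 'runs/gail_reward.h5', 'runs/gail_vf.h5', 'snapshots/iter0000100')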
Example #2
def main():
    np.set_printoptions(suppress=True, precision=5, linewidth=1000)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=MODES, required=True)
    # Expert dataset
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument(
        '--resume_training',
        action='store_true',
        help=
        "Resume training from the checkpoint given by --checkpoint. Currently only supported for GAIL with an nn policy, reward and vf"
    )
    parser.add_argument(
        '--checkpoint',
        type=str,
        help="Load from checkpoint if provided and if --resume_training")
    parser.add_argument(
        '--limit_trajs',
        type=int,
        required=True,
        help=
        "Number of expert trajectories to use for training. If None, the full dataset is used."
    )
    parser.add_argument(
        '--data_subsamp_freq',
        type=int,
        required=True,
        help=
        "A number between 0 and max_traj_len. Subsampling rate applied to expert trajectories when building the dataset of expert (state, action) transitions."
    )
    # MDP options
    parser.add_argument('--env_name', type=str, required=True)
    parser.add_argument('--max_traj_len', type=int, default=None)
    # Policy architecture
    parser.add_argument('--policy_hidden_spec',
                        type=str,
                        default=SIMPLE_ARCHITECTURE)
    parser.add_argument('--tiny_policy', action='store_true')
    parser.add_argument('--obsnorm_mode',
                        choices=OBSNORM_MODES,
                        default='expertdata')
    # Behavioral cloning optimizer
    parser.add_argument('--bclone_lr', type=float, default=1e-3)
    parser.add_argument('--bclone_batch_size', type=int, default=128)
    # parser.add_argument('--bclone_eval_nsa', type=int, default=128*100)
    parser.add_argument('--bclone_eval_ntrajs', type=int, default=20)
    parser.add_argument('--bclone_eval_freq', type=int, default=1000)
    parser.add_argument('--bclone_train_frac', type=float, default=.7)
    # Imitation optimizer
    parser.add_argument('--discount', type=float, default=.995)

    parser.add_argument('--lam', type=float, default=.97)
    parser.add_argument('--max_iter', type=int, default=1000000)
    parser.add_argument('--policy_max_kl', type=float, default=.01)
    parser.add_argument('--policy_cg_damping',
                        type=float,
                        default=.1,
                        help="TRPO parameter")
    parser.add_argument('--no_vf', type=int, default=0)
    parser.add_argument('--vf_max_kl', type=float, default=.01)
    parser.add_argument('--vf_cg_damping', type=float, default=.1)
    parser.add_argument('--policy_ent_reg', type=float, default=0.)
    parser.add_argument('--reward_type', type=str, default='nn')
    # parser.add_argument('--linear_reward_bin_features', type=int, default=0)
    parser.add_argument('--reward_max_kl',
                        type=float,
                        default=.01,
                        help="TRPO parameter")
    parser.add_argument('--reward_lr', type=float, default=.01)
    parser.add_argument('--reward_steps', type=int, default=1)
    parser.add_argument('--reward_ent_reg_weight', type=float, default=.001)
    parser.add_argument('--reward_include_time', type=int, default=0)
    parser.add_argument('--sim_batch_size', type=int, default=None)
    parser.add_argument('--min_total_sa', type=int, default=50000)
    parser.add_argument('--favor_zero_expert_reward', type=int, default=0)
    # Saving stuff
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--save_freq', type=int, default=20)
    parser.add_argument('--plot_freq', type=int, default=0)
    parser.add_argument('--log', type=str, required=False)
    # CVaR parameters
    parser.add_argument('--useCVaR', action='store_true')
    parser.add_argument('--CVaR_alpha', type=float, default=0.9)
    parser.add_argument('--CVaR_beta', type=float, default=0.)
    parser.add_argument('--CVaR_lr', type=float, default=0.01)
    # !!! The following argument --disc_CVaR_weight is unused and should be removed
    parser.add_argument(
        '--disc_CVaR_weight',
        type=float,
        default=1.,
        help=
        "Weight given to CVaR loss for the discriminator. Added by Anirban for smooth convergence."
    )
    parser.add_argument('--CVaR_Lambda_not_trainable', action='store_false')
    parser.add_argument('--CVaR_Lambda_val_if_not_trainable',
                        type=float,
                        default=0.5)
    #Filtering expert trajectories
    parser.add_argument('--use_expert_traj_filtering', action='store_true')
    parser.add_argument('--expert_traj_filt_percentile_threshold',
                        type=float,
                        default=20)
    # Additive state prior formulation
    parser.add_argument('--use_additiveStatePrior', action='store_true')
    parser.add_argument('--additiveStatePrior_weight', type=float, default=1.)
    parser.add_argument('--n_gmm_components', type=int, default=5)
    parser.add_argument('--cov_type_gmm', type=str, default='diag')
    parser.add_argument('--familiarity_alpha', type=float, default=10000000)
    parser.add_argument('--familiarity_beta', type=float, default=100)

    parser.add_argument('--kickThreshold_percentile',
                        type=float,
                        default=100.0)
    parser.add_argument('--appendFlag', action='store_true')

    args = parser.parse_args()

    if args.useCVaR:
        print ">>>>>>>>>>>>>>>>>>> TRAINING RAIL <<<<<<<<<<<<<<<<<<<"
    elif args.use_additiveStatePrior:
        print ">>>>>>>>>>>>>>>>>>> USING ADDITIVE STATE PRIOR <<<<<<<<<<<<<<<<<<<"
    else:
        print ">>>>>>>>> TRAINING GAIL <<<<<<<<<<"

    # Initialize the MDP
    if args.tiny_policy:
        assert args.policy_hidden_spec == SIMPLE_ARCHITECTURE, 'policy_hidden_spec must remain unspecified if --tiny_policy is set'
        args.policy_hidden_spec = TINY_ARCHITECTURE
    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    print(argstr)
    print "\n\n========== Policy network specifications loaded ===========\n\n"

    mdp = rlgymenv.RLGymMDP(args.env_name)
    util.header('MDP observation space, action space sizes: %d, %d\n' %
                (mdp.obs_space.dim, mdp.action_space.storage_size))

    print "\n\n========== MDP initialized ===========\n\n"

    # Initialize the policy
    enable_obsnorm = args.obsnorm_mode != 'none'
    if isinstance(mdp.action_space, policyopt.ContinuousSpace):
        policy_cfg = rl.GaussianPolicyConfig(
            hidden_spec=args.policy_hidden_spec,
            min_stdev=0.,
            init_logstdev=0.,
            enable_obsnorm=enable_obsnorm)
        policy = rl.GaussianPolicy(policy_cfg, mdp.obs_space, mdp.action_space,
                                   'GaussianPolicy', args.useCVaR)
    else:
        policy_cfg = rl.GibbsPolicyConfig(hidden_spec=args.policy_hidden_spec,
                                          enable_obsnorm=enable_obsnorm)
        policy = rl.GibbsPolicy(policy_cfg, mdp.obs_space, mdp.action_space,
                                'GibbsPolicy', args.useCVaR)

    offset = 0
    # Load from checkpoint if provided
    if args.resume_training:
        if args.checkpoint is not None:
            file, policy_key = util.split_h5_name(args.checkpoint)
            offset = int(policy_key.split('/')[-1][4:])
            print '\n**************************************************'
            print 'Resuming from checkpoint : %d of %s' % (offset, file)
            print '**************************************************\n'

            if args.appendFlag and file != args.log:
                raise RuntimeError(
                    'Log file and checkpoint should have the same name if appendFlag is on. %s vs %s'
                    % (file, args.log))

            policy_file = file[:-3] + '_policy.h5'  # Because we're naming the file as *_policy.h5 itself
            policy.load_h5(policy_file, policy_key)

    util.header('Policy architecture')
    for v in policy.get_trainable_variables():
        util.header('- %s (%d parameters)' % (v.name, v.get_value().size))
    util.header('Total: %d parameters' % (policy.get_num_params(), ))

    print "\n\n========== Policy initialized ===========\n\n"

    # Load expert data

    exobs_Bstacked_Do, exa_Bstacked_Da, ext_Bstacked = load_dataset(
        args.data,
        args.limit_trajs,
        args.data_subsamp_freq,
        len_filtering=args.use_expert_traj_filtering,
        len_filter_threshold=args.expert_traj_filt_percentile_threshold)

    assert exobs_Bstacked_Do.shape[1] == mdp.obs_space.storage_size
    assert exa_Bstacked_Da.shape[1] == mdp.action_space.storage_size
    assert ext_Bstacked.ndim == 1

    print "\n\n========== Expert data loaded ===========\n\n"

    print '\n==================== Hyperparams ===================='
    print '\texpert_traj_filt_percentile_threshold = %f' % args.expert_traj_filt_percentile_threshold
    print '\tfamiliarity_alpha = %f' % args.familiarity_alpha
    print '\tfamiliarity_beta = %f' % args.familiarity_beta
    print '\tkickThreshold_percentile = %f' % args.kickThreshold_percentile
    print '==============================================\n'

    # Start optimization
    max_traj_len = args.max_traj_len if args.max_traj_len is not None else mdp.env_spec.timestep_limit
    print 'Max traj len:', max_traj_len

    if args.mode == 'bclone':
        # For behavioral cloning, only print output when evaluating
        args.print_freq = args.bclone_eval_freq
        args.save_freq = args.bclone_eval_freq

        reward, vf = None, None  # The reward function and value function play no role in behavioral cloning
        opt = imitation.BehavioralCloningOptimizer(
            mdp,
            policy,
            lr=args.bclone_lr,
            batch_size=args.bclone_batch_size,
            obsfeat_fn=lambda o: o,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            eval_sim_cfg=policyopt.SimConfig(
                min_num_trajs=args.bclone_eval_ntrajs,
                min_total_sa=-1,
                batch_size=args.sim_batch_size,
                max_traj_len=max_traj_len),
            eval_freq=args.bclone_eval_freq,
            train_frac=args.bclone_train_frac)

        print "======= Behavioral Cloning optimizer initialized ======="

    elif args.mode == 'ga':
        if args.reward_type == 'nn':
            reward = imitation.TransitionClassifier(  #Add resume training functionality
                hidden_spec=args.policy_hidden_spec,
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                max_kl=args.reward_max_kl,
                adam_lr=args.reward_lr,
                adam_steps=args.reward_steps,
                ent_reg_weight=args.reward_ent_reg_weight,
                enable_inputnorm=True,
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                varscope_name='TransitionClassifier',
                useCVaR=args.useCVaR,
                CVaR_loss_weightage=args.disc_CVaR_weight)
            # Load from checkpoint if provided
            if args.resume_training:
                if args.checkpoint is not None:
                    file, reward_key = util.split_h5_name(args.checkpoint)
                    reward_file = file[:-3] + '_reward.h5'
                    print reward_file
                    reward.load_h5(reward_file, reward_key)

        elif args.reward_type in ['l2ball', 'simplex']:
            reward = imitation.LinearReward(
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                mode=args.reward_type,
                enable_inputnorm=True,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                exobs_Bex_Do=exobs_Bstacked_Do,
                exa_Bex_Da=exa_Bstacked_Da,
                ext_Bex=ext_Bstacked)
        else:
            raise NotImplementedError(args.reward_type)

        vf = None if bool(
            args.no_vf) else rl.ValueFunc(  #Add resume training functionality
                hidden_spec=args.policy_hidden_spec,
                obsfeat_space=mdp.obs_space,
                enable_obsnorm=args.obsnorm_mode != 'none',
                enable_vnorm=True,
                max_kl=args.vf_max_kl,
                damping=args.vf_cg_damping,
                time_scale=1. / mdp.env_spec.timestep_limit,
                varscope_name='ValueFunc')
        if args.resume_training:
            if args.checkpoint is not None:
                file, vf_key = util.split_h5_name(args.checkpoint)
                vf_file = file[:-3] + '_vf.h5'
                vf.load_h5(vf_file, vf_key)
        if args.useCVaR:
            opt = imitation.ImitationOptimizer_CVaR(
                mdp=mdp,
                discount=args.discount,
                lam=args.lam,
                policy=policy,
                sim_cfg=policyopt.SimConfig(min_num_trajs=-1,
                                            min_total_sa=args.min_total_sa,
                                            batch_size=args.sim_batch_size,
                                            max_traj_len=max_traj_len),
                step_func=rl.TRPO(max_kl=args.policy_max_kl,
                                  damping=args.policy_cg_damping,
                                  useCVaR=True),
                reward_func=reward,
                value_func=vf,
                policy_obsfeat_fn=lambda obs: obs,
                reward_obsfeat_fn=lambda obs: obs,
                policy_ent_reg=args.policy_ent_reg,
                ex_obs=exobs_Bstacked_Do,
                ex_a=exa_Bstacked_Da,
                ex_t=ext_Bstacked,
                #For CVaR
                CVaR_alpha=args.CVaR_alpha,
                CVaR_beta=args.CVaR_beta,
                CVaR_lr=args.CVaR_lr,
                CVaR_Lambda_trainable=args.CVaR_Lambda_not_trainable,
                CVaR_Lambda_val_if_not_trainable=args.
                CVaR_Lambda_val_if_not_trainable,
                offset=offset + 1)
        elif args.use_additiveStatePrior:
            opt = imitation.ImitationOptimizer_additiveStatePrior(
                mdp=mdp,
                discount=args.discount,
                lam=args.lam,
                policy=policy,
                sim_cfg=policyopt.SimConfig(min_num_trajs=-1,
                                            min_total_sa=args.min_total_sa,
                                            batch_size=args.sim_batch_size,
                                            max_traj_len=max_traj_len),
                step_func=rl.TRPO(max_kl=args.policy_max_kl,
                                  damping=args.policy_cg_damping,
                                  useCVaR=False),
                reward_func=reward,
                value_func=vf,
                policy_obsfeat_fn=lambda obs: obs,
                reward_obsfeat_fn=lambda obs: obs,
                policy_ent_reg=args.policy_ent_reg,
                ex_obs=exobs_Bstacked_Do,
                ex_a=exa_Bstacked_Da,
                ex_t=ext_Bstacked,
                n_gmm_components=args.n_gmm_components,
                cov_type_gmm=args.cov_type_gmm,
                additiveStatePrior_weight=args.additiveStatePrior_weight,
                alpha=args.familiarity_alpha,
                beta=args.familiarity_beta,
                kickThreshold_percentile=args.kickThreshold_percentile,
                offset=offset + 1)
        else:
            opt = imitation.ImitationOptimizer(
                mdp=mdp,
                discount=args.discount,
                lam=args.lam,
                policy=policy,
                sim_cfg=policyopt.SimConfig(min_num_trajs=-1,
                                            min_total_sa=args.min_total_sa,
                                            batch_size=args.sim_batch_size,
                                            max_traj_len=max_traj_len),
                step_func=rl.TRPO(max_kl=args.policy_max_kl,
                                  damping=args.policy_cg_damping,
                                  useCVaR=False),
                reward_func=reward,
                value_func=vf,
                policy_obsfeat_fn=lambda obs: obs,
                reward_obsfeat_fn=lambda obs: obs,
                policy_ent_reg=args.policy_ent_reg,
                ex_obs=exobs_Bstacked_Do,
                ex_a=exa_Bstacked_Da,
                ex_t=ext_Bstacked)

    # Set observation normalization
    if args.obsnorm_mode == 'expertdata':
        policy.update_obsnorm(exobs_Bstacked_Do)
        if reward is not None:
            reward.update_inputnorm(opt.reward_obsfeat_fn(exobs_Bstacked_Do),
                                    exa_Bstacked_Da)
        if vf is not None:
            vf.update_obsnorm(opt.policy_obsfeat_fn(exobs_Bstacked_Do))

        print "======== Observation normalization done ========"

    # Run optimizer
    print "======== Optimization begins ========"

    # Trial: make checkpoints for policy, reward and vf
    policy_log = nn.TrainingLog(args.log[:-3] + '_policy.h5',
                                [('args', argstr)], args.appendFlag)
    reward_log = nn.TrainingLog(args.log[:-3] + '_reward.h5',
                                [('args', argstr)], args.appendFlag)
    vf_log = nn.TrainingLog(args.log[:-3] + '_vf.h5', [('args', argstr)],
                            args.appendFlag)

    kickStatesData = []

    print '\n**************************************'
    print 'Running iterations from %d to %d' % (offset + 1, args.max_iter)

    for i in xrange(offset + 1, args.max_iter):
        # for i in range(1): #FIXME: this is just for studying the insides of the training algo

        # All training a.k.a. optimization happens in the next line!!! -_-
        # pdb.set_trace()
        iter_info = opt.step(
            i, kickStatesData) if args.use_additiveStatePrior else opt.step(i)

        #========= The rest is fluff =============

        #Log and plot
        #pdb.set_trace()
        policy_log.write(
            iter_info,
            print_header=i % (20 * args.print_freq) == 0,
            # display=False
            display=i % args.print_freq == 0  ## FIXME: AS remove comment
        )
        # reward_log.write(iter_info,
        #         print_header=i % (20*args.print_freq) == 0,
        #         display=False
        #         # display=i % args.print_freq == 0 ## FIXME: AS remove comment
        #         )
        # vf_log.write(iter_info,
        #         print_header=i % (20*args.print_freq) == 0,
        #         display=False
        #         # display=i % args.print_freq == 0 ## FIXME: AS remove comment
        #         )

        #FIXME: problem running this on 211 and 138. No problem on 151
        if args.save_freq != 0 and i % args.save_freq == 0 and args.log is not None:
            policy_log.write_snapshot(policy, i)
            reward_log.write_snapshot(reward, i)
            vf_log.write_snapshot(vf, i)

            # analysisFile=open(args.log[:-3]+'_kickedStates' + str(i) + '.pkl', 'wb')
            analysisFile = open(args.log[:-3] + '_kickedStates.pkl', 'wb')
            pkl.dump({'kickStatesData': kickStatesData},
                     analysisFile,
                     protocol=2)
            analysisFile.close()

        if args.plot_freq != 0 and i % args.plot_freq == 0:
            exdata_N_Doa = np.concatenate([exobs_Bstacked_Do, exa_Bstacked_Da],
                                          axis=1)
            pdata_M_Doa = np.concatenate(
                [opt.last_sampbatch.obs.stacked, opt.last_sampbatch.a.stacked],
                axis=1)

            # Plot reward
            import matplotlib.pyplot as plt
            _, ax = plt.subplots()
            idx1, idx2 = 0, 1
            range1 = (min(exdata_N_Doa[:, idx1].min(),
                          pdata_M_Doa[:, idx1].min()),
                      max(exdata_N_Doa[:, idx1].max(),
                          pdata_M_Doa[:, idx1].max()))
            range2 = (min(exdata_N_Doa[:, idx2].min(),
                          pdata_M_Doa[:, idx2].min()),
                      max(exdata_N_Doa[:, idx2].max(),
                          pdata_M_Doa[:, idx2].max()))
            reward.plot(ax, idx1, idx2, range1, range2, n=100)

            # Plot expert data
            ax.scatter(exdata_N_Doa[:, idx1],
                       exdata_N_Doa[:, idx2],
                       color='blue',
                       s=1,
                       label='expert')

            # Plot policy samples
            ax.scatter(pdata_M_Doa[:, idx1],
                       pdata_M_Doa[:, idx2],
                       color='red',
                       s=1,
                       label='apprentice')

            ax.legend()
            plt.show()
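This variant additionally recovers the iteration offset from the checkpoint key with `policy_key.split('/')[-1][4:]`, which presupposes snapshot keys of the form `snapshots/iterNNNNNNN`. A small sketch under that assumption (the key below is hypothetical):

# Sketch of the resume-offset parsing above; assumes snapshot keys look like
# 'snapshots/iter0000480', i.e. an 'iter' prefix followed by the iteration number.
policy_key = 'snapshots/iter0000480'
offset = int(policy_key.split('/')[-1][4:])  # 'iter0000480'[4:] -> '0000480' -> 480
start_iter = offset + 1                      # the loop then runs xrange(offset + 1, max_iter)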
Example #3
def main():
    """ 
    NOTE! Don't forget that these are effectively called directly from the yaml
    files. They call imitate_mj.py with their own arguments, so check there if
    some of the values differ from the default ones.
    """
    np.set_printoptions(suppress=True, precision=5, linewidth=1000)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=MODES, required=True)
    # Expert dataset
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--limit_trajs', type=int, required=True)
    parser.add_argument('--data_subsamp_freq', type=int, required=True)
    # MDP options
    parser.add_argument('--env_name', type=str, required=True)
    parser.add_argument('--max_traj_len', type=int, default=None)
    # Policy architecture
    parser.add_argument('--policy_hidden_spec',
                        type=str,
                        default=SIMPLE_ARCHITECTURE)
    parser.add_argument('--tiny_policy', action='store_true')
    parser.add_argument('--obsnorm_mode',
                        choices=OBSNORM_MODES,
                        default='expertdata')
    # Behavioral cloning optimizer (ok ... 128 and 0.7 settings are in the paper).
    parser.add_argument('--bclone_lr', type=float, default=1e-3)
    parser.add_argument('--bclone_batch_size', type=int, default=128)
    # parser.add_argument('--bclone_eval_nsa', type=int, default=128*100)
    parser.add_argument('--bclone_eval_ntrajs', type=int, default=20)
    parser.add_argument('--bclone_eval_freq', type=int, default=1000)
    parser.add_argument('--bclone_train_frac', type=float, default=.7)
    # Imitation optimizer
    parser.add_argument('--discount', type=float, default=.995)
    parser.add_argument('--lam', type=float, default=.97)
    parser.add_argument('--max_iter', type=int, default=1000000)
    parser.add_argument('--policy_max_kl', type=float, default=.01)
    parser.add_argument('--policy_cg_damping', type=float, default=.1)
    parser.add_argument('--no_vf', type=int, default=0)
    parser.add_argument('--vf_max_kl', type=float, default=.01)
    parser.add_argument('--vf_cg_damping', type=float, default=.1)
    parser.add_argument('--policy_ent_reg', type=float, default=0.)
    parser.add_argument('--reward_type', type=str, default='nn')
    # parser.add_argument('--linear_reward_bin_features', type=int, default=0)
    parser.add_argument('--reward_max_kl', type=float, default=.01)
    parser.add_argument('--reward_lr', type=float, default=.01)
    parser.add_argument('--reward_steps', type=int, default=1)
    parser.add_argument('--reward_ent_reg_weight', type=float, default=.001)
    parser.add_argument('--reward_include_time', type=int, default=0)
    parser.add_argument('--sim_batch_size', type=int, default=None)
    parser.add_argument('--min_total_sa', type=int, default=50000)
    parser.add_argument('--favor_zero_expert_reward', type=int, default=0)
    # Saving stuff
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--save_freq', type=int, default=20)
    parser.add_argument('--plot_freq', type=int, default=0)
    parser.add_argument('--log', type=str, required=False)

    args = parser.parse_args()

    # Initialize the MDP
    if args.tiny_policy:
        assert args.policy_hidden_spec == SIMPLE_ARCHITECTURE, 'policy_hidden_spec must remain unspecified if --tiny_policy is set'
        args.policy_hidden_spec = TINY_ARCHITECTURE
    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    print(argstr)

    mdp = rlgymenv.RLGymMDP(args.env_name)
    util.header('MDP observation space, action space sizes: %d, %d\n' %
                (mdp.obs_space.dim, mdp.action_space.storage_size))

    # Initialize the policy
    print("\n\tNow initializing the policy:")
    enable_obsnorm = args.obsnorm_mode != 'none'
    if isinstance(mdp.action_space, policyopt.ContinuousSpace):
        policy_cfg = rl.GaussianPolicyConfig(
            hidden_spec=args.policy_hidden_spec,
            min_stdev=0.,
            init_logstdev=0.,
            enable_obsnorm=enable_obsnorm)
        policy = rl.GaussianPolicy(policy_cfg, mdp.obs_space, mdp.action_space,
                                   'GaussianPolicy')
    else:
        policy_cfg = rl.GibbsPolicyConfig(hidden_spec=args.policy_hidden_spec,
                                          enable_obsnorm=enable_obsnorm)
        policy = rl.GibbsPolicy(policy_cfg, mdp.obs_space, mdp.action_space,
                                'GibbsPolicy')

    util.header('Policy architecture')
    for v in policy.get_trainable_variables():
        util.header('- %s (%d parameters)' % (v.name, v.get_value().size))
    util.header('Total: %d parameters' % (policy.get_num_params(), ))
    print("\tFinished initializing the policy.\n")

    # Load expert data
    exobs_Bstacked_Do, exa_Bstacked_Da, ext_Bstacked = load_dataset(
        args.data, args.limit_trajs, args.data_subsamp_freq)
    assert exobs_Bstacked_Do.shape[1] == mdp.obs_space.storage_size
    assert exa_Bstacked_Da.shape[1] == mdp.action_space.storage_size
    assert ext_Bstacked.ndim == 1

    # Start optimization
    max_traj_len = args.max_traj_len if args.max_traj_len is not None else mdp.env_spec.timestep_limit
    print 'Max traj len:', max_traj_len

    if args.mode == 'bclone':
        # For behavioral cloning, only print output when evaluating
        args.print_freq = args.bclone_eval_freq
        args.save_freq = args.bclone_eval_freq

        reward, vf = None, None
        opt = imitation.BehavioralCloningOptimizer(
            mdp,
            policy,
            lr=args.bclone_lr,
            batch_size=args.bclone_batch_size,
            obsfeat_fn=lambda o: o,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            eval_sim_cfg=policyopt.SimConfig(
                min_num_trajs=args.bclone_eval_ntrajs,
                min_total_sa=-1,
                batch_size=args.sim_batch_size,
                max_traj_len=max_traj_len),
            eval_freq=args.bclone_eval_freq,
            train_frac=args.bclone_train_frac)

    elif args.mode == 'ga':
        if args.reward_type == 'nn':
            # FYI: this is the GAIL case. Note that it doesn't take in any of
            # the raw expert data, unlike the other reward types. And we call
            # them `reward types` since the optimizer can use their output in
            # some way to improve itself.
            reward = imitation.TransitionClassifier(
                hidden_spec=args.policy_hidden_spec,
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                max_kl=args.reward_max_kl,
                adam_lr=args.reward_lr,
                adam_steps=args.reward_steps,
                ent_reg_weight=args.reward_ent_reg_weight,
                enable_inputnorm=True,
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                varscope_name='TransitionClassifier')
        elif args.reward_type in ['l2ball', 'simplex']:
            # FEM or game-theoretic apprenticeship learning, respectively.
            reward = imitation.LinearReward(
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                mode=args.reward_type,
                enable_inputnorm=True,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                exobs_Bex_Do=exobs_Bstacked_Do,
                exa_Bex_Da=exa_Bstacked_Da,
                ext_Bex=ext_Bstacked)
        else:
            raise NotImplementedError(args.reward_type)

        # All three of these 'advanced' IL algorithms use neural network value
        # functions to reduce variance for policy gradient estimates.
        print("\n\tThe **VALUE** function (may have action concatenated):")
        vf = None if bool(args.no_vf) else rl.ValueFunc(
            hidden_spec=args.policy_hidden_spec,
            obsfeat_space=mdp.obs_space,
            enable_obsnorm=args.obsnorm_mode != 'none',
            enable_vnorm=True,
            max_kl=args.vf_max_kl,
            damping=args.vf_cg_damping,
            time_scale=1. / mdp.env_spec.timestep_limit,
            varscope_name='ValueFunc')

        opt = imitation.ImitationOptimizer(
            mdp=mdp,
            discount=args.discount,
            lam=args.lam,
            policy=policy,
            sim_cfg=policyopt.SimConfig(min_num_trajs=-1,
                                        min_total_sa=args.min_total_sa,
                                        batch_size=args.sim_batch_size,
                                        max_traj_len=max_traj_len),
            step_func=rl.TRPO(max_kl=args.policy_max_kl,
                              damping=args.policy_cg_damping),
            reward_func=reward,
            value_func=vf,
            policy_obsfeat_fn=lambda obs: obs,
            reward_obsfeat_fn=lambda obs: obs,
            policy_ent_reg=args.policy_ent_reg,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            ex_t=ext_Bstacked)

    # Set observation normalization
    if args.obsnorm_mode == 'expertdata':
        policy.update_obsnorm(exobs_Bstacked_Do)
        if reward is not None:
            reward.update_inputnorm(opt.reward_obsfeat_fn(exobs_Bstacked_Do),
                                    exa_Bstacked_Da)
        if vf is not None:
            vf.update_obsnorm(opt.policy_obsfeat_fn(exobs_Bstacked_Do))

    # Run optimizer, i.e. {BehavioralCloning,Imitation}Optimizer.
    log = nn.TrainingLog(args.log, [('args', argstr)])
    for i in xrange(args.max_iter):
        iter_info = opt.step()
        log.write(iter_info,
                  print_header=i % (20 * args.print_freq) == 0,
                  display=i % args.print_freq == 0)
        if args.save_freq != 0 and i % args.save_freq == 0 and args.log is not None:
            log.write_snapshot(policy, i)

        if args.plot_freq != 0 and i % args.plot_freq == 0:
            exdata_N_Doa = np.concatenate([exobs_Bstacked_Do, exa_Bstacked_Da],
                                          axis=1)
            pdata_M_Doa = np.concatenate(
                [opt.last_sampbatch.obs.stacked, opt.last_sampbatch.a.stacked],
                axis=1)

            # Plot reward
            import matplotlib.pyplot as plt
            _, ax = plt.subplots()
            idx1, idx2 = 0, 1
            range1 = (min(exdata_N_Doa[:, idx1].min(),
                          pdata_M_Doa[:, idx1].min()),
                      max(exdata_N_Doa[:, idx1].max(),
                          pdata_M_Doa[:, idx1].max()))
            range2 = (min(exdata_N_Doa[:, idx2].min(),
                          pdata_M_Doa[:, idx2].min()),
                      max(exdata_N_Doa[:, idx2].max(),
                          pdata_M_Doa[:, idx2].max()))
            reward.plot(ax, idx1, idx2, range1, range2, n=100)

            # Plot expert data
            ax.scatter(exdata_N_Doa[:, idx1],
                       exdata_N_Doa[:, idx2],
                       color='blue',
                       s=1,
                       label='expert')

            # Plot policy samples
            ax.scatter(pdata_M_Doa[:, idx1],
                       pdata_M_Doa[:, idx2],
                       color='red',
                       s=1,
                       label='apprentice')

            ax.legend()
            plt.show()
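The `--discount` and `--lam` values handed to `ImitationOptimizer` play the usual roles of gamma and lambda in generalized advantage estimation for the TRPO policy step. Purely as an illustration (this is not the library's own code), a sketch of the GAE recursion with the defaults used above:

import numpy as np

# Illustrative GAE(gamma, lambda) sketch with the --discount/--lam defaults above.
def gae_advantages(rewards_T, values_T, gamma=0.995, lam=0.97):
    T = len(rewards_T)
    adv_T = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        next_value = values_T[t + 1] if t + 1 < T else 0.0  # treat the episode as ending at T
        delta = rewards_T[t] + gamma * next_value - values_T[t]
        running = delta + gamma * lam * running
        adv_T[t] = running
    return adv_T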
Example #4
def main():
    np.set_printoptions(suppress=True, precision=5, linewidth=1000)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=MODES, required=True)
    # Expert dataset
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--limit_trajs', type=int, required=True)
    parser.add_argument('--data_subsamp_freq', type=int, required=True)
    # MDP options
    parser.add_argument('--env_name', type=str, required=True)
    parser.add_argument('--max_traj_len', type=int, default=None)
    # Policy architecture
    parser.add_argument('--policy_hidden_spec',
                        type=str,
                        default=SIMPLE_ARCHITECTURE)
    parser.add_argument('--tiny_policy', action='store_true')
    parser.add_argument('--obsnorm_mode',
                        choices=OBSNORM_MODES,
                        default='expertdata')

    # add a spec for transition classifier
    parser.add_argument('--clf_hidden_spec',
                        type=str,
                        default=SIMPLE_ARCHITECTURE)

    # Behavioral cloning optimizer
    parser.add_argument('--bclone_lr', type=float, default=1e-3)
    parser.add_argument('--bclone_batch_size', type=int, default=128)
    # parser.add_argument('--bclone_eval_nsa', type=int, default=128*100)
    parser.add_argument('--bclone_eval_ntrajs', type=int, default=20)
    parser.add_argument('--bclone_eval_freq', type=int, default=1000)
    parser.add_argument('--bclone_train_frac', type=float, default=.7)
    # Imitation optimizer
    parser.add_argument('--discount', type=float, default=.995)
    parser.add_argument('--lam', type=float, default=.97)
    parser.add_argument('--max_iter', type=int, default=1000000)
    parser.add_argument('--policy_max_kl', type=float, default=.01)
    parser.add_argument('--policy_cg_damping', type=float, default=.1)
    parser.add_argument('--no_vf', type=int, default=0)
    parser.add_argument('--vf_max_kl', type=float, default=.01)
    parser.add_argument('--vf_cg_damping', type=float, default=.1)
    parser.add_argument('--policy_ent_reg', type=float, default=0.)
    parser.add_argument('--reward_type', type=str, default='nn')
    # parser.add_argument('--linear_reward_bin_features', type=int, default=0)
    parser.add_argument('--reward_max_kl', type=float, default=.01)
    parser.add_argument('--reward_lr', type=float, default=.01)
    parser.add_argument('--reward_steps', type=int, default=1)
    parser.add_argument('--reward_ent_reg_weight', type=float, default=.001)
    parser.add_argument('--reward_include_time', type=int, default=0)
    parser.add_argument('--sim_batch_size', type=int, default=None)
    parser.add_argument('--min_total_sa', type=int, default=50000)
    parser.add_argument('--favor_zero_expert_reward', type=int, default=0)
    # Saving stuff
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--save_freq', type=int, default=20)
    parser.add_argument('--plot_freq', type=int, default=100)
    parser.add_argument('--log', type=str, required=False)

    # Sequential model
    parser.add_argument('--seq_model', type=int, default=0)
    parser.add_argument('--time_step', type=int, default=10)

    args = parser.parse_args()

    # Initialize the MDP
    if not args.seq_model:
        if args.tiny_policy:
            assert args.policy_hidden_spec == SIMPLE_ARCHITECTURE, 'policy_hidden_spec must remain unspecified if --tiny_policy is set'
            args.policy_hidden_spec = TINY_ARCHITECTURE
        argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
        print(argstr)
    # Add sequential model
    else:
        if args.tiny_policy:
            assert args.policy_hidden_spec == SEQ_SIMPLE_ARCHITECTURE, 'policy_hidden_spec must remain unspecified if --tiny_policy is set'
            args.policy_hidden_spec = SEQ_TINY_ARCHITECTURE
#        # change the default architecture to fit sequential model
#        if args.policy_hidden_spec == SIMPLE_ARCHITECTURE:
#            args.policy_hidden_spec = SEQ_SIMPLE_ARCHITECTURE
        if args.clf_hidden_spec == SIMPLE_ARCHITECTURE:
            args.clf_hidden_spec = SEQ_SIMPLE_ARCHITECTURE
        argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)

    mdp = rlgymenv.RLGymMDP(args.env_name)
    util.header('MDP observation space, action space sizes: %d, %d\n' %
                (mdp.obs_space.dim, mdp.action_space.storage_size))

    # Initialize the policy
    enable_obsnorm = args.obsnorm_mode != 'none'

    if not args.seq_model:
        if isinstance(mdp.action_space, policyopt.ContinuousSpace):
            policy_cfg = rl.GaussianPolicyConfig(
                hidden_spec=args.policy_hidden_spec,
                min_stdev=0.,
                init_logstdev=0.,
                enable_obsnorm=enable_obsnorm)
            policy = rl.GaussianPolicy(policy_cfg, mdp.obs_space,
                                       mdp.action_space, 'GaussianPolicy')
        else:
            policy_cfg = rl.GibbsPolicyConfig(
                hidden_spec=args.policy_hidden_spec,
                enable_obsnorm=enable_obsnorm)
            policy = rl.GibbsPolicy(policy_cfg, mdp.obs_space,
                                    mdp.action_space, 'GibbsPolicy')
    # Add sequential model
    else:
        if isinstance(mdp.action_space, policyopt.ContinuousSpace):
            policy_cfg = rl.SeqGaussianPolicyConfig(
                hidden_spec=args.policy_hidden_spec,
                time_step=args.time_step,  # add time step
                min_stdev=0.,
                init_logstdev=0.,
                enable_obsnorm=enable_obsnorm,
                enable_actnorm=False)  # XXX actnorm not implemented yet
            policy = rl.SeqGaussianPolicy(policy_cfg, mdp.obs_space,
                                          mdp.action_space,
                                          'SeqGaussianPolicy')
        else:
            policy_cfg = rl.SeqGibbsPolicyConfig(
                hidden_spec=args.policy_hidden_spec,
                time_step=args.time_step,  # add time step
                enable_obsnorm=enable_obsnorm,
                enable_actnorm=False)  # XXX actnorm not implemented yet
            policy = rl.SeqGibbsPolicy(policy_cfg, mdp.obs_space,
                                       mdp.action_space, 'SeqGibbsPolicy')

    util.header('Policy architecture')
    for v in policy.get_trainable_variables():
        util.header('- %s (%d parameters)' % (v.name, v.get_value().size))
    util.header('Total: %d parameters' % (policy.get_num_params(), ))

    # Load expert data
    exobs_Bstacked_Do, exa_Bstacked_Da, ext_Bstacked = load_dataset(
        args.data, args.limit_trajs, args.data_subsamp_freq)
    assert exobs_Bstacked_Do.shape[1] == mdp.obs_space.storage_size
    assert exa_Bstacked_Da.shape[1] == mdp.action_space.storage_size
    assert ext_Bstacked.ndim == 1

    #    print 'Debug: exobs_Bstacked_Do dtype:', exobs_Bstacked_Do.dtype
    #    print 'Debug: exa_Bstacked_Da dtype:', exa_Bstacked_Da.dtype
    #    print 'Debug: ext_Bstacked dtype:', ext_Bstacked.dtype

    #    assert 1 == 0

    # Start optimization
    max_traj_len = args.max_traj_len if args.max_traj_len is not None else mdp.env_spec.timestep_limit
    print('Max traj len:', max_traj_len)

    if args.mode == 'bclone':
        # For behavioral cloning, only print output when evaluating
        #        args.print_freq = args.bclone_eval_freq
        #        args.save_freq = args.bclone_eval_freq

        reward, vf = None, None
        opt = imitation.BehavioralCloningOptimizer(
            mdp,
            policy,
            lr=args.bclone_lr,
            batch_size=args.bclone_batch_size,
            obsfeat_fn=lambda o: o,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            eval_sim_cfg=policyopt.SimConfig(
                min_num_trajs=args.bclone_eval_ntrajs,
                min_total_sa=-1,
                batch_size=args.sim_batch_size,
                max_traj_len=max_traj_len,
                smp_traj_len=-1),
            eval_freq=args.
            bclone_eval_freq,  # XXX set a value when using bclone
            train_frac=args.bclone_train_frac)

    elif args.mode == 'ga':
        if args.reward_type == 'nn':
            reward = imitation.TransitionClassifier(
                hidden_spec=args.policy_hidden_spec,
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                max_kl=args.reward_max_kl,
                adam_lr=args.reward_lr,
                adam_steps=args.reward_steps,
                ent_reg_weight=args.reward_ent_reg_weight,
                enable_inputnorm=True,
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                varscope_name='TransitionClassifier')
        elif args.reward_type in ['l2ball', 'simplex']:
            reward = imitation.LinearReward(
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                mode=args.reward_type,
                enable_inputnorm=True,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                exobs_Bex_Do=exobs_Bstacked_Do,
                exa_Bex_Da=exa_Bstacked_Da,
                ext_Bex=ext_Bstacked)
        else:
            raise NotImplementedError(args.reward_type)

        vf = None if bool(args.no_vf) else rl.ValueFunc(
            hidden_spec=args.policy_hidden_spec,
            obsfeat_space=mdp.obs_space,
            enable_obsnorm=args.obsnorm_mode != 'none',
            enable_vnorm=True,
            max_kl=args.vf_max_kl,
            damping=args.vf_cg_damping,
            time_scale=1. / mdp.env_spec.timestep_limit,
            varscope_name='ValueFunc')

        opt = imitation.ImitationOptimizer(
            mdp=mdp,
            discount=args.discount,
            lam=args.lam,
            policy=policy,
            sim_cfg=policyopt.SimConfig(min_num_trajs=-1,
                                        min_total_sa=args.min_total_sa,
                                        batch_size=args.sim_batch_size,
                                        max_traj_len=max_traj_len,
                                        smp_traj_len=-1),
            step_func=rl.TRPO(max_kl=args.policy_max_kl,
                              damping=args.policy_cg_damping,
                              sequential_model=False),  # add sequential model
            reward_func=reward,
            value_func=vf,
            policy_obsfeat_fn=lambda obs: obs,
            reward_obsfeat_fn=lambda obs: obs,
            policy_ent_reg=args.policy_ent_reg,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            ex_t=ext_Bstacked)
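        # Note: discount and lam are, by the usual convention, the gamma and lambda
        # of generalized advantage estimation (GAE) used when computing advantages
        # for the TRPO step (see the standalone GAE sketch after this example).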

    # Add Sequential Model
    elif args.mode == 'sga':
        if args.reward_type == 'nn':
            reward = imitation.SequentialTransitionClassifier(
                hidden_spec=args.clf_hidden_spec,
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                max_kl=args.reward_max_kl,
                adam_lr=args.reward_lr,
                adam_steps=args.reward_steps,
                ent_reg_weight=args.reward_ent_reg_weight,
                time_step=args.time_step,  # add time step
                enable_inputnorm=True,
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                varscope_name='SequentialTransitionClassifier')
#        elif args.reward_type in ['l2ball', 'simplex']:
#            reward = imitation.LinearReward(
#                obsfeat_space=mdp.obs_space,
#                action_space=mdp.action_space,
#                mode=args.reward_type,
#                enable_inputnorm=True,
#                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
#                include_time=bool(args.reward_include_time),
#                time_scale=1./mdp.env_spec.timestep_limit,
#                exobs_Bex_Do=exobs_Bstacked_Do,
#                exa_Bex_Da=exa_Bstacked_Da,
#                ext_Bex=ext_Bstacked)
        else:
            raise NotImplementedError(args.reward_type)

        vf = None if bool(args.no_vf) else rl.SequentialValueFunc(
            hidden_spec=args.policy_hidden_spec,
            obsfeat_space=mdp.obs_space,
            time_step=args.time_step,  # add time step
            enable_obsnorm=args.obsnorm_mode != 'none',
            enable_vnorm=True,
            max_kl=args.vf_max_kl,
            damping=args.vf_cg_damping,
            time_scale=1. / mdp.env_spec.timestep_limit,
            varscope_name='SequentialValueFunc')

        opt = imitation.SequentialImitationOptimizer(
            mdp=mdp,
            discount=args.discount,
            lam=args.lam,
            policy=policy,
            sim_cfg=policyopt.SeqSimConfig(
                min_num_trajs=-1,
                min_total_sa=args.min_total_sa,
                batch_size=args.sim_batch_size,
                max_traj_len=max_traj_len,
                time_step=args.time_step),  # add time step
            step_func=rl.TRPO(
                max_kl=args.policy_max_kl,
                damping=args.policy_cg_damping,
                sequential_model=False),  # XXX do not use sequential TRPO
            reward_func=reward,
            value_func=vf,
            policy_obsfeat_fn=lambda obs: obs,
            reward_obsfeat_fn=lambda obs: obs,
            policy_ent_reg=args.policy_ent_reg,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            ex_t=ext_Bstacked)
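        # Note (assumption from how args.time_step is used below): time_step is the
        # fixed sequence length T that the sequential classifier / value function and
        # SeqSimConfig operate on; the expert data is later reshaped to (B, T, ...)
        # with this T.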

    # Set observation normalization
    if args.obsnorm_mode == 'expertdata':
        if not args.seq_model:
            policy.update_obsnorm(exobs_Bstacked_Do)
            if reward is not None:
                reward.update_inputnorm(
                    opt.reward_obsfeat_fn(exobs_Bstacked_Do), exa_Bstacked_Da)
            if vf is not None:
                vf.update_obsnorm(opt.policy_obsfeat_fn(exobs_Bstacked_Do))
        # Add sequential model
        else:
            Bstacked, Do, T = (exobs_Bstacked_Do.shape[0],
                               exobs_Bstacked_Do.shape[1], args.time_step)
            exobs_BT_Do = exobs_Bstacked_Do[:T * (Bstacked // T), :]
            exa_BT_Da = exa_Bstacked_Da[:T * (Bstacked // T), :]
            # reshape:(B*T, ...) => (B, T, ...)
            exobs_B_T_Do = np.reshape(
                exobs_BT_Do, (Bstacked // T, T, exobs_Bstacked_Do.shape[1]))
            exa_B_T_Da = np.reshape(
                exa_BT_Da, (Bstacked // T, T, exa_Bstacked_Da.shape[1]))
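            # Worked example: with Bstacked = 1003 stacked transitions and
            # T = args.time_step = 50, only the first 1000 rows are kept and
            # reshaped to (20, 50, Do); the trailing Bstacked % T = 3 transitions
            # are dropped.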
            print("Debug: exobs_Bstacked_Do:", exobs_Bstacked_Do.shape[0],
                  exobs_Bstacked_Do.shape[1])
            print("Debug: exobs_B_T_Do:", exobs_B_T_Do.shape[0],
                  exobs_B_T_Do.shape[1], exobs_B_T_Do.shape[2])
            # XXX use original policy (not sequential)
            policy.update_obsnorm(exobs_Bstacked_Do)
            if reward is not None:
                reward.update_inputnorm(opt.reward_obsfeat_fn(exobs_B_T_Do),
                                        exa_B_T_Da)
            if vf is not None:
                vf.update_obsnorm(opt.policy_obsfeat_fn(exobs_Bstacked_Do))

    # Run optimizer

    # log = nn.TrainingLog(args.log, [('args', argstr)])
    log = nn.BasicTrainingLog(args.log, [('args', argstr)])
    for i in xrange(args.max_iter):
        iter_info = opt.step()
        # log.write(iter_info, print_header=i % (20*args.print_freq) == 0, display=i % args.print_freq == 0)
        log.add_log(iter_info,
                    print_header=i % (20 * args.print_freq) == 0,
                    display=i % args.print_freq == 0)
        if args.save_freq != 0 and i % args.save_freq == 0 and args.log is not None:
            print('%i/%i iters done. Saving snapshot.' % (i, args.max_iter))
            log.write_snapshot(policy, i)

        if args.mode == 'ga' and args.plot_freq != 0 and i % args.plot_freq == 0:
            print('%i/%i iters done. Saving plot.' % (i, args.max_iter))
            exdata_N_Doa = np.concatenate([exobs_Bstacked_Do, exa_Bstacked_Da],
                                          axis=1)
            pdata_M_Doa = np.concatenate(
                [opt.last_sampbatch.obs.stacked, opt.last_sampbatch.a.stacked],
                axis=1)
            # convert dtype to follow theano config
            exdata_N_Doa = exdata_N_Doa.astype(theano.config.floatX)
            pdata_M_Doa = pdata_M_Doa.astype(theano.config.floatX)
            # print('Debug: exobs_Bstacked_Do dtype:', exobs_Bstacked_Do.dtype)    # float32
            # print('Debug: exa_Bstacked_Da dtype:', exa_Bstacked_Da.dtype)    # int64
            # print('Debug: opt.last_sampbatch.obs.stacked dtype:', opt.last_sampbatch.obs.stacked.dtype)    # float32
            # print('Debug: opt.last_sampbatch.a.stacked dtype:', opt.last_sampbatch.a.stacked.dtype)    # int64
            # print('Debug: exdata_N_Doa dtype:', exdata_N_Doa.dtype)    # float32
            # print('Debug: pdata_M_Doa dtype:', pdata_M_Doa.dtype)    # float32

            # Plot reward
            #            import matplotlib
            #            matplotlib.use('Agg')
            #            import matplotlib.pyplot as plt
            _, ax = plt.subplots()
            idx1, idx2 = 0, 1
            range1 = (min(exdata_N_Doa[:, idx1].min(),
                          pdata_M_Doa[:, idx1].min()),
                      max(exdata_N_Doa[:, idx1].max(),
                          pdata_M_Doa[:, idx1].max()))
            range2 = (min(exdata_N_Doa[:, idx2].min(),
                          pdata_M_Doa[:, idx2].min()),
                      max(exdata_N_Doa[:, idx2].max(),
                          pdata_M_Doa[:, idx2].max()))

            # print('Debug: range1 types:', type(range1[0]), type(range1[1]))    # float32, float32
            # print('Debug: range2 types:', type(range2[0]), type(range2[1]))    # float32, float32

            x, y, z = reward.plot(ax, idx1, idx2, range1, range2, n=100)
            plot = [
                x, y, z, exdata_N_Doa[:, idx1], exdata_N_Doa[:, idx2],
                pdata_M_Doa[:, idx1], pdata_M_Doa[:, idx2]
            ]
            log.write_plot(plot, i)

            # Plot expert data
            # ax.scatter(exdata_N_Doa[:, idx1], exdata_N_Doa[:, idx2], color='blue', s=1, label='expert')

            # Plot policy samples
            # ax.scatter(pdata_M_Doa[:, idx1], pdata_M_Doa[:, idx2], color='red', s=1, label='apprentice')

            # ax.legend()
            # plt.show()
            # plt.savefig()
            # plot = [x, y, z, exdata_N_Doa[:, idx1], exdata_N_Doa[:, idx2], pdata_M_Doa[:, idx1], pdata_M_Doa[:, idx2]]
            # log.write_plot(plot, i)

        # if args.mode == 'sga' and args.plot_freq != 0 and i % args.plot_freq == 0:
        #     print('%i/%i iters done. Saving plot.' % (i, args.max_iter))
        #     exdata_N_Doa = np.concatenate([exobs_Bstacked_Do, exa_Bstacked_Da], axis=1)
        #     # reshape: (B, T, ...) => (B*T, ...)
        #     # B, T, Df = opt.last_sampbatch.obs.stacked.shape
        #     # obs_flatten = np.reshape(opt.last_sampbatch.obs.stacked, (B*T, opt.last_sampbatch.obs.stacked.shape[2]))
        #     # a_flatten = np.reshape(opt.last_sampbatch.a.stacked, (B*T, opt.last_sampbatch.a.stacked.shape[2]))
        #     pdata_M_Doa = np.concatenate([opt.last_sampbatch.obs.stacked, opt.last_sampbatch.a.stacked], axis=1)
        #     # convert dtype to follow theano config
        #     exdata_N_Doa = exdata_N_Doa.astype(theano.config.floatX)
        #     pdata_M_Doa = pdata_M_Doa.astype(theano.config.floatX)
        #     # print('Debug: exobs_Bstacked_Do dtype:', exobs_Bstacked_Do.dtype)    # float32
        #     # print('Debug: exa_Bstacked_Da dtype:', exa_Bstacked_Da.dtype)    # int64
        #     # print('Debug: opt.last_sampbatch.obs.stacked dtype:', opt.last_sampbatch.obs.stacked.dtype)    # float32
        #     # print('Debug: opt.last_sampbatch.a.stacked dtype:', opt.last_sampbatch.a.stacked.dtype)    # int64
        #     # print('Debug: exdata_N_Doa dtype:', exdata_N_Doa.dtype)    # float32
        #     # print('Debug: pdata_M_Doa dtype:', pdata_M_Doa.dtype)    # float32
        #
        #     # Plot reward
        #     # import matplotlib
        #     # matplotlib.use('Agg')
        #     # import matplotlib.pyplot as plt
        #     _, ax = plt.subplots()
        #     idx1, idx2 = 0, 1
        #     range1 = (min(exdata_N_Doa[:, idx1].min(), pdata_M_Doa[:, idx1].min()),
        #               max(exdata_N_Doa[:, idx1].max(), pdata_M_Doa[:, idx1].max()))
        #     range2 = (min(exdata_N_Doa[:, idx2].min(), pdata_M_Doa[:, idx2].min()),
        #               max(exdata_N_Doa[:, idx2].max(), pdata_M_Doa[:, idx2].max()))
        #     # print('Debug: range1 types:', type(range1[0]), type(range1[1]))    # float32, float32
        #     # print('Debug: range2 types:', type(range2[0]), type(range2[1]))    # float32, float32
        #
        #     # for the sequential model, pass the sequence length
        #     # XXX take care of memory usage!
        #     x, y, z = reward.plot(ax, idx1, idx2, range1, range2, args.time_step, n=100)
        #     plot = [x, y, z, exdata_N_Doa[:, idx1], exdata_N_Doa[:, idx2], pdata_M_Doa[:, idx1], pdata_M_Doa[:, idx2]]
        #     log.write_plot(plot, i)
        #
        #     # Plot expert data
        #     # ax.scatter(exdata_N_Doa[:, idx1], exdata_N_Doa[:, idx2], color='blue', s=1, label='expert')
        #
        #     # Plot policy samples
        #     # ax.scatter(pdata_M_Doa[:, idx1], pdata_M_Doa[:, idx2], color='red', s=1, label='apprentice')
        #
        #     # ax.legend()
        #     # plt.show()
        #     # plt.savefig()

    # Write log
    print('Training is done. Saving log.')
    log.write_log()
    log.close()
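
Aside: the `discount` and `lam` arguments passed to ImitationOptimizer above are, by the usual convention, the gamma and lambda of generalized advantage estimation (GAE). Below is a minimal, self-contained NumPy sketch of that estimator; it is an illustration only, not the policyopt implementation, and the function name is hypothetical.

import numpy as np


def gae_advantages(rewards, values, gamma=0.995, lam=0.97):
    # rewards: (T,) rewards along one trajectory
    # values:  (T+1,) value estimates, with values[-1] the bootstrap value
    deltas = rewards + gamma * values[1:] - values[:-1]
    adv = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        # Discounted sum of TD residuals: A_t = sum_l (gamma*lam)^l * delta_{t+l}
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv


# Example usage: a 4-step trajectory with constant reward and zero bootstrap value.
# gae_advantages(np.ones(4), np.array([0.5, 0.5, 0.5, 0.5, 0.0]))
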
Example #5
def main():
    np.set_printoptions(suppress=True, precision=5, linewidth=1000)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=MODES, required=True)
    parser.add_argument('--seed', type=int, default=0)
    # Expert dataset
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--limit_trajs', type=int, required=True)
    parser.add_argument('--data_subsamp_freq', type=int, required=True)
    # MDP options
    parser.add_argument('--env_name', type=str, required=True)
    parser.add_argument('--max_traj_len', type=int, default=None)
    # Policy architecture
    parser.add_argument('--policy_hidden_spec',
                        type=str,
                        default=SIMPLE_ARCHITECTURE)
    parser.add_argument('--tiny_policy', action='store_true')
    parser.add_argument('--obsnorm_mode',
                        choices=OBSNORM_MODES,
                        default='expertdata')
    # Behavioral cloning optimizer
    parser.add_argument('--bclone_lr', type=float, default=1e-3)
    parser.add_argument('--bclone_batch_size', type=int, default=128)
    # parser.add_argument('--bclone_eval_nsa', type=int, default=128*100)
    parser.add_argument('--bclone_eval_ntrajs', type=int, default=20)
    parser.add_argument('--bclone_eval_freq', type=int, default=1000)
    parser.add_argument('--bclone_train_frac', type=float, default=.7)
    # Imitation optimizer
    parser.add_argument('--discount', type=float, default=.995)
    parser.add_argument('--lam', type=float, default=.97)
    parser.add_argument('--max_iter', type=int, default=1000000)
    parser.add_argument('--policy_max_kl', type=float, default=.01)
    parser.add_argument('--policy_cg_damping', type=float, default=.1)
    parser.add_argument('--no_vf', type=int, default=0)
    parser.add_argument('--vf_max_kl', type=float, default=.01)
    parser.add_argument('--vf_cg_damping', type=float, default=.1)
    parser.add_argument('--policy_ent_reg', type=float, default=0.)
    parser.add_argument('--reward_type', type=str, default='nn')
    # parser.add_argument('--linear_reward_bin_features', type=int, default=0)
    parser.add_argument('--reward_max_kl', type=float, default=.01)
    parser.add_argument('--reward_lr', type=float, default=.01)
    parser.add_argument('--reward_steps', type=int, default=1)
    parser.add_argument('--reward_ent_reg_weight', type=float, default=.001)
    parser.add_argument('--reward_include_time', type=int, default=0)
    parser.add_argument('--sim_batch_size', type=int, default=None)
    parser.add_argument('--min_total_sa', type=int, default=50000)
    parser.add_argument('--favor_zero_expert_reward', type=int, default=0)
    parser.add_argument('--use_shared_std_network', type=int, default=0)
    # Generative Moment matching
    parser.add_argument('--kernel_batchsize', type=int, default=1000)
    parser.add_argument('--kernel_reg_weight', type=float, default=0.)
    parser.add_argument('--use_median_heuristic', type=int, default=1)
    parser.add_argument('--use_logscale_reward', type=int, default=0)
    parser.add_argument('--reward_epsilon', type=float, default=0.0001)
    # Auto-Encoder Information
    # Saving stuff
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--save_freq', type=int, default=20)
    parser.add_argument('--plot_freq', type=int, default=0)
    parser.add_argument('--log', type=str, required=False)
    parser.add_argument('--save_reward', type=int, default=0)

    args = parser.parse_args()

    # Initialize the MDP
    if args.tiny_policy:
        assert args.policy_hidden_spec == SIMPLE_ARCHITECTURE, 'policy_hidden_spec must remain unspecified if --tiny_policy is set'
        args.policy_hidden_spec = TINY_ARCHITECTURE
    argstr = json.dumps(vars(args), separators=(',', ':'), indent=2)
    print(argstr)

    mdp = rlgymenv.RLGymMDP(args.env_name)
    util.header('MDP observation space, action space sizes: %d, %d\n' %
                (mdp.obs_space.dim, mdp.action_space.storage_size))

    # Initialize the policy
    enable_obsnorm = args.obsnorm_mode != 'none'
    if isinstance(mdp.action_space, policyopt.ContinuousSpace):
        policy_cfg = rl.GaussianPolicyConfig(
            hidden_spec=args.policy_hidden_spec,
            min_stdev=0.,
            init_logstdev=0.,
            enable_obsnorm=enable_obsnorm)
        policy = rl.GaussianPolicy(policy_cfg, mdp.obs_space, mdp.action_space,
                                   'GaussianPolicy',
                                   bool(args.use_shared_std_network))
    else:
        policy_cfg = rl.GibbsPolicyConfig(hidden_spec=args.policy_hidden_spec,
                                          enable_obsnorm=enable_obsnorm)
        policy = rl.GibbsPolicy(policy_cfg, mdp.obs_space,
                                mdp.action_space, 'GibbsPolicy',
                                bool(args.use_shared_std_network))
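    # Note (assumption about the rl module's naming, not stated here): GaussianPolicy
    # handles continuous action spaces with a diagonal-Gaussian output, while
    # GibbsPolicy is presumably a softmax (Boltzmann/"Gibbs") distribution over
    # discrete actions.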

    util.header('Policy architecture')
    for v in policy.get_trainable_variables():
        util.header('- %s (%d parameters)' % (v.name, v.get_value().size))
    util.header('Total: %d parameters' % (policy.get_num_params(), ))

    # Load expert data
    exobs_Bstacked_Do, exa_Bstacked_Da, ext_Bstacked = load_dataset(
        args.data, args.limit_trajs, args.data_subsamp_freq, args.seed)
    assert exobs_Bstacked_Do.shape[1] == mdp.obs_space.storage_size
    assert exa_Bstacked_Da.shape[1] == mdp.action_space.storage_size
    assert ext_Bstacked.ndim == 1
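    # Naming convention (evident from the asserts above): array suffixes encode
    # shapes, e.g. exobs_Bstacked_Do is (B stacked expert transitions) x (obs dim Do),
    # exa_Bstacked_Da is B x (action dim Da), and ext_Bstacked holds per-transition
    # timesteps as a 1-D array of length B.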

    # Start optimization
    max_traj_len = args.max_traj_len if args.max_traj_len is not None else mdp.env_spec.timestep_limit
    print('Max traj len:', max_traj_len)

    if args.mode == 'bclone':
        # For behavioral cloning, only print output when evaluating
        args.print_freq = args.bclone_eval_freq
        args.save_freq = args.bclone_eval_freq

        reward, vf = None, None
        opt = imitation.BehavioralCloningOptimizer(
            mdp,
            policy,
            lr=args.bclone_lr,
            batch_size=args.bclone_batch_size,
            obsfeat_fn=lambda o: o,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            eval_sim_cfg=policyopt.SimConfig(
                min_num_trajs=args.bclone_eval_ntrajs,
                min_total_sa=-1,
                batch_size=args.sim_batch_size,
                max_traj_len=max_traj_len),
            eval_freq=args.bclone_eval_freq,
            train_frac=args.bclone_train_frac)

    elif args.mode == 'ga':
        if args.reward_type == 'nn':
            reward = imitation.TransitionClassifier(
                hidden_spec=args.policy_hidden_spec,
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                max_kl=args.reward_max_kl,
                adam_lr=args.reward_lr,
                adam_steps=args.reward_steps,
                ent_reg_weight=args.reward_ent_reg_weight,
                enable_inputnorm=True,
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                varscope_name='TransitionClassifier')

        elif args.reward_type in ['l2ball', 'simplex']:
            reward = imitation.LinearReward(
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                mode=args.reward_type,
                enable_inputnorm=True,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                exobs_Bex_Do=exobs_Bstacked_Do,
                exa_Bex_Da=exa_Bstacked_Da,
                ext_Bex=ext_Bstacked)
        else:
            raise NotImplementedError(args.reward_type)

        vf = None if bool(args.no_vf) else rl.ValueFunc(
            hidden_spec=args.policy_hidden_spec,
            obsfeat_space=mdp.obs_space,
            enable_obsnorm=args.obsnorm_mode != 'none',
            enable_vnorm=True,
            max_kl=args.vf_max_kl,
            damping=args.vf_cg_damping,
            time_scale=1. / mdp.env_spec.timestep_limit,
            varscope_name='ValueFunc')

        opt = imitation.ImitationOptimizer(
            mdp=mdp,
            discount=args.discount,
            lam=args.lam,
            policy=policy,
            sim_cfg=policyopt.SimConfig(min_num_trajs=-1,
                                        min_total_sa=args.min_total_sa,
                                        batch_size=args.sim_batch_size,
                                        max_traj_len=max_traj_len),
            step_func=rl.TRPO(max_kl=args.policy_max_kl,
                              damping=args.policy_cg_damping),
            reward_func=reward,
            value_func=vf,
            policy_obsfeat_fn=lambda obs: obs,
            reward_obsfeat_fn=lambda obs: obs,
            policy_ent_reg=args.policy_ent_reg,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            ex_t=ext_Bstacked)

    elif args.mode == 'gmmil':
        if args.use_median_heuristic == 0:
            bandwidth_params = [
                1.0, 1.0 / 2.0, 1.0 / 5.0, 1.0 / 10.0, 1.0 / 40.0, 1.0 / 80.0
            ]
        else:
            bandwidth_params = []

        if args.reward_type == 'mmd':
            reward = gmmil.MMDReward(
                obsfeat_space=mdp.obs_space,
                action_space=mdp.action_space,
                enable_inputnorm=True,
                favor_zero_expert_reward=bool(args.favor_zero_expert_reward),
                include_time=bool(args.reward_include_time),
                time_scale=1. / mdp.env_spec.timestep_limit,
                exobs_Bex_Do=exobs_Bstacked_Do,
                exa_Bex_Da=exa_Bstacked_Da,
                ext_Bex=ext_Bstacked,
                kernel_bandwidth_params=bandwidth_params,
                kernel_reg_weight=args.kernel_reg_weight,
                kernel_batchsize=args.kernel_batchsize,
                use_median_heuristic=args.use_median_heuristic,
                use_logscale_reward=bool(args.use_logscale_reward),
                save_reward=bool(args.save_reward),
                epsilon=args.reward_epsilon)
        else:
            raise NotImplementedError(args.reward_type)
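        # Note (an interpretation of the flags above, not taken from the gmmil source):
        # MMDReward is built around the RBF-kernel MMD^2 estimator
        #   MMD^2(pi, expert) = mean k(x_pi, x_pi') - 2 mean k(x_pi, x_exp) + mean k(x_exp, x_exp'),
        # with the kernel bandwidth chosen by the median heuristic when
        # --use_median_heuristic > 0, and the fixed bandwidth_params mixture otherwise.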

        vf = None if bool(args.no_vf) else rl.ValueFunc(
            hidden_spec=args.policy_hidden_spec,
            obsfeat_space=mdp.obs_space,
            enable_obsnorm=args.obsnorm_mode != 'none',
            enable_vnorm=True,
            max_kl=args.vf_max_kl,
            damping=args.vf_cg_damping,
            time_scale=1. / mdp.env_spec.timestep_limit,
            varscope_name='ValueFunc')

        opt = imitation.ImitationOptimizer(
            mdp=mdp,
            discount=args.discount,
            lam=args.lam,
            policy=policy,
            sim_cfg=policyopt.SimConfig(min_num_trajs=-1,
                                        min_total_sa=args.min_total_sa,
                                        batch_size=args.sim_batch_size,
                                        max_traj_len=max_traj_len),
            step_func=rl.TRPO(max_kl=args.policy_max_kl,
                              damping=args.policy_cg_damping),
            reward_func=reward,
            value_func=vf,
            policy_obsfeat_fn=lambda obs: obs,
            reward_obsfeat_fn=lambda obs: obs,
            policy_ent_reg=args.policy_ent_reg,
            ex_obs=exobs_Bstacked_Do,
            ex_a=exa_Bstacked_Da,
            ex_t=ext_Bstacked)

    # Set observation normalization
    if args.obsnorm_mode == 'expertdata':
        policy.update_obsnorm(exobs_Bstacked_Do)
        if reward is not None:
            reward.update_inputnorm(opt.reward_obsfeat_fn(exobs_Bstacked_Do),
                                    exa_Bstacked_Da)
        if vf is not None:
            vf.update_obsnorm(opt.policy_obsfeat_fn(exobs_Bstacked_Do))

    # Run optimizer
    log = nn.TrainingLog(args.log, [('args', argstr)])
    for i in xrange(args.max_iter):
        iter_info = opt.step()
        log.write(iter_info,
                  print_header=i % (20 * args.print_freq) == 0,
                  display=i % args.print_freq == 0)
        if args.save_freq != 0 and i % args.save_freq == 0 and args.log is not None:
            log.write_snapshot(policy, i)

        if args.plot_freq != 0 and i % args.plot_freq == 0:
            exdata_N_Doa = np.concatenate([exobs_Bstacked_Do, exa_Bstacked_Da],
                                          axis=1)
            pdata_M_Doa = np.concatenate(
                [opt.last_sampbatch.obs.stacked, opt.last_sampbatch.a.stacked],
                axis=1)

            # Plot reward
            import matplotlib.pyplot as plt
            _, ax = plt.subplots()
            idx1, idx2 = 0, 1
            range1 = (min(exdata_N_Doa[:, idx1].min(),
                          pdata_M_Doa[:, idx1].min()),
                      max(exdata_N_Doa[:, idx1].max(),
                          pdata_M_Doa[:, idx1].max()))
            range2 = (min(exdata_N_Doa[:, idx2].min(),
                          pdata_M_Doa[:, idx2].min()),
                      max(exdata_N_Doa[:, idx2].max(),
                          pdata_M_Doa[:, idx2].max()))
            reward.plot(ax, idx1, idx2, range1, range2, n=100)

            # Plot expert data
            ax.scatter(exdata_N_Doa[:, idx1],
                       exdata_N_Doa[:, idx2],
                       color='blue',
                       s=1,
                       label='expert')

            # Plot policy samples
            ax.scatter(pdata_M_Doa[:, idx1],
                       pdata_M_Doa[:, idx2],
                       color='red',
                       s=1,
                       label='apprentice')

            ax.legend()
            plt.show()