Exemple #1
0
                                else:
                                    baseline = GaussianMLPBaseline(env_spec=env.spec)
                                # Construct the MAMLIL trainer.  Keyword arguments keep their
                                # original order so argument evaluation order is unchanged.
                                algo = MAMLIL(
                                    env=env,
                                    policy=policy,
                                    baseline=baseline,
                                    batch_size=fast_batch_size,  # trajs per alpha grad update
                                    max_path_length=max_path_length,
                                    meta_batch_size=meta_batch_size,  # tasks sampled per beta grad update
                                    num_grad_updates=num_grad_updates,  # count of alpha grad updates
                                    n_itr=2000,  # alt: 100
                                    use_maml=use_maml,
                                    use_pooled_goals=True,
                                    step_size=meta_step_size,
                                    plot=False,
                                    beta_steps=beta_steps,
                                    adam_steps=adam_steps,
                                    adam_curve=adam_curve,
                                    pre_std_modifier=pre_std_modifier,
                                    l2loss_std_mult=l2loss_std_mult,
                                    post_std_modifier_train=post_std_modifier_train,
                                    post_std_modifier_test=post_std_modifier_test,
                                    expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[
                                        env_option + "." + mode + goals_suffix
                                    ],
                                    test_on_training_goals=True,
                                    make_video=True,
                                    extra_input=extra_input,
                                    # 0 when no extra input is used, otherwise the configured size.
                                    extra_input_dim=extra_input_dim if extra_input is not None else 0,
                                    limit_demos_num=limit_demos_num,
                                )
                                run_experiment_lite(
                                    algo.train(),
 else:
     baseline = GaussianMLPBaseline(
         env_spec=env.spec)
 # Construct the MAMLIL trainer (keyword order preserved).
 algo = MAMLIL(
     env=env,
     policy=policy,
     baseline=baseline,
     batch_size=fast_batch_size,  # trajs per alpha grad update
     max_path_length=max_path_length,
     meta_batch_size=meta_batch_size,  # tasks sampled per beta grad update
     num_grad_updates=num_grad_updates,  # count of alpha grad updates
     n_itr=800,  # alt: 100
     use_maml=use_maml,
     use_pooled_goals=True,
     step_size=meta_step_size,
     plot=False,
     beta_steps=beta_steps,
     adam_steps=adam_steps,
     pre_std_modifier=pre_std_modifier,
     l2loss_std_mult=l2loss_std_mult,
     post_std_modifier_train=post_std_modifier_train,
     post_std_modifier_test=post_std_modifier_test,
     expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[
         env_option + "." + mode + ".noise0.1.small"
     ],
 )
 run_experiment_lite(
     algo.train(),
     n_parallel=1,
Exemple #3
0
                                                            # Suffix appended both to the expert-trajectory lookup key and
                                                            # passed as expert_trajs_suffix: "_<dim>" when extra_input_dim is
                                                            # an int, "" otherwise (e.g. None).  Computed once so the two
                                                            # uses cannot drift apart.
                                                            expert_trajs_suffix = (
                                                                "_" + str(extra_input_dim)
                                                                if isinstance(extra_input_dim, int) else ""
                                                            )
                                                            algo = MAMLIL(
                                                                env=env,
                                                                policy=policy,
                                                                baseline=baseline,
                                                                batch_size=fast_batch_size,  # number of trajs for alpha grad update
                                                                max_path_length=max_path_length,
                                                                meta_batch_size=meta_batch_size,  # number of tasks sampled for beta grad update
                                                                num_grad_updates=num_grad_updates,  # number of alpha grad updates
                                                                n_itr=200,  # alt: 100
                                                                make_video=True,
                                                                use_maml=use_maml,
                                                                use_pooled_goals=True,
                                                                use_corr_term=use_corr_term,
                                                                test_on_training_goals=test_on_training_goals,
                                                                metalearn_baseline=(bas == "MAMLGaussianMLP"),
                                                                limit_demos_num=limit_demos_num,
                                                                test_goals_mult=test_goals_mult,
                                                                step_size=meta_step_size,
                                                                plot=False,
                                                                beta_steps=beta_steps,
                                                                adam_curve=adam_curve,
                                                                adam_steps=adam_steps,
                                                                pre_std_modifier=pre_std_modifier,
                                                                l2loss_std_mult=l2loss_std_mult,
                                                                importance_sampling_modifier=MOD_FUNC[ism],
                                                                post_std_modifier_train=post_std_modifier_train,
                                                                post_std_modifier_test=post_std_modifier_test,
                                                                expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[
                                                                    env_option + "." + mode + goals_suffix + expert_trajs_suffix
                                                                ],
                                                                expert_trajs_suffix=expert_trajs_suffix,
                                                                seed=seed,
                                                                extra_input=extra_input,
                                                                # 0 when no extra input is used, otherwise the configured size.
                                                                extra_input_dim=(0 if extra_input is None else extra_input_dim),
                                                            )
Exemple #4
0
 # Construct the MAMLIL trainer (keyword order preserved).
 algo = MAMLIL(
     env=env,
     policy=policy,
     # load_policy="/home/rosen/paper_ready_experiments/pusher/best/PU_IL_1_flr0.01_dem24_ei5_as10_basl_1805_09_14/params.pkl",
     baseline=baseline,
     batch_size=fast_batch_size,  # trajs per alpha grad update
     max_path_length=max_path_length,
     meta_batch_size=meta_batch_size,  # tasks sampled per beta grad update
     num_grad_updates=num_grad_updates,  # count of alpha grad updates
     n_itr=1600,  # alt: 100
     make_video=False,
     use_maml=use_maml,
     use_pooled_goals=True,
     use_corr_term=use_corr_term,
     test_on_training_goals=test_on_training_goals,
     metalearn_baseline=(bas == "MAMLGaussianMLP"),
     limit_demos_num=limit_demos_num,
     test_goals_mult=test_goals_mult,
     step_size=meta_step_size,
     plot=False,
     beta_steps=beta_steps,
     adam_curve=adam_curve,
     adam_steps=adam_steps,
     pre_std_modifier=pre_std_modifier,
     l2loss_std_mult=l2loss_std_mult,
     importance_sampling_modifier=MOD_FUNC[ism],
     post_std_modifier_train=post_std_modifier_train,
     post_std_modifier_test=post_std_modifier_test,
     # "_<dim>" is folded into the lookup key when extra_input_dim is an int.
     expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[
         env_option + "." + mode + goals_suffix
         + ("_" + str(extra_input_dim) if type(extra_input_dim) is int else "")
     ],
     # expert_trajs_suffix is passed as an empty string here.
     expert_trajs_suffix="",
     seed=seed,
     extra_input=extra_input,
     extra_input_dim=extra_input_dim if extra_input is not None else 0,
     input_feed=INPUT_FEED,
 )
 # Construct the MAMLIL trainer (keyword order preserved).
 algo = MAMLIL(
     env=env,
     policy=policy,
     # policy=None,
     # load_policy='/home/alvin/maml_rl/data/local/R7-IL-0918/R7_IL_200_40_1_1_dem40_ei5_as50_basl_1809_04_27/itr_24.pkl',
     baseline=baseline,
     batch_size=fast_batch_size,  # trajs per alpha grad update
     max_path_length=max_path_length,
     meta_batch_size=meta_batch_size,  # tasks sampled per beta grad update
     num_grad_updates=num_grad_updates,  # count of alpha grad updates
     n_itr=50,  # alt: 100
     make_video=False,
     use_maml=use_maml,
     use_pooled_goals=True,
     use_corr_term=use_corr_term,
     test_on_training_goals=test_on_training_goals,
     metalearn_baseline=(bas == "MAMLGaussianMLP"),
     limit_demos_num=limit_demos_num,
     test_goals_mult=test_goals_mult,
     step_size=meta_step_size,
     plot=False,
     beta_steps=beta_steps,
     adam_curve=adam_curve,
     adam_steps=adam_steps,
     pre_std_modifier=pre_std_modifier,
     l2loss_std_mult=l2loss_std_mult,
     importance_sampling_modifier=MOD_FUNC[ism],
     post_std_modifier_train=post_std_modifier_train,
     post_std_modifier_test=post_std_modifier_test,
     expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[env_option + "." + mode + goals_suffix],
     expert_trajs_suffix="",
     seed=seed,
     extra_input=extra_input,
     # 0 when no extra input is used, otherwise the configured size.
     extra_input_dim=extra_input_dim if extra_input is not None else 0,
     updateMode=updateMode,
 )
Exemple #6
0
                                    # expert_policy = PointEnvExpertPolicy(env_spec=env.spec)
                                    # Construct the MAMLIL trainer (keyword order preserved).
                                    algo = MAMLIL(
                                        env=env,
                                        policy=policy,
                                        baseline=baseline,
                                        # expert_policy=expert_policy,  TODO: we will want to define the expert policy here
                                        batch_size=fast_batch_size,  # trajs per alpha grad update
                                        max_path_length=max_path_length,
                                        meta_batch_size=meta_batch_size,  # tasks sampled per beta grad update
                                        num_grad_updates=num_grad_updates,  # alpha grad updates per beta update
                                        n_itr=100,
                                        use_maml=use_maml,
                                        use_pooled_goals=True,
                                        step_size=meta_step_size,
                                        plot=False,
                                        beta_steps=beta_steps,
                                        adam_steps=adam_steps,
                                        pre_std_modifier=pre_std_modifier,
                                        l2loss_std_mult=l2loss_std_mult,
                                        importance_sampling_modifier=MOD_FUNC[""],
                                        post_std_modifier_train=post_std_modifier_train,
                                        post_std_modifier_test=post_std_modifier_test,
                                        expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[".ec2"],
                                    )
 # Suffix appended both to the expert-trajectory lookup key and passed as
 # expert_trajs_suffix: "_<dim>" when extra_input_dim is an int, "" otherwise
 # (e.g. None).  Computed once so the two uses cannot drift apart.
 expert_trajs_suffix = (
     "_" + str(extra_input_dim) if isinstance(extra_input_dim, int) else ""
 )
 algo = MAMLIL(
     env=env,
     # policy=policy,
     policy=None,
     # NOTE(review): machine-specific absolute checkpoint path; consider making
     # it a configurable parameter.
     load_policy=
     "/home/rosen/maml_rl/data/local/R7-IL-0909/R7_IL_vision_2distr_dummy_1nocorr_fbs1_mbs1_flr0.0_dem5_as1_basz_0909_07_34/itr_799.pkl",
     # load_policy="/home/rosen/maml_rl/data/local/R7-IL-0828/R7_IL_vision_2distr_dummy_1nocorr_fbs1_mbs1_flr0.0_dem300_as1_basz_2808_20_49/itr_680.pkl",
     baseline=baseline,
     batch_size=fast_batch_size,  # number of trajs for alpha grad update
     max_path_length=max_path_length,
     meta_batch_size=meta_batch_size,  # number of tasks sampled for beta grad update
     num_grad_updates=num_grad_updates,  # number of alpha grad updates
     n_itr=800,  # alt: 100
     make_video=True,
     # sampler_cls=BatchSampler,
     # sampler_args=dict(n_envs=1),
     use_maml=use_maml,
     use_vision=True,
     use_pooled_goals=True,
     use_corr_term=use_corr_term,
     test_on_training_goals=test_on_training_goals,
     metalearn_baseline=(bas == "MAMLGaussianMLP"),
     limit_demos_num=limit_demos_num,
     test_goals_mult=test_goals_mult,
     step_size=meta_step_size,
     plot=False,
     beta_steps=beta_steps,
     adam_curve=adam_curve,
     adam_steps=adam_steps,
     pre_std_modifier=pre_std_modifier,
     l2loss_std_mult=l2loss_std_mult,
     importance_sampling_modifier=MOD_FUNC[ism],
     post_std_modifier_train=post_std_modifier_train,
     post_std_modifier_test=post_std_modifier_test,
     expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[
         env_option + "." + mode + goals_suffix + expert_trajs_suffix
     ],
     expert_trajs_suffix=expert_trajs_suffix,
     seed=seed,
     extra_input=extra_input,
     # 0 when no extra input is used, otherwise the configured size.
     extra_input_dim=(0 if extra_input is None else extra_input_dim),
 )
                                 
                                    # Construct the MAMLIL trainer (keyword order preserved).
                                    algo = MAMLIL(
                                        env=env,
                                        trainGoals=goals,
                                        partitions=partitions,
                                        policy=metaPolicy,
                                        baseline=metaBaseline,
                                        max_path_length=max_path_length,
                                        expert_batch_size=expert_batch_size,
                                        expert_num_itrs=expert_num_itrs,
                                        penalty=kl_penalty,
                                        fast_batch_size=fast_batch_size,  # trajs per alpha grad update
                                        meta_batch_size=meta_batch_size,  # tasks sampled per beta grad update
                                        taskPoolSize=taskPoolSize,
                                        num_grad_updates=num_grad_updates,  # alpha grad updates per beta update
                                        metaL_num_itrs=100,
                                        use_maml=use_maml,
                                        use_pooled_goals=True,
                                        step_size=meta_step_size,
                                        plot=False,
                                        beta_steps=beta_steps,
                                        adam_steps=adam_steps,
                                        pre_std_modifier=pre_std_modifier,
                                        l2loss_std_mult=l2loss_std_mult,
                                        importance_sampling_modifier=MOD_FUNC[""],
                                        post_std_modifier_train=post_std_modifier_train,
                                        post_std_modifier_test=post_std_modifier_test,
                                        expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT["." + mode],
                                        updateMode=updateMode,
                                    )
Exemple #9
0
def experiment(variant):
    """Build and train one MAMLIL run configured by ``variant``.

    ``variant`` must provide the keys ``'seed'``, ``'n_parallel'``,
    ``'log_dir'``, ``'flr'`` (fast learning rate, used as the policy's
    ``grad_step_size``), ``'fbs'`` (fast batch size), ``'mbs'`` (meta batch
    size), ``'envClass'`` and ``'adam_steps'``.  Every other hyper-parameter
    is fixed inside this function.

    The configuration is read and validated before ``setup(seed, n_parallel,
    log_dir)`` is called, so an invalid variant fails without running setup.

    Raises:
        KeyError: if a required key is missing from ``variant``.
        AssertionError: if ``variant['envClass']`` is neither ``'Ant'`` nor
            ``'SawyerPusher'`` (type kept for backward compatibility).
    """
    # Read every required key up front so a malformed variant fails fast.
    seed = variant['seed']
    n_parallel = variant['n_parallel']
    log_dir = variant['log_dir']
    fast_learning_rate = variant['flr']
    fast_batch_size = variant['fbs']
    meta_batch_size = variant['mbs']
    env_class = variant['envClass']
    adam_steps = variant['adam_steps']

    if env_class not in ('Ant', 'SawyerPusher'):
        raise AssertionError('Env must be either Ant or SawyerPusher')

    setup(seed, n_parallel, log_dir)

    # Fixed hyper-parameters for this experiment.
    beta_steps = 1
    adam_curve = None
    update_mode = 'vec'
    num_grad_updates = 1
    meta_step_size = 0.01
    pre_std_modifier = 1.0
    post_std_modifier_train = 0.00001
    post_std_modifier_test = 0.00001
    l2loss_std_mult = 1.0
    ism = ''  # key into MOD_FUNC for importance_sampling_modifier
    limit_demos_num = 40
    test_goals_mult = 1
    use_corr_term = True
    use_maml = True
    test_on_training_goals = False
    policy_hidden_sizes = (100, 100)

    # Other values noted for extra_input: "gaussian_exploration", or None.
    extra_input = "onehot_exploration"
    extra_input_dim = 5
    # Policy and algorithm both receive 0 when no extra input is used.
    effective_extra_input_dim = 0 if extra_input is None else extra_input_dim

    if env_class == 'Ant':
        env = TfEnv(normalize(AntEnvRandGoalRing()))
        max_path_length = 200
        expert_trajs_dir = '/root/code/rllab/saved_expert_traj/Expert_trajs_dense_ant/'
    else:  # 'SawyerPusher' -- the only other value allowed by the check above
        base_env = FlatGoalEnv(SawyerPushEnv(tasks=None), obs_keys=['state_observation'])
        env = TfEnv(NormalizedBoxEnv(FinnMamlEnv(base_env, reset_mode='task')))
        max_path_length = 150
        expert_trajs_dir = '/root/code/maml_gps/saved_expert_traj/Expert_trajs_sawyer_pusher/'

    policy = MAMLGaussianMLPPolicy(
        name="policy",
        env_spec=env.spec,
        grad_step_size=fast_learning_rate,
        hidden_nonlinearity=tf.nn.relu,
        hidden_sizes=policy_hidden_sizes,
        std_modifier=pre_std_modifier,
        extra_input_dim=effective_extra_input_dim,
        updateMode=update_mode,
        num_tasks=meta_batch_size,
    )

    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = MAMLIL(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=fast_batch_size,  # number of trajs for alpha grad update
        max_path_length=max_path_length,
        meta_batch_size=meta_batch_size,  # number of tasks sampled for beta grad update
        num_grad_updates=num_grad_updates,  # number of alpha grad updates
        n_itr=200,
        make_video=False,
        use_maml=use_maml,
        use_pooled_goals=True,
        use_corr_term=use_corr_term,
        test_on_training_goals=test_on_training_goals,
        metalearn_baseline=False,
        limit_demos_num=limit_demos_num,
        test_goals_mult=test_goals_mult,
        step_size=meta_step_size,
        plot=False,
        beta_steps=beta_steps,
        adam_curve=adam_curve,
        adam_steps=adam_steps,
        pre_std_modifier=pre_std_modifier,
        l2loss_std_mult=l2loss_std_mult,
        importance_sampling_modifier=MOD_FUNC[ism],
        post_std_modifier_train=post_std_modifier_train,
        post_std_modifier_test=post_std_modifier_test,
        expert_trajs_dir=expert_trajs_dir,
        expert_trajs_suffix="",
        seed=seed,
        extra_input=extra_input,
        extra_input_dim=effective_extra_input_dim,
        updateMode=update_mode,
    )
    algo.train()
Exemple #10
0
def run_FaReLI(input_feed=None):
    beta_adam_steps_list = [(1,50)]
    # beta_curve = [250,250,250,250,250,5,5,5,5,1,1,1,1,] # make sure to check maml_experiment_vars
    # beta_curve = [1000] # make sure to check maml_experiment_vars
    adam_curve = [250,249,248,247,245,50,50,10] # make sure to check maml_experiment_vars
    # adam_curve = None

    fast_learning_rates = [1.0]
    baselines = ['linear',]  # linear GaussianMLP MAMLGaussianMLP zero
    env_option = ''
    # mode = "ec2"
    mode = "local"
    extra_input = "onehot_exploration" # "onehot_exploration" "gaussian_exploration"
    # extra_input = None
    extra_input_dim = 5
    # extra_input_dim = None
    goals_suffixes = ["_200_40_1"] #,"_200_40_2", "_200_40_3","_200_40_4"]
    # goals_suffixes = ["_1000_40"]

    fast_batch_size_list = [20]  # 20 # 10 works for [0.1, 0.2], 20 doesn't improve much for [0,0.2]  #inner grad update size
    meta_batch_size_list = [40]  # 40 @ 10 also works, but much less stable, 20 is fairly stable, 40 is more stable
    max_path_length = 100  # 100
    num_grad_updates = 1
    meta_step_size = 0.01
    pre_std_modifier_list = [1.0]
    post_std_modifier_train_list = [0.00001]
    post_std_modifier_test_list = [0.00001]
    l2loss_std_mult_list = [1.0]
    importance_sampling_modifier_list = ['']  #'', 'clip0.5_'
    limit_demos_num_list = [1]  # 40
    test_goals_mult = 1
    bas_lr = 0.01 # baseline learning rate
    momentum=0.5
    bas_hnl = tf.nn.relu
    baslayers_list = [(32,32), ]

    basas = 60 # baseline adam steps
    use_corr_term = True
    seeds = [1] #,2,3,4,5]
    envseeds = [6]
    use_maml = True
    test_on_training_goals = False
    for goals_suffix in goals_suffixes:
        for envseed in envseeds:
            for seed in seeds:
                for baslayers in baslayers_list:
                    for fast_batch_size in fast_batch_size_list:
                        for meta_batch_size in meta_batch_size_list:
                            for ism in importance_sampling_modifier_list:
                                for limit_demos_num in limit_demos_num_list:
                                    for l2loss_std_mult in l2loss_std_mult_list:
                                        for post_std_modifier_train in post_std_modifier_train_list:
                                            for post_std_modifier_test in post_std_modifier_test_list:
                                                for pre_std_modifier in pre_std_modifier_list:
                                                    for fast_learning_rate in fast_learning_rates:
                                                        for beta_steps, adam_steps in beta_adam_steps_list:
                                                            for bas in baselines:
                                                                stub(globals())
                                                                tf.set_random_seed(seed)
                                                                np.random.seed(seed)
                                                                rd.seed(seed)
                                                                env = TfEnv(normalize(Reacher7DofMultitaskEnv(envseed=envseed)))
                                                                exp_name = str(
                                                                    'R7_IL'
                                                                    # +time.strftime("%D").replace("/", "")[0:4]
                                                                    + goals_suffix + "_"
                                                                    + str(seed)
                                                                    # + str(envseed)
                                                                    + ("" if use_corr_term else "nocorr")
                                                                    # + str(int(use_maml))
                                                                    + ('_fbs' + str(fast_batch_size) if fast_batch_size!=20 else "")
                                                                    + ('_mbs' + str(meta_batch_size) if meta_batch_size!=40 else "")
                                                                    + ('_flr' + str(fast_learning_rate) if fast_learning_rate!=1.0 else "")
                                                                    + '_dem' + str(limit_demos_num)
                                                                    + ('_ei' + str(extra_input_dim) if type(
                                                                        extra_input_dim) == int else "")
                                                                    # + '_tgm' + str(test_goals_mult)
                                                                    #     +'metalr_'+str(meta_step_size)
                                                                    #     +'_ngrad'+str(num_grad_updates)
                                                                    + ("_bs" + str(beta_steps) if beta_steps != 1 else "")
                                                                    + "_as" + str(adam_steps)
                                                                    # +"_net" + str(net_size[0])
                                                                    # +"_L2m" + str(l2loss_std_mult)
                                                                    + ("_prsm" + str(
                                                                        pre_std_modifier) if pre_std_modifier != 1 else "")
                                                                    # + "_pstr" + str(post_std_modifier_train)
                                                                    # + "_posm" + str(post_std_modifier_test)
                                                                    #  + "_l2m" + str(l2loss_std_mult)
                                                                    + ("_" + ism if len(ism) > 0 else "")
                                                                    + "_bas" + bas[0]
                                                                    # +"_tfbe" # TF backend for baseline
                                                                    # +"_qdo" # quad dist optimizer
                                                                    + (("_bi" if bas_hnl == tf.identity else (
                                                                        "_brel" if bas_hnl == tf.nn.relu else "_bth"))  # identity or relu or tanh for baseline
                                                                       # + "_" + str(baslayers)  # size
                                                                       + "_baslr" + str(bas_lr)
                                                                       + "_basas" + str(basas) if bas[0] in ["G",
                                                                                                             "M"] else "")  # baseline adam steps
                                                                    + ("r" if test_on_training_goals else "")
                                                                    + "_" + time.strftime("%d%m_%H_%M"))



                                                                policy = MAMLGaussianMLPPolicy(
                                                                    name="policy",
                                                                    env_spec=env.spec,
                                                                    grad_step_size=fast_learning_rate,
                                                                    hidden_nonlinearity=tf.nn.relu,
                                                                    hidden_sizes=(100, 100),
                                                                    std_modifier=pre_std_modifier,
                                                                    # metalearn_baseline=(bas == "MAMLGaussianMLP"),
                                                                    extra_input_dim=(0 if extra_input is None else extra_input_dim),
                                                                )
                                                                if bas == 'zero':
                                                                    baseline = ZeroBaseline(env_spec=env.spec)
                                                                elif bas == 'MAMLGaussianMLP':
                                                                    baseline = MAMLGaussianMLPBaseline(env_spec=env.spec,
                                                                                                       learning_rate=bas_lr,
                                                                                                       hidden_sizes=baslayers,
                                                                                                       hidden_nonlinearity=bas_hnl,
                                                                                                       repeat=basas,
                                                                                                       repeat_sym=basas,
                                                                                                       momentum=momentum,
                                                                                                       extra_input_dim=( 0 if extra_input is None else extra_input_dim),

                                                                                                       # learn_std=False,
                                                                                                       # use_trust_region=False,
                                                                                                       # optimizer=QuadDistExpertOptimizer(
                                                                                                       #      name="bas_optimizer",
                                                                                                       #     #  tf_optimizer_cls=tf.train.GradientDescentOptimizer,
                                                                                                       #     #  tf_optimizer_args=dict(
                                                                                                       #     #      learning_rate=bas_lr,
                                                                                                       #     #  ),
                                                                                                       #     # # tf_optimizer_cls=tf.train.AdamOptimizer,
                                                                                                       #     # max_epochs=200,
                                                                                                       #     # batch_size=None,
                                                                                                       #      adam_steps=basas
                                                                                                       #     )
                                                                                                       )

                                                                # Remaining branches of the baseline-selection chain (its `if`
                                                                # head is above this view). Each branch assigns `baseline`;
                                                                # NOTE(review): if no branch matches, `baseline` is not assigned
                                                                # here and a value from elsewhere/earlier would be used — confirm.
                                                                elif bas == 'linear':
                                                                    # Linear-feature baseline: constructed from env.spec only.
                                                                    baseline = LinearFeatureBaseline(env_spec=env.spec)
                                                                elif "GaussianMLP" in bas:
                                                                    # NOTE(review): substring test — any `bas` containing
                                                                    # "GaussianMLP" (e.g. "MAMLGaussianMLP") lands here unless an
                                                                    # earlier branch of this chain catches it first; confirm order.
                                                                    # Regressor: hidden sizes from `baslayers`, nonlinearity
                                                                    # `bas_hnl`, learn_std=False, and a QuadDistExpertOptimizer
                                                                    # configured with adam_steps=basas and momentum enabled.
                                                                    # The commented-out kwargs are disabled alternatives.
                                                                    baseline = GaussianMLPBaseline(env_spec=env.spec,
                                                                                                   regressor_args=dict(
                                                                                                       hidden_sizes=baslayers,
                                                                                                       hidden_nonlinearity=bas_hnl,
                                                                                                       learn_std=False,
                                                                                                       # use_trust_region=False,
                                                                                                       # normalize_inputs=False,
                                                                                                       # normalize_outputs=False,
                                                                                                       optimizer=QuadDistExpertOptimizer(
                                                                                                           name="bas_optimizer",
                                                                                                           #  tf_optimizer_cls=tf.train.GradientDescentOptimizer,
                                                                                                           #  tf_optimizer_args=dict(
                                                                                                           #      learning_rate=bas_lr,
                                                                                                           #  ),
                                                                                                           # # tf_optimizer_cls=tf.train.AdamOptimizer,
                                                                                                           # max_epochs=200,
                                                                                                           # batch_size=None,
                                                                                                           adam_steps=basas,
                                                                                                           use_momentum_optimizer=True,
                                                                                                       )))
                                                                # Suffix selecting the expert-trajectory variant for this
                                                                # extra_input_dim. Computed once so the EXPERT_TRAJ_LOCATION_DICT
                                                                # key and `expert_trajs_suffix` below can never disagree.
                                                                # Exact type check: only a plain int adds "_<dim>"; None, bools
                                                                # and any other value give "".
                                                                expert_trajs_dim_suffix = (
                                                                    "_" + str(extra_input_dim) if type(extra_input_dim) is int else ""
                                                                )
                                                                algo = MAMLIL(
                                                                    env=env,
                                                                    policy=policy,
                                                                    baseline=baseline,
                                                                    batch_size=fast_batch_size,  # number of trajs for alpha grad update
                                                                    max_path_length=max_path_length,
                                                                    meta_batch_size=meta_batch_size,  # number of tasks sampled for beta grad update
                                                                    num_grad_updates=num_grad_updates,  # number of alpha grad updates
                                                                    n_itr=800,  # 100 was noted here as an alternative value
                                                                    make_video=True,
                                                                    use_maml=use_maml,
                                                                    use_pooled_goals=True,
                                                                    use_corr_term=use_corr_term,
                                                                    test_on_training_goals=test_on_training_goals,
                                                                    # True only for the "MAMLGaussianMLP" baseline option.
                                                                    metalearn_baseline=(bas == "MAMLGaussianMLP"),
                                                                    limit_demos_num=limit_demos_num,
                                                                    test_goals_mult=test_goals_mult,
                                                                    step_size=meta_step_size,
                                                                    plot=False,
                                                                    beta_steps=beta_steps,
                                                                    adam_curve=adam_curve,
                                                                    adam_steps=adam_steps,
                                                                    pre_std_modifier=pre_std_modifier,
                                                                    l2loss_std_mult=l2loss_std_mult,
                                                                    importance_sampling_modifier=MOD_FUNC[ism],
                                                                    post_std_modifier_train=post_std_modifier_train,
                                                                    post_std_modifier_test=post_std_modifier_test,
                                                                    expert_trajs_dir=EXPERT_TRAJ_LOCATION_DICT[
                                                                        env_option + "." + mode + goals_suffix + expert_trajs_dim_suffix
                                                                    ],
                                                                    expert_trajs_suffix=expert_trajs_dim_suffix,
                                                                    seed=seed,
                                                                    extra_input=extra_input,
                                                                    # Dimension is forced to 0 when no extra input is used.
                                                                    extra_input_dim=(0 if extra_input is None else extra_input_dim),
                                                                    input_feed=input_feed,
                                                                    run_on_pr2=False,
                                                                )
                                                                # NOTE(review): algo.train() is evaluated here and its result is
                                                                # handed to run_experiment_lite; this presumably relies on rllab
                                                                # stub mode so the call is recorded rather than run inline — confirm.
                                                                run_experiment_lite(
                                                                    algo.train(),
                                                                    n_parallel=1,
                                                                    snapshot_mode="last",
                                                                    python_command='python3',
                                                                    seed=seed,
                                                                    # "%D" is usually mm/dd/yy; dropping "/" and keeping the first
                                                                    # four characters tags the prefix with mmdd.
                                                                    exp_prefix='R7_IL_' + time.strftime("%D").replace("/", "")[0:4],
                                                                    exp_name=exp_name,
                                                                    plot=False,
                                                                    sync_s3_pkl=True,
                                                                    mode=mode,
                                                                    terminate_machine=True,
                                                                )