Example #1
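A nested helper from InteractiveRewardOptimizer.run (shown in full in Example #3): it pads a batch of rollouts to a common length, capped at self.env.max_ep_len, then vectorizes them and splits the result with utils.split_rollouts (presumably into training and validation sets). Note that self refers to the enclosing optimizer.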
def proc_rollouts(rollouts, traj_len=None):
    if rollouts is None:
        return None
    else:
        # Pad every rollout to a common length (capped at the env's max
        # episode length) before vectorizing.
        max_len = max(len(rollout) for rollout in rollouts)
        max_len = min(
            max_len if traj_len is None else max(traj_len - 1, max_len),
            self.env.max_ep_len)
        return utils.split_rollouts(
            utils.vectorize_rollouts(rollouts, max_len))
Example #2
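An end-to-end smoke test of the CarRacing pipeline: it builds the environments and a random policy, collects short rollouts, trains a VAEModel encoder and an MDNRNNDynamicsModel (then swaps in pretrained world-models weights), fits RewardModels on demos, sketches, and auto-labeled preferences, exercises the gradient-descent and stochastic trajectory optimizers, and finally runs the InteractiveRewardOptimizer loop.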
def main():
  sess = utils.make_tf_session(gpu_mode=False)

  env = envs.make_carracing_env(sess)
  trans_env = envs.make_carracing_trans_env(sess)
  random_policy = utils.make_random_policy(env)

  utils.run_ep(random_policy, env, max_ep_len=3, render=False)
  trans_rollout = utils.run_ep(
      random_policy, trans_env, max_ep_len=3, render=False)

  logging.info('envs and policies OK')

  raw_demo_rollouts = [
      utils.run_ep(random_policy, env, max_ep_len=3, render=False)
      for _ in range(n_demo_rollouts)
  ]
  raw_aug_rollouts = [
      utils.run_ep(random_policy, env, max_ep_len=3, render=False)
      for _ in range(n_aug_rollouts)
  ]
  raw_aug_rollouts += raw_demo_rollouts

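  # Flatten all rollouts into an array of raw observations (the first
  # element of each transition) for training the VAE encoder.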
  raw_aug_obses = []
  for rollout in raw_aug_rollouts:
    for x in rollout:
      raw_aug_obses.append(x[0])
  raw_aug_obses = np.array(raw_aug_obses)
  raw_aug_obs_data = utils.split_rollouts({'obses': raw_aug_obses})

  logging.info('data collection OK')

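  # Train a small VAE encoder on the raw frames, then swap in the pretrained
  # world-models VAE and exercise save/load and encode/decode.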
  encoder = VAEModel(
      sess,
      env,
      learning_rate=0.0001,
      kl_tolerance=0.5,
      scope=str(uuid.uuid4()),
      scope_file=os.path.join(test_data_dir, 'enc_scope.pkl'),
      tf_file=os.path.join(test_data_dir, 'enc.tf'))

  encoder.train(
      raw_aug_obs_data,
      iterations=1,
      ftol=1e-4,
      learning_rate=1e-3,
      val_update_freq=1,
      verbose=False)

  encoder = load_wm_pretrained_vae(sess, env)

  encoder.save()

  encoder.load()

  obs = raw_aug_rollouts[0][0][0]
  latent = encoder.encode_frame(obs)
  unused_recon = encoder.decode_latent(latent)

  logging.info('encoder OK')

  raw_aug_traj_data = utils.split_rollouts(
      utils.vectorize_rollouts(
          raw_aug_rollouts, env.max_ep_len, preserve_trajs=True))

  abs_model = AbsorptionModel(
      sess,
      env,
      n_layers=1,
      layer_size=32,
      scope=str(uuid.uuid4()),
      scope_file=os.path.join(test_data_dir, 'abs_scope.pkl'),
      tf_file=os.path.join(test_data_dir, 'abs.tf'))

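  # Fit an MDN-RNN dynamics model (with an absorption model) on the collected
  # trajectories, then swap in the pretrained world-models RNN and exercise
  # save/load.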
  dynamics_model = MDNRNNDynamicsModel(
      encoder,
      sess,
      env,
      scope=str(uuid.uuid4()),
      tf_file=os.path.join(test_data_dir, 'dyn.tf'),
      scope_file=os.path.join(test_data_dir, 'dyn_scope.pkl'),
      abs_model=abs_model)

  dynamics_model.train(
      raw_aug_traj_data,
      iterations=1,
      learning_rate=1e-3,
      ftol=1e-4,
      batch_size=2,
      val_update_freq=1,
      verbose=False)

  dynamics_model = load_wm_pretrained_rnn(encoder, sess, env)

  dynamics_model.save()

  dynamics_model.load()

  demo_traj_data = utils.rnn_encode_rollouts(raw_demo_rollouts, env, encoder,
                                             dynamics_model)
  aug_traj_data = utils.rnn_encode_rollouts(raw_aug_rollouts, env, encoder,
                                            dynamics_model)
  demo_rollouts = utils.rollouts_of_traj_data(demo_traj_data)
  aug_rollouts = utils.rollouts_of_traj_data(aug_traj_data)
  demo_data = utils.split_rollouts(utils.flatten_traj_data(demo_traj_data))
  aug_data = utils.split_rollouts(utils.flatten_traj_data(aug_traj_data))

  env.default_init_obs = aug_rollouts[0][0][0]

  trans_rollouts = utils.rollouts_of_traj_data(
      utils.rnn_encode_rollouts([trans_rollout], trans_env, encoder,
                                dynamics_model))
  trans_env.default_init_obs = trans_rollouts[0][0][0]

  logging.info('mdnrnn dynamics OK')

  demo_data_for_reward_model = demo_data
  demo_rollouts_for_reward_model = demo_rollouts

  sketch_data_for_reward_model = aug_data
  sketch_rollouts_for_reward_model = aug_rollouts

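  # Build a stand-in reward for the CarRacing envs (presumably attaching a
  # reward_func) that is used below to auto-label preferences and to
  # synthesize feedback inside the interactive loop.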
  reward_init_kwargs = {
      'n_rew_nets_in_ensemble': 2,
      'n_layers': 1,
      'layer_size': 32,
      'scope': str(uuid.uuid4()),
      'scope_file': os.path.join(test_data_dir, 'true_rew_scope.pkl'),
      'tf_file': os.path.join(test_data_dir, 'true_rew.tf'),
      'rew_func_input': "s'",
      'use_discrete_rewards': True
  }

  reward_train_kwargs = {
      'demo_coeff': 1.,
      'sketch_coeff': 1.,
      'iterations': 1,
      'ftol': 1e-4,
      'batch_size': 2,
      'learning_rate': 1e-3,
      'val_update_freq': 1,
      'verbose': False
  }

  data = envs.make_carracing_rew(
      sess,
      env,
      sketch_data=sketch_data_for_reward_model,
      reward_init_kwargs=reward_init_kwargs,
      reward_train_kwargs=reward_train_kwargs)
  env.__dict__.update(data)
  trans_env.__dict__.update(data)

  autolabels = reward_models.autolabel_prefs(
      aug_rollouts, env, segment_len=env.max_ep_len + 1)

  pref_logs_for_reward_model = autolabels
  pref_data_for_reward_model = utils.split_prefs(autolabels)

  logging.info('autolabels OK')

  for rew_func_input in ['s', 'sa', "s'"]:
    reward_model = reward_models.RewardModel(
        sess,
        env,
        n_rew_nets_in_ensemble=2,
        n_layers=1,
        layer_size=32,
        scope=str(uuid.uuid4()),
        scope_file=os.path.join(test_data_dir, 'rew_scope.pkl'),
        tf_file=os.path.join(test_data_dir, 'rew.tf'),
        rew_func_input=rew_func_input,
        use_discrete_rewards=True)

  for demo_data in [None, demo_data_for_reward_model]:
    for sketch_data in [None, sketch_data_for_reward_model]:
      for pref_data in [None, pref_data_for_reward_model]:
        if pref_data is None and sketch_data is None:
          continue
        reward_model.train(
            demo_data=demo_data,
            sketch_data=sketch_data,
            pref_data=pref_data,
            demo_coeff=1.,
            sketch_coeff=1.,
            iterations=1,
            ftol=1e-4,
            batch_size=2,
            learning_rate=1e-3,
            val_update_freq=1,
            verbose=False)

  reward_model.save()

  reward_model.load()

  logging.info('reward models OK')

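  # Exercise the gradient-descent trajectory optimizer across query losses,
  # initial-observation options, trajectory joining, and shooting steps.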
  for query_loss_opt in [
      'pref_uncertainty', 'rew_uncertainty', 'max_rew', 'min_rew', 'max_nov'
  ]:
    for init_obs in [None, env.default_init_obs]:
      for join_trajs_at_init_state in [True, False]:
        for shoot_steps in [1, 2]:
          if (shoot_steps > 1 and
              np.array(init_obs == env.default_init_obs).all()):
            continue
          traj_opt = GDTrajOptimizer(
              sess,
              env,
              reward_model,
              dynamics_model,
              traj_len=2,
              n_trajs=2,
              prior_coeff=1.,
              diversity_coeff=0.,
              query_loss_opt=query_loss_opt,
              opt_init_obs=(init_obs is None),
              join_trajs_at_init_state=join_trajs_at_init_state,
              shoot_steps=shoot_steps,
              learning_rate=1e-2)

          traj_opt.run(
              init_obs=init_obs,
              iterations=1,
              ftol=1e-4,
              verbose=False,
          )

  logging.info('grad descent traj opt OK')

  imitation_kwargs = {'plan_horizon': 10, 'n_blind_steps': 2, 'test_mode': True}

  for n_eval_rollouts in [0, 1]:
    reward_models.evaluate_reward_model(
        sess,
        env,
        trans_env,
        reward_model,
        dynamics_model,
        offpol_eval_rollouts=sketch_rollouts_for_reward_model,
        n_eval_rollouts=n_eval_rollouts,
        imitation_kwargs=imitation_kwargs)

  logging.info('reward eval OK')

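  # Exercise the stochastic (sampling-based) trajectory optimizer across
  # query losses, with and without a random rollout policy.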
  for query_loss_opt in [
      'pref_uncertainty', 'rew_uncertainty', 'max_rew', 'min_rew', 'max_nov',
      'unif'
  ]:
    for use_rand_policy in [False, True]:
      traj_opt = StochTrajOptimizer(
          sess,
          env,
          reward_model,
          dynamics_model,
          traj_len=2,
          rollout_len=2,
          query_loss_opt=query_loss_opt,
          use_rand_policy=use_rand_policy)

      for init_obs in [None, env.default_init_obs]:
        traj_opt.run(n_trajs=2, n_samples=2, init_obs=init_obs, verbose=False)

  logging.info('stoch traj opt OK')

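  # Fresh reward model and interactive optimizer for the full
  # query-synthesis loop.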
  reward_model = reward_models.RewardModel(
      sess,
      env,
      n_rew_nets_in_ensemble=2,
      n_layers=1,
      layer_size=32,
      scope=str(uuid.uuid4()),
      scope_file=os.path.join(test_data_dir, 'rew_scope.pkl'),
      tf_file=os.path.join(test_data_dir, 'rew.tf'),
      rew_func_input="s'",
      use_discrete_rewards=True)

  rew_optimizer = InteractiveRewardOptimizer(sess, env, trans_env, reward_model,
                                             dynamics_model)

  reward_train_kwargs = {
      'demo_coeff': 1.,
      'sketch_coeff': 1.,
      'iterations': 1,
      'ftol': 1e-4,
      'batch_size': 2,
      'learning_rate': 1e-3,
      'val_update_freq': 1,
      'verbose': False
  }

  dynamics_train_kwargs = {
      'iterations': 1,
      'batch_size': 2,
      'learning_rate': 1e-3,
      'ftol': 1e-4,
      'val_update_freq': 1,
      'verbose': False
  }

  gd_traj_opt_init_kwargs = {
      'traj_len': env.max_ep_len,
      'n_trajs': 2,
      'prior_coeff': 1.,
      'diversity_coeff': 1.,
      'query_loss_opt': 'pref_uncertainty',
      'opt_init_obs': False,
      'learning_rate': 1e-2,
      'join_trajs_at_init_state': False
  }

  gd_traj_opt_run_kwargs = {
      'init_obs': env.default_init_obs,
      'iterations': 1,
      'ftol': 1e-4,
      'verbose': False,
  }

  unused_stoch_traj_opt_init_kwargs = {
      'traj_len': 2,
      'rollout_len': 2,
      'query_loss_opt': 'pref_uncertainty'
  }

  unused_stoch_traj_opt_run_kwargs = {
      'n_samples': 2,
      'init_obs': None,
      'verbose': False
  }

  eval_kwargs = {'n_eval_rollouts': 1}

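  # Run a single interactive iteration for each query type, both with and
  # without initial reward/dynamics training.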
  for init_train in [True, False]:
    for query_type in ['pref', 'sketch']:
      rew_optimizer.run(
          demo_rollouts=demo_rollouts_for_reward_model,
          sketch_rollouts=sketch_rollouts_for_reward_model,
          pref_logs=pref_logs_for_reward_model,
          rollouts_for_dyn=raw_aug_rollouts,
          reward_train_kwargs=reward_train_kwargs,
          dynamics_train_kwargs=dynamics_train_kwargs,
          traj_opt_cls=GDTrajOptimizer,
          traj_opt_run_kwargs=gd_traj_opt_run_kwargs,
          traj_opt_init_kwargs=gd_traj_opt_init_kwargs,
          imitation_kwargs=imitation_kwargs,
          eval_kwargs=eval_kwargs,
          init_train_dyn=init_train,
          init_train_rew=init_train,
          n_imitation_rollouts_per_dyn_update=1,
          n_queries=1,
          reward_update_freq=1,
          reward_eval_freq=1,
          dyn_update_freq=1,
          verbose=False,
          query_type=query_type)

  rew_optimizer.save()

  rew_optimizer.load()

  logging.info('rqst OK')
Example #3
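The InteractiveRewardOptimizer.run method: an interactive loop that repeatedly synthesizes query trajectories with a trajectory optimizer, elicits synthetic feedback (preferences, sketches, or demos), retrains the reward model (and periodically the dynamics model on imitation rollouts), and tracks reward-model performance until the query budget n_queries is exhausted.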
    def run(self,
            demo_rollouts=None,
            sketch_rollouts=None,
            pref_logs=None,
            rollouts_for_dyn=None,
            reward_train_kwargs=None,
            dynamics_train_kwargs=None,
            traj_opt_cls=None,
            traj_opt_run_kwargs=None,
            traj_opt_init_kwargs=None,
            imitation_kwargs=None,
            eval_kwargs=None,
            init_train_dyn=False,
            init_train_rew=False,
            n_imitation_rollouts_per_dyn_update=1,
            n_queries=1000,
            reward_update_freq=1,
            reward_eval_freq=1,
            dyn_update_freq=1,
            verbose=True,
            warm_start_rew=False,
            query_type='pref',
            callback=None):

        if query_type not in ['pref', 'sketch', 'demo']:
            raise ValueError('unrecognized query_type: %s' % query_type)

        if query_type == 'demo' and not any(
                utils.isinstance(self.reward_model, reward_model_cls_name)
                for reward_model_cls_name in
                ['REDRewardModel', 'BCRewardModel']):
            raise ValueError(
                'demo queries require a REDRewardModel or BCRewardModel')

        if rollouts_for_dyn is None:
            rollouts_for_dyn = []

        if reward_train_kwargs is None:
            reward_train_kwargs = default_reward_train_kwargs

        if dynamics_train_kwargs is None:
            dynamics_train_kwargs = default_dynamics_train_kwargs

        if traj_opt_cls is None:
            traj_opt_cls = default_traj_opt_cls

        if traj_opt_run_kwargs is None:
            traj_opt_run_kwargs = default_traj_opt_run_kwargs

        if traj_opt_init_kwargs is None:
            traj_opt_init_kwargs = default_traj_opt_init_kwargs

        if type(traj_opt_init_kwargs) != type(traj_opt_run_kwargs):
            raise ValueError(
                'traj_opt_init_kwargs and traj_opt_run_kwargs must both be '
                'dicts or both be lists')

        if (type(traj_opt_init_kwargs) == dict
                and 'opt_init_obs' in traj_opt_init_kwargs
                and not traj_opt_init_kwargs['opt_init_obs']
                and 'init_obs' not in traj_opt_run_kwargs):
            raise ValueError(
                'an init_obs must be provided when opt_init_obs is False')

        if imitation_kwargs is None:
            imitation_kwargs = default_imitation_kwargs

        if eval_kwargs is None:
            eval_kwargs = default_eval_kwargs
        eval_kwargs['imitation_kwargs'] = imitation_kwargs

        if type(traj_opt_init_kwargs) == list:
            for i in range(len(traj_opt_init_kwargs)):
                traj_opt_init_kwargs[i]['query_type'] = query_type
        else:
            traj_opt_init_kwargs['query_type'] = query_type

        if verbose:
            print('initializing reward and dynamics models...')

        demo_rollouts = deepcopy(demo_rollouts)
        sketch_rollouts = deepcopy(sketch_rollouts)
        pref_logs = deepcopy(pref_logs)

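        # Helpers for converting raw rollouts and preference logs into
        # training data for the reward model.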
        def proc_rollouts(rollouts, traj_len=None):
            if rollouts is None:
                return None
            else:
                # for padding
                max_len = max(len(rollout) for rollout in rollouts)
                max_len = min(
                    max_len if traj_len is None else max(
                        traj_len - 1, max_len), self.env.max_ep_len)
                return utils.split_rollouts(
                    utils.vectorize_rollouts(rollouts, max_len))

        def proc_pref_logs(pref_logs):
            return None if pref_logs is None else utils.split_prefs(pref_logs)

        demo_data = proc_rollouts(demo_rollouts)
        sketch_data = proc_rollouts(sketch_rollouts)
        pref_data = proc_pref_logs(pref_logs)

        if init_train_rew:
            self.reward_model.train(demo_data=demo_data,
                                    sketch_data=sketch_data,
                                    pref_data=pref_data,
                                    **reward_train_kwargs)
        else:
            self.reward_model.init_tf_vars()

        using_rnn_dyn = utils.isinstance(self.dynamics_model,
                                         'MDNRNNDynamicsModel')
        def proc_dyn_rollouts(rollouts):
            return utils.split_rollouts(
                utils.vectorize_rollouts(rollouts,
                                         self.env.max_ep_len,
                                         preserve_trajs=using_rnn_dyn))

        if init_train_dyn:
            self.dynamics_model.train(proc_dyn_rollouts(rollouts_for_dyn),
                                      **dynamics_train_kwargs)

        if pref_logs is None and query_type == 'pref':
            pref_logs = {
                'ref_trajs': [],
                'trajs': [],
                'ref_act_seqs': [],
                'act_seqs': [],
                'prefs': []
            }
        if sketch_rollouts is None:
            sketch_rollouts = []
        if demo_rollouts is None:
            demo_rollouts = []

        def make_traj_optimizer(kwargs):
            return traj_opt_cls(self.sess, self.env, self.reward_model,
                                self.dynamics_model, **kwargs)

        if type(traj_opt_init_kwargs) != list:
            traj_opt_run_kwargs = [traj_opt_run_kwargs]
            traj_opt_init_kwargs = [traj_opt_init_kwargs]

        imitator = traj_opt.make_imitation_policy(self.sess, self.env,
                                                  self.reward_model,
                                                  self.dynamics_model,
                                                  **imitation_kwargs)

        def update_rew_perf(rew_perf_evals, n_queries_made):
            rew_eval = reward_models.evaluate_reward_model(self.sess,
                                                           self.env,
                                                           self.trans_env,
                                                           self.reward_model,
                                                           self.dynamics_model,
                                                           imitator=imitator,
                                                           **eval_kwargs)

            rew_perf = rew_eval['perf']
            rew_perf['n_queries'] = n_queries_made
            rew_perf['n_real_rollouts'] = len(
                rollouts_for_dyn) + n_real_rollouts_from_traj_opt

            if verbose:
                print('\n'.join(
                    ['%s: %s' % (k, str(v)) for k, v in rew_perf.items()]),
                      flush=True)
                # uncomment to plot learned rewards
                #utils.viz_rew_eval(rew_eval, self.env, encoder=self.dynamics_model.encoder)

            if rew_perf_evals == {}:
                rew_perf_evals = {k: [] for k in rew_perf}
            for k, v in rew_perf.items():
                rew_perf_evals[k].append(v)

            return rew_perf_evals

        if verbose:
            print('initializing traj optimizers...')

        traj_optimizers = [
            make_traj_optimizer(kwargs) for kwargs in traj_opt_init_kwargs
        ]

        if query_type == 'demo' and self.dynamics_model.encoder is not None:
            proc_obses = self.dynamics_model.encoder.decode_batch_latents
        else:
            proc_obses = None

        if verbose:
            print('evaluating reward model...')

        n_queries_made = 0
        n_real_rollouts_from_traj_opt = 0
        rew_perf_evals = update_rew_perf({}, n_queries_made)

        if verbose:
            print('')

        iter_idx = 0

        if type(traj_opt_run_kwargs[0]['init_obs']) == list:
            init_obses = traj_opt_run_kwargs[0]['init_obs']
        else:
            init_obses = None
            traj_opt_run_kwargs_update = {}

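        # Main loop: synthesize query trajectories, elicit synthetic
        # feedback, and update the reward/dynamics models until the query
        # budget is spent.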
        while n_queries_made < n_queries:
            start_time = time.time()
            if verbose:
                print('iter %d' % iter_idx)
                print('synthesizing queries...')

            query_trajs = []
            query_act_seqs = []
            if init_obses is not None:
                traj_opt_run_kwargs_update = {
                    'init_obs': init_obses[iter_idx % len(init_obses)]
                }
            for traj_optimizer, kwargs in zip(traj_optimizers,
                                              traj_opt_run_kwargs):
                kwargs.update(traj_opt_run_kwargs_update)
                data = traj_optimizer.run(**kwargs)
                query_trajs.extend(data['traj'])
                query_act_seqs.extend(data['act_seq'])

            if utils.isinstance(
                    traj_optimizers[0], 'StochTrajOptimizer'
            ) and traj_opt_run_kwargs[0]['init_obs'] is None:
                assert len(traj_opt_run_kwargs) == 1
                n_real_rollouts_from_traj_opt += traj_opt_run_kwargs[0][
                    'n_samples']

            if verbose:
                print('eliciting feedback...')

            if query_type == 'pref':
                if len(query_trajs) != 2:
                    raise ValueError
                pref = reward_models.synth_pref(query_trajs[0],
                                                query_act_seqs[0],
                                                query_trajs[1],
                                                query_act_seqs[1],
                                                self.env.reward_func)
                pref_logs['ref_trajs'].append(query_trajs[0])
                pref_logs['ref_act_seqs'].append(query_act_seqs[0])
                pref_logs['trajs'].append(query_trajs[1])
                pref_logs['act_seqs'].append(query_act_seqs[1])
                pref_logs['prefs'].append(pref)
                n_queries_made += 1
            elif query_type == 'sketch':
                sketches = [
                    reward_models.synth_sketch(traj, act_seq,
                                               self.env.reward_func)
                    for traj, act_seq in zip(query_trajs, query_act_seqs)
                ]
                sketch_rollouts.extend(sketches)
                n_queries_made += sum(len(sketch) for sketch in sketches)
            elif query_type == 'demo':
                demos = [
                    reward_models.synth_demo(traj,
                                             self.env.expert_policy,
                                             proc_obses=proc_obses)
                    for traj in query_trajs
                ]
                demo_rollouts.extend(demos)
                n_queries_made += sum(len(demo) for demo in demos)

            if verbose:
                query_data = {
                    'demo_rollouts': demo_rollouts,
                    'sketch_rollouts': sketch_rollouts,
                    'pref_logs': pref_logs
                }
                utils.viz_query_data(query_data,
                                     self.env,
                                     encoder=self.dynamics_model.encoder)

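            # Periodically refit the reward model on the feedback collected
            # so far for the active query type.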
            update_reward = iter_idx % reward_update_freq == 0
            if update_reward:
                if verbose:
                    print('updating reward model...')

                if query_type == 'pref':
                    pref_data = proc_pref_logs(pref_logs)
                elif query_type == 'demo':
                    demo_data = proc_rollouts(
                        demo_rollouts, traj_len=traj_optimizers[0].traj_len)
                elif query_type == 'sketch':
                    sketch_data = proc_rollouts(
                        sketch_rollouts, traj_len=traj_optimizers[0].traj_len)

                self.reward_model.train(demo_data=demo_data,
                                        sketch_data=sketch_data,
                                        pref_data=pref_data,
                                        warm_start=warm_start_rew,
                                        **reward_train_kwargs)

                if verbose:
                    self.reward_model.viz_learned_rew()

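            # Periodically roll out the imitation policy in the real env and
            # refit the dynamics model on the growing dataset.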
            update_dynamics = (dyn_update_freq is not None
                               and iter_idx % dyn_update_freq == 0)
            if update_dynamics:
                if verbose:
                    print('updating dynamics model...')

                rollouts_for_dyn += [
                    self.dynamics_model.run_ep(imitator,
                                               self.env,
                                               store_raw_obs=using_rnn_dyn)
                    for _ in range(n_imitation_rollouts_per_dyn_update)
                ]
                self.dynamics_model.train(proc_dyn_rollouts(rollouts_for_dyn),
                                          **dynamics_train_kwargs)

            if (iter_idx + 1) % reward_eval_freq == 0:
                if verbose:
                    print('evaluating reward model...')

                rew_perf_evals = update_rew_perf(rew_perf_evals,
                                                 n_queries_made)

                if callback is not None:
                    result = (rew_perf_evals, None)
                    callback(result)

            if verbose:
                print('time elapsed: %f' % (time.time() - start_time),
                      flush=True)
                print('')

            iter_idx += 1

        query_data = {
            'demo_rollouts': demo_rollouts,
            'sketch_rollouts': sketch_rollouts,
            'pref_logs': pref_logs
        }

        return rew_perf_evals, query_data
Example #4
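The point-mass counterpart of Example #2: an expert policy supplies demonstrations, an MLPDynamicsModel (with an AbsorptionModel) replaces the MDN-RNN world model, and the same reward-model, trajectory-optimizer, and InteractiveRewardOptimizer checks are exercised.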
def main():
    sess = utils.make_tf_session(gpu_mode=False)

    env = envs.make_pointmass_env()
    trans_env = envs.make_pointmass_trans_env(env)
    expert_policy = env.make_expert_policy()
    random_policy = utils.make_random_policy(env)

    default_init_obs = env.default_init_obs

    utils.run_ep(expert_policy, env)
    utils.run_ep(random_policy, env)
    utils.run_ep(expert_policy, trans_env)
    utils.run_ep(random_policy, trans_env)

    logging.info('envs and policies OK')

    demo_rollouts = [
        utils.run_ep(expert_policy, env) for _ in range(n_demo_rollouts)
    ]

    aug_rollouts = demo_rollouts + [
        utils.run_ep(random_policy, env) for _ in range(n_aug_rollouts)
    ]

    demo_data = utils.split_rollouts(
        utils.vectorize_rollouts(demo_rollouts, env.max_ep_len))
    aug_data = utils.split_rollouts(
        utils.vectorize_rollouts(aug_rollouts, env.max_ep_len))

    unused_demo_traj_data = utils.split_rollouts(
        utils.vectorize_rollouts(demo_rollouts,
                                 env.max_ep_len,
                                 preserve_trajs=True))
    unused_aug_traj_data = utils.split_rollouts(
        utils.vectorize_rollouts(aug_rollouts,
                                 env.max_ep_len,
                                 preserve_trajs=True))

    logging.info('data collection OK')

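    # Fit an MLP dynamics model (with an absorption model) directly on the
    # point-mass transitions and check save/load.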
    abs_model = AbsorptionModel(sess,
                                env,
                                n_layers=1,
                                layer_size=32,
                                scope=str(uuid.uuid4()),
                                scope_file=os.path.join(
                                    test_data_dir, 'abs_scope.pkl'),
                                tf_file=os.path.join(test_data_dir, 'abs.tf'))

    dynamics_model = MLPDynamicsModel(
        sess,
        env,
        n_layers=1,
        layer_size=32,
        scope=str(uuid.uuid4()),
        scope_file=os.path.join(test_data_dir, 'dyn_scope.pkl'),
        tf_file=os.path.join(test_data_dir, 'dyn.tf'),
        abs_model=abs_model)

    dynamics_model.train(aug_data,
                         iterations=1,
                         ftol=1e-4,
                         learning_rate=1e-3,
                         batch_size=4,
                         val_update_freq=1,
                         verbose=False)

    dynamics_model.save()

    dynamics_model.load()

    logging.info('dynamics model OK')

    demo_data_for_reward_model = demo_data
    demo_rollouts_for_reward_model = demo_rollouts
    sketch_data_for_reward_model = aug_data
    sketch_rollouts_for_reward_model = aug_rollouts

    autolabels = reward_models.autolabel_prefs(
        sketch_rollouts_for_reward_model, env, segment_len=env.max_ep_len + 1)

    pref_logs_for_reward_model = autolabels
    pref_data_for_reward_model = utils.split_prefs(autolabels)

    logging.info('autolabels OK')

    for rew_func_input in ['sa', 's', "s'"]:
        reward_model = reward_models.RewardModel(
            sess,
            env,
            n_rew_nets_in_ensemble=4,
            n_layers=1,
            layer_size=64,
            scope=str(uuid.uuid4()),
            scope_file=os.path.join(test_data_dir, 'rew_scope.pkl'),
            tf_file=os.path.join(test_data_dir, 'rew.tf'),
            rew_func_input=rew_func_input,
            use_discrete_rewards=True)

    for demo_data in [None, demo_data_for_reward_model]:
        for sketch_data in [None, sketch_data_for_reward_model]:
            for pref_data in [None, pref_data_for_reward_model]:
                if pref_data is None and sketch_data is None:
                    continue
                reward_model.train(demo_data=demo_data,
                                   sketch_data=sketch_data,
                                   pref_data=pref_data,
                                   demo_coeff=1.,
                                   sketch_coeff=1.,
                                   iterations=1,
                                   ftol=1e-4,
                                   batch_size=4,
                                   learning_rate=1e-3,
                                   val_update_freq=1,
                                   verbose=False)

    reward_model.save()

    reward_model.load()

    logging.info('reward models OK')

    for query_loss_opt in [
            'pref_uncertainty', 'rew_uncertainty', 'max_rew', 'min_rew',
            'max_nov'
    ]:
        for init_obs in [None, default_init_obs]:
            for join_trajs_at_init_state in [True, False]:
                for query_type in ['pref', 'demo', 'sketch']:
                    if query_type == 'pref' and query_loss_opt == 'max_nov':
                        continue

                    for shoot_steps in [1, 2]:
                        traj_optimizer = GDTrajOptimizer(
                            sess,
                            env,
                            reward_model,
                            dynamics_model,
                            traj_len=2,
                            n_trajs=2,
                            prior_coeff=1.,
                            diversity_coeff=0.,
                            query_loss_opt=query_loss_opt,
                            opt_init_obs=(init_obs is None),
                            join_trajs_at_init_state=join_trajs_at_init_state,
                            shoot_steps=shoot_steps,
                            learning_rate=1e-2,
                            query_type=query_type)

                        traj_optimizer.run(
                            init_obs=init_obs,
                            iterations=1,
                            ftol=1e-4,
                            verbose=False,
                        )

    logging.info('grad descent traj opt OK')

    imitation_kwargs = {
        'plan_horizon': 10,
        'n_blind_steps': 2,
        'test_mode': True
    }

    for n_eval_rollouts in [0, 1]:
        reward_models.evaluate_reward_model(
            sess,
            env,
            trans_env,
            reward_model,
            dynamics_model,
            offpol_eval_rollouts=sketch_rollouts_for_reward_model,
            n_eval_rollouts=n_eval_rollouts,
            imitation_kwargs=imitation_kwargs)

    logging.info('reward eval OK')

    for query_loss_opt in [
            'pref_uncertainty', 'rew_uncertainty', 'max_rew', 'min_rew',
            'max_nov', 'unif'
    ]:
        for use_rand_policy in [False, True]:
            traj_optimizer = StochTrajOptimizer(
                sess,
                env,
                reward_model,
                dynamics_model,
                traj_len=2,
                rollout_len=2,
                query_loss_opt=query_loss_opt,
                use_rand_policy=use_rand_policy)

            for init_obs in [None, default_init_obs]:
                traj_optimizer.run(n_trajs=2,
                                   n_samples=2,
                                   init_obs=init_obs,
                                   verbose=False)

    logging.info('stoch traj opt OK')

    reward_model = reward_models.RewardModel(
        sess,
        env,
        n_rew_nets_in_ensemble=4,
        n_layers=1,
        layer_size=64,
        scope=str(uuid.uuid4()),
        scope_file=os.path.join(test_data_dir, 'rew_scope.pkl'),
        tf_file=os.path.join(test_data_dir, 'rew.tf'),
        use_discrete_rewards=True)

    rew_optimizer = InteractiveRewardOptimizer(sess, env, trans_env,
                                               reward_model, dynamics_model)

    reward_train_kwargs = {
        'demo_coeff': 1.,
        'sketch_coeff': 1.,
        'iterations': 1,
        'ftol': 1e-4,
        'batch_size': 4,
        'learning_rate': 1e-3,
        'val_update_freq': 1,
        'verbose': False
    }

    dynamics_train_kwargs = {
        'iterations': 1,
        'batch_size': 4,
        'learning_rate': 1e-3,
        'ftol': 1e-4,
        'val_update_freq': 1,
        'verbose': False
    }

    gd_traj_opt_init_kwargs = {
        'traj_len': env.max_ep_len,
        'n_trajs': 2,
        'prior_coeff': 1.,
        'diversity_coeff': 1.,
        'query_loss_opt': 'pref_uncertainty',
        'opt_init_obs': False,
        'learning_rate': 1e-2,
        'join_trajs_at_init_state': False
    }

    gd_traj_opt_run_kwargs = {
        'init_obs': default_init_obs,
        'iterations': 1,
        'ftol': 1e-4,
        'verbose': False
    }

    unused_stoch_traj_opt_init_kwargs = {
        'traj_len': 2,
        'rollout_len': 2,
        'query_loss_opt': 'pref_uncertainty'
    }

    unused_stoch_traj_opt_run_kwargs = {
        'n_samples': 2,
        'init_obs': None,
        'verbose': False
    }

    eval_kwargs = {'n_eval_rollouts': 1}

    for init_train in [True, False]:
        for query_type in ['pref', 'sketch']:
            rew_optimizer.run(demo_rollouts=demo_rollouts_for_reward_model,
                              sketch_rollouts=sketch_rollouts_for_reward_model,
                              pref_logs=pref_logs_for_reward_model,
                              rollouts_for_dyn=aug_rollouts,
                              reward_train_kwargs=reward_train_kwargs,
                              dynamics_train_kwargs=dynamics_train_kwargs,
                              traj_opt_cls=GDTrajOptimizer,
                              traj_opt_run_kwargs=gd_traj_opt_run_kwargs,
                              traj_opt_init_kwargs=gd_traj_opt_init_kwargs,
                              imitation_kwargs=imitation_kwargs,
                              eval_kwargs=eval_kwargs,
                              init_train_dyn=init_train,
                              init_train_rew=init_train,
                              n_imitation_rollouts_per_dyn_update=1,
                              n_queries=1,
                              reward_update_freq=1,
                              reward_eval_freq=1,
                              dyn_update_freq=1,
                              verbose=False,
                              query_type=query_type)

    rew_optimizer.save()

    rew_optimizer.load()

    logging.info('rqst OK')