Example #1
  def train(self,
            data,
            iterations=100000,
            ftol=1e-4,
            batch_size=32,
            learning_rate=1e-3,
            val_update_freq=100,
            verbose=False):

    # nothing to train if no loss has been defined
    if self.loss is None:
      return

    opt_scope = utils.opt_scope_of_obj(self)
    with tf.variable_scope(opt_scope, reuse=tf.AUTO_REUSE):
      self.update_op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)

    utils.init_tf_vars(self.sess, [self.scope, opt_scope])

    val_losses = []
    val_batch = utils.sample_batch(
        size=len(data['val_idxes']),
        data=data,
        data_keys=self.data_keys,
        idxes_key='val_idxes')

    if verbose:
      print('iters total_iters train_loss val_loss')

    for t in range(iterations):
      batch = utils.sample_batch(
          size=batch_size,
          data=data,
          data_keys=self.data_keys,
          idxes_key='train_idxes',
          class_idxes_key='train_idxes_of_act')

      train_loss = self.compute_batch_loss(self.format_batch(batch), update=True)

      if t % val_update_freq == 0:
        val_loss = self.compute_batch_loss(self.format_batch(val_batch), update=False)

        if verbose:
          print('%d %d %f %f' % (t, iterations, train_loss, val_loss))

        val_losses.append(val_loss)

        if utils.converged(val_losses, ftol):
          break

    if verbose:
      plt.plot(val_losses)
      plt.show()
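
A minimal sketch of the convergence test these training loops rely on, assuming utils.converged returns True once the relative change between the most recent validation losses falls below ftol (the actual rqst implementation may differ):

  def converged(losses, ftol, min_iters=2):
    # Too few evaluations yet to judge convergence.
    if len(losses) < max(min_iters, 2):
      return False
    # Relative change between the two most recent evaluations.
    prev, curr = losses[-2], losses[-1]
    return abs(curr - prev) / max(abs(prev), 1e-8) < ftol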
Example #2
    def train(self,
              raw_rollout_data,
              iterations=1000,
              learning_rate=1e-3,
              ftol=1e-4,
              batch_size=32,
              val_update_freq=1,
              verbose=False):
        """
    Args:
     raw_rollout_data: a dict containing the output of a call to
      rqst.utils.vectorize_rollouts(, preserve_trajs=True)
      contains processed (but not yet encoded) frames
      raw_rollout_data['obses'] maps to a np.array with dimensions (n_trajs, traj_len, 64, 64, 3)
    """

        rollout_data = self.preproc_rollouts(raw_rollout_data)

        opt_scope = utils.opt_scope_of_obj(self)
        with tf.variable_scope(opt_scope, reuse=tf.AUTO_REUSE):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            gvs = optimizer.compute_gradients(self.loss)
            capped_gvs = [(utils.tf_clip(grad, self.grad_clip), var)
                          for grad, var in gvs]
            global_step = tf.Variable(0, name='global_step', trainable=False)
            self.update_op = optimizer.apply_gradients(capped_gvs,
                                                       global_step=global_step,
                                                       name='train_step')

        utils.init_tf_vars(self.sess, [self.scope, opt_scope])

        val_losses = []
        val_batch = self.sample_seq_batch(len(rollout_data['val_idxes']),
                                          rollout_data,
                                          idxes_key='val_idxes')

        if verbose:
            print('iters total_iters train_loss val_loss')
        for t in range(iterations):
            batch = self.sample_seq_batch(batch_size,
                                          rollout_data,
                                          idxes_key='train_idxes')
            train_loss = self.compute_batch_loss(self.format_batch(batch),
                                                 update=True)
            if t % val_update_freq == 0:
                val_loss = self.compute_batch_loss(
                    self.format_batch(val_batch), update=False)
                if verbose:
                    print('%d %d %f %f' %
                          (t, iterations, train_loss, val_loss))
                val_losses.append(val_loss)
                if utils.converged(val_losses, ftol):
                    break
        if verbose:
            plt.plot(val_losses)
            plt.show()

        if self.abs_model is not None:
            abs_data = utils.flatten_traj_data(
                self.rnn_encode_rollouts(rollout_data))
            self.abs_model.train(abs_data,
                                 iterations=iterations,
                                 learning_rate=learning_rate,
                                 ftol=ftol,
                                 batch_size=batch_size,
                                 val_update_freq=val_update_freq,
                                 verbose=verbose)
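
A minimal sketch of a clipping helper compatible with the utils.tf_clip(grad, self.grad_clip) call above, assuming clip-by-value semantics (the project's actual helper might clip by norm instead):

    import tensorflow as tf  # TF 1.x API, as in the examples

    def tf_clip(grad, clip_val):
        # compute_gradients yields None for variables the loss does not touch.
        if grad is None or clip_val is None:
            return grad
        return tf.clip_by_value(grad, -clip_val, clip_val)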
Example #3
    def _run(
        self,
        init_obs=None,
        act_seq=None,
        iterations=10000,
        ftol=1e-6,
        min_iters=2,
        verbose=False,
        warm_start=False,
        init_with_lbfgs=False,
        init_act_seq=None,
        init_traj=None,
    ):

        # init_obs and act_seq must each be provided iff the optimizer is not
        # solving for them; an explicit init_act_seq is incompatible with a
        # fixed act_seq or a warm start.
        if (init_obs is not None) == self.opt_init_obs:
            raise ValueError

        if (act_seq is not None) == self.opt_act_seq:
            raise ValueError

        if act_seq is not None and init_act_seq is not None:
            raise ValueError

        if init_act_seq is not None and warm_start:
            raise ValueError

        if self.query_loss_opt == 'unif':
            if self.env.name == 'clfbandit':
                std = np.exp(-self.prior_coeff)

                def rand_traj():
                    obs = np.random.normal(0, std,
                                           self.env.n_z_dim)[np.newaxis, :]
                    next_obs = self.env.absorbing_state[np.newaxis, :]
                    return np.concatenate((obs, next_obs), axis=0)

                trajs_eval = [rand_traj() for _ in range(self.n_trajs)]
                act_seqs_eval = [[self.env.action_space.sample()]
                                 for _ in range(self.n_trajs)]
            elif self.env.name == 'pointmass':
                unif_env = envs.make_pointmass_env()
                random_policy = utils.make_random_policy(unif_env)
                unif_rollouts = [
                    utils.run_ep(random_policy, unif_env, max_ep_len=1)
                    for _ in range(self.n_trajs)
                ]
                trajs_eval = [
                    utils.traj_of_rollout(rollout) for rollout in unif_rollouts
                ]
                act_seqs_eval = [
                    utils.act_seq_of_rollout(rollout)
                    for rollout in unif_rollouts
                ]
            else:
                raise ValueError
            loss_eval = 0.
            return {
                'traj': trajs_eval,
                'act_seq': act_seqs_eval,
                'loss': loss_eval
            }

        scopes = [self.opt_scope]
        if not warm_start:
            scopes.append(self.traj_scope)
        utils.init_tf_vars(self.sess, scopes, use_cache=True)

        feed_dict = {}
        assign_ops = []
        if init_act_seq is not None:
            feed_dict[self.init_act_seq_ph] = init_act_seq
            assign_ops.append(self.assign_init_act_seq)
        if init_traj is not None:
            self.obs_dim = (self.env.n_z_dim if self.env.name == 'carracing'
                            else self.env.n_obs_dim)
            feed_dict[self.init_traj_ph] = init_traj[1:, :self.obs_dim]
            assign_ops.append(self.assign_init_traj)
        if assign_ops:
            self.sess.run(assign_ops, feed_dict=feed_dict)

        feed_dict = {}
        if init_obs is not None:
            feed_dict[self.init_obs_ph] = init_obs() if callable(
                init_obs) else init_obs
        if act_seq is not None:
            feed_dict[self.act_seq_ph] = act_seq

        if verbose:
            print('iters loss')

        if init_with_lbfgs:
            self.lbfgs_optimizer.minimize(self.sess, feed_dict=feed_dict)

        loss_evals = []
        loss_eval, trajs_eval, act_seqs_eval = self.sess.run(
            [self.loss, self.trajs, self.act_seqs], feed_dict=feed_dict)
        best_eval = {
            'traj': trajs_eval,
            'act_seq': act_seqs_eval,
            'loss': loss_eval
        }
        #start_time = time.time() # uncomment for profiling
        for t in range(iterations):
            loss_eval, trajs_eval, act_seqs_eval, _ = self.sess.run(
                [self.loss, self.trajs, self.act_seqs, self.update_op],
                feed_dict=feed_dict)

            if verbose:
                print('%d %f' % (t, loss_eval))

            loss_evals.append(loss_eval)

            if loss_eval < best_eval['loss']:
                best_eval = {
                    'traj': trajs_eval,
                    'act_seq': act_seqs_eval,
                    'loss': loss_eval
                }

            if ftol is not None and utils.converged(
                    loss_evals, ftol, min_iters=min_iters):
                break
        # uncomment for profiling
        #print('call to update_op: %0.3f' % ((time.time() - start_time) / t))
        #print('iterations: %d' % t)

        if verbose:
            plt.plot(loss_evals)
            plt.show()

        return best_eval
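
The init_with_lbfgs branch assumes an lbfgs_optimizer whose minimize takes a session and a feed dict; in TF 1.x that signature matches tf.contrib.opt.ScipyOptimizerInterface. A sketch of how such an optimizer might be constructed, with a hypothetical stand-in loss and variable list (the real construction site is not part of this example):

    import tensorflow as tf
    from tensorflow.contrib.opt import ScipyOptimizerInterface

    # Hypothetical stand-ins for the class's trajectory variables and loss.
    traj_var = tf.get_variable('traj', shape=[10, 2])
    loss = tf.reduce_sum(tf.square(traj_var))

    lbfgs_optimizer = ScipyOptimizerInterface(
        loss, var_list=[traj_var], method='L-BFGS-B',
        options={'maxiter': 100})

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        lbfgs_optimizer.minimize(sess)  # as above: minimize(sess, feed_dict=...)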
Example #4
    def init_tf_vars(self):
        utils.init_tf_vars(self.sess, [self.scope])
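
Every example here delegates to utils.init_tf_vars(sess, scopes). A minimal sketch of what it plausibly does, assuming it (re)initializes all global variables under the given scopes; the real helper also accepts a use_cache flag (see Example #3) that is not modeled here:

    import tensorflow as tf  # TF 1.x

    def init_tf_vars(sess, scopes):
        tf_vars = []
        for scope in scopes:
            # Collect every global variable created under this scope prefix.
            tf_vars.extend(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope))
        sess.run(tf.variables_initializer(tf_vars))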
Example #5
    def train(self,
              demo_data=None,
              sketch_data=None,
              pref_data=None,
              demo_coeff=1.,
              sketch_coeff=1.,
              iterations=100000,
              ftol=1e-4,
              batch_size=512,
              learning_rate=1e-3,
              val_update_freq=100,
              verbose=False,
              warm_start=False):
        """
    Args:
     demo_data: output of a call to rqst.utils.split_rollouts
     sketch_data: output of a call to rqst.utils.split_rollouts
     pref_data: output of a call to rqst.utils.split_prefs
    """
        if demo_data is None and pref_data is None and sketch_data is None:
            raise ValueError

        self.demo_data = demo_data
        self.sketch_data = sketch_data

        self.build_ensemble_outputs()

        # Each loss term is either a Tensor or the int 0 (when its data or
        # loss is unavailable), so the `== 0` check below is True only for
        # terms that were disabled, never for a live Tensor.
        pref_loss = (self.pref_loss if pref_data is not None
                     and self.pref_loss is not None else 0)
        demo_loss = (self.demo_loss if demo_data is not None
                     and self.demo_loss is not None else 0)
        sketch_loss = (self.sketch_loss if sketch_data is not None
                       and self.sketch_loss is not None else 0)

        if pref_loss == 0 and demo_loss == 0 and sketch_loss == 0:
            raise ValueError

        self.loss = (demo_coeff * demo_loss + sketch_coeff * sketch_loss +
                     pref_loss)

        opt_scope = utils.opt_scope_of_obj(self)
        with tf.variable_scope(opt_scope, reuse=tf.AUTO_REUSE):
            self.update_op = tf.train.AdamOptimizer(learning_rate).minimize(
                self.loss)

        scopes = [opt_scope]
        if not warm_start:
            scopes.append(self.scope)
        utils.init_tf_vars(self.sess, scopes)

        sketch_data_keys = ['obses', 'actions', 'next_obses', 'rews']
        demo_data_keys = ['obses', 'actions']
        pref_data_keys = [
            'ref_trajs', 'ref_act_seqs', 'trajs', 'act_seqs', 'prefs', 'mask',
            'ref_mask'
        ]

        unif_member_mask = np.ones(
            self.n_rew_nets_in_ensemble) / self.n_rew_nets_in_ensemble
        val_losses = []

        if pref_data is None:
            val_pref_batch = None
        else:
            val_pref_batch = utils.sample_batch(
                size=len(pref_data['val_idxes']),
                data=pref_data,
                data_keys=pref_data_keys,
                idxes_key='val_idxes')

        if demo_data is None:
            val_demo_batch = None
        else:
            val_demo_batch = utils.sample_batch(
                size=len(demo_data['val_idxes']),
                data=demo_data,
                data_keys=demo_data_keys,
                idxes_key='val_idxes')

        if sketch_data is None:
            val_sketch_batch = None
        else:
            val_sketch_batch = utils.sample_batch(
                size=len(sketch_data['val_idxes']),
                data=sketch_data,
                data_keys=sketch_data_keys,
                idxes_key='val_idxes')

        pref_batch = None
        demo_batch = None
        sketch_batch = None

        member_masks = [
            utils.onehot_encode(member_idx, self.n_rew_nets_in_ensemble)
            for member_idx in range(self.n_rew_nets_in_ensemble)
        ]

        bootstrap_prob = 1.

        def bootstrap(train_idxes, mem_idx):
            guar_idxes = [
                x for i, x in enumerate(train_idxes)
                if i % self.n_rew_nets_in_ensemble == mem_idx
            ]
            nonguar_idxes = [
                x for i, x in enumerate(train_idxes)
                if i % self.n_rew_nets_in_ensemble != mem_idx
            ]
            n_train_per_mem = int(np.ceil(bootstrap_prob * len(nonguar_idxes)))
            return guar_idxes + random.sample(nonguar_idxes, n_train_per_mem)

        train_idxes_key_of_mem = []
        for mem_idx in range(self.n_rew_nets_in_ensemble):
            train_idxes_key = 'train_idxes_of_mem_%d' % mem_idx
            train_idxes_key_of_mem.append(train_idxes_key)
            if demo_data is not None:
                if self.use_discrete_actions:
                    train_idxes_of_act_key = 'train_idxes_of_act_of_mem_%d' % mem_idx
                    demo_data[train_idxes_of_act_key] = {}
                    for c, idxes_of_c in demo_data[
                            'train_idxes_of_act'].items():
                        demo_data[train_idxes_of_act_key][c] = bootstrap(
                            idxes_of_c, mem_idx)
                    demo_data[train_idxes_key] = sum(
                        (v for v in
                         demo_data[train_idxes_of_act_key].values()),
                        [])
                else:
                    demo_data[train_idxes_key] = bootstrap(
                        demo_data['train_idxes'], mem_idx)
            if pref_data is not None:
                pref_data[train_idxes_key] = bootstrap(
                    pref_data['train_idxes'], mem_idx)
            if sketch_data is not None:
                if self.use_discrete_rewards:
                    train_idxes_of_rew_class_key = 'train_idxes_of_rew_class_of_mem_%d' % mem_idx
                    sketch_data[train_idxes_of_rew_class_key] = {}
                    for c, idxes_of_c in sketch_data[
                            'train_idxes_of_rew_class'].items():
                        sketch_data[train_idxes_of_rew_class_key][
                            c] = bootstrap(idxes_of_c, mem_idx)
                    sketch_data[train_idxes_key] = sum(
                        (v for v in
                         sketch_data[train_idxes_of_rew_class_key].values()),
                        [])
                else:
                    sketch_data[train_idxes_key] = bootstrap(
                        sketch_data['train_idxes'], mem_idx)

        if verbose:
            print('iters total_iters train_loss val_loss')

        #best_val_loss = None # uncomment to save model with lowest val loss
        for t in range(iterations):
            for mem_idx, member_mask in enumerate(member_masks):
                if demo_data is not None:
                    if self.use_discrete_actions:
                        class_idxes_key = 'train_idxes_of_act_of_mem_%d' % mem_idx
                    else:
                        class_idxes_key = None
                    demo_batch = utils.sample_batch(
                        size=batch_size,
                        data=demo_data,
                        data_keys=demo_data_keys,
                        idxes_key=train_idxes_key_of_mem[mem_idx],
                        class_idxes_key=class_idxes_key)

                if sketch_data is not None:
                    if self.use_discrete_rewards:
                        class_idxes_key = 'train_idxes_of_rew_class_of_mem_%d' % mem_idx
                    else:
                        class_idxes_key = None
                    sketch_batch = utils.sample_batch(
                        size=batch_size,
                        data=sketch_data,
                        data_keys=sketch_data_keys,
                        idxes_key=train_idxes_key_of_mem[mem_idx],
                        class_idxes_key=class_idxes_key)

                if pref_data is not None:
                    pref_batch = utils.sample_batch(
                        size=batch_size,
                        data=pref_data,
                        data_keys=pref_data_keys,
                        idxes_key=train_idxes_key_of_mem[mem_idx])

                formatted_batch = self.format_batch(demo_batch, sketch_batch,
                                                    pref_batch, member_mask)
                train_loss = self.compute_batch_loss(formatted_batch,
                                                     update=True)

            if t % val_update_freq == 0:
                formatted_batch = self.format_batch(val_demo_batch,
                                                    val_sketch_batch,
                                                    val_pref_batch,
                                                    unif_member_mask)
                val_loss = self.compute_batch_loss(formatted_batch,
                                                   update=False)

                if verbose:
                    print('%d %d %f %f' %
                          (t, iterations, train_loss, val_loss))

                val_losses.append(val_loss)

                # uncomment to save model checkpoint if it achieves lower val loss
                #if best_val_loss is None or val_loss < best_val_loss:
                #  best_val_loss = val_loss
                #  self.save()

                if utils.converged(val_losses, ftol):
                    break

        if verbose:
            plt.plot(val_losses)
            plt.show()
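
The per-member bootstrap above guarantees each training index to exactly one ensemble member and resamples the rest with probability bootstrap_prob. The same logic as a standalone function, with n_members and bootstrap_prob lifted out as explicit parameters for illustration:

    import random
    import numpy as np

    def bootstrap(train_idxes, mem_idx, n_members, bootstrap_prob=1.):
        # Every n_members-th index is guaranteed to member mem_idx; the
        # others are resampled without replacement.
        guar = [x for i, x in enumerate(train_idxes)
                if i % n_members == mem_idx]
        nonguar = [x for i, x in enumerate(train_idxes)
                   if i % n_members != mem_idx]
        n_extra = int(np.ceil(bootstrap_prob * len(nonguar)))
        return guar + random.sample(nonguar, n_extra)

    print(bootstrap(list(range(10)), mem_idx=0, n_members=3,
                    bootstrap_prob=0.5))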