Example #1
    def __init__(self,
                 cluster_job_name,
                 core_network,
                 lr=1e-4,
                 cluster_dict=None,
                 batchnorm=False,
                 dropout=0.0,
                 n_preds=1,
                 log_dir=None):
        self.n_preds = n_preds
        graph, self.sess = self.init_sess(cluster_dict, cluster_job_name)
        # Why not just use soft device placement? With soft placement,
        # if we have a bug which prevents an operation being placed on the GPU
        # (e.g. we're using uint8s for operations that the GPU can't do),
        # then TensorFlow will be silent and just place the operation on a CPU.
        # Instead, we want to say: if there's a GPU present, definitely try and
        # put things on the GPU. If it fails, tell us!
        if tf.test.gpu_device_name():
            worker_device = "/job:{}/task:0/gpu:0".format(cluster_job_name)
        else:
            worker_device = "/job:{}/task:0".format(cluster_job_name)
        device_setter = tf.train.replica_device_setter(
            cluster=cluster_dict,
            ps_device="/job:ps/task:0",
            worker_device=worker_device)
        self.rps = []
        with graph.as_default():
            for pred_n in range(n_preds):
                with tf.device(device_setter):
                    with tf.variable_scope("pred_{}".format(pred_n)):
                        rp = RewardPredictorNetwork(core_network=core_network,
                                                    dropout=dropout,
                                                    batchnorm=batchnorm,
                                                    lr=lr)
                self.rps.append(rp)
            self.init_op = tf.global_variables_initializer()
            # Why save_relative_paths=True?
            # So that the plain-text 'checkpoint' file written uses relative paths,
            # which seems to be needed in order to avoid confusing saver.restore()
            # when restoring from FloydHub runs.
            self.saver = tf.train.Saver(max_to_keep=1,
                                        save_relative_paths=True)
            self.summaries = self.add_summary_ops()

        self.checkpoint_file = osp.join(log_dir,
                                        'reward_predictor_checkpoints',
                                        'reward_predictor.ckpt')
        self.train_writer = tf.summary.FileWriter(osp.join(
            log_dir, 'reward_predictor', 'train'),
                                                  flush_secs=5)
        self.test_writer = tf.summary.FileWriter(osp.join(
            log_dir, 'reward_predictor', 'test'),
                                                 flush_secs=5)

        self.n_steps = 0
        self.r_norm = RunningStat(shape=n_preds)

        misc_logs_dir = osp.join(log_dir, 'reward_predictor', 'misc')
        easy_tf_log.set_dir(misc_logs_dir)
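The comment above prefers explicit device pinning over soft placement so that a placement failure surfaces as an error rather than a silent CPU fallback. A minimal standalone sketch of that idea (TF1-style API, not tied to this example's cluster setup; the op names here are hypothetical):

import tensorflow as tf

def pick_device():
    # tf.test.gpu_device_name() returns '' when no GPU is visible
    return '/gpu:0' if tf.test.gpu_device_name() else '/cpu:0'

graph = tf.Graph()
with graph.as_default(), tf.device(pick_device()):
    frames = tf.placeholder(tf.float32, [None, 84, 84, 4], name='frames')
    mean_pixel = tf.reduce_mean(frames)

# allow_soft_placement defaults to False, so if an op cannot run on the
# requested device TensorFlow raises an error instead of quietly moving it.
sess = tf.Session(graph=graph,
                  config=tf.ConfigProto(log_device_placement=True))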
Example #2
 def test_running_stat(self):
     for shp in ((), (3, ), (3, 4)):
         li = []
         rs = RunningStat(shp)
         for i in range(5):
             val = np.random.randn(*shp)
             rs.push(val)
             li.append(val)
             m = np.mean(li, axis=0)
             assert np.allclose(rs.mean, m)
             if i == 0:
                 continue
             # ddof=1 => calculate unbiased sample variance
             v = np.var(li, ddof=1, axis=0)
             assert np.allclose(rs.var, v)
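The test above pins down the contract RunningStat has to satisfy: push() accepts values of a fixed shape, mean tracks np.mean, and var tracks the unbiased (ddof=1) sample variance. A minimal sketch that satisfies this contract using Welford's online algorithm (hypothetical; the project's actual RunningStat may be implemented differently):

import numpy as np

class RunningStatSketch:
    """Running mean and unbiased variance via Welford's algorithm."""

    def __init__(self, shape=()):
        self._n = 0
        self._mean = np.zeros(shape)
        self._m2 = np.zeros(shape)  # sum of squared deviations from the mean

    def push(self, x):
        x = np.asarray(x)
        self._n += 1
        delta = x - self._mean
        self._mean = self._mean + delta / self._n
        self._m2 = self._m2 + delta * (x - self._mean)

    @property
    def mean(self):
        return self._mean

    @property
    def var(self):
        # ddof=1: unbiased sample variance, defined once n >= 2
        return self._m2 / (self._n - 1)

    @property
    def std(self):
        return np.sqrt(self.var)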
Example #3
    def __init__(self,
                 env,
                 tuning_parameters,
                 replicated_device=None,
                 task_id=0):
        """
        :param env: An environment instance
        :type env: EnvironmentWrapper
        :param tuning_parameters: A Preset class instance with all the running parameters
        :type tuning_parameters: Preset
        :param replicated_device: A tensorflow device for distributed training (optional)
        :type replicated_device: instancemethod
        :param task_id: The current task id
        :type task_id: int
        """

        screen.log_title("Creating agent {}".format(task_id))
        self.task_id = task_id
        self.sess = tuning_parameters.sess
        self.env = tuning_parameters.env_instance = env
        self.imitation = False

        # i/o dimensions
        if not tuning_parameters.env.desired_observation_width or not tuning_parameters.env.desired_observation_height:
            tuning_parameters.env.desired_observation_width = self.env.width
            tuning_parameters.env.desired_observation_height = self.env.height
        self.action_space_size = tuning_parameters.env.action_space_size = self.env.action_space_size
        self.measurements_size = tuning_parameters.env.measurements_size = self.env.measurements_size
        if tuning_parameters.agent.use_accumulated_reward_as_measurement:
            self.measurements_size = tuning_parameters.env.measurements_size = (
                self.measurements_size[0] + 1, )

        # modules
        if tuning_parameters.agent.load_memory_from_file_path:
            screen.log_title(
                "Loading replay buffer from pickle. Pickle path: {}".format(
                    tuning_parameters.agent.load_memory_from_file_path))
            self.memory = read_pickle(
                tuning_parameters.agent.load_memory_from_file_path)
        else:
            self.memory = eval(tuning_parameters.memory +
                               '(tuning_parameters)')
        # self.architecture = eval(tuning_parameters.architecture)

        self.has_global = replicated_device is not None
        self.replicated_device = replicated_device
        self.worker_device = "/job:worker/task:{}/cpu:0".format(
            task_id) if replicated_device is not None else "/gpu:0"

        self.exploration_policy = eval(tuning_parameters.exploration.policy +
                                       '(tuning_parameters)')
        self.evaluation_exploration_policy = eval(
            tuning_parameters.exploration.evaluation_policy +
            '(tuning_parameters)')
        self.evaluation_exploration_policy.change_phase(RunPhase.TEST)

        # initialize all internal variables
        self.tp = tuning_parameters
        self.in_heatup = False
        self.total_reward_in_current_episode = 0
        self.total_steps_counter = 0
        self.running_reward = None
        self.training_iteration = 0
        self.current_episode = self.tp.current_episode = 0
        self.curr_state = {}
        self.current_episode_steps_counter = 0
        self.episode_running_info = {}
        self.last_episode_evaluation_ran = 0
        self.running_observations = []
        logger.set_current_time(self.current_episode)
        self.main_network = None
        self.networks = []
        self.last_episode_images = []
        self.renderer = Renderer()

        # signals
        self.signals = []
        self.loss = Signal('Loss')
        self.signals.append(self.loss)
        self.curr_learning_rate = Signal('Learning Rate')
        self.signals.append(self.curr_learning_rate)

        if self.tp.env.normalize_observation and not self.env.is_state_type_image:
            if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
                self.running_observation_stats = RunningStat(
                    (self.tp.env.desired_observation_width, ))
                self.running_reward_stats = RunningStat(())
            else:
                self.running_observation_stats = SharedRunningStats(
                    self.tp,
                    replicated_device,
                    shape=(self.tp.env.desired_observation_width, ),
                    name='observation_stats')
                self.running_reward_stats = SharedRunningStats(
                    self.tp, replicated_device, shape=(), name='reward_stats')

        # env is already reset at this point. Otherwise we're getting an error where you cannot
        # reset an env which is not done
        self.reset_game(do_not_reset_env=True)

        # use seed
        if self.tp.seed is not None:
            random.seed(self.tp.seed)
            np.random.seed(self.tp.seed)
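The excerpt creates running_observation_stats and running_reward_stats but does not show how they are consumed. A hedged sketch of the usual consumption pattern (hypothetical helper; the agent's real preprocessing code is not part of this excerpt):

import numpy as np

def normalize_observation(observation, running_observation_stats, clip=5.0):
    # Update the running statistics with the new observation, then standardize
    # and clip it before feeding it to the network.
    running_observation_stats.push(observation)
    standardized = (observation - running_observation_stats.mean) / \
                   (running_observation_stats.std + 1e-8)
    return np.clip(standardized, -clip, clip)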
Example #4
class RewardPredictorEnsemble:
    """
    An ensemble of reward predictors and associated helper functions.
    """
    def __init__(self,
                 cluster_job_name,
                 core_network,
                 lr=1e-4,
                 cluster_dict=None,
                 batchnorm=False,
                 dropout=0.0,
                 n_preds=1,
                 log_dir=None):
        self.n_preds = n_preds
        graph, self.sess = self.init_sess(cluster_dict, cluster_job_name)
        # Why not just use soft device placement? With soft placement,
        # if we have a bug which prevents an operation being placed on the GPU
        # (e.g. we're using uint8s for operations that the GPU can't do),
        # then TensorFlow will be silent and just place the operation on a CPU.
        # Instead, we want to say: if there's a GPU present, definitely try and
        # put things on the GPU. If it fails, tell us!
        if tf.test.gpu_device_name():
            worker_device = "/job:{}/task:0/gpu:0".format(cluster_job_name)
        else:
            worker_device = "/job:{}/task:0".format(cluster_job_name)
        device_setter = tf.train.replica_device_setter(
            cluster=cluster_dict,
            ps_device="/job:ps/task:0",
            worker_device=worker_device)
        self.rps = []
        with graph.as_default():
            for pred_n in range(n_preds):
                with tf.device(device_setter):
                    with tf.variable_scope("pred_{}".format(pred_n)):
                        rp = RewardPredictorNetwork(core_network=core_network,
                                                    dropout=dropout,
                                                    batchnorm=batchnorm,
                                                    lr=lr)
                self.rps.append(rp)
            self.init_op = tf.global_variables_initializer()
            # Why save_relative_paths=True?
            # So that the plain-text 'checkpoint' file written uses relative paths,
            # which seems to be needed in order to avoid confusing saver.restore()
            # when restoring from FloydHub runs.
            self.saver = tf.train.Saver(max_to_keep=1,
                                        save_relative_paths=True)
            self.summaries = self.add_summary_ops()

        self.checkpoint_file = osp.join(log_dir,
                                        'reward_predictor_checkpoints',
                                        'reward_predictor.ckpt')
        self.train_writer = tf.summary.FileWriter(osp.join(
            log_dir, 'reward_predictor', 'train'),
                                                  flush_secs=5)
        self.test_writer = tf.summary.FileWriter(osp.join(
            log_dir, 'reward_predictor', 'test'),
                                                 flush_secs=5)

        self.n_steps = 0
        self.r_norm = RunningStat(shape=n_preds)

        misc_logs_dir = osp.join(log_dir, 'reward_predictor', 'misc')
        easy_tf_log.set_dir(misc_logs_dir)

    @staticmethod
    def init_sess(cluster_dict, cluster_job_name):
        graph = tf.Graph()
        cluster = tf.train.ClusterSpec(cluster_dict)
        config = tf.ConfigProto(gpu_options={'allow_growth': True})
        server = tf.train.Server(cluster,
                                 job_name=cluster_job_name,
                                 config=config)
        sess = tf.Session(server.target, graph)
        return graph, sess

    def add_summary_ops(self):
        summary_ops = []

        for pred_n, rp in enumerate(self.rps):
            name = 'reward_predictor_accuracy_{}'.format(pred_n)
            op = tf.summary.scalar(name, rp.accuracy)
            summary_ops.append(op)
            name = 'reward_predictor_loss_{}'.format(pred_n)
            op = tf.summary.scalar(name, rp.loss)
            summary_ops.append(op)

        mean_accuracy = tf.reduce_mean([rp.accuracy for rp in self.rps])
        op = tf.summary.scalar('reward_predictor_accuracy_mean', mean_accuracy)
        summary_ops.append(op)

        mean_loss = tf.reduce_mean([rp.loss for rp in self.rps])
        op = tf.summary.scalar('reward_predictor_loss_mean', mean_loss)
        summary_ops.append(op)

        summaries = tf.summary.merge(summary_ops)

        return summaries

    def init_network(self, load_ckpt_dir=None):
        if load_ckpt_dir:
            ckpt_file = tf.train.latest_checkpoint(load_ckpt_dir)
            if ckpt_file is None:
                msg = "No reward predictor checkpoint found in '{}'".format(
                    load_ckpt_dir)
                raise FileNotFoundError(msg)
            self.saver.restore(self.sess, ckpt_file)
            print("Loaded reward predictor checkpoint from '{}'".format(
                ckpt_file))
        else:
            self.sess.run(self.init_op)

    def save(self):
        ckpt_name = self.saver.save(self.sess, self.checkpoint_file,
                                    self.n_steps)
        print("Saved reward predictor checkpoint to '{}'".format(ckpt_name))

    def raw_rewards(self, obs):
        """
        Return (unnormalized) reward for each frame of a single segment
        from each member of the ensemble.
        """
        assert_equal(obs.shape[1:], (84, 84, 4))
        n_steps = obs.shape[0]
        feed_dict = {}
        for rp in self.rps:
            feed_dict[rp.training] = False
            feed_dict[rp.s1] = [obs]
        # This will return nested lists of sizes n_preds x 1 x n_steps
        # (x 1 because of the batch size of 1)
        rs = self.sess.run([rp.r1 for rp in self.rps], feed_dict)
        rs = np.array(rs)
        # Get rid of the extra x 1 dimension
        rs = rs[:, 0, :]
        assert_equal(rs.shape, (self.n_preds, n_steps))
        return rs

    def reward(self, obs):
        """
        Return (normalized) reward for each frame of a single segment.

        (Normalization involves normalizing the rewards from each member of the
        ensemble separately, then averaging the resulting rewards across all
        ensemble members.)
        """
        assert_equal(obs.shape[1:], (84, 84, 4))
        n_steps = obs.shape[0]

        # Get unnormalized rewards

        ensemble_rs = self.raw_rewards(obs)
        logging.debug("Unnormalized rewards:\n%s", ensemble_rs)

        # Normalize rewards

        # Note that we implement this here instead of in the network itself
        # because:
        # * It's simpler not to do it in TensorFlow
        # * Preference prediction doesn't need normalized rewards. Only
        #   rewards sent to the RL algorithm need to be normalized.
        #   So we can save on computation.

        # Page 4:
        # "We normalized the rewards produced by r^ to have zero mean and
        #  constant standard deviation."
        # Page 15: (Atari)
        # "Since the reward predictor is ultimately used to compare two sums
        #  over timesteps, its scale is arbitrary, and we normalize it to have
        #  a standard deviation of 0.05"
        # Page 5:
        # "The estimate r^ is defined by independently normalizing each of
        #  these predictors..."

        # We want to keep track of running mean/stddev for each member of the
        # ensemble separately, so we have to be a little careful here.
        assert_equal(ensemble_rs.shape, (self.n_preds, n_steps))
        ensemble_rs = ensemble_rs.transpose()
        assert_equal(ensemble_rs.shape, (n_steps, self.n_preds))
        for ensemble_rs_step in ensemble_rs:
            self.r_norm.push(ensemble_rs_step)
        ensemble_rs -= self.r_norm.mean
        ensemble_rs /= (self.r_norm.std + 1e-12)
        ensemble_rs *= 0.05
        ensemble_rs = ensemble_rs.transpose()
        assert_equal(ensemble_rs.shape, (self.n_preds, n_steps))
        logging.debug("Reward mean/stddev:\n%s %s", self.r_norm.mean,
                      self.r_norm.std)
        logging.debug("Normalized rewards:\n%s", ensemble_rs)

        # "...and then averaging the results."
        rs = np.mean(ensemble_rs, axis=0)
        assert_equal(rs.shape, (n_steps, ))
        logging.debug("After ensemble averaging:\n%s", rs)

        return rs

    def preferences(self, s1s, s2s):
        """
        Predict probability of human preferring one segment over another
        for each segment in the supplied batch of segment pairs.
        """
        feed_dict = {}
        for rp in self.rps:
            feed_dict[rp.s1] = s1s
            feed_dict[rp.s2] = s2s
            feed_dict[rp.training] = False
        preds = self.sess.run([rp.pred for rp in self.rps], feed_dict)
        return preds

    def train(self, prefs_train, prefs_val, val_interval):
        """
        Train all ensemble members for one epoch.
        """
        print("Training/testing with %d/%d preferences" %
              (len(prefs_train), len(prefs_val)))

        start_steps = self.n_steps
        start_time = time.time()

        for _, batch in enumerate(
                batch_iter(prefs_train.prefs, batch_size=32, shuffle=True)):
            self.train_step(batch, prefs_train)
            self.n_steps += 1

            if self.n_steps and self.n_steps % val_interval == 0:
                self.val_step(prefs_val)

        end_time = time.time()
        end_steps = self.n_steps
        rate = (end_steps - start_steps) / (end_time - start_time)
        easy_tf_log.tflog('reward_predictor_training_steps_per_second', rate)

    def train_step(self, batch, prefs_train):
        s1s = [prefs_train.segments[k1] for k1, k2, pref, in batch]
        s2s = [prefs_train.segments[k2] for k1, k2, pref, in batch]
        prefs = [pref for k1, k2, pref, in batch]
        feed_dict = {}
        for rp in self.rps:
            feed_dict[rp.s1] = s1s
            feed_dict[rp.s2] = s2s
            feed_dict[rp.pref] = prefs
            feed_dict[rp.training] = True
        ops = [self.summaries, [rp.train for rp in self.rps]]
        summaries, _ = self.sess.run(ops, feed_dict)
        self.train_writer.add_summary(summaries, self.n_steps)

    def val_step(self, prefs_val):
        val_batch_size = 32
        if len(prefs_val) <= val_batch_size:
            batch = prefs_val.prefs
        else:
            idxs = np.random.choice(len(prefs_val.prefs),
                                    val_batch_size,
                                    replace=False)
            batch = [prefs_val.prefs[i] for i in idxs]
        s1s = [prefs_val.segments[k1] for k1, k2, pref, in batch]
        s2s = [prefs_val.segments[k2] for k1, k2, pref, in batch]
        prefs = [pref for k1, k2, pref, in batch]
        feed_dict = {}
        for rp in self.rps:
            feed_dict[rp.s1] = s1s
            feed_dict[rp.s2] = s2s
            feed_dict[rp.pref] = prefs
            feed_dict[rp.training] = False
        summaries = self.sess.run(self.summaries, feed_dict)
        self.test_writer.add_summary(summaries, self.n_steps)
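For reference, the normalization performed in reward() condenses to a few lines of NumPy. This is a restatement of the code above rather than an independent implementation; target_std is the 0.05 from the quoted paper:

import numpy as np

def normalize_ensemble_rewards(ensemble_rs, r_norm, target_std=0.05):
    # ensemble_rs: (n_preds, n_steps) unnormalized rewards; r_norm: RunningStat
    # with shape (n_preds,), tracking per-member statistics.
    for step_rs in ensemble_rs.T:  # one (n_preds,) vector per timestep
        r_norm.push(step_rs)
    normalized = (ensemble_rs.T - r_norm.mean) / (r_norm.std + 1e-12)
    normalized *= target_std
    return normalized.T.mean(axis=0)  # average over members -> (n_steps,)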
Example #5
def run():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    gym.logger.set_level(40)
    env = gym.make(args.env_name)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.shape[0]
    state_stat = RunningStat(env.observation_space.shape, eps=1e-2)
    action_space = env.action_space
    policy = Policy(state_size, action_size, args.hidden_size,
                    action_space.low, action_space.high)
    num_params = policy.num_params
    optim = Adam(num_params, args.lr)

    ray.init(num_cpus=args.num_parallel)

    return_list = []
    for epoch in range(100000):
        #####################################
        ### Rollout and Update State Stat ###
        #####################################

        policy.set_state_stat(state_stat.mean, state_stat.std)

        # set diff params (mirror sampling)
        assert args.episodes_per_batch % 2 == 0
        diff_params = torch.empty((args.episodes_per_batch, num_params),
                                  dtype=torch.float)
        diff_params_pos = torch.randn(args.episodes_per_batch // 2,
                                      num_params) * args.noise_std
        diff_params[::2] = diff_params_pos
        diff_params[1::2] = -diff_params_pos

        rets = []
        num_episodes_popped = 0
        num_timesteps_popped = 0
        while num_episodes_popped < args.episodes_per_batch \
                and num_timesteps_popped < args.timesteps_per_batch:
            #or num_timesteps_popped < args.timesteps_per_batch:
            results = []
            for i in range(min(args.episodes_per_batch, 500)):
                # set policy
                randomized_policy = deepcopy(policy)
                randomized_policy.add_params(diff_params[num_episodes_popped +
                                                         i])
                # rollout
                results.append(
                    rollout.remote(randomized_policy,
                                   args.env_name,
                                   seed=np.random.randint(0, 10000000)))

            for result in results:
                ret, timesteps, states = ray.get(result)
                rets.append(ret)
                # update state stat
                if states is not None:
                    state_stat.increment(states.sum(axis=0),
                                         np.square(states).sum(axis=0),
                                         states.shape[0])

                num_timesteps_popped += timesteps
                num_episodes_popped += 1
        rets = np.array(rets, dtype=np.float32)
        diff_params = diff_params[:num_episodes_popped]

        best_policy_idx = np.argmax(rets)
        best_policy = deepcopy(policy)
        best_policy.add_params(diff_params[best_policy_idx])
        best_rets = [
            rollout.remote(best_policy,
                           args.env_name,
                           seed=np.random.randint(0, 10000000),
                           calc_state_stat_prob=0.0,
                           test=True) for _ in range(10)
        ]
        best_rets = np.average(ray.get(best_rets))

        print('epoch:', epoch, 'mean:', np.average(rets), 'max:', np.max(rets),
              'best:', best_rets)
        with open(args.outdir + '/return.csv', 'w') as f:
            return_list.append(
                [epoch, np.max(rets),
                 np.average(rets), best_rets])
            writer = csv.writer(f, lineterminator='\n')
            writer.writerows(return_list)

            plt.figure()
            sns.lineplot(data=np.array(return_list)[:, 1:])
            plt.savefig(args.outdir + '/return.png')
            plt.close('all')

        #############
        ### Train ###
        #############

        fitness = compute_centered_ranks(rets).reshape(-1, 1)
        if args.weight_decay > 0:
            #l2_decay = args.weight_decay * ((policy.get_params() + diff_params)**2).mean(dim=1, keepdim=True).numpy()
            l1_decay = args.weight_decay * (policy.get_params() +
                                            diff_params).mean(
                                                dim=1, keepdim=True).numpy()
            fitness += l1_decay
        grad = (fitness * diff_params.numpy()).mean(axis=0)
        policy = optim.update(policy, -grad)
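Unlike the push()-based interface in the earlier examples, this script feeds RunningStat through increment() with per-batch sums, squared sums, and counts. A hedged sketch of a class matching that call signature (assumed; the repo's actual implementation may differ, for instance in how eps seeds the moments):

import numpy as np

class BatchRunningStat:
    def __init__(self, shape, eps=1e-2):
        # eps seeds the count and second moment so mean/std are defined
        # before any data arrives.
        self.sum = np.zeros(shape, dtype=np.float64)
        self.sumsq = np.full(shape, eps, dtype=np.float64)
        self.count = eps

    def increment(self, batch_sum, batch_sumsq, batch_count):
        self.sum += batch_sum
        self.sumsq += batch_sumsq
        self.count += batch_count

    @property
    def mean(self):
        return self.sum / self.count

    @property
    def std(self):
        # sqrt(E[x^2] - E[x]^2), floored to avoid dividing by ~0 downstream
        return np.sqrt(np.maximum(
            self.sumsq / self.count - np.square(self.mean), 1e-2))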
Example #6
def run():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    gym.logger.set_level(40)
    env = gym.make(args.env_name)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.shape[0]
    state_stat = RunningStat(
        env.observation_space.shape,
        eps=1e-2
    )
    action_space = env.action_space
    policy = Policy(state_size, action_size, args.hidden_size, action_space.low, action_space.high)
    num_params = policy.num_params
    es = cma.CMAEvolutionStrategy([0] * num_params,
                                  args.sigma_init,
                                  {'popsize': args.popsize})
    
    ray.init(num_cpus=args.num_parallel)

    return_list = []
    for epoch in range(100000):
        #####################################
        ### Rollout and Update State Stat ###
        #####################################

        solutions = np.array(es.ask(), dtype=np.float32)
        policy.set_state_stat(state_stat.mean, state_stat.std)

        rets = []
        results = []
        for i in range(args.popsize):
            # set policy
            randomized_policy = deepcopy(policy)
            randomized_policy.set_params(solutions[i])
            # rollout
            results.append(rollout.remote(randomized_policy, args.env_name, seed=np.random.randint(0,10000000)))
        
        for result in results:
            ret, timesteps, states = ray.get(result)
            rets.append(ret)
            # update state stat
            if states is not None:
                state_stat.increment(states.sum(axis=0), np.square(states).sum(axis=0), states.shape[0])
            
        rets = np.array(rets, dtype=np.float32)
        
        best_policy_idx = np.argmax(rets)
        best_policy = deepcopy(policy)
        best_policy.set_params(solutions[best_policy_idx])
        best_rets = [rollout.remote(best_policy, args.env_name, seed=np.random.randint(0,10000000), calc_state_stat_prob=0.0, test=True) for _ in range(10)]
        best_rets = np.average(ray.get(best_rets))
        
        print('epoch:', epoch, 'mean:', np.average(rets), 'max:', np.max(rets), 'best:', best_rets)
        with open(args.outdir + '/return.csv', 'w') as f:
            return_list.append([epoch, np.max(rets), np.average(rets), best_rets])
            writer = csv.writer(f, lineterminator='\n')
            writer.writerows(return_list)

            plt.figure()
            sns.lineplot(data=np.array(return_list)[:,1:])
            plt.savefig(args.outdir + '/return.png')
            plt.close('all')
        

        #############
        ### Train ###
        #############

        ranks = compute_centered_ranks(rets)
        fitness = ranks
        if args.weight_decay > 0:
            l2_decay = compute_weight_decay(args.weight_decay, solutions)
            fitness -= l2_decay
        # convert minimize to maximize
        es.tell(solutions, fitness)
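Both ES scripts shape returns with compute_centered_ranks before computing a gradient or telling CMA-ES. One common definition of that transform, shown here as an assumption since the helper itself is not included in these excerpts, maps returns to ranks rescaled into [-0.5, 0.5]:

import numpy as np

def compute_centered_ranks_sketch(x):
    flat = x.ravel()
    ranks = np.empty(len(flat), dtype=np.float32)
    ranks[flat.argsort()] = np.arange(len(flat), dtype=np.float32)
    ranks = ranks.reshape(x.shape)
    # best return -> +0.5, worst -> -0.5, which makes the update scale-free
    return ranks / (x.size - 1) - 0.5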
Example #7
    def __init__(self,
                 actions_size,
                 policy,
                 sess=None,
                 gamma=0.99,
                 lr=1e-5,
                 batch_size=32,
                 num_itr=20,
                 use_vairl=False,
                 mutual_information=0.5,
                 alpha=0.0005,
                 with_action=False,
                 name='reward_model',
                 entropy_weight=0.5,
                 with_value=True,
                 fixed_reward_model=False,
                 vs=None,
                 **kwargs):

        # Initialize some model attributes
        # RunningStat to normalize reward from the model
        if not fixed_reward_model:
            self.r_norm = DynamicRunningStat()
        else:
            self.r_norm = RunningStat(1)

        # Discount factor
        self.gamma = gamma
        # Policy agent needed to compute the discriminator
        self.policy = policy
        # Demonstrations buffer
        self.expert_traj = None
        self.validation_traj = None
        # Num of actions available in the environment
        self.actions_size = actions_size
        # Whether this is a state-only or a state-action discriminator
        self.with_action = with_action

        # TF parameters
        self.sess = sess
        self.lr = lr
        self.batch_size = batch_size
        self.num_itr = num_itr
        self.entropy_weight = entropy_weight

        # Use a variational bottleneck (V-AIRL)
        self.use_vairl = use_vairl
        self.mutual_information = mutual_information
        self.alpha = alpha
        self.name = name

        # Buffer of policy experience with which to train the reward model
        self.buffer = dict()
        self.create_buffer()

        with tf.compat.v1.variable_scope(name) as vs:
            with tf.compat.v1.variable_scope('irl'):

                # Input spec for both reward and value function

                # Current state (DeepCrawl spec)
                self.global_state = tf.compat.v1.placeholder(
                    tf.float32, [None, 10, 10, 52], name='global_state')
                self.local_state = tf.compat.v1.placeholder(tf.float32,
                                                            [None, 5, 5, 52],
                                                            name='local_state')
                self.local_two_state = tf.compat.v1.placeholder(
                    tf.float32, [None, 3, 3, 52], name='local_two_state')
                self.agent_stats = tf.compat.v1.placeholder(tf.int32,
                                                            [None, 16],
                                                            name='agent_stats')
                self.target_stats = tf.compat.v1.placeholder(
                    tf.int32, [None, 15], name='target_stats')
                if self.with_action:
                    self.acts = tf.compat.v1.placeholder(tf.int32, [None, 1],
                                                         name='acts')

                # Next state (DeepCrawl spec) - for discriminator
                self.global_state_n = tf.compat.v1.placeholder(
                    tf.float32, [None, 10, 10, 52], name='global_state_n')
                self.local_state_n = tf.compat.v1.placeholder(
                    tf.float32, [None, 5, 5, 52], name='local_state_n')
                self.local_two_state_n = tf.compat.v1.placeholder(
                    tf.float32, [None, 3, 3, 52], name='local_two_state_n')
                self.agent_stats_n = tf.compat.v1.placeholder(
                    tf.int32, [None, 16], name='agent_stats_n')
                self.target_stats_n = tf.compat.v1.placeholder(
                    tf.int32, [None, 15], name='target_stats_n')

                # Probability distribution and labels - whether or not this state belongs to the expert buffer
                self.probs = tf.compat.v1.placeholder(tf.float32, [None, 1],
                                                      name='probs')
                self.labels = tf.compat.v1.placeholder(tf.float32, [None, 1],
                                                       name='labels')

                # For V-AIRL
                self.use_noise = tf.compat.v1.placeholder(shape=[1],
                                                          dtype=tf.float32,
                                                          name="noise")

                self.z_sigma_g = None
                self.z_sigma_h = None
                if self.use_vairl:
                    self.z_sigma_g = tf.compat.v1.get_variable(
                        'z_sigma_g',
                        100,
                        dtype=tf.float32,
                        initializer=tf.compat.v1.ones_initializer(),
                    )
                    self.z_sigma_g_sq = self.z_sigma_g * self.z_sigma_g
                    self.z_log_sigma_g_sq = tf.compat.v1.log(
                        self.z_sigma_g_sq + eps)

                    self.z_sigma_h = tf.compat.v1.get_variable(
                        "z_sigma_h",
                        100,
                        dtype=tf.float32,
                        initializer=tf.compat.v1.ones_initializer(),
                    )
                    self.z_sigma_h_sq = self.z_sigma_h * self.z_sigma_h
                    self.z_log_sigma_h_sq = tf.compat.v1.log(
                        self.z_sigma_h_sq + eps)

                # Reward Function
                with tf.compat.v1.variable_scope('reward'):
                    self.reward, self.z_g = self.conv_net(
                        self.global_state,
                        self.local_state,
                        self.local_two_state,
                        self.agent_stats,
                        self.target_stats,
                        with_action=self.with_action,
                        z_sigma=self.z_sigma_g,
                        use_noise=self.use_noise)

                # Value Function
                if with_value:
                    with tf.compat.v1.variable_scope('value'):
                        self.value, self.z_h = self.conv_net(
                            self.global_state,
                            self.local_state,
                            self.local_two_state,
                            self.agent_stats,
                            self.target_stats,
                            z_sigma=self.z_sigma_h,
                            use_noise=self.use_noise,
                            with_action=False)
                    with tf.compat.v1.variable_scope('value', reuse=True):
                        self.value_n, self.z_1_h = self.conv_net(
                            self.global_state_n,
                            self.local_state_n,
                            self.local_two_state_n,
                            self.agent_stats_n,
                            self.target_stats_n,
                            z_sigma=self.z_sigma_h,
                            use_noise=self.use_noise,
                            with_action=False)

                    self.f = self.reward + self.gamma * self.value_n - self.value
                else:
                    self.f = self.reward

                # Discriminator
                self.discriminator = tf.math.divide(
                    tf.math.exp(self.f),
                    tf.math.add(tf.math.exp(self.f), self.probs))

                # Loss Function
                self.loss = -tf.reduce_mean(
                    (self.labels * tf.math.log(self.discriminator + eps)) +
                    ((1 - self.labels) *
                     tf.math.log(1 - self.discriminator + eps)))

                # Loss function modification for V-AIRL
                if self.use_vairl:
                    # Define beta
                    self.beta = tf.compat.v1.get_variable(
                        "airl_beta",
                        [],
                        trainable=False,
                        dtype=tf.float32,
                        initializer=tf.compat.v1.ones_initializer(),
                    )

                    # Number of batch elements
                    self.batch = tf.compat.v1.shape(self.z_g)[0]
                    self.batch_index = tf.dtypes.cast(self.batch / 2, tf.int32)

                    self.kl_loss = tf.reduce_mean(-tf.reduce_sum(
                        1 + self.z_log_sigma_g_sq -
                        0.5 * tf.square(self.z_g[0:self.batch_index, :] *
                                        self.z_h[0:self.batch_index, :] *
                                        self.z_1_h[0:self.batch_index, :]) -
                        0.5 * tf.square(self.z_g[self.batch_index:, :] *
                                        self.z_h[self.batch_index:, :] *
                                        self.z_1_h[self.batch_index:, :]) -
                        tf.exp(self.z_log_sigma_g_sq),
                        1,
                    ))

                    self.loss = self.beta * (
                        self.kl_loss - self.mutual_information) + self.loss

                # Adam optimizer with gradient clipping
                optimizer = tf.compat.v1.train.AdamOptimizer(self.lr)
                gradients, variables = zip(
                    *optimizer.compute_gradients(self.loss))
                gradients, _ = tf.compat.v1.clip_by_global_norm(gradients, 1.0)
                self.step = optimizer.apply_gradients(zip(
                    gradients, variables))
                #self.step = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)

                if self.use_vairl:
                    self.make_beta_update()

        self.vs = vs
        self.saver = tf.compat.v1.train.Saver(max_to_keep=None)
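The graph above implements the AIRL-style discriminator D = exp(f) / (exp(f) + pi(a|s)) with a binary cross-entropy loss over expert (label 1) and policy (label 0) transitions. A small NumPy restatement, handy for sanity-checking shapes outside the session (eps here is a placeholder for the module-level constant used in the graph):

import numpy as np

def airl_discriminator_loss(f, policy_probs, labels, eps=1e-7):
    # f: (batch, 1) values of r + gamma * V(s') - V(s); policy_probs: pi(a|s)
    # for the same transitions; labels: 1 for expert data, 0 for policy data.
    d = np.exp(f) / (np.exp(f) + policy_probs)
    loss = -np.mean(labels * np.log(d + eps) +
                    (1.0 - labels) * np.log(1.0 - d + eps))
    return d, loss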
Example #8
class RewardModel:
    def __init__(self,
                 actions_size,
                 policy,
                 sess=None,
                 gamma=0.99,
                 lr=1e-5,
                 batch_size=32,
                 num_itr=20,
                 use_vairl=False,
                 mutual_information=0.5,
                 alpha=0.0005,
                 with_action=False,
                 name='reward_model',
                 entropy_weight=0.5,
                 with_value=True,
                 fixed_reward_model=False,
                 vs=None,
                 **kwargs):

        # Initialize some model attributes
        # RunningStat to normalize reward from the model
        if not fixed_reward_model:
            self.r_norm = DynamicRunningStat()
        else:
            self.r_norm = RunningStat(1)

        # Discount factor
        self.gamma = gamma
        # Policy agent needed to compute the discriminator
        self.policy = policy
        # Demonstrations buffer
        self.expert_traj = None
        self.validation_traj = None
        # Num of actions available in the environment
        self.actions_size = actions_size
        # Whether this is a state-only or a state-action discriminator
        self.with_action = with_action

        # TF parameters
        self.sess = sess
        self.lr = lr
        self.batch_size = batch_size
        self.num_itr = num_itr
        self.entropy_weight = entropy_weight

        # Use a variational bottleneck (V-AIRL)
        self.use_vairl = use_vairl
        self.mutual_information = mutual_information
        self.alpha = alpha
        self.name = name

        # Buffer of policy experience with which to train the reward model
        self.buffer = dict()
        self.create_buffer()

        with tf.compat.v1.variable_scope(name) as vs:
            with tf.compat.v1.variable_scope('irl'):

                # Input spec for both reward and value function

                # Current state (DeepCrawl spec)
                self.global_state = tf.compat.v1.placeholder(
                    tf.float32, [None, 10, 10, 52], name='global_state')
                self.local_state = tf.compat.v1.placeholder(tf.float32,
                                                            [None, 5, 5, 52],
                                                            name='local_state')
                self.local_two_state = tf.compat.v1.placeholder(
                    tf.float32, [None, 3, 3, 52], name='local_two_state')
                self.agent_stats = tf.compat.v1.placeholder(tf.int32,
                                                            [None, 16],
                                                            name='agent_stats')
                self.target_stats = tf.compat.v1.placeholder(
                    tf.int32, [None, 15], name='target_stats')
                if self.with_action:
                    self.acts = tf.compat.v1.placeholder(tf.int32, [None, 1],
                                                         name='acts')

                # Next state (DeepCrawl spec) - for discriminator
                self.global_state_n = tf.compat.v1.placeholder(
                    tf.float32, [None, 10, 10, 52], name='global_state_n')
                self.local_state_n = tf.compat.v1.placeholder(
                    tf.float32, [None, 5, 5, 52], name='local_state_n')
                self.local_two_state_n = tf.compat.v1.placeholder(
                    tf.float32, [None, 3, 3, 52], name='local_two_state_n')
                self.agent_stats_n = tf.compat.v1.placeholder(
                    tf.int32, [None, 16], name='agent_stats_n')
                self.target_stats_n = tf.compat.v1.placeholder(
                    tf.int32, [None, 15], name='target_stats_n')

                # Probability distribution and labels - whether or not this state belongs to the expert buffer
                self.probs = tf.compat.v1.placeholder(tf.float32, [None, 1],
                                                      name='probs')
                self.labels = tf.compat.v1.placeholder(tf.float32, [None, 1],
                                                       name='labels')

                # For V-AIRL
                self.use_noise = tf.compat.v1.placeholder(shape=[1],
                                                          dtype=tf.float32,
                                                          name="noise")

                self.z_sigma_g = None
                self.z_sigma_h = None
                if self.use_vairl:
                    self.z_sigma_g = tf.compat.v1.get_variable(
                        'z_sigma_g',
                        100,
                        dtype=tf.float32,
                        initializer=tf.compat.v1.ones_initializer(),
                    )
                    self.z_sigma_g_sq = self.z_sigma_g * self.z_sigma_g
                    self.z_log_sigma_g_sq = tf.compat.v1.log(
                        self.z_sigma_g_sq + eps)

                    self.z_sigma_h = tf.compat.v1.get_variable(
                        "z_sigma_h",
                        100,
                        dtype=tf.float32,
                        initializer=tf.compat.v1.ones_initializer(),
                    )
                    self.z_sigma_h_sq = self.z_sigma_h * self.z_sigma_h
                    self.z_log_sigma_h_sq = tf.compat.v1.log(
                        self.z_sigma_h_sq + eps)

                # Reward Function
                with tf.compat.v1.variable_scope('reward'):
                    self.reward, self.z_g = self.conv_net(
                        self.global_state,
                        self.local_state,
                        self.local_two_state,
                        self.agent_stats,
                        self.target_stats,
                        with_action=self.with_action,
                        z_sigma=self.z_sigma_g,
                        use_noise=self.use_noise)

                # Value Function
                if with_value:
                    with tf.compat.v1.variable_scope('value'):
                        self.value, self.z_h = self.conv_net(
                            self.global_state,
                            self.local_state,
                            self.local_two_state,
                            self.agent_stats,
                            self.target_stats,
                            z_sigma=self.z_sigma_h,
                            use_noise=self.use_noise,
                            with_action=False)
                    with tf.compat.v1.variable_scope('value', reuse=True):
                        self.value_n, self.z_1_h = self.conv_net(
                            self.global_state_n,
                            self.local_state_n,
                            self.local_two_state_n,
                            self.agent_stats_n,
                            self.target_stats_n,
                            z_sigma=self.z_sigma_h,
                            use_noise=self.use_noise,
                            with_action=False)

                    self.f = self.reward + self.gamma * self.value_n - self.value
                else:
                    self.f = self.reward

                # Discriminator
                self.discriminator = tf.math.divide(
                    tf.math.exp(self.f),
                    tf.math.add(tf.math.exp(self.f), self.probs))

                # Loss Function
                self.loss = -tf.reduce_mean(
                    (self.labels * tf.math.log(self.discriminator + eps)) +
                    ((1 - self.labels) *
                     tf.math.log(1 - self.discriminator + eps)))

                # Loss function modification for V-AIRL
                if self.use_vairl:
                    # Define beta
                    self.beta = tf.compat.v1.get_variable(
                        "airl_beta",
                        [],
                        trainable=False,
                        dtype=tf.float32,
                        initializer=tf.compat.v1.ones_initializer(),
                    )

                    # Number of batch elements
                    self.batch = tf.compat.v1.shape(self.z_g)[0]
                    self.batch_index = tf.dtypes.cast(self.batch / 2, tf.int32)

                    self.kl_loss = tf.reduce_mean(-tf.reduce_sum(
                        1 + self.z_log_sigma_g_sq -
                        0.5 * tf.square(self.z_g[0:self.batch_index, :] *
                                        self.z_h[0:self.batch_index, :] *
                                        self.z_1_h[0:self.batch_index, :]) -
                        0.5 * tf.square(self.z_g[self.batch_index:, :] *
                                        self.z_h[self.batch_index:, :] *
                                        self.z_1_h[self.batch_index:, :]) -
                        tf.exp(self.z_log_sigma_g_sq),
                        1,
                    ))

                    self.loss = self.beta * (
                        self.kl_loss - self.mutual_information) + self.loss

                # Adam optimizer with gradient clipping
                optimizer = tf.compat.v1.train.AdamOptimizer(self.lr)
                gradients, variables = zip(
                    *optimizer.compute_gradients(self.loss))
                gradients, _ = tf.compat.v1.clip_by_global_norm(gradients, 1.0)
                self.step = optimizer.apply_gradients(zip(
                    gradients, variables))
                #self.step = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)

                if self.use_vairl:
                    self.make_beta_update()

        self.vs = vs
        self.saver = tf.compat.v1.train.Saver(max_to_keep=None)

    ## Layers
    def linear(self,
               inp,
               inner_size,
               name='linear',
               bias=True,
               activation=None,
               init=None):
        with tf.compat.v1.variable_scope(name):

            lin = tf.compat.v1.layers.dense(inp,
                                            inner_size,
                                            name=name,
                                            activation=activation,
                                            use_bias=bias,
                                            kernel_initializer=init)
            return lin

    def conv_layer_2d(self,
                      input,
                      filters,
                      kernel_size,
                      strides=(1, 1),
                      padding="SAME",
                      name='conv',
                      activation=None,
                      bias=True):
        with tf.compat.v1.variable_scope(name):

            conv = tf.compat.v1.layers.conv2d(input,
                                              filters,
                                              kernel_size,
                                              strides,
                                              padding=padding,
                                              name=name,
                                              activation=activation,
                                              use_bias=bias)
            return conv

    def embedding(self, input, indices, size, name='embs'):
        with tf.compat.v1.variable_scope(name):
            shape = (indices, size)
            stddev = min(0.1, sqrt(2.0 / (product(xs=shape[:-1]) + shape[-1])))
            initializer = tf.random.normal(shape=shape,
                                           stddev=stddev,
                                           dtype=tf.float32)
            W = tf.Variable(initial_value=initializer,
                            trainable=True,
                            validate_shape=True,
                            name='W',
                            dtype=tf.float32,
                            shape=shape)
            return tf.nn.tanh(
                tf.compat.v1.nn.embedding_lookup(params=W,
                                                 ids=input,
                                                 max_norm=None))

    # Network specification
    def conv_net(self,
                 global_state,
                 local_state,
                 local_two_state,
                 agent_stats,
                 target_stats,
                 z_sigma=None,
                 use_noise=None,
                 with_action=False):

        conv_10 = self.conv_layer_2d(global_state,
                                     32, [1, 1],
                                     name='conv_10',
                                     activation=tf.nn.tanh)
        conv_11 = self.conv_layer_2d(conv_10,
                                     32, [3, 3],
                                     name='conv_11',
                                     activation=tf.nn.leaky_relu)
        conv_12 = self.conv_layer_2d(conv_11,
                                     32, [3, 3],
                                     name='conv_12',
                                     activation=tf.nn.leaky_relu)
        fc11 = tf.reshape(conv_12, [-1, 10 * 10 * 32])

        embs_41 = tf.nn.tanh(
            self.embedding(agent_stats, 129, 32, name='embs_41'))
        embs_41 = tf.reshape(embs_41, [-1, 16 * 32])
        fc_41 = self.linear(embs_41,
                            100,
                            name='fc_41',
                            activation=tf.nn.leaky_relu)

        embs_51 = self.embedding(target_stats, 125, 32, name='embs_51')
        embs_51 = tf.reshape(embs_51, [-1, 15 * 32])
        fc_51 = self.linear(embs_51,
                            100,
                            name='fc_51',
                            activation=tf.nn.leaky_relu)

        all_flat = tf.concat([fc11, fc_41, fc_51], axis=1)

        all_flat = self.linear(all_flat,
                               32,
                               name='fc1',
                               activation=tf.nn.leaky_relu)

        if with_action:
            hot_acts = tf.one_hot(self.acts, self.actions_size)
            hot_acts = tf.reshape(hot_acts, [-1, self.actions_size])
            all_flat = tf.concat([all_flat, hot_acts], axis=1)

        z_mean = None
        fc2 = self.linear(all_flat,
                          32,
                          name='fc2',
                          activation=tf.nn.leaky_relu)

        # In case we want to use V-AIRL
        if self.use_vairl:

            z_mean = self.linear(
                fc2,
                32,
                name='z_mean',
                init=tf.compat.v1.initializers.variance_scaling(0.01))
            noise = tf.compat.v1.random_normal(tf.compat.v1.shape(z_mean),
                                               dtype=tf.float32)

            z = z_mean + z_sigma * noise * use_noise
            fc2 = z

            return self.linear(fc2, 1, name='out'), z_mean
        else:
            return self.linear(fc2, 1, name='out'), None

    # Train method of the discriminator
    def train(self):

        losses = []

        # Update discriminator
        for it in range(self.num_itr):

            expert_batch_idxs = random.sample(
                range(len(self.expert_traj['obs'])), self.batch_size)
            policy_batch_idxs = random.sample(range(len(self.buffer['obs'])),
                                              self.batch_size)

            #expert_batch_idxs = np.random.randint(0, len(expert_traj['obs']), batch_size)
            #policy_batch_idxs = np.random.randint(0, len(policy_traj['obs']), batch_size)

            expert_obs = [
                self.expert_traj['obs'][id] for id in expert_batch_idxs
            ]
            policy_obs = [self.buffer['obs'][id] for id in policy_batch_idxs]

            expert_obs_n = [
                self.expert_traj['obs_n'][id] for id in expert_batch_idxs
            ]
            policy_obs_n = [
                self.buffer['obs_n'][id] for id in policy_batch_idxs
            ]

            expert_acts = [
                self.expert_traj['acts'][id] for id in expert_batch_idxs
            ]
            policy_acts = [self.buffer['acts'][id] for id in policy_batch_idxs]

            expert_probs = []
            for (index, state) in enumerate(expert_obs):
                _, probs = self.select_action(state)
                expert_probs.append(probs[expert_acts[index]])

            policy_probs = []
            for (index, state) in enumerate(policy_obs):
                _, probs = self.select_action(state)
                policy_probs.append(probs[policy_acts[index]])

            expert_probs = np.asarray(expert_probs)
            policy_probs = np.asarray(policy_probs)

            labels = np.ones((self.batch_size, 1))
            labels = np.concatenate([labels, np.zeros((self.batch_size, 1))])

            e_states = self.obs_to_state(expert_obs)
            p_states = self.obs_to_state(policy_obs)

            all_global = np.concatenate([e_states[0], p_states[0]], axis=0)
            all_local = np.concatenate([e_states[1], p_states[1]], axis=0)
            all_local_two = np.concatenate([e_states[2], p_states[2]], axis=0)
            all_agent_stats = np.concatenate([e_states[3], p_states[3]],
                                             axis=0)
            all_target_stats = np.concatenate([e_states[4], p_states[4]],
                                              axis=0)

            e_states_n = self.obs_to_state(expert_obs_n)
            p_states_n = self.obs_to_state(policy_obs_n)

            all_global_n = np.concatenate([e_states_n[0], p_states_n[0]],
                                          axis=0)
            all_local_n = np.concatenate([e_states_n[1], p_states_n[1]],
                                         axis=0)
            all_local_two_n = np.concatenate([e_states_n[2], p_states_n[2]],
                                             axis=0)
            all_agent_stats_n = np.concatenate([e_states_n[3], p_states_n[3]],
                                               axis=0)
            all_target_stats_n = np.concatenate([e_states_n[4], p_states_n[4]],
                                                axis=0)

            all_probs = np.concatenate([expert_probs, policy_probs], axis=0)
            all_probs = np.expand_dims(all_probs, axis=1)

            feed_dict = {
                self.local_state: all_local,
                self.local_two_state: all_local_two,
                self.agent_stats: all_agent_stats,
                self.target_stats: all_target_stats,
                self.local_state_n: all_local_n,
                self.local_two_state_n: all_local_two_n,
                self.agent_stats_n: all_agent_stats_n,
                self.target_stats_n: all_target_stats_n,
                self.probs: all_probs,
                self.labels: labels,
                self.use_noise: [1]
            }

            if self.with_action:
                all_acts = np.concatenate([expert_acts, policy_acts], axis=0)
                all_acts = np.expand_dims(all_acts, axis=1)

                feed_dict[self.acts] = all_acts

            feed_dict[self.global_state] = all_global
            feed_dict[self.global_state_n] = all_global_n

            if self.use_vairl:
                loss, f, _, _, kl_loss = self.sess.run([
                    self.loss, self.f, self.step, self.update_beta,
                    self.kl_loss
                ],
                                                       feed_dict=feed_dict)
            else:
                loss, f, disc, _ = self.sess.run(
                    [self.loss, self.f, self.discriminator, self.step],
                    feed_dict=feed_dict)

            losses.append(loss)

        # Update normalization parameters
        for it in range(self.num_itr):

            expert_batch_idxs = random.sample(
                range(len(self.expert_traj['obs'])), self.batch_size)
            policy_batch_idxs = random.sample(range(len(self.buffer['obs'])),
                                              self.batch_size)

            expert_obs = [
                self.expert_traj['obs'][id] for id in expert_batch_idxs
            ]
            policy_obs = [self.buffer['obs'][id] for id in policy_batch_idxs]

            expert_obs_n = [
                self.expert_traj['obs_n'][id] for id in expert_batch_idxs
            ]
            policy_obs_n = [
                self.buffer['obs_n'][id] for id in policy_batch_idxs
            ]

            expert_acts = [
                self.expert_traj['acts'][id] for id in expert_batch_idxs
            ]
            policy_acts = [self.buffer['acts'][id] for id in policy_batch_idxs]

            e_states = self.obs_to_state(expert_obs)
            p_states = self.obs_to_state(policy_obs)

            all_global = np.concatenate([e_states[0], p_states[0]], axis=0)
            all_local = np.concatenate([e_states[1], p_states[1]], axis=0)
            all_local_two = np.concatenate([e_states[2], p_states[2]], axis=0)
            all_agent_stats = np.concatenate([e_states[3], p_states[3]],
                                             axis=0)
            all_target_stats = np.concatenate([e_states[4], p_states[4]],
                                              axis=0)

            e_states_n = self.obs_to_state(expert_obs_n)
            p_states_n = self.obs_to_state(policy_obs_n)

            all_global_n = np.concatenate([e_states_n[0], p_states_n[0]],
                                          axis=0)
            all_local_n = np.concatenate([e_states_n[1], p_states_n[1]],
                                         axis=0)
            all_local_two_n = np.concatenate([e_states_n[2], p_states_n[2]],
                                             axis=0)
            all_agent_stats_n = np.concatenate([e_states_n[3], p_states_n[3]],
                                               axis=0)
            all_target_stats_n = np.concatenate([e_states_n[4], p_states_n[4]],
                                                axis=0)

            expert_probs = []
            for (index, state) in enumerate(expert_obs):
                _, probs = self.select_action(state)
                expert_probs.append(probs[expert_acts[index]])

            policy_probs = []
            for (index, state) in enumerate(policy_obs):
                _, probs = self.select_action(state)
                policy_probs.append(probs[policy_acts[index]])

            expert_probs = np.asarray(expert_probs)
            policy_probs = np.asarray(policy_probs)

            probs = np.concatenate([expert_probs, policy_probs], axis=0)
            probs = np.expand_dims(probs, axis=1)

            feed_dict = {
                self.global_state: all_global,
                self.local_state: all_local,
                self.local_two_state: all_local_two,
                self.agent_stats: all_agent_stats,
                self.target_stats: all_target_stats,
                self.global_state_n: all_global_n,
                self.local_state_n: all_local_n,
                self.local_two_state_n: all_local_two_n,
                self.agent_stats_n: all_agent_stats_n,
                self.target_stats_n: all_target_stats_n,
            }

            if self.use_vairl:
                feed_dict[self.use_noise] = [0]

            if self.with_action:
                all_acts = np.concatenate([expert_acts, policy_acts], axis=0)
                all_acts = np.expand_dims(all_acts, axis=1)

                feed_dict[self.acts] = all_acts

            f = self.sess.run([self.f], feed_dict=feed_dict)
            f -= self.entropy_weight * np.log(probs)
            f = np.squeeze(f)
            self.push_reward(f)

        # Update Dynamic Running Stat
        if isinstance(self.r_norm, DynamicRunningStat):
            self.r_norm.reset()

        return np.mean(losses), 0
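
    # A minimal usage sketch (assumption, not part of the original class): the update
    # above is typically driven from an outer loop that first refills the replay
    # buffer with fresh policy transitions. `reward_model`, `rollout` and `policy`
    # are illustrative names, and the name of the update method itself is not shown
    # in this excerpt, so `fit_discriminator` below is hypothetical.
    #
    #   reward_model.create_buffer()
    #   for obs, obs_n, act in rollout(env, policy):
    #       reward_model.add_to_buffer(obs, obs_n, act)
    #   mean_loss, _ = reward_model.fit_discriminator()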

    # Evaluate without the discriminator - reward function only
    def eval(self, obs, obs_n, acts=None, probs=None):

        states = self.obs_to_state(obs)

        feed_dict = {
            self.global_state: states[0],
            self.local_state: states[1],
            self.local_two_state: states[2],
            self.agent_stats: states[3],
            self.target_stats: states[4],
            self.use_noise: [0]
        }

        if self.with_action and acts is not None:
            acts = np.expand_dims(acts, axis=1)
            feed_dict[self.acts] = acts

        reward = self.sess.run([self.reward], feed_dict=feed_dict)
        if probs is not None:
            reward -= self.entropy_weight * np.log(probs)

        # Normalize the reward
        #self.r_norm.push(reward[0][0])
        #reward = [[self.normalize_rewards(reward[0][0])]]
        #if self.r_norm.n == 0:
        #    reward = [[0]]

        return reward[0][0]
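
    # Sketch (assumption): eval() is the call used to replace the environment reward
    # during policy training, mirroring its use in create_demonstrations() below.
    # `reward_model` is an illustrative name.
    #
    #   state_n, done, env_reward = env.execute(action)
    #   r = reward_model.eval([state], [state_n], [action])
    #   r = reward_model.normalize_rewards(r)   # optional, see normalize_rewards()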

    # Evaluate with the discriminator - returns an entropy-regularized objective
    def eval_discriminator(self, obs, obs_n, probs, acts=None):
        states = self.obs_to_state(obs)
        states_n = self.obs_to_state(obs_n)

        probs = np.expand_dims(probs, axis=1)

        feed_dict = {
            self.global_state: states[0],
            self.local_state: states[1],
            self.local_two_state: states[2],
            self.agent_stats: states[3],
            self.target_stats: states[4],
            self.global_state_n: states_n[0],
            self.local_state_n: states_n[1],
            self.local_two_state_n: states_n[2],
            self.agent_stats_n: states_n[3],
            self.target_stats_n: states_n[4],
            self.use_noise: [0]
        }

        if self.with_action and acts is not None:
            acts = np.expand_dims(acts, axis=1)
            feed_dict[self.acts] = acts

        f = self.sess.run([self.f], feed_dict=feed_dict)
        f -= self.entropy_weight * np.log(probs)
        f = self.normalize_rewards(f)
        return f
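
    # Note (hedged interpretation): the quantity returned here,
    # f - entropy_weight * log(pi(a|s)), is the entropy-regularized objective used by
    # AIRL-style discriminators, whose standard form is
    # D = exp(f) / (exp(f) + pi(a|s)). The exact graph definition of self.f is not
    # shown in this excerpt.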

    # Transform a DeepCrawl obs to state
    def obs_to_state(self, obs):

        global_batch = np.stack(
            [np.asarray(state['global_in']) for state in obs])
        local_batch = np.stack(
            [np.asarray(state['local_in']) for state in obs])
        local_two_batch = np.stack(
            [np.asarray(state['local_in_two']) for state in obs])
        agent_stats_batch = np.stack(
            [np.asarray(state['agent_stats']) for state in obs])
        target_stats_batch = np.stack(
            [np.asarray(state['target_stats']) for state in obs])

        return global_batch, local_batch, local_two_batch, agent_stats_batch, target_stats_batch
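
    # Observation layout assumed by obs_to_state(); the keys come from the code above,
    # while the per-key descriptions are illustrative only.
    #
    #   obs = [{
    #       'global_in':    ...,  # global view of the map
    #       'local_in':     ...,  # local view around the agent
    #       'local_in_two': ...,  # second, tighter local view
    #       'agent_stats':  ...,  # agent statistics vector
    #       'target_stats': ...,  # target statistics vector
    #   }, ...]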

    # For V-AIRL
    def make_beta_update(self):

        new_beta = tf.maximum(
            self.beta + self.alpha * (self.kl_loss - self.mutual_information),
            eps)
        with tf.control_dependencies([self.step]):
            self.update_beta = tf.compat.v1.assign(self.beta, new_beta)
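
    # The rule above is the usual dual-ascent update for a variational (information
    # bottleneck) discriminator: beta <- max(beta + alpha * (KL - I_c), eps), with
    # I_c = self.mutual_information and eps a small constant assumed to be defined at
    # module level. Illustrative numbers: beta = 0.1, alpha = 1e-5, KL = 0.7,
    # I_c = 0.5 gives beta = 0.100002.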

    # Push each reward of the sequence into the running statistics used for normalization
    def push_reward(self, rewards):
        for r in rewards:
            self.r_norm.push(r)

    def normalize_rewards(self, rewards):
        rewards -= self.r_norm.mean
        rewards /= (self.r_norm.std + 1e-12)
        rewards *= 0.05

        return rewards
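
    # Worked example with illustrative statistics: if r_norm.mean = 1.0 and
    # r_norm.std = 2.0, a raw reward of 3.0 becomes
    #   (3.0 - 1.0) / (2.0 + 1e-12) * 0.05 = 0.05
    # i.e. rewards are standardized and then scaled by a constant factor of 0.05.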

    # Select action from the policy and fetch the probability distribution over the action space
    def select_action(self, state):

        act, _, probs = self.policy.eval([state])

        return (act, probs[0])

    # Update demonstrations
    def set_demonstrations(self, demonstrations, validations):
        self.expert_traj = demonstrations

        if validations is not None:
            self.validation_traj = validations

    # Create the replay buffer with which to train the discriminator
    def create_buffer(self):
        self.buffer['obs'] = []
        self.buffer['obs_n'] = []
        self.buffer['acts'] = []

    # Add a transition to the buffer
    def add_to_buffer(self, obs, obs_n, acts, buffer_length=100000):

        if len(self.buffer['obs']) >= buffer_length:
            random_index = np.random.randint(0, len(self.buffer['obs']))
            del self.buffer['obs'][random_index]
            del self.buffer['obs_n'][random_index]
            del self.buffer['acts'][random_index]

        self.buffer['obs'].append(obs)
        self.buffer['obs_n'].append(obs_n)
        self.buffer['acts'].append(acts)
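
    # Design note (interpretation): when the buffer is full a *random* transition is
    # evicted rather than the oldest one, so the buffer keeps a mixture of samples
    # from older and newer policies instead of behaving like a plain FIFO queue.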

    # Create and return demonstrations [(states, actions, frames)]. The demonstrations
    # must be ordered from best to worst; the number of demonstrations is chosen by the user.
    def create_demonstrations(self,
                              env,
                              save_demonstrations=True,
                              inference=False,
                              verbose=False,
                              with_policy=False,
                              num_episode=31,
                              max_timestep=20):
        end = False

        # Initialize trajectories buffer
        expert_traj = {
            'obs': [],
            'obs_n': [],
            'acts': [],
        }

        val_traj = {'obs': [], 'obs_n': [], 'acts': []}

        if with_policy is None:
            num_episode = None

        episode = 1

        while not end:
            # Make another demonstration
            print('Demonstration no. ' + str(episode))
            # Reset the environment
            state = env.reset()
            states = [state]
            actions = []
            done = False
            step = 0
            cum_reward = 0
            # New sequence of states and actions
            while not done:
                try:
                    # Input the action and save the new state and action
                    step += 1
                    print("Timestep: " + str(step))
                    if verbose:
                        env.print_observation(state)
                    if not with_policy:
                        action = input('action: ')
                        if action == "f":
                            done = True
                            continue
                        while env.command_to_action(action) >= 99:
                            action = input('action: ')
                    else:
                        action, probs = self.select_action(state)
                    print(action)
                    state_n, done, reward = env.execute(action)

                    cum_reward += reward
                    if not with_policy:
                        action = env.command_to_action(action)
                    # If inference is true, print the reward
                    if inference:
                        _, probs = self.select_action(state)
                        reward = self.eval([state], [state_n], [action])
                        # print('Discriminator probability: ' + str(disc))
                        print('Unnormalized reward: ' + str(reward))
                        reward = self.normalize_rewards(reward)
                        print('Normalized reward: ' + str(reward))
                        print('Probability of state space: ')
                        print(probs)
                    state = state_n
                    states.append(state)
                    actions.append(action)
                    if step >= max_timestep:
                        done = True
                except Exception as e:
                    print(e)
                    continue

            if not inference:
                y = None

                print('Demonstration number: ' + str(episode))
                if with_policy:
                    if episode < num_episode:
                        print(state_n['target_stats'][0])
                        # Demonstrations generated by the policy are always accepted
                        y = 'y'
                while y != 'y' and y != 'n':
                    y = input('Do you want to save this demonstration? [y/n] ')

                if y == 'y':
                    # Update expert trajectories
                    expert_traj['obs'].extend(np.array(states[:-1]))
                    expert_traj['obs_n'].extend(np.array(states[1:]))
                    expert_traj['acts'].extend(np.array(actions))
                    episode += 1
                else:

                    if with_policy:
                        if episode > num_episode - 1:
                            y = input(
                                'Do you want to save this demonstration as validation? [y/n] '
                            )
                        else:
                            y = 'n'
                    else:
                        y = input(
                            'Do you want to save this demonstration as validation? [y/n] '
                        )

                    if y == 'y':
                        val_traj['obs'].extend(np.array(states[:-1]))
                        val_traj['obs_n'].extend(np.array(states[1:]))
                        val_traj['acts'].extend(np.array(actions))
                        episode += 1

            y = None
            if num_episode is None:
                while y != 'y' and y != 'n':

                    if not inference:
                        y = input(
                            'Do you want to create another demonstration? [y/n] '
                        )
                    else:
                        y = input('Do you want to try another episode? [y/n] ')

                    if y == 'n':
                        end = True
            else:
                if episode >= num_episode + 1:
                    end = True

        if len(val_traj['obs']) <= 0:
            val_traj = None

        # Save demonstrations to file
        if save_demonstrations and not inference:
            print('Saving the demonstrations...')
            self.save_demonstrations(expert_traj, val_traj)
            print('Demonstrations saved!')

        if not inference:
            self.set_demonstrations(expert_traj, val_traj)

        return expert_traj, val_traj
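
    # Sketch (assumption): typical calls for collecting demonstrations; argument
    # values are illustrative.
    #
    #   # Interactive, human-provided demonstrations:
    #   dems, vals = reward_model.create_demonstrations(env, verbose=True)
    #   # Demonstrations generated by an existing policy:
    #   dems, vals = reward_model.create_demonstrations(env, with_policy=True,
    #                                                   num_episode=31)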

    # Save demonstrations dict to file
    @staticmethod
    def save_demonstrations(demonstrations,
                            validations=None,
                            name='dems_potions.pkl'):
        with open('reward_model/dems/' + name, 'wb') as f:
            pickle.dump(demonstrations, f, pickle.HIGHEST_PROTOCOL)
        if validations is not None:
            with open('reward_model/dems/vals.pkl', 'wb') as f:
                pickle.dump(validations, f, pickle.HIGHEST_PROTOCOL)

    # Load demonstrations from file
    def load_demonstrations(self, name='dems.pkl'):
        with open('reward_model/dems/' + name, 'rb') as f:
            expert_traj = pickle.load(f)

        with open('reward_model/dems/vals.pkl', 'rb') as f:
            val_traj = pickle.load(f)

        self.set_demonstrations(expert_traj, val_traj)

        return expert_traj, val_traj
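
    # Sketch (assumption): reusing previously saved demonstrations. Note that the
    # default file names differ: save_demonstrations() writes 'dems_potions.pkl'
    # while load_demonstrations() reads 'dems.pkl', so pass the name explicitly
    # when they should match.
    #
    #   expert_traj, val_traj = reward_model.load_demonstrations('dems_potions.pkl')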

    # Save the entire model
    def save_model(self, name=None):
        self.saver.save(self.sess, 'reward_model/models/{}'.format(name))
        return

    # Load entire model
    def load_model(self, name=None):
        self.saver = tf.compat.v1.train.import_meta_graph(
            'reward_model/models/' + name + '.meta')
        self.saver.restore(self.sess, 'reward_model/models/' + name)
        return
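
    # Sketch (assumption): checkpointing around training; the model name is illustrative.
    #
    #   reward_model.save_model('airl_reward_model')
    #   # ... later, or in a separate run ...
    #   reward_model.load_model('airl_reward_model')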