Example #1
    def __init__(self,
                 env,
                 trpo: TRPO,
                 tensorboard_path,
                 no_encoder=False,
                 feature_dim=10,
                 forward_weight=0.8,
                 external_reward_weight=0.01,
                 forward_cos=False,
                 init_learning_rate=1e-4,
                 icm_batch_size=128,
                 replay_pool_size=1000000,
                 min_pool_size=200,
                 n_updates_per_iter=10,
                 obs_dtype='float32',
                 normalize_input=False,
                 gpu_fraction=0.95,
                 pretrained_icm=False,
                 pretrained_icm_path=None,
                 freeze_icm=False,
                 **kwargs):
        """
        :param env: Environment
        :param trpo: TRPO algorithm that will be wrapped with ICM
        :param encoder: State encoder that maps s to f
        :param inverse_model: Inverse dynamics model that maps (f1, f2) to actions
        :param forward_model: Forward dynamics model that maps (f1, a) to f2
        :param forward_weight: Weight from 0 to 1 that balances forward loss and inverse loss
        :param external_reward_weight: Weight that balances external reward and internal reward
        :param init_learning_rate: Initial learning rate of the optimizer
        """
        self.trpo = trpo
        self.freeze_icm = freeze_icm
        # Replace sampler to inject intrinsic reward
        self.trpo.sampler = self.get_sampler(self.trpo)
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=gpu_fraction)
        self.sess = tf.get_default_session() or tf.Session(
            config=tf.ConfigProto(gpu_options=gpu_options))
        self.external_reward_weight = external_reward_weight
        self.summary_writer = tf.summary.FileWriter(
            tensorboard_path, graph=tf.get_default_graph())
        self.n_updates_per_iter = n_updates_per_iter
        self.icm_batch_size = icm_batch_size
        self.act_space = env.action_space
        self.obs_space = env.observation_space

        self.pool = TRPOReplayPool(replay_pool_size,
                                   self.obs_space.flat_dim,
                                   self.act_space.flat_dim,
                                   obs_dtype=obs_dtype)

        self.min_pool_size = min_pool_size
        # Setup ICM models
        self.s1 = tf.placeholder(tf.float32,
                                 [None] + list(self.obs_space.shape))
        self.s2 = tf.placeholder(tf.float32,
                                 [None] + list(self.obs_space.shape))
        if normalize_input:
            s1 = self.s1 / 255.0 - 0.5
            s2 = self.s2 / 255.0 - 0.5
        else:
            s1 = self.s1
            s2 = self.s2

        self.asample = tf.placeholder(tf.float32,
                                      [None, self.act_space.flat_dim])
        self.external_rewards = tf.placeholder(tf.float32, (None, ))

        # Snapshot existing variables so pretrained ICM variables can be
        # excluded from re-initialization at the end of __init__
        temp_vars = set(tf.all_variables())

        if pretrained_icm:
            with self.sess.as_default():
                icm_data = joblib.load(pretrained_icm_path)
                _encoder = icm_data['encoder']
                _forward_model = icm_data['forward_model']
                _inverse_model = icm_data['inverse_model']

            icm_vars = set(tf.all_variables()) - temp_vars
        else:
            icm_vars = set([])

        if pretrained_icm:
            self._encoder = _encoder
            # raise NotImplementedError("Currently only supports flat observation input!")
        else:
            if len(self.obs_space.shape) == 1:
                if no_encoder:
                    self._encoder = NoEncoder(self.obs_space.flat_dim,
                                              env_spec=env.spec)
                else:
                    self._encoder = FullyConnectedEncoder(feature_dim,
                                                          env_spec=env.spec)
            else:
                self._encoder = ConvEncoder(feature_dim,
                                            env.spec.observation_space.shape)

        self._encoder.sess = self.sess

        if not pretrained_icm:
            # Initialize variables for get_copy to work
            self.sess.run(tf.initialize_all_variables())

        with self.sess.as_default():
            self.encoder1 = self._encoder.get_weight_tied_copy(
                observation_input=s1)
            self.encoder2 = self._encoder.get_weight_tied_copy(
                observation_input=s2)

        if not pretrained_icm:
            self._inverse_model = InverseModel(feature_dim, env_spec=env.spec)
            self._forward_model = ForwardModel(feature_dim, env_spec=env.spec)
        else:
            self._inverse_model = _inverse_model
            self._forward_model = _forward_model

        self._inverse_model.sess = self.sess
        self._forward_model.sess = self.sess

        if not pretrained_icm:
            # Initialize variables for get_copy to work
            self.sess.run(tf.initialize_all_variables())

        # Clip actions so they match what the env actually receives
        clipped_asample = tf.clip_by_value(self.asample, -1.0, 1.0)

        with self.sess.as_default():
            self.inverse_model = self._inverse_model.get_weight_tied_copy(
                feature_input1=self.encoder1.output,
                feature_input2=self.encoder2.output)
            self.forward_model = self._forward_model.get_weight_tied_copy(
                feature_input=self.encoder1.output,
                action_input=clipped_asample)

        # Define losses, by default it uses L2 loss
        if forward_cos:
            self.forward_loss = cos_loss(self.encoder2.output,
                                         self.forward_model.output)
        else:
            self.forward_loss = tf.reduce_mean(
                tf.square(self.encoder2.output - self.forward_model.output))
        if isinstance(self.act_space, Box):
            self.inverse_loss = tf.reduce_mean(
                tf.square(clipped_asample - self.inverse_model.output))
        elif isinstance(self.act_space, Discrete):
            # TODO: Implement softmax loss
            raise NotImplementedError
        else:
            raise NotImplementedError

        if forward_cos:
            self.internal_rewards = cos_loss(self.encoder2.output,
                                             self.forward_model.output,
                                             mean=False)
        else:
            self.internal_rewards = tf.reduce_sum(
                tf.square(self.encoder2.output - self.forward_model.output),
                axis=1)
        self.mean_internal_rewards = tf.reduce_mean(self.internal_rewards)
        self.mean_external_rewards = tf.reduce_mean(self.external_rewards)

        self.total_loss = forward_weight * self.forward_loss + \
            (1. - forward_weight) * self.inverse_loss
        self.icm_opt = tf.train.AdamOptimizer(init_learning_rate).\
            minimize(self.total_loss)

        # Setup summaries
        inverse_loss_summ = tf.summary.scalar("icm_inverse_loss",
                                              self.inverse_loss)
        forward_loss_summ = tf.summary.scalar("icm_forward_loss",
                                              self.forward_loss)
        total_loss_summ = tf.summary.scalar("icm_total_loss", self.total_loss)
        internal_rewards = tf.summary.scalar("mean_internal_rewards",
                                             self.mean_internal_rewards)
        external_rewards = tf.summary.scalar("mean_external_rewards",
                                             self.mean_external_rewards)

        self.summary = tf.summary.merge([
            inverse_loss_summ, forward_loss_summ, total_loss_summ,
            internal_rewards, external_rewards
        ])

        # Initialize uninitialized variables
        self.sess.run(
            tf.initialize_variables(set(tf.all_variables()) - icm_vars))
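
The constructor above builds the two ICM heads and their losses. A minimal NumPy sketch of the objective it assembles (function and argument names are illustrative, not part of the original class):

import numpy as np

def icm_objective(phi2_true, phi2_pred, a_true, a_pred, forward_weight=0.8):
    # Forward loss: squared error of the predicted next feature,
    # mirroring self.forward_loss in the default (non-cosine) branch.
    forward_loss = np.mean(np.square(phi2_true - phi2_pred))
    # Inverse loss: squared error of the recovered action (Box action spaces),
    # mirroring self.inverse_loss.
    inverse_loss = np.mean(np.square(a_true - a_pred))
    # Weighted combination, mirroring self.total_loss.
    total_loss = (forward_weight * forward_loss +
                  (1. - forward_weight) * inverse_loss)
    # Per-transition curiosity reward, mirroring self.internal_rewards.
    internal_rewards = np.sum(np.square(phi2_true - phi2_pred), axis=1)
    return total_loss, internal_rewards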
Example #2
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('env_name', type=str, help="name of gym env")
    parser.add_argument('dataset_path',
                        type=str,
                        help="path of training and validation dataset")
    parser.add_argument('val_random_data',
                        type=str,
                        help="path of the random-interaction validation dataset")
    parser.add_argument('val_contact_data',
                        type=str,
                        help="path of the contact validation dataset")
    parser.add_argument('tfboard_path', type=str, nargs='?', default='/tmp/tfboard')
    parser.add_argument('tfmodel_path', type=str, nargs='?', default='/tmp/tfmodels')
    parser.add_argument('--restore', action='store_true')
    # Training parameters
    parser.add_argument('--num_itr', type=int, default=10000000)
    parser.add_argument('--val_freq', type=int, default=200)
    parser.add_argument('--log_freq', type=int, default=50)
    parser.add_argument('--save_freq', type=int, default=5000)

    # ICM parameters
    parser.add_argument('--init_lr', type=float, default=2e-3)
    parser.add_argument('--forward_weight',
                        type=float,
                        default=0.5,
                        help="the ratio of forward loss vs inverse loss")
    parser.add_argument('--cos_forward',
                        action='store_true',
                        help="whether to use cosine forward loss")

    args = parser.parse_args()

    # Get dataset
    train_set_names = list(
        map(lambda file_name: osp.join(args.dataset_path, file_name),
            listdir(args.dataset_path)))
    val_random_set_names = list(
        map(lambda file_name: osp.join(args.val_random_data, file_name),
            listdir(args.val_random_data)))
    val_contact_set_names = list(
        map(lambda file_name: osp.join(args.val_contact_data, file_name),
            listdir(args.val_contact_data)))

    obs_shape = OBS_SHAPE_MAP[args.env_name]
    action_dim = ACTION_DIM_MAP[args.env_name]
    train_obs, train_next_obs, train_action = inputs(train_set_names,
                                                     obs_shape,
                                                     train=True)
    val_random_obs, val_random_next_obs, val_random_action = inputs(
        val_random_set_names, obs_shape, train=False)
    val_contact_obs, val_contact_next_obs, val_contact_action = inputs(
        val_contact_set_names, obs_shape, train=False)

    if args.restore:
        models_dict = joblib.load(args.tfmodel_path)
        _encoder = models_dict['encoder']
        _inverse_model = models_dict['inverse_model']
        _forward_model = models_dict['forward_model']
    else:
        _encoder = NoEncoder(obs_shape, observation_dim=[obs_shape])
        _inverse_model = InverseModel(
            feature_dim=obs_shape,
            action_dim=action_dim,
            hidden_sizes=(256, 256),
            hidden_activation=tf.nn.elu,
            output_activation=tf.nn.tanh,
        )
        _forward_model = ForwardModel(
            feature_dim=obs_shape,
            action_dim=action_dim,
            hidden_sizes=(256, 256),
            hidden_activation=tf.nn.elu,
        )
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1)

    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    _encoder.sess = sess
    _inverse_model.sess = sess
    _forward_model.sess = sess

    with sess.as_default():
        # Initialize variables for get_copy to work
        sess.run(tf.initialize_all_variables())

        train_encoder1 = _encoder.get_weight_tied_copy(
            observation_input=train_obs)
        train_encoder2 = _encoder.get_weight_tied_copy(
            observation_input=train_next_obs)
        train_inverse_model = _inverse_model.get_weight_tied_copy(
            feature_input1=train_encoder1.output,
            feature_input2=train_encoder2.output)
        train_forward_model = _forward_model.get_weight_tied_copy(
            feature_input=train_encoder1.output, action_input=train_action)

        val_random_encoder1 = _encoder.get_weight_tied_copy(
            observation_input=val_random_obs)
        val_random_encoder2 = _encoder.get_weight_tied_copy(
            observation_input=val_random_next_obs)
        val_random_inverse_model = _inverse_model.get_weight_tied_copy(
            feature_input1=val_random_encoder1.output,
            feature_input2=val_random_encoder2.output)
        val_random_forward_model = _forward_model.get_weight_tied_copy(
            feature_input=val_random_encoder1.output,
            action_input=val_random_action)

        val_contact_encoder1 = _encoder.get_weight_tied_copy(
            observation_input=val_contact_obs)
        val_contact_encoder2 = _encoder.get_weight_tied_copy(
            observation_input=val_contact_next_obs)
        val_contact_inverse_model = _inverse_model.get_weight_tied_copy(
            feature_input1=val_contact_encoder1.output,
            feature_input2=val_contact_encoder2.output)
        val_contact_forward_model = _forward_model.get_weight_tied_copy(
            feature_input=val_contact_encoder1.output,
            action_input=val_contact_action)

        if args.cos_forward:
            train_forward_loss = cos_loss(train_encoder2.output,
                                          train_forward_model.output)
            val_random_forward_loss = cos_loss(val_random_encoder2.output,
                                               val_random_forward_model.output)
            val_contact_forward_loss = cos_loss(
                val_contact_encoder2.output, val_contact_forward_model.output)
        else:
            train_forward_loss = tf.reduce_mean(
                tf.square(train_encoder2.output - train_forward_model.output))
            val_random_forward_loss = tf.reduce_mean(
                tf.square(val_random_encoder2.output -
                          val_random_forward_model.output))
            val_contact_forward_loss = tf.reduce_mean(
                tf.square(val_contact_encoder2.output -
                          val_contact_forward_model.output))

        # NOTE: the arm/box splits below assume state-space observations,
        # where dims [0:4] are arm state and [4:7] are box position
        train_forward_loss_arm = tf.reduce_mean(
            tf.square(train_encoder2.output[:, :4] -
                      train_forward_model.output[:, :4]))
        train_forward_loss_box = tf.reduce_mean(
            tf.square(train_encoder2.output[:, 4:7] -
                      train_forward_model.output[:, 4:7]))

        val_random_forward_loss_arm = tf.reduce_mean(
            tf.square(val_random_encoder2.output[:, :4] -
                      val_random_forward_model.output[:, :4]))
        val_random_forward_loss_box = tf.reduce_mean(
            tf.square(val_random_encoder2.output[:, 4:7] -
                      val_random_forward_model.output[:, 4:7]))

        val_contact_forward_loss_arm = tf.reduce_mean(
            tf.square(val_contact_encoder2.output[:, :4] -
                      val_contact_forward_model.output[:, :4]))
        val_contact_forward_loss_box = tf.reduce_mean(
            tf.square(val_contact_encoder2.output[:, 4:7] -
                      val_contact_forward_model.output[:, 4:7]))

        train_inverse_losses = tf.reduce_mean(
            tf.square(train_action - train_inverse_model.output), axis=0)
        val_random_inverse_losses = tf.reduce_mean(
            tf.square(val_random_action - val_random_inverse_model.output),
            axis=0)
        val_contact_inverse_losses = tf.reduce_mean(
            tf.square(val_contact_action - val_contact_inverse_model.output),
            axis=0)

        train_inverse_separate_summ = []
        val_random_inverse_separate_summ = []
        val_contact_inverse_separate_summ = []
        for joint_idx in range(action_dim):
            train_inverse_separate_summ.append(
                tf.summary.scalar(
                    "train/icm_inverse_loss/joint_{}".format(joint_idx),
                    train_inverse_losses[joint_idx]))
            val_random_inverse_separate_summ.append(
                tf.summary.scalar(
                    "random_val/icm_inverse_random_loss/joint_{}".format(
                        joint_idx), val_random_inverse_losses[joint_idx]))
            val_contact_inverse_separate_summ.append(
                tf.summary.scalar(
                    "contact_val/icm_inverse_random_loss/joint_{}".format(
                        joint_idx), val_contact_inverse_losses[joint_idx]))

        train_inverse_loss = tf.reduce_mean(train_inverse_losses)
        val_random_inverse_loss = tf.reduce_mean(val_random_inverse_losses)
        val_contact_inverse_loss = tf.reduce_mean(val_contact_inverse_losses)

        train_total_loss = args.forward_weight * train_forward_loss + (
            1. - args.forward_weight) * train_inverse_loss
        val_random_total_loss = args.forward_weight * val_random_forward_loss + (
            1. - args.forward_weight) * val_random_inverse_loss
        val_contact_total_loss = args.forward_weight * val_contact_forward_loss + (
            1. - args.forward_weight) * val_contact_inverse_loss

        icm_opt = tf.train.AdamOptimizer(
            args.init_lr).minimize(train_total_loss)
        _, train_data_forward_var = tf.nn.moments(train_obs, axes=[1])
        _, train_data_box_var = tf.nn.moments(train_obs[:, 4:7], axes=[1])

        # Setup summaries
        summary_writer = tf.summary.FileWriter(args.tfboard_path,
                                               graph=tf.get_default_graph())
        train_forward_loss_arm_summ = tf.summary.scalar(
            "train/forward_loss_arm", train_forward_loss_arm)
        train_forward_loss_box_summ = tf.summary.scalar(
            "train/forward_loss_box", train_forward_loss_box)
        train_inverse_loss_summ = tf.summary.scalar(
            "train/icm_inverse_loss/total_mean", train_inverse_loss)
        train_forward_loss_summ = tf.summary.scalar("train/icm_forward_loss",
                                                    train_forward_loss)
        train_total_loss_summ = tf.summary.scalar("train/icm_total_loss",
                                                  train_total_loss)

        random_val_forward_loss_arm_summ = tf.summary.scalar(
            "random_val/forward_loss_arm", val_random_forward_loss_arm)
        random_val_forward_loss_box_summ = tf.summary.scalar(
            "random_val/forward_loss_box", val_random_forward_loss_box)
        random_val_inverse_loss_summ = tf.summary.scalar(
            "random_val/icm_inverse_loss/total_mean", val_random_inverse_loss)
        random_val_forward_loss_summ = tf.summary.scalar(
            "random_val/icm_forward_loss", val_random_forward_loss)
        random_val_total_loss_summ = tf.summary.scalar(
            "random_val/icm_total_loss", val_random_total_loss)

        contact_val_forward_loss_arm_summ = tf.summary.scalar(
            "contact_val/forward_loss_arm", val_contact_forward_loss_arm)
        contact_val_forward_loss_box_summ = tf.summary.scalar(
            "contact_val/forward_loss_box", val_contact_forward_loss_box)
        contact_val_inverse_loss_summ = tf.summary.scalar(
            "contact_val/icm_inverse_loss/total_mean",
            val_contact_inverse_loss)
        contact_val_forward_loss_summ = tf.summary.scalar(
            "contact_val/icm_forward_loss", val_contact_forward_loss)
        contact_val_total_loss_summ = tf.summary.scalar(
            "contact_val/icm_total_loss", val_contact_total_loss)

        forward_data_variance_summ = tf.summary.scalar(
            "training_data_forward_variance",
            tf.reduce_mean(train_data_forward_var))
        forward_data_box_variance_summ = tf.summary.scalar(
            "training_data_forward_box_variance",
            tf.reduce_mean(train_data_box_var))

        train_summary_op = tf.summary.merge([
            train_inverse_loss_summ,
            train_forward_loss_summ,
            train_forward_loss_arm_summ,
            train_forward_loss_box_summ,
            train_total_loss_summ,
            forward_data_variance_summ,
            forward_data_box_variance_summ,
        ] + train_inverse_separate_summ)

        val_summary_op = tf.summary.merge([
            random_val_forward_loss_arm_summ,
            random_val_forward_loss_box_summ,
            random_val_inverse_loss_summ,
            random_val_forward_loss_summ,
            random_val_total_loss_summ,
            contact_val_forward_loss_arm_summ,
            contact_val_forward_loss_box_summ,
            contact_val_inverse_loss_summ,
            contact_val_forward_loss_summ,
            contact_val_total_loss_summ,
        ] + val_random_inverse_separate_summ +
                                          val_contact_inverse_separate_summ)

        logger.log("Finished creating ICM model")

        sess.run(tf.initialize_all_variables())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            for timestep in range(args.num_itr):
                if timestep % args.log_freq == 0:
                    logger.log("Start itr {}".format(timestep))
                    _, train_summary = sess.run([icm_opt, train_summary_op])
                    summary_writer.add_summary(train_summary, timestep)
                else:
                    sess.run(icm_opt)

                if timestep % args.save_freq == 0:
                    save_snapshot(_encoder, _inverse_model, _forward_model,
                                  args.tfmodel_path)

                if timestep % args.val_freq == 0:
                    val_summary = sess.run(val_summary_op)
                    summary_writer.add_summary(val_summary, timestep)

        except KeyboardInterrupt:
            print("End training...")

        coord.join(threads)
        sess.close()
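
The arm/box forward losses in main() slice the feature vector into arm dims [0:4] and box dims [4:7], which (as the comment there notes) assumes state-space observations. A minimal NumPy sketch of that split, with illustrative names:

import numpy as np

def split_forward_losses(phi2_true, phi2_pred):
    # Full forward loss over all feature dims.
    full = np.mean(np.square(phi2_true - phi2_pred))
    # Arm-only loss: first four dims, assumed to be arm state.
    arm = np.mean(np.square(phi2_true[:, :4] - phi2_pred[:, :4]))
    # Box-only loss: dims 4..6, assumed to be box position.
    box = np.mean(np.square(phi2_true[:, 4:7] - phi2_pred[:, 4:7]))
    return full, arm, box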
Example #3
	def __init__(
			self,
			env,
			algo: OnlineAlgorithm,
			no_encoder=False,
			feature_dim=10,
			forward_weight=0.8,
			external_reward_weight=0.01,
			inverse_tanh=False,
			init_learning_rate=1e-4,
			algo_update_freq=1,
			**kwargs
	):
		"""
		:param env: Environment
		:param algo: Algorithm that will be used with ICM
		:param encoder: State encoder that maps s to f
		:param inverse_model: Inverse dynamics model that maps (f1, f2) to actions
		:param forward_model: Forward dynamics model that maps (f1, a) to f2
		:param forward_weight: Weight from 0 to 1 that balances forward loss and inverse loss
		:param external_reward_weight: Weight that balances external reward and internal reward
		:param init_learning_rate: Initial learning rate of optimizer
		"""
		self.algo = algo
		gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
		self.sess = self.algo.sess or tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
		self.external_reward_weight = external_reward_weight
		self.summary_writer = self.algo.summary_writer
		self.algo_update_freq = algo_update_freq
		act_space = env.action_space
		obs_space = env.observation_space
		
		# Setup ICM models
		self.s1 = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
		self.s2 = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
		self.asample = tf.placeholder(tf.float32, [None, act_space.flat_dim])
		self.external_rewards = tf.placeholder(tf.float32, (None,))

		if len(obs_space.shape) == 1:
			if no_encoder:
				self._encoder = NoEncoder(obs_space.flat_dim, env_spec=env.spec)
			else:
				self._encoder = FullyConnectedEncoder(feature_dim, env_spec=env.spec)
		else:
			# TODO: implement conv encoder
			raise NotImplementedError("Currently only supports flat observation input!")

		self._encoder.sess = self.sess
		# Initialize variables for get_copy to work
		self.sess.run(tf.initialize_all_variables())
		with self.sess.as_default():
			self.encoder1 = self._encoder.get_weight_tied_copy(observation_input=self.s1)
			self.encoder2 = self._encoder.get_weight_tied_copy(observation_input=self.s2)

		self._inverse_model = InverseModel(feature_dim, env_spec=env.spec)
		self._forward_model = ForwardModel(feature_dim, env_spec=env.spec)
		self._inverse_model.sess = self.sess
		self._forward_model.sess = self.sess
		# Initialize variables for get_copy to work
		self.sess.run(tf.initialize_all_variables())
		with self.sess.as_default():
			self.inverse_model = self._inverse_model.get_weight_tied_copy(
				feature_input1=self.encoder1.output,
				feature_input2=self.encoder2.output)
			self.forward_model = self._forward_model.get_weight_tied_copy(
				feature_input=self.encoder1.output,
				action_input=self.asample)

		# Define losses
		self.forward_loss = tf.reduce_mean(tf.square(self.encoder2.output - self.forward_model.output))
		# self.forward_loss = tf.nn.l2_loss(self.encoder2.output - self.forward_model.output)
		if isinstance(act_space, Box):
			self.inverse_loss = tf.reduce_mean(tf.square(self.asample - self.inverse_model.output))
		elif isinstance(act_space, Discrete):
			# TODO: Implement softmax loss
			raise NotImplementedError
		else:
			raise NotImplementedError
		self.internal_rewards = tf.reduce_sum(tf.square(self.encoder2.output - self.forward_model.output), axis=1)
		self.mean_internal_rewards = tf.reduce_mean(self.internal_rewards)
		self.mean_external_rewards = tf.reduce_mean(self.external_rewards)
		
		self.total_loss = forward_weight * self.forward_loss + \
						(1. - forward_weight) * self.inverse_loss
		self.icm_opt = tf.train.AdamOptimizer(init_learning_rate).\
						minimize(self.total_loss)


		# Setup summaries
		inverse_loss_summ = tf.summary.scalar("icm_inverse_loss", self.inverse_loss)
		forward_loss_summ = tf.summary.scalar("icm_forward_loss", self.forward_loss)
		total_loss_summ = tf.summary.scalar("icm_total_loss", self.total_loss)
		internal_rewards = tf.summary.scalar("mean_internal_rewards", self.mean_internal_rewards)
		external_rewards = tf.summary.scalar("mean_external_rewards_training", self.mean_external_rewards)
		var_summ = []
		for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
			var_summ.append(tf.summary.histogram(var.op.name, var))
		self.summary = tf.summary.merge(
			[inverse_loss_summ, forward_loss_summ, total_loss_summ,
			 internal_rewards, external_rewards])
		# Initialize variables
		
		self.sess.run(tf.initialize_all_variables())
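
The docstring above says external_reward_weight "balances external reward and internal reward", but the mixing itself happens elsewhere (in the algorithm's sampler, which is not shown). A minimal NumPy sketch of one plausible mixing rule; the exact formula used by the original sampler is an assumption here:

import numpy as np

def combined_rewards(external_rewards, internal_rewards,
                     external_reward_weight=0.01):
    # Hypothetical mix: down-weighted task reward plus curiosity bonus.
    external_rewards = np.asarray(external_rewards, dtype=np.float32)
    internal_rewards = np.asarray(internal_rewards, dtype=np.float32)
    return external_reward_weight * external_rewards + internal_rewards

# Example: a sparse task reward gets a dense curiosity bonus on top.
print(combined_rewards([0.0, 1.0, 0.0], [0.3, 0.1, 0.5]))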
Example #4
    def __init__(self,
                 env,
                 trpo: TRPO,
                 tensorboard_path,
                 no_encoder=False,
                 feature_dim=10,
                 forward_weight=0.8,
                 external_reward_weight=0.01,
                 inverse_tanh=True,
                 forward_cos=False,
                 init_learning_rate=1e-4,
                 icm_batch_size=128,
                 replay_pool_size=1000000,
                 min_pool_size=1000,
                 n_updates_per_iter=500,
                 rel_curiosity=False,
                 clip_curiosity=0.0,
                 debug_save_data=False,
                 debug_log_weights=False,
                 **kwargs):
        """
        :param env: Environment
        :param trpo: TRPO algorithm that will be wrapped with ICM
        :param encoder: State encoder that maps s to f
        :param inverse_model: Inverse dynamics model that maps (f1, f2) to actions
        :param forward_model: Forward dynamics model that maps (f1, a) to f2
        :param forward_weight: Weight from 0 to 1 that balances forward loss and inverse loss
        :param external_reward_weight: Weight that balances external reward and internal reward
        :param init_learning_rate: Initial learning rate of the optimizer
        """
        self.trpo = trpo
        # Replace sampler to inject intrinsic reward
        self.trpo.sampler = self.get_sampler(self.trpo)
        self.sess = tf.get_default_session() or tf.Session()
        self.external_reward_weight = external_reward_weight
        self.summary_writer = tf.summary.FileWriter(
            tensorboard_path, graph=tf.get_default_graph())
        self.n_updates_per_iter = n_updates_per_iter
        self.forward_cos = forward_cos
        self.inverse_tanh = inverse_tanh
        self.icm_batch_size = icm_batch_size
        self.rel_curiosity = rel_curiosity
        self.clip_curiosity = clip_curiosity
        self.debug_save_data = debug_save_data
        self.debug_log_weights = debug_log_weights

        # Debug purpose: Save (ob1, a, ob2, if_contact)
        if self.debug_save_data:
            self.DEBUG_DATA_PATH = "/home/dianchen/icm_data.csv"
            with open(self.DEBUG_DATA_PATH, 'w+') as csvfile:
                fieldnames = ['obs', 'a', 'next_obs', 'contact']
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()

        act_space = env.action_space
        obs_space = env.observation_space

        self.pool = TRPOReplayPool(replay_pool_size, obs_space.flat_dim,
                                   act_space.flat_dim)
        self.min_pool_size = min_pool_size
        # Setup ICM models
        self.s1 = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
        self.s2 = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
        self.asample = tf.placeholder(tf.float32, [None, act_space.flat_dim])

        self.external_rewards = tf.placeholder(tf.float32, (None, ))
        self.contact_rewards = tf.placeholder(tf.float32, (None, ))

        if len(obs_space.shape) == 1:
            if no_encoder:
                self._encoder = NoEncoder(obs_space.flat_dim,
                                          env_spec=env.spec)
            else:
                self._encoder = FullyConnectedEncoder(feature_dim,
                                                      env_spec=env.spec)
        else:
            # TODO: implement conv encoder
            raise NotImplementedError(
                "Currently only supports flat observation input!")

        self._encoder.sess = self.sess
        # Initialize variables for get_copy to work
        self.sess.run(tf.initialize_all_variables())
        with self.sess.as_default():
            self.encoder1 = self._encoder.get_weight_tied_copy(
                observation_input=self.s1)
            self.encoder2 = self._encoder.get_weight_tied_copy(
                observation_input=self.s2)

        if self.inverse_tanh:
            self._inverse_model = InverseModel(feature_dim, env_spec=env.spec)
        else:
            self._inverse_model = InverseModel(feature_dim,
                                               env_spec=env.spec,
                                               output_activation=None)
        self._forward_model = ForwardModel(feature_dim, env_spec=env.spec)
        self._inverse_model.sess = self.sess
        self._forward_model.sess = self.sess
        # Initialize variables for get_copy to work
        self.sess.run(tf.initialize_all_variables())
        with self.sess.as_default():
            self.inverse_model = self._inverse_model.get_weight_tied_copy(
                feature_input1=self.encoder1.output,
                feature_input2=self.encoder2.output)
            self.forward_model = self._forward_model.get_weight_tied_copy(
                feature_input=self.encoder1.output, action_input=self.asample)

        # Define losses
        if self.forward_cos:
            self.forward_loss = cosine_loss(self.encoder2.output,
                                            self.forward_model.output)
        else:
            self.forward_loss = tf.reduce_mean(
                tf.square(self.encoder2.output - self.forward_model.output))
        # self.forward_loss = tf.nn.l2_loss(self.encoder2.output - self.forward_model.output)
        if isinstance(act_space, Box):
            self.inverse_loss = tf.reduce_mean(
                tf.square(self.asample - self.inverse_model.output))
        elif isinstance(act_space, Discrete):
            # TODO: Implement softmax loss
            raise NotImplementedError
        else:
            raise NotImplementedError
        if self.forward_cos:
            self.internal_rewards = 1.0 - tf.reduce_sum(
                tf.multiply(tf.nn.l2_normalize(self.forward_model.output, 1),
                            tf.nn.l2_normalize(self.encoder2.output, 1)), 1)
        else:
            self.internal_rewards = tf.reduce_sum(
                tf.square(self.encoder2.output - self.forward_model.output),
                axis=1)
        self.mean_internal_rewards = tf.reduce_mean(self.internal_rewards)
        self.mean_external_rewards = tf.reduce_mean(self.external_rewards)
        self.mean_contact_rewards = tf.reduce_mean(self.contact_rewards)

        self.total_loss = forward_weight * self.forward_loss + \
            (1. - forward_weight) * self.inverse_loss
        self.icm_opt = tf.train.AdamOptimizer(init_learning_rate).\
            minimize(self.total_loss)

        # Setup summaries
        inverse_loss_summ = tf.summary.scalar("icm_inverse_loss",
                                              self.inverse_loss)
        forward_loss_summ = tf.summary.scalar("icm_forward_loss",
                                              self.forward_loss)
        total_loss_summ = tf.summary.scalar("icm_total_loss", self.total_loss)
        internal_rewards = tf.summary.scalar("mean_internal_rewards",
                                             self.mean_internal_rewards)
        external_rewards = tf.summary.scalar("mean_external_rewards",
                                             self.mean_external_rewards)
        # Setup env_info logs
        contact_summ = tf.summary.scalar("mean_contact_rewards",
                                         self.mean_contact_rewards)

        var_summ = []
        if self.debug_log_weights:
            for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                var_summ.append(tf.summary.histogram(var.op.name, var))

        self.summary_icm = tf.summary.merge(
            [inverse_loss_summ, forward_loss_summ, total_loss_summ] + var_summ)
        self.summary_env = tf.summary.merge(
            [internal_rewards, external_rewards, contact_summ])
        # Initialize variables

        self.sess.run(tf.initialize_all_variables())
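
When forward_cos is set, Example #4 uses 1 minus the cosine similarity between the predicted and true next features as the curiosity signal. A NumPy sketch equivalent to the tf.nn.l2_normalize expression above (the eps guard mirrors l2_normalize's built-in epsilon):

import numpy as np

def cosine_internal_rewards(phi2_true, phi2_pred, eps=1e-12):
    # Normalize each row to unit length, then take 1 - cosine similarity.
    true_n = phi2_true / np.maximum(
        np.linalg.norm(phi2_true, axis=1, keepdims=True), eps)
    pred_n = phi2_pred / np.maximum(
        np.linalg.norm(phi2_pred, axis=1, keepdims=True), eps)
    return 1.0 - np.sum(true_n * pred_n, axis=1)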