コード例 #1
0
class CustomTFRPGModel(TFModelV2):
    """Example of interpreting repeated observations.

    All of the actual computation is delegated to a wrapped ``TFFCNet``;
    ``forward`` only prints a few views of the repeated observation
    structure before handing the batch to that network.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        fc_net = TFFCNet(obs_space, action_space, num_outputs, model_config,
                         name)
        self.model = fc_net
        self.register_variables(fc_net.variables())

    def forward(self, input_dict, state, seq_lens):
        """Print the repeated-observation views, then delegate to the FC net.

        Per the original example, the unpacked tensors look like
        (M = MAX_PLAYERS, N = MAX_ITEMS):
            'items':    shape (?, M, N, 5)
            'location': shape (?, M, 2)
            'status':   shape (?, M, 10)
        """
        repeated_obs = input_dict["obs"]
        print("The unpacked input tensors:", repeated_obs)
        print()
        print("Unbatched repeat dim", repeated_obs.unbatch_repeat_dim())
        print()
        # Full unbatching is only attempted when running eagerly.
        if tf.executing_eagerly():
            print("Fully unbatched", repeated_obs.unbatch_all())
            print()
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value estimate."""
        return self.model.value_function()
コード例 #2
0
class MaskedActionsMLP(DistributionalQModel, TFModelV2):
    """Tensorflow model for Envs that provide action masks with observations.

    The observation is a Dict with a ``'board'`` component, which is flattened
    and fed to an MLP, and an ``'action_mask'`` component. The mask is added
    to the MLP output in log space: entries equal to 1 leave the logit
    unchanged (log 1 = 0), and entries equal to 0 are clamped to
    ``tf.float32.min``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kwargs)

        # The preprocessor flattens all Dict components together, but the MLP
        # only consumes the game board, so build a flat Box for 'board' alone
        # spanning the board's min/max bounds.
        original_space = obs_space.original_space['board']
        flat_obs_space = spaces.Box(low=np.min(original_space.low),
                                    high=np.max(original_space.high),
                                    shape=(np.prod(original_space.shape), ))
        self.mlp = FullyConnectedNetwork(flat_obs_space, action_space,
                                         num_outputs, model_config, name)
        self.register_variables(self.mlp.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return masked MLP logits for the flattened board observation."""
        obs = flatten(input_dict['obs']['board'])
        # log(0) = -inf is clamped to the most negative finite float32 for
        # numerical stability. tf.math.log is used because tf.log no longer
        # exists in the TF2 top-level namespace.
        action_mask = tf.maximum(
            tf.math.log(input_dict['obs']['action_mask']), tf.float32.min)
        model_out, _ = self.mlp({'obs': obs})
        return action_mask + model_out, state

    def value_function(self):
        """Return the MLP's value estimate."""
        return self.mlp.value_function()
コード例 #3
0
ファイル: rl_utils.py プロジェクト: grossmann-group/or-gym
class VMActionMaskModel(TFModelV2):
    """Action-masking model that scores actions from an embedding of the state.

    The ``'state'`` observation is embedded by a FullyConnectedNetwork into a
    vector of size ``action_embed_size``. That vector is combined with
    ``'avail_actions'`` to produce logits. ``'action_mask'`` is then applied in
    log space, so entries equal to 0 receive a logit of ``tf.float32.min``.

    Args:
        true_obs_shape: shape of the ``'state'`` observation; the embedding
            network sees it as ``Box(0, 1, shape=true_obs_shape)``.
        action_embed_size: output size of the embedding network.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(51, 3),
                 action_embed_size=50,
                 *args,
                 **kwargs):
        super(VMActionMaskModel,
              self).__init__(obs_space, action_space, num_outputs,
                             model_config, name, *args, **kwargs)
        self.action_embed_model = FullyConnectedNetwork(
            spaces.Box(0, 1, shape=true_obs_shape), action_space,
            action_embed_size, model_config, name + "_action_embedding")
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return masked action logits for the given observation batch."""
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]
        action_embedding, _ = self.action_embed_model(
            {"obs": input_dict["obs"]["state"]})
        intent_vector = tf.expand_dims(action_embedding, 1)
        # NOTE(review): in the usual parametric-actions pattern,
        # avail_actions is [BATCH, MAX_ACTIONS, EMBED] and the dot product
        # reduces over axis=2. Reducing over axis=1 sums over actions
        # instead. Left unchanged because the shape of 'avail_actions' is not
        # visible here; confirm against the env's observation space.
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=1)
        # Use tf.math.log (tf.log is gone from the TF2 top-level namespace).
        # log(0) = -inf is clamped to float32.min for numerical stability.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        """Return the embedding network's value estimate."""
        return self.action_embed_model.value_function()
コード例 #4
0
ファイル: centralized_critic.py プロジェクト: zuoxiaolei/ray
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized VF.

    Actions come from a decentralized FullyConnectedNetwork over the agent's
    own observation. ``central_value_function`` scores the tuple
    (own obs, opponent obs, opponent action) with a separate Keras model.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        # Base of the model: policy logits and a local value estimate.
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)
        self.register_variables(self.model.variables())

        # Central VF maps (obs, opp_obs, opp_act) -> vf_pred. Input widths
        # are hard-coded: 6-dim observations and a 2-way one-hot action.
        obs = tf.keras.layers.Input(shape=(6, ), name="obs")
        opp_obs = tf.keras.layers.Input(shape=(6, ), name="opp_obs")
        opp_act = tf.keras.layers.Input(shape=(2, ), name="opp_act")
        concat_obs = tf.keras.layers.Concatenate(axis=1)(
            [obs, opp_obs, opp_act])
        central_vf_dense = tf.keras.layers.Dense(
            16, activation=tf.nn.tanh, name="c_vf_dense")(concat_obs)
        central_vf_out = tf.keras.layers.Dense(
            1, activation=None, name="c_vf_out")(central_vf_dense)
        self.central_vf = tf.keras.Model(
            inputs=[obs, opp_obs, opp_act], outputs=central_vf_out)
        self.register_variables(self.central_vf.variables)

    def forward(self, input_dict, state, seq_lens):
        """Delegate the policy forward pass to the base network."""
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Return the centralized value estimate, flattened to 1-D.

        ``opponent_actions`` are action indices; they are one-hot encoded
        with depth 2 before being fed to the central VF.
        """
        return tf.reshape(
            self.central_vf(
                [obs, opponent_obs,
                 tf.one_hot(opponent_actions, 2)]), [-1])

    def value_function(self):
        """Return the base network's (decentralized) value estimate.

        Provided so the standard ModelV2 ``value_function`` API is available
        on this model. Centralized training is expected to use
        ``central_value_function`` instead.
        """
        return self.model.value_function()
コード例 #5
0
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized VF.

    The observation is a dict holding 'own_obs' and 'opponent_obs'. The
    former drives action selection (decentralized execution); the latter is
    only used for optimization (centralized learning).

    Two sub-networks are kept:
    - an action model that reads only 'own_obs' to compute actions;
    - a value model that reads the full 'obs_flat' tensor, and so also sees
      'opponent_obs' / 'opponent_action', to compute the value.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)

        # Policy network over the agent's own observation, which is a
        # one-hot encoded Discrete(6) value.
        own_obs_space = Box(low=0, high=1, shape=(6, ))
        self.action_model = FullyConnectedNetwork(
            own_obs_space, action_space, num_outputs, model_config,
            name + "_action")
        self.register_variables(self.action_model.variables())

        # Value network over the full flattened observation, one output.
        self.value_model = FullyConnectedNetwork(
            obs_space, action_space, 1, model_config, name + "_vf")
        self.register_variables(self.value_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Cache the central value, then return the action model output."""
        value_inputs = {"obs": input_dict["obs_flat"]}
        self._value_out, _ = self.value_model(value_inputs, state, seq_lens)

        policy_inputs = {"obs": input_dict["obs"]["own_obs"]}
        return self.action_model(policy_inputs, state, seq_lens)

    def value_function(self):
        """Return the value computed in the last forward pass, as 1-D."""
        return tf.reshape(self._value_out, [-1])
コード例 #6
0
class ParametricActionsModel(DistributionalQTFModel):
    """Parametric action model that handles the dot product and masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise to the reader.

    Options read from ``model_config['custom_options']``:
      - ``spy``: choose ``make_spy_space`` (truthy) or ``make_blind_space``,
        both called with ``parties`` and ``blocks``.
      - ``extended``: action embedding size 6 (truthy) or 4.

    NOTE(review): the ``true_obs_shape`` and ``action_embed_size`` arguments
    are accepted for signature compatibility but have no effect; the
    embedding input size and output size come from the options above.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=6,
                 **kw):
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        options = model_config['custom_options']
        make_space = make_spy_space if options['spy'] else make_blind_space
        true_obs_space = make_space(options['parties'], options['blocks'])
        action_embed_size = 6 if options['extended'] else 4
        # The embedding network's input is the concatenation of every
        # component space after preprocessing.
        total_dim = sum(get_preprocessor(space)(space).size
                        for space in true_obs_space)
        self.action_embed_model = FullyConnectedNetwork(
            Box(-1, 1, shape=(total_dim, )), action_space, action_embed_size,
            model_config, name + "_action_embed")
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return masked logits from the action-embedding dot product."""
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]
        # Compute the predicted action embedding.
        action_embed, _ = self.action_embed_model({
            "obs": input_dict["obs"]["bitcoin"]
        })

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability).
        # tf.math.log replaces tf.log, which TF2 removed from the top level.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        """Return the embedding network's value estimate."""
        return self.action_embed_model.value_function()
コード例 #7
0
class CustomLossModel(TFModelV2):
    """Custom model that adds an imitation loss on top of the policy loss.

    Forward passes go straight to an internal FullyConnectedNetwork. The
    extra loss re-runs that network on expert data read by ``JsonReader``
    from ``model_config["custom_model_config"]["input_files"]``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        net = FullyConnectedNetwork(self.obs_space,
                                    self.action_space,
                                    num_outputs,
                                    model_config,
                                    name="fcnet")
        self.fcnet = net
        self.register_variables(net.variables())

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Run the wrapped fully-connected network."""
        return self.fcnet(input_dict, state, seq_lens)

    @override(ModelV2)
    def custom_loss(self, policy_loss, loss_inputs):
        """Return ``policy_loss + 10 * imitation_loss``.

        The imitation loss is the mean negative log-likelihood of the expert
        actions under this model's Categorical distribution.
        """
        # Each worker builds its own reader over the expert files.
        input_files = self.model_config["custom_model_config"]["input_files"]
        expert_ops = JsonReader(input_files).tf_input_ops()

        # Second pass through the same weights, this time on expert data.
        expert_obs = restore_original_dimensions(
            tf.cast(expert_ops["obs"], tf.float32), self.obs_space)
        logits, _ = self.forward({"obs": expert_obs}, [], None)

        # Other tensors from loss_inputs could feed extra self-supervised
        # terms (e.g. an autoencoder-style reconstruction loss).
        print("FYI: You can also use these tensors: {}, ".format(loss_inputs))

        dist = Categorical(logits, self.model_config)
        self.policy_loss = policy_loss
        self.imitation_loss = tf.reduce_mean(
            -dist.logp(expert_ops["actions"]))
        return policy_loss + 10 * self.imitation_loss

    def custom_stats(self):
        """Expose the last computed policy and imitation losses."""
        return dict(
            policy_loss=self.policy_loss,
            imitation_loss=self.imitation_loss,
        )
コード例 #8
0
class CustomModel(TFModelV2):
    """Example of a custom model that just delegates to a fc-net.

    Both the policy output and the value estimate come from one
    FullyConnectedNetwork built with exactly this model's constructor
    arguments.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CustomModel, self).__init__(obs_space, action_space, num_outputs,
                                          model_config, name)
        ctor_args = (obs_space, action_space, num_outputs, model_config, name)
        self.model = FullyConnectedNetwork(*ctor_args)
        self.register_variables(self.model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Forward the batch through the wrapped network unchanged."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value estimate."""
        return self.model.value_function()
コード例 #9
0
ファイル: rl_utils.py プロジェクト: grossmann-group/or-gym
class FCModel(TFModelV2):
    """Fully Connected Model.

    A thin TFModelV2 wrapper: every call is delegated to an inner
    FullyConnectedNetwork built from the same arguments.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(FCModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        inner = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name)
        self.model = inner
        self.register_variables(inner.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return the inner network's forward output."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the inner network's value estimate."""
        return self.model.value_function()
コード例 #10
0
class HanabiFullyConnected(LegalActionsDistributionalQModel):
    """Q-model that runs a FullyConnectedNetwork over the 'board' observation.

    The network output is handed to ``calculate_and_store_q`` and also
    returned as the model output.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        super(HanabiFullyConnected, self).__init__(
            obs_space, action_space, num_outputs, model_config, name,
            **kwargs)
        board_space = obs_space.original_space["board"]
        self.fc = FullyConnectedNetwork(board_space, action_space,
                                        num_outputs, model_config,
                                        name + "fc")
        self.register_variables(self.fc.variables())

    def forward(self, input_dict, state, seq_lens):
        """Encode the board, store Q-values, and return the network output."""
        board = input_dict["obs"]["board"]
        model_out, new_state = self.fc({"obs": board}, state, seq_lens)
        self.calculate_and_store_q(input_dict, model_out)
        return model_out, new_state
コード例 #11
0
class ParametricActionsModel(TFModelV2):
    """Parametric model that handles varying action spaces.

    A FullyConnectedNetwork maps the ``'real_obs'`` component to one score
    per action. ``'action_mask'`` is then applied in log space, so masked
    entries (0) get a logit of ``tf.float32.min``.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(24, ),
                 action_embed_size=None):
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)

        # Default to one output per action; action_space.n only exists for
        # Discrete() action spaces.
        if action_embed_size is None:
            action_embed_size = action_space.n

        # The embedding input size is whatever RLlib's automatically chosen
        # preprocessor produces for the 'real_obs' space (an integer).
        # The true_obs_shape argument is not used for this.
        real_obs_space = obs_space.original_space['real_obs']
        flat_obs_size = get_preprocessor(real_obs_space)(real_obs_space).size
        self.action_embed_model = FullyConnectedNetwork(
            obs_space=Box(-1, 1, shape=(flat_obs_size, )),
            action_space=action_space,
            num_outputs=action_embed_size,
            model_config=model_config,
            name=name + "_action_embed")
        self.base_model = self.action_embed_model.base_model
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return masked action scores for the batch."""
        # When 'real_obs' comes from a Tuple space it arrives as a list of
        # 1-D tensors rather than one tensor; concatenate it in that case.
        real_obs = input_dict["obs"]["real_obs"]
        if isinstance(real_obs, list):
            real_obs = tf.concat(values=flatten_list(real_obs), axis=1)
        action_scores, _ = self.action_embed_model({"obs": real_obs})

        # Mask out invalid actions; clamp log(0) = -inf to float32.min for
        # numerical stability.
        mask = input_dict["obs"]["action_mask"]
        log_mask = tf.maximum(tf.math.log(mask), tf.float32.min)
        return action_scores + log_mask, state

    def value_function(self):
        """Return the embedding network's value estimate."""
        return self.action_embed_model.value_function()
コード例 #12
0
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized VF.

    Actions and the local value come from a FullyConnectedNetwork. A separate
    Keras model maps the concatenated observations of up to
    ``custom_options['max_num_agents']`` other agents to a scalar value.
    """

    # TODO(@evinitsky) make this work with more than boxes

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        # Decentralized base network.
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)
        self.register_variables(self.model.variables())

        # Central VF: (other agents' obs) -> vf_pred.
        options = model_config['custom_options']
        self.max_num_agents = options['max_num_agents']
        self.obs_space_shape = obs_space.shape[0]
        central_width = self.obs_space_shape * self.max_num_agents
        other_obs = tf.keras.layers.Input(shape=(central_width, ),
                                          name="opp_obs")
        hidden = tf.keras.layers.Dense(options['central_vf_size'],
                                       activation=tf.nn.tanh,
                                       name="c_vf_dense")(other_obs)
        value_out = tf.keras.layers.Dense(1, activation=None,
                                          name="c_vf_out")(hidden)
        self.central_vf = tf.keras.Model(inputs=[other_obs],
                                         outputs=value_out)
        self.register_variables(self.central_vf.variables)

    def forward(self, input_dict, state, seq_lens):
        """Delegate the policy forward pass to the base network."""
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs):
        """Return the centralized value as a 1-D tensor.

        Only ``opponent_obs`` is fed to the central VF; ``obs`` is accepted
        but not used.
        """
        central_value = self.central_vf([opponent_obs])
        return tf.reshape(central_value, [-1])

    def value_function(self):
        """Return the base network's value estimate (not used)."""
        return self.model.value_function()
コード例 #13
0
ファイル: PaModel.py プロジェクト: nicofirst1/rl_werewolf
class ParametricActionsModel(TFModelV2):
    """
    Parametric action model used to filter out invalid action from environment.

    A FullyConnectedNetwork scores actions from the ``'array_obs'`` component.
    The ``'action_mask'`` component is applied in log space, so entries equal
    to 0 receive ``tf.float32.min``.
    """
    def import_from_h5(self, h5_file):
        # Intentionally a no-op: weights are never imported from HDF5 files.
        pass

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
    ):
        # NOTE: the caller-provided name is ignored; every instance is named
        # "Pa_model" (and its embed network "Pa_model_action_embed").
        name = "Pa_model"
        super(ParametricActionsModel,
              self).__init__(obs_space, action_space, num_outputs,
                             model_config, name)

        # get real obs space, discarding action mask
        real_obs_space = obs_space.original_space.spaces['array_obs']

        # define action embed model
        self.action_embed_model = FullyConnectedNetwork(
            real_obs_space, action_space, num_outputs, model_config,
            name + "_action_embed")
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """
        Override forward pass to mask out invalid actions.

        Arguments:
            input_dict (dict): dictionary of input tensors, including "obs",
                "obs_flat", "prev_action", "prev_reward", "is_training"
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of size
                [BATCH, num_outputs]
        """
        obs = input_dict['obs']

        # extract action mask  [batch size, num players]
        action_mask = obs['action_mask']
        # extract original observations [batch size, obs size]
        array_obs = obs['array_obs']

        # Compute the predicted action embedding
        # size [batch size, num players * num players]
        action_embed, _ = self.action_embed_model({"obs": array_obs})

        # Mask out invalid actions (use tf.float32.min for stability)
        # size [batch size, num players * num players]
        # tf.math.log replaces tf.log, which TF2 removed from the top level.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        inf_mask = tf.cast(inf_mask, tf.float32)

        masked_actions = action_embed + inf_mask

        # return masked action embed and state
        return masked_actions, state

    def value_function(self):
        """Return the embedding network's value estimate."""
        return self.action_embed_model.value_function()
コード例 #14
0
class HanabiHandInference(LegalActionsDistributionalQModel):
    """Hanabi Q-model with an auxiliary hidden-hand inference head.

    The 'board' observation goes through a shared ``obs_module``. Its output
    feeds two branches:
      - ``q_module``: its output, multiplied element-wise by the
        gradient-stopped ``aux_module`` output, is the model output;
      - ``aux_module`` -> ``aux_head``: trained in ``extra_loss`` with a
        softmax cross-entropy against the normalized 'hidden_hand'
        observation.

    Layer sizes come from ``model_config["custom_options"]``:
    ``obs_module_hiddens``, ``q_module_hiddens``, ``aux_module_hiddens`` and
    ``aux_head_hiddens``. ``aux_loss_formula`` (default ``"sqrt"``) selects
    how the policy and auxiliary losses are combined.
    """

    @staticmethod
    def _sub_model_config(activation, hiddens, no_final_linear):
        """Build the model_config dict for one internal FC sub-network."""
        return {
            "fcnet_activation": activation,
            "fcnet_hiddens": hiddens,
            "no_final_linear": no_final_linear,
            "vf_share_layers": True,
        }

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        options = model_config["custom_options"]
        # NOTE: num_outputs is not forwarded; the Q-model head size is the
        # last q_module hidden size.
        super(HanabiHandInference, self).__init__(
            obs_space, action_space, options["q_module_hiddens"][-1],
            model_config, name, **kwargs)
        activation = model_config["fcnet_activation"]
        obs_hiddens = options["obs_module_hiddens"]
        q_hiddens = options["q_module_hiddens"]
        aux_hiddens = options["aux_module_hiddens"]

        self.obs_module = FullyConnectedNetwork(
            obs_space.original_space["board"], None, obs_hiddens[-1],
            self._sub_model_config(activation, obs_hiddens, True),
            name + "obs_module")

        # A zero vector stands in for a gym space for the downstream
        # sub-networks. NOTE(review): presumably FullyConnectedNetwork only
        # reads its .shape here; confirm.
        obs_module_output_dummy = numpy.zeros(obs_hiddens[-1])
        self.q_module = FullyConnectedNetwork(
            obs_module_output_dummy, None, q_hiddens[-1],
            self._sub_model_config(activation, q_hiddens, True),
            name + "q_module")

        self.aux_module = FullyConnectedNetwork(
            obs_module_output_dummy, None, aux_hiddens[-1],
            self._sub_model_config(activation, aux_hiddens, True),
            name + "aux_module")

        aux_head_input_dummy = numpy.zeros(aux_hiddens[-1])
        hidden_hand_size = numpy.prod(
            obs_space.original_space["hidden_hand"].shape)
        self.aux_head = FullyConnectedNetwork(
            aux_head_input_dummy, None, hidden_hand_size,
            self._sub_model_config(activation, options["aux_head_hiddens"],
                                   False),
            name + "aux_head")

        for sub_model in (self.obs_module, self.q_module, self.aux_module,
                          self.aux_head):
            self.register_variables(sub_model.variables())
        self.aux_loss_formula = get_aux_loss_formula(
            options.get("aux_loss_formula", "sqrt"))

    def forward(self, input_dict, state, seq_lens):
        """Compute Q-model output gated by the (gradient-stopped) aux branch."""
        obs_module_out, obs_state = self.obs_module(
            {"obs": input_dict["obs"]["board"]}, state, seq_lens)
        q_module_out, q_state = self.q_module(
            {"obs": obs_module_out}, obs_state, seq_lens)
        aux_module_out, _ = self.aux_module(
            {"obs": obs_module_out}, obs_state, seq_lens)

        # stop_gradient: gradients from model_out do not flow back through
        # the aux branch.
        model_out = tf.multiply(q_module_out, tf.stop_gradient(aux_module_out))

        self.calculate_and_store_q(input_dict, model_out)

        return model_out, q_state

    def extra_loss(self, policy_loss, loss_inputs, stats):
        """Combine policy_loss with the hidden-hand inference loss.

        Adds "combined_loss" and "hand_inference_loss" to ``stats`` and
        returns the combined loss.
        """
        # Restore the Dict observation once and read both components.
        restored = restore_original_dimensions(
            loss_inputs["obs"], self.obs_space, self.framework)
        obs = restored["board"]
        hidden_hand = restored["hidden_hand"]
        # Flatten each sample's (d1, d2) hand tensor into one vector, then
        # normalize it to sum to 1; all-zero rows stay zero via divide_no_nan.
        hidden_hand = tf.reshape(
            hidden_hand,
            [tf.shape(hidden_hand)[0], hidden_hand.shape[1] * hidden_hand.shape[2]])
        hidden_hand = tf.math.divide_no_nan(
            hidden_hand, tf.expand_dims(tf.reduce_sum(hidden_hand, 1), 1))

        obs_module_out, obs_state = self.obs_module({"obs": obs}, None, None)
        aux_module_out, aux_state = self.aux_module(
            {"obs": obs_module_out}, obs_state, None)
        aux_head_out, _ = self.aux_head({"obs": aux_module_out}, aux_state, None)
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(hidden_hand),
            logits=aux_head_out)
        hand_inference_loss = tf.reduce_mean(cross_entropy)
        combined_loss = self.aux_loss_formula(policy_loss, hand_inference_loss)
        stats.update({
            "combined_loss": combined_loss,
            "hand_inference_loss": hand_inference_loss
        })
        return combined_loss
コード例 #15
0
class CustomLossModel(TFModelV2):
    """Custom model that adds an imitation loss on top of the policy loss.

    Options read from ``model_config["custom_options"]``:
      - ``input_files``: expert data, read with ``JsonReader``;
      - ``expert_size`` (default 1): passed to ``JsonReader.tf_input_ops``;
      - ``loss`` (default ``"ce"``): one of ``"ce"``, ``"kl"``, ``"dqfd"``;
      - ``lambda1`` / ``lambda2``: weights of the policy / imitation losses.
    """

    # Imitation loss types accepted by ``custom_loss``.
    _LOSS_TYPES = ("ce", "kl", "dqfd")

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.fcnet = FullyConnectedNetwork(self.obs_space,
                                           self.action_space,
                                           num_outputs,
                                           model_config,
                                           name="fcnet")
        self.register_variables(self.fcnet.variables())

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Delegate to our FCNet.
        return self.fcnet(input_dict, state, seq_lens)

    @override(ModelV2)
    def custom_loss(self, policy_loss, loss_inputs):
        """Return ``lambda1 * policy_loss + lambda2 * imitation_loss``.

        Raises:
            ValueError: if ``custom_options["loss"]`` is not a supported type
                (previously this silently produced an imitation loss of 0).
        """
        options = self.model_config["custom_options"]
        loss_type = options.get("loss", "ce")
        if loss_type not in self._LOSS_TYPES:
            raise ValueError(
                "Unknown imitation loss type {!r}; expected one of {}".format(
                    loss_type, self._LOSS_TYPES))

        # Create a new input reader per worker.
        reader = JsonReader(options["input_files"])
        input_ops = reader.tf_input_ops(options.get("expert_size", 1))

        # Define a secondary loss by building a graph copy with weight sharing.
        obs = restore_original_dimensions(
            tf.cast(input_ops["obs"], tf.float32), self.obs_space)
        logits, _ = self.forward({"obs": obs}, [], None)

        self.policy_loss = policy_loss
        # NOTE(review): get_q_value_distributions is not defined on
        # TFModelV2; this relies on a distributional-Q model providing it.
        _, model_logits, _ = self.get_q_value_distributions(logits)
        model_logits = tf.squeeze(model_logits)
        action_dist = Categorical(model_logits, self.model_config)

        # Expert action ids from the recorded data.
        expert_actions = tf.cast(input_ops["actions"], tf.int32)

        if loss_type == "ce":
            imitation_loss = tf.reduce_mean(-action_dist.logp(expert_actions))
        elif loss_type == "kl":
            expert_dist = Categorical(
                tf.one_hot(expert_actions, self.num_outputs),
                self.model_config)
            # KL divergence is non-negative, so minimize it directly. The
            # previous `-kl` maximized the divergence from the expert.
            imitation_loss = tf.reduce_mean(action_dist.kl(expert_dist))
        else:  # "dqfd"
            imitation_loss = self._dqfd_margin_loss(
                action_dist, model_logits, expert_actions)

        self.imitation_loss = imitation_loss
        return (options["lambda1"] * policy_loss
                + options["lambda2"] * self.imitation_loss)

    def _dqfd_margin_loss(self, action_dist, q_values, expert_actions):
        """Margin-style loss for the "dqfd" option (logic kept unchanged).

        NOTE(review): this differs from the DQfD large-margin loss. The
        margin (0.8) is added only where the model's greedy action equals
        ``expert_action``, whereas DQfD adds it to non-expert actions. Also,
        ``argmax`` runs over axis 0 of the action ids, and ``q_values[a]``
        indexes the leading axis. Confirm the intended shapes and semantics
        before relying on this.
        """
        expert_action = tf.math.argmax(expert_actions)
        expert_action_one_hot = tf.one_hot(expert_action, self.num_outputs)
        model_action = action_dist.deterministic_sample()
        model_action_one_hot = tf.one_hot(model_action, self.num_outputs)
        model_expert = model_action_one_hot * expert_action_one_hot
        max_value = float("-inf")
        for a in range(self.num_outputs):
            max_value = tf.maximum(
                q_values[a] + 0.8 * tf.cast(model_expert[a], tf.float32),
                max_value)
        return tf.reduce_mean(
            max_value - q_values[tf.cast(expert_action, tf.int32)])

    def custom_stats(self):
        """Expose the last computed policy and imitation losses."""
        return {
            "policy_loss": self.policy_loss,
            "imitation_loss": self.imitation_loss,
        }