Example No. 1
    def _build_graph(self, network):
        # shared feature extractor over the raw state
        self.feature_net = FeatureNet(self.state_shape, network,
                                      self.special.get("feature_net", None))

        self.hidden_state = LinearHiddenState(
            self.feature_net.feature_state,
            self.special.get("hidden_size", 512),
            self.special.get("hidden_activation", tf.nn.elu))

        # actor (policy) and critic (value) heads on top of the shared hidden state
        self.policy_net = PolicyNet(self.hidden_state.state, self.n_actions,
                                    self.special.get("policy_net", None))
        self.value_net = ValueNet(self.hidden_state.state,
                                  self.special.get("value_net", None))

        build_model_optimization(
            self.policy_net, self.special.get("policy_net_optimization", None))
        build_model_optimization(
            self.value_net, self.special.get("value_net_optimization", None))
        # the shared hidden and feature layers train against the averaged sum
        # of both heads' losses
        build_model_optimization(
            self.hidden_state,
            self.special.get("hidden_state_optimization", None),
            loss=0.5 * (self.policy_net.loss + self.value_net.loss))
        build_model_optimization(
            self.feature_net,
            self.special.get("feature_net_optimization", None),
            loss=0.5 * (self.policy_net.loss + self.value_net.loss))
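
The special dict supplies per-subnetwork configuration through the .get() lookups above. Below is a minimal sketch of what such a dict might look like for this actor-critic style graph; the keys mirror the lookups in _build_graph, while the values are purely illustrative defaults rather than settings taken from the repository, and None entries simply fall back to whatever each subnetwork does by default.

import tensorflow as tf

special = {
    "feature_net": None,                # config forwarded to FeatureNet
    "hidden_size": 512,                 # width of LinearHiddenState
    "hidden_activation": tf.nn.elu,
    "policy_net": None,                 # config forwarded to PolicyNet
    "value_net": None,                  # config forwarded to ValueNet
    "policy_net_optimization": None,    # per-subnetwork optimizer settings
    "value_net_optimization": None,
    "hidden_state_optimization": None,
    "feature_net_optimization": None,
}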
Example No. 2
    def _build_graph(self, network):
        self.feature_net = FeatureNet(
            self.state_shape, network,
            self.special.get("feature_net", {}))

        self.hidden_state = LinearHiddenState(
            self.feature_net.feature_state,
            self.special.get("hidden_size", 512),
            self.special.get("hidden_activation", tf.nn.elu))

        if self.special.get("dueling_network", False):
            # dueling architecture: separate advantage and state-value streams
            self.qvalue_net = QvalueNet(
                self.hidden_state.state, self.n_actions,
                {**self.special.get("qvalue_net", {}), "advantage": True})
            self.value_net = ValueNet(
                self.hidden_state.state,
                self.special.get("value_net", {}))

            # a bit hacky way: combine the dueling streams, Q(s, a) = V(s) + A(s, a)
            self.predicted_qvalues = (
                self.value_net.predicted_values
                + self.qvalue_net.predicted_qvalues)
            self.predicted_qvalues_for_action = (
                self.value_net.predicted_values_for_actions
                + self.qvalue_net.predicted_qvalues_for_actions)
            self.agent_loss = tf.losses.mean_squared_error(
                labels=self.qvalue_net.td_target,
                predictions=self.predicted_qvalues_for_action)

            build_model_optimization(
                self.value_net,
                self.special.get("value_net_optimization", None),
                loss=self.agent_loss)
        else:
            # plain (non-dueling) Q-value head
            self.qvalue_net = QvalueNet(
                self.hidden_state.state, self.n_actions,
                self.special.get("qvalue_net", {}))
            self.predicted_qvalues = self.qvalue_net.predicted_qvalues
            self.predicted_qvalues_for_action = self.qvalue_net.predicted_qvalues_for_actions
            self.agent_loss = self.qvalue_net.loss

        build_model_optimization(
            self.qvalue_net,
            self.special.get("qvalue_net_optimization", None))

        # the shared hidden and feature layers are optimized against the same agent loss
        build_model_optimization(
            self.hidden_state,
            self.special.get("hidden_state_optimization", None),
            loss=self.agent_loss)
        build_model_optimization(
            self.feature_net,
            self.special.get("feature_net_optimization", None),
            loss=self.agent_loss)
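
When dueling_network is enabled, the graph above splits the Q-function into a state-value stream and an advantage stream and sums them. The following is a self-contained sketch of that decomposition in plain TensorFlow 1.x; the layer names, sizes, and the placeholder standing in for self.hidden_state.state are assumptions for illustration, not the internals of QvalueNet or ValueNet.

import tensorflow as tf

n_actions = 4
hidden = tf.placeholder(tf.float32, [None, 512], name="hidden_state")

state_value = tf.layers.dense(hidden, 1, name="state_value")       # V(s): [batch, 1]
advantage = tf.layers.dense(hidden, n_actions, name="advantage")   # A(s, a): [batch, n_actions]

# Broadcasting adds the scalar state value to every action's advantage:
# Q(s, a) = V(s) + A(s, a), mirroring predicted_qvalues above.
qvalues = state_value + advantage

The two streams are simply summed here, matching the example; many dueling implementations additionally subtract the mean advantage to keep the decomposition identifiable.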
Example No. 3
    def __init__(self,
                 vocab_size, embedding_size,
                 encoder_args, decoder_args,
                 embeddings_optimization_args=None,
                 encoder_optimization_args=None,
                 decoder_optimization_args=None):
        self.embeddings = Embeddings(
            vocab_size,
            embedding_size,
            scope="embeddings")

        self.encoder = DynamicRnnEncoder(
            embedding_matrix=self.embeddings.embedding_matrix,
            **encoder_args)

        self.decoder = DynamicRnnDecoder(
            encoder_state=self.encoder.state,
            encoder_outputs=self.encoder.outputs,
            encoder_inputs_length=self.encoder.inputs_length,
            embedding_matrix=self.embeddings.embedding_matrix,
            **decoder_args)

        # the encoder and embeddings are optimized against the decoder's loss
        build_model_optimization(
            self.encoder, encoder_optimization_args, loss=self.decoder.loss)
        build_model_optimization(self.decoder, decoder_optimization_args)
        build_model_optimization(
            self.embeddings, embeddings_optimization_args, loss=self.decoder.loss)
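
The encoder and decoder above wrap TensorFlow's dynamic RNN machinery around a shared embedding matrix. Below is a self-contained sketch of the encoder-side pattern in TensorFlow 1.x; the cell type, sizes, and placeholder names are assumptions for illustration, not the actual internals of DynamicRnnEncoder.

import tensorflow as tf

vocab_size, embedding_size, hidden_size = 10000, 128, 256

inputs = tf.placeholder(tf.int32, [None, None], name="inputs")           # [batch, time] token ids
inputs_length = tf.placeholder(tf.int32, [None], name="inputs_length")   # true length of each sequence

embedding_matrix = tf.get_variable(
    "embedding_matrix", [vocab_size, embedding_size], dtype=tf.float32)
embedded = tf.nn.embedding_lookup(embedding_matrix, inputs)              # [batch, time, embedding_size]

cell = tf.nn.rnn_cell.GRUCell(hidden_size)
outputs, state = tf.nn.dynamic_rnn(
    cell, embedded, sequence_length=inputs_length, dtype=tf.float32)

A decoder built on this pattern consumes the final state, and attention-style decoders also need the per-step outputs and the sequence lengths, which matches the three encoder attributes handed to DynamicRnnDecoder in the example.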