def _build_graph(self, network):
    """Assemble the actor-critic graph.

    Builds a shared feature/hidden trunk that feeds separate policy and
    value heads, then attaches an optimization op to each component via
    ``build_model_optimization``.

    Parameters
    ----------
    network :
        Network spec/callable forwarded to ``FeatureNet`` — assumed to
        describe the state-feature extractor (TODO confirm against
        FeatureNet's signature).
    """
    self.feature_net = FeatureNet(
        self.state_shape, network,
        # FIX: key was "feature_get" — a typo. The sibling builder reads
        # "feature_net" for this option, so user-provided feature-net
        # settings were silently dropped here.
        self.special.get("feature_net", None))
    self.hidden_state = LinearHiddenState(
        self.feature_net.feature_state,
        self.special.get("hidden_size", 512),
        self.special.get("hidden_activation", tf.nn.elu))
    self.policy_net = PolicyNet(
        self.hidden_state.state, self.n_actions,
        self.special.get("policy_net", None))
    self.value_net = ValueNet(
        self.hidden_state.state,
        self.special.get("value_net", None))
    build_model_optimization(
        self.policy_net,
        self.special.get("policy_net_optimization", None))
    build_model_optimization(
        self.value_net,
        self.special.get("value_net_optimization", None))
    # The shared trunk (hidden state + feature extractor) is trained on
    # the mean of the two head losses so neither head dominates the
    # shared representation.
    build_model_optimization(
        self.hidden_state,
        self.special.get("hidden_state_optimization", None),
        loss=0.5 * (self.policy_net.loss + self.value_net.loss))
    build_model_optimization(
        self.feature_net,
        self.special.get("feature_net_optimization", None),
        loss=0.5 * (self.policy_net.loss + self.value_net.loss))
def _build_graph(self, network):
    """Assemble the Q-learning graph.

    Builds a feature/hidden trunk, then either a dueling head
    (state-value + advantage, summed) or a plain Q-value head, and
    attaches an optimization op to each component.

    Parameters
    ----------
    network :
        Network spec/callable forwarded to ``FeatureNet`` — presumably
        the state-feature extractor definition (TODO confirm).
    """
    self.feature_net = FeatureNet(
        self.state_shape, network,
        self.special.get("feature_net", {}))
    self.hidden_state = LinearHiddenState(
        self.feature_net.feature_state,
        self.special.get("hidden_size", 512),
        self.special.get("hidden_activation", tf.nn.elu))
    if self.special.get("dueling_network", False):
        # Dueling architecture: force "advantage": True on top of any
        # user-supplied qvalue_net options so QvalueNet emits advantages
        # rather than raw Q-values.
        self.qvalue_net = QvalueNet(
            self.hidden_state.state, self.n_actions,
            dict(**self.special.get("qvalue_net", {}), **{"advantage": True}))
        self.value_net = ValueNet(
            self.hidden_state.state,
            self.special.get("value_net", {}))
        # a bit hacky way: compose Q(s, a) = V(s) + A(s, a) by summing
        # the two heads' outputs directly.
        self.predicted_qvalues = self.value_net.predicted_values + \
            self.qvalue_net.predicted_qvalues
        self.predicted_qvalues_for_action = self.value_net.predicted_values_for_actions + \
            self.qvalue_net.predicted_qvalues_for_actions
        # TD loss on the composed Q-values; the TD target placeholder
        # lives on the qvalue net.
        self.agent_loss = tf.losses.mean_squared_error(
            labels=self.qvalue_net.td_target,
            predictions=self.predicted_qvalues_for_action)
        # The value head only gets an optimizer in the dueling case;
        # it is trained on the combined agent loss.
        build_model_optimization(
            self.value_net,
            self.special.get("value_net_optimization", None),
            loss=self.agent_loss)
    else:
        # Plain DQN head: Q-values, selections, and loss come straight
        # from the qvalue net.
        self.qvalue_net = QvalueNet(
            self.hidden_state.state, self.n_actions,
            self.special.get("qvalue_net", {}))
        self.predicted_qvalues = self.qvalue_net.predicted_qvalues
        self.predicted_qvalues_for_action = self.qvalue_net.predicted_qvalues_for_actions
        self.agent_loss = self.qvalue_net.loss
    # NOTE(review): in the dueling branch this optimizes qvalue_net with
    # build_model_optimization's default loss (not agent_loss) — looks
    # intentional but worth confirming against build_model_optimization.
    build_model_optimization(
        self.qvalue_net,
        self.special.get("qvalue_net_optimization", None))
    # Shared trunk components are trained on the full agent loss.
    build_model_optimization(
        self.hidden_state,
        self.special.get("hidden_state_optimization", None),
        loss=self.agent_loss)
    build_model_optimization(
        self.feature_net,
        self.special.get("feature_net_optimization", None),
        loss=self.agent_loss)
def __init__(self, vocab_size, embedding_size, encoder_args, decoder_args,
             embeddings_optimization_args=None,
             encoder_optimization_args=None,
             decoder_optimization_args=None):
    """Wire up an embeddings -> encoder -> decoder seq2seq graph.

    The decoder consumes the encoder's state/outputs and shares the
    embedding matrix with the encoder; optimization ops are then
    attached to each component.

    Parameters
    ----------
    vocab_size : vocabulary size for the shared embedding table.
    embedding_size : dimensionality of each embedding vector.
    encoder_args : kwargs forwarded to ``DynamicRnnEncoder``.
    decoder_args : kwargs forwarded to ``DynamicRnnDecoder``.
    embeddings_optimization_args, encoder_optimization_args,
    decoder_optimization_args : optional per-component optimizer specs
        passed to ``build_model_optimization``.
    """
    embeddings = Embeddings(
        vocab_size, embedding_size, scope="embeddings")
    self.embeddings = embeddings

    encoder = DynamicRnnEncoder(
        embedding_matrix=embeddings.embedding_matrix,
        **encoder_args)
    self.encoder = encoder

    decoder = DynamicRnnDecoder(
        encoder_state=encoder.state,
        encoder_outputs=encoder.outputs,
        encoder_inputs_length=encoder.inputs_length,
        embedding_matrix=embeddings.embedding_matrix,
        **decoder_args)
    self.decoder = decoder

    # Encoder and embeddings are trained against the decoder's loss;
    # the decoder itself uses build_model_optimization's default loss.
    seq_loss = decoder.loss
    build_model_optimization(encoder, encoder_optimization_args, seq_loss)
    build_model_optimization(decoder, decoder_optimization_args)
    build_model_optimization(embeddings, embeddings_optimization_args, seq_loss)