Example #1
    def _get_select_action_probs(self, pi, selected_spatial_action_flat):
        action_id = select_from_each_row(
            pi.action_id_log_probs, self.placeholders.selected_action_id
        )
        spatial = select_from_each_row(
            pi.spatial_action_log_probs, selected_spatial_action_flat
        )
        total = spatial + action_id

        return SelectedLogProbs(action_id, spatial, total)
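All of these snippets rely on a select_from_each_row helper that is not shown here. A minimal TF 1.x sketch of such a helper (an assumption about its behaviour: pick one column per row of a 2-D tensor) could look like this:

    import tensorflow as tf

    def select_from_each_row(values, col_indices):
        # values: [batch, n] tensor; col_indices: [batch] integer tensor.
        # Pair each row index with its column index and gather one element
        # per row, returning a [batch] tensor.
        row_indices = tf.range(tf.shape(values)[0])
        gather_indices = tf.stack([row_indices, tf.cast(col_indices, tf.int32)], axis=1)
        return tf.gather_nd(values, gather_indices)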
Example #2
    def get_selected_action_probability(self, theta, selected_spatial_action):
        """get_selected_action_probability

        :param theta: The current state of the policy.
        :param selected_spatial_action: The current spatial action to evaluate.
        """

        action_id = select_from_each_row(theta.action_id_log_probs,
                                         self.placeholders.selected_action_id)

        spatial_coord = select_from_each_row(theta.spatial_action_log_probs,
                                             selected_spatial_action)

        total = spatial_coord + action_id

        return SelectedLogProbs(action_id, spatial_coord, total)
Example #3
    def _get_select_action_probs(self, pi):
        action_id = select_from_each_row(pi.action_id_log_probs,
                                         self.placeholders.selected_action_id)

        total = action_id

        return SelectedLogProbs(action_id, total)
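SelectedLogProbs is also not defined in these snippets. It behaves like a simple record, so a plausible (assumed) definition is a namedtuple; Examples #1 and #2 fill three fields, while Example #3 would use a two-field variant without the spatial entry:

    from collections import namedtuple

    # Assumed container type, not shown in the original snippets.
    # Examples #1 and #2 construct it with (action_id, spatial, total);
    # Example #3 presumably uses a variant without the spatial field.
    SelectedLogProbs = namedtuple("SelectedLogProbs", ["action_id", "spatial", "total"])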
Example #4
    def test_select_from_each_row(self):
        x = np.random.rand(4, 5)
        col_idx = np.array([0, 0, 1, 2])
        with self.test_session():
            selection = select_from_each_row(x, col_idx).eval()

        assert selection.shape == (4,)

        for i, s in enumerate(selection):
            assert s == x[i, col_idx[i]]
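The loop in this test checks the same thing that NumPy advanced indexing expresses in one line, which can serve as a reference implementation (a sketch, not part of the original test):

    import numpy as np

    x = np.random.rand(4, 5)
    col_idx = np.array([0, 0, 1, 2])
    # One element per row: equivalent to [x[i, col_idx[i]] for i in range(4)].
    expected = x[np.arange(x.shape[0]), col_idx]
    assert expected.shape == (4,)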
Example #5
    def build_model(self):
        self._define_input_placeholders()

        spatial_action_probs, action_id_probs, value_estimate = \
            self._build_fullyconv_network()

        selected_spatial_action_flat = ravel_index_pairs(
            self.ph_selected_spatial_action, self.spatial_dim
        )

        def logclip(x):
            return tf.log(tf.clip_by_value(x, 1e-12, 1.0))

        spatial_action_log_probs = (
            logclip(spatial_action_probs)
            * tf.expand_dims(self.ph_is_spatial_action_available, axis=1)
        )

        # unavailable action ids get log(1e-12) from the clip, but that's ok because they are never selected
        action_id_log_probs = logclip(action_id_probs)

        selected_spatial_action_log_prob = select_from_each_row(
            spatial_action_log_probs, selected_spatial_action_flat
        )
        selected_action_id_log_prob = select_from_each_row(
            action_id_log_probs, self.ph_selected_action_id
        )
        selected_action_total_log_prob = (
            selected_spatial_action_log_prob
            + selected_action_id_log_prob
        )

        # maximum is to avoid 0 / 0 because this is used to calculate some means
        sum_spatial_action_available = tf.maximum(
            1e-10, tf.reduce_sum(self.ph_is_spatial_action_available)
        )
        neg_entropy_spatial = tf.reduce_sum(
            spatial_action_probs * spatial_action_log_probs
        ) / sum_spatial_action_available
        neg_entropy_action_id = tf.reduce_mean(tf.reduce_sum(
            action_id_probs * action_id_log_probs, axis=1
        ))

        advantage = tf.stop_gradient(self.ph_value_target - value_estimate)
        policy_loss = -tf.reduce_mean(selected_action_total_log_prob * advantage)
        value_loss = tf.losses.mean_squared_error(self.ph_value_target, value_estimate)

        loss = (
            policy_loss
            + value_loss * self.loss_value_weight
            + neg_entropy_spatial * self.entropy_weight_spatial
            + neg_entropy_action_id * self.entropy_weight_action_id
        )

        scalar_summary_collection_name = "scalar_summaries"
        s_collections = [scalar_summary_collection_name, tf.GraphKeys.SUMMARIES]
        tf.summary.scalar("loss/policy", policy_loss, collections=s_collections)
        tf.summary.scalar("loss/value", value_loss, s_collections)
        tf.summary.scalar("loss/neg_entropy_spatial", neg_entropy_spatial, s_collections)
        tf.summary.scalar("loss/neg_entropy_action_id", neg_entropy_action_id, s_collections)
        tf.summary.scalar("loss/total", loss, s_collections)
        tf.summary.scalar("value/advantage", tf.reduce_mean(advantage), s_collections)
        tf.summary.scalar("value/estimate", tf.reduce_mean(value_estimate), s_collections)
        tf.summary.scalar("value/target", tf.reduce_mean(self.ph_value_target), s_collections)
        tf.summary.scalar("action/is_spatial_action_available",
            tf.reduce_mean(self.ph_is_spatial_action_available), s_collections)
        tf.summary.scalar("action/is_spatial_action_available",
            tf.reduce_mean(self.ph_is_spatial_action_available), s_collections)
        tf.summary.scalar("action/selected_id_log_prob",
            tf.reduce_mean(selected_action_id_log_prob))
        tf.summary.scalar("action/selected_total_log_prob",
            tf.reduce_mean(selected_action_total_log_prob))
        tf.summary.scalar("action/selected_spatial_log_prob",
            tf.reduce_sum(selected_spatial_action_log_prob) / sum_spatial_action_available
        )

        self.sampled_action_id = weighted_random_sample(action_id_probs)
        self.sampled_spatial_action = weighted_random_sample(spatial_action_probs)
        self.value_estimate = value_estimate

        self.train_op = layers.optimize_loss(
            loss=loss,
            global_step=framework.get_global_step(),
            optimizer=self.optimiser,
            clip_gradients=self.max_gradient_norm,
            summaries=OPTIMIZER_SUMMARIES,
            learning_rate=None,
            name="train_op"
        )

        self.init_op = tf.global_variables_initializer()
        self.saver = tf.train.Saver(max_to_keep=2)
        self.all_summary_op = tf.summary.merge_all(tf.GraphKeys.SUMMARIES)
        self.scalar_summary_op = tf.summary.merge(tf.get_collection(scalar_summary_collection_name))
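Example #5 also uses two helpers whose implementations are not shown: ravel_index_pairs, which appears to turn (row, col) spatial coordinates into flat indices so select_from_each_row can index the flattened spatial distribution, and weighted_random_sample, which draws one index per row of a probability matrix. Minimal TF 1.x sketches under those assumptions might be:

    import tensorflow as tf

    def ravel_index_pairs(index_pairs, n_cols):
        # index_pairs: [batch, 2] integer tensor of (row, col) coordinates on a
        # grid that is n_cols wide; returns flat indices row * n_cols + col.
        return index_pairs[:, 0] * n_cols + index_pairs[:, 1]

    def weighted_random_sample(probs):
        # probs: [batch, n] tensor of per-row probabilities.
        # tf.multinomial samples from logits, so take (clipped) logs first,
        # then draw one sample per row and drop the trailing dimension.
        logits = tf.log(tf.clip_by_value(probs, 1e-12, 1.0))
        return tf.squeeze(tf.multinomial(logits, num_samples=1), axis=1)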