Exemplo n.º 1
0
    def define_prediction_with_loss(self, node_embeddings):
        """Create the loss, probability and accuracy ops for domain nodes.

        Every node selected by ``placeholders['domain']`` is scored by an
        MLP with two output logits. The MLP input is the node's embedding
        concatenated with the pooled embedding picked for that node through
        ``placeholders['domain_node_graph_ids_list']``.

        The following entries of ``self.ops`` are set:
            'loss': mean sparse softmax cross-entropy of the logits against
                ``placeholders['domain_labels']``.
            'probabilities': softmax over the two logits of each domain node.
            'accuracy': mean over graphs of a 0/1 score obtained as the
                product of per-node correctness, i.e. a graph only scores 1
                when all of its domain nodes are predicted correctly.

        Args:
            node_embeddings: tensor of node embeddings; rows are selected
                from it with ``placeholders['domain']``.
        """
        node_embeddings = tf.identity(node_embeddings)
        pooled_graphs = self.define_pooling(node_embeddings)

        domain_embeddings = tf.gather(params=node_embeddings,
                                      indices=self.placeholders['domain'])
        graph_ids = self.placeholders['domain_node_graph_ids_list']
        # One pooled embedding per domain node, so both inputs line up
        # row by row for the concatenation below.
        per_node_graph_embeddings = tf.gather(params=pooled_graphs,
                                              indices=graph_ids)

        classifier_in_size = (domain_embeddings.get_shape()[1] +
                              per_node_graph_embeddings.get_shape()[1])
        classifier = MLP(in_size=classifier_in_size,
                         out_size=2,
                         hid_sizes=self.classifier_hidden_dims)

        classifier_input = tf.concat(
            [domain_embeddings, per_node_graph_embeddings], axis=-1)
        logits = classifier(classifier_input)

        labels = self.placeholders['domain_labels']
        per_node_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
        self.ops['loss'] = tf.reduce_mean(per_node_loss)

        probabilities = tf.nn.softmax(logits)
        self.ops['probabilities'] = probabilities

        predicted_classes = tf.argmax(probabilities, -1, output_type=tf.int32)
        node_is_correct = tf.cast(tf.equal(predicted_classes, labels),
                                  tf.float32)
        # Product of 0/1 values: a single wrong node zeroes the whole graph.
        graph_is_correct = tf.unsorted_segment_prod(
            data=node_is_correct,
            segment_ids=graph_ids,
            num_segments=self.placeholders['num_graphs'])
        self.ops['accuracy'] = tf.reduce_mean(graph_is_correct)
Exemplo n.º 2
0
    def define_prediction_with_loss(self, node_embeddings):
        """Create the loss, probability and accuracy ops for node selection.

        Each node selected by ``placeholders['domain']`` receives a single
        score from an MLP fed with the node's embedding concatenated with
        the pooled embedding picked for it through
        ``placeholders['domain_node_graph_ids_list']``. The scores are then
        normalised with ``SegmentBasedSoftmax`` using those same ids as
        segments.

        The following entries of ``self.ops`` are set:
            'probabilities': the segment-wise softmax of the scores.
            'loss': ``-label * log_prob`` summed per segment and averaged
                over the ``placeholders['num_graphs']`` segments.
            'accuracy': mean over segments of a 0/1 score; a segment scores
                1 only when, for every one of its nodes, "score equals the
                segment maximum" matches ``placeholders['domain_labels']``.

        Args:
            node_embeddings: tensor of node embeddings; rows are selected
                from it with ``placeholders['domain']``.
        """
        node_embeddings = tf.identity(node_embeddings)
        pooled_graphs = self.define_pooling(node_embeddings)

        domain_embeddings = tf.gather(params=node_embeddings,
                                      indices=self.placeholders['domain'])
        graph_ids = self.placeholders['domain_node_graph_ids_list']
        per_node_graph_embeddings = tf.gather(params=pooled_graphs,
                                              indices=graph_ids)

        scorer_in_size = (domain_embeddings.get_shape()[1] +
                          per_node_graph_embeddings.get_shape()[1])
        scorer = MLP(in_size=scorer_in_size,
                     out_size=1,
                     hid_sizes=self.classifier_hidden_dims)

        scorer_input = tf.concat(
            [domain_embeddings, per_node_graph_embeddings], axis=-1)
        # Flatten the single-output MLP result into one score per node.
        scores = tf.reshape(scorer(scorer_input), [-1])

        num_graphs = self.placeholders['num_graphs']
        probabilities, log_probabilities = SegmentBasedSoftmax(
            data=scores,
            segment_ids=graph_ids,
            num_segments=num_graphs,
            return_log=True)
        self.ops['probabilities'] = probabilities

        labels = self.placeholders['domain_labels']
        node_loss = -tf.cast(labels, tf.float32) * log_probabilities
        graph_loss = tf.unsorted_segment_sum(data=node_loss,
                                             segment_ids=graph_ids,
                                             num_segments=num_graphs)
        self.ops['loss'] = tf.reduce_mean(graph_loss)

        # Broadcast each segment's best score back to its nodes so the
        # node(s) achieving that maximum can be marked as selected.
        best_score_per_graph = tf.unsorted_segment_max(
            data=scores, segment_ids=graph_ids, num_segments=num_graphs)
        best_score_per_node = tf.gather(params=best_score_per_graph,
                                        indices=graph_ids)
        is_selected = tf.cast(tf.equal(best_score_per_node, scores),
                              dtype=tf.int32)

        node_is_correct = tf.cast(tf.equal(is_selected, labels), tf.float32)
        # Product of 0/1 values: a single mismatch zeroes the whole segment.
        graph_is_correct = tf.unsorted_segment_prod(data=node_is_correct,
                                                    segment_ids=graph_ids,
                                                    num_segments=num_graphs)

        self.ops['accuracy'] = tf.reduce_mean(graph_is_correct)
Exemplo n.º 3
0
    def define_prediction_with_loss(self, node_embeddings):
        """Create loss, probability and accuracy ops for per-timestep selection.

        For each of ``placeholders['max_length']`` timesteps, every node
        selected by ``placeholders['domain']`` receives one score. The score
        comes from an MLP fed with the node's own embedding concatenated
        with the output of an LSTM. The LSTM input is the sum of the domain
        node embeddings of the node's graph concatenated with the pooled
        graph embedding, repeated at every timestep.

        The following entries of ``self.ops`` are set:
            'probabilities': ``SegmentBasedSoftmax`` of the scores over the
                segments in
                ``placeholders['domain_node_timestep_graph_ids_list']``.
            'loss': ``-label * log_prob`` multiplied by
                ``placeholders['loss_mask']``, summed over the segments in
                ``placeholders['domain_node_timestep_graph_ids_list_unshifted']``
                and averaged over the ``placeholders['num_graphs']`` segments.
            'accuracy': mean over those segments of a 0/1 score. Entries
                where ``placeholders['acc_mask']`` is 1 always count as
                correct; elsewhere "score equals the segment maximum" has
                to match ``placeholders['domain_labels']``.

        Args:
            node_embeddings: tensor of node embeddings; rows are selected
                from it with ``placeholders['domain']``.
        """
        node_embeddings = tf.identity(node_embeddings)
        graph_embeddings = self.define_pooling(node_embeddings)
        domain_node_graph_ids = self.placeholders['domain_node_graph_ids_list']

        # Shape notes below use D = number of domain nodes in the batch and
        # H = embedding width (as in the original comments).
        domain_nodes_embeddings = tf.gather(
            params=node_embeddings, indices=self.placeholders['domain'])
        #  One row per graph: sum of the graph's domain node embeddings
        domain_nodes_pooled = tf.unsorted_segment_sum(
            data=domain_nodes_embeddings,
            segment_ids=domain_node_graph_ids,
            num_segments=self.placeholders['num_graphs'])

        #  Shape is (D, H) for both the tensors below: row i holds the
        #  value belonging to the graph of domain node i
        graph_embeddings_copied = tf.gather(
            params=graph_embeddings,
            indices=domain_node_graph_ids)
        domain_nodes_pooled_copied = tf.gather(
            params=domain_nodes_pooled,
            indices=domain_node_graph_ids)

        #  Shape is now (D, max_length, H) for both the tensors below
        tiling = [1, self.placeholders['max_length'], 1]
        tiled_graph_embeddings_copied = tf.tile(
            tf.expand_dims(graph_embeddings_copied, 1), tiling)
        tiled_domain_nodes_pooled_copied = tf.tile(
            tf.expand_dims(domain_nodes_pooled_copied, 1), tiling)

        #  Shape is (D, max_length, H)
        tiled_domain_nodes_embeddings = tf.tile(
            tf.expand_dims(domain_nodes_embeddings, 1), tiling)

        #  Shape is (D, max_length, 2H)
        rnn_input = tf.concat(
            [tiled_domain_nodes_pooled_copied, tiled_graph_embeddings_copied],
            axis=-1)
        #  Shape is (D, max_length, classifier_hidden_dims[0])
        rnn_output = tf.keras.layers.LSTM(self.classifier_hidden_dims[0],
                                          return_sequences=True)(rnn_input)
        # rnn_output already has one row per domain node because its inputs
        # were gathered per node above. It must NOT be gathered again with
        # the graph ids: that would use graph ids as indices into per-node
        # rows and hand most nodes the LSTM output of a different graph.
        # NOTE(review): running the LSTM on the un-gathered per-graph
        # tensors and gathering afterwards would be equivalent and cheaper,
        # provided define_pooling returns exactly num_graphs rows - confirm.

        mlp_domain_nodes = MLP(in_size=rnn_output.get_shape()[-1] +
                               tiled_domain_nodes_embeddings.get_shape()[-1],
                               out_size=1,
                               hid_sizes=self.classifier_hidden_dims)

        #  Shape is (D, max_length, 1)
        domain_node_logits = mlp_domain_nodes(
            tf.concat([rnn_output, tiled_domain_nodes_embeddings],
                      axis=-1))
        #  Shape is (D, max_length)
        domain_node_logits = tf.squeeze(domain_node_logits, [-1])
        #  Shape is (D, max_length) for both
        probs, log_probs = SegmentBasedSoftmax(
            data=domain_node_logits,
            segment_ids=self.
            placeholders['domain_node_timestep_graph_ids_list'],
            num_segments=self.placeholders['num_graphs'] *
            self.placeholders['max_length'],
            return_log=True)
        self.ops['probabilities'] = probs

        loss_per_domain_node = -tf.cast(self.placeholders['domain_labels'],
                                        tf.float32) * log_probs
        # Zero out the contribution of entries excluded by the loss mask.
        loss_per_domain_node *= self.placeholders['loss_mask']
        loss_per_graph = tf.unsorted_segment_sum(
            data=loss_per_domain_node,
            segment_ids=self.
            placeholders['domain_node_timestep_graph_ids_list_unshifted'],
            num_segments=self.placeholders['num_graphs'])
        self.ops['loss'] = tf.reduce_mean(loss_per_graph)

        # Broadcast each segment's best score back to its entries so the
        # entries achieving that maximum can be marked as selected.
        domain_node_max_scores = tf.unsorted_segment_max(
            data=domain_node_logits,
            segment_ids=self.
            placeholders['domain_node_timestep_graph_ids_list'],
            num_segments=self.placeholders['num_graphs'] *
            self.placeholders['max_length'])
        copied_domain_node_max_scores = tf.gather(
            params=domain_node_max_scores,
            indices=self.placeholders['domain_node_timestep_graph_ids_list'])

        selected_domain_nodes = tf.cast(tf.equal(copied_domain_node_max_scores,
                                                 domain_node_logits),
                                        dtype=tf.int32)
        correct_prediction_per_node = tf.cast(
            tf.equal(selected_domain_nodes,
                     self.placeholders['domain_labels']), tf.float32)
        # Entries with acc_mask == 1 are forced to count as correct so they
        # cannot zero the per-graph product below.
        correct_prediction_per_node = tf.maximum(correct_prediction_per_node,
                                                 self.placeholders['acc_mask'])
        correct_prediction = tf.unsorted_segment_prod(
            data=correct_prediction_per_node,
            segment_ids=self.
            placeholders['domain_node_timestep_graph_ids_list_unshifted'],
            num_segments=self.placeholders['num_graphs'])

        self.ops['accuracy'] = tf.reduce_mean(correct_prediction)