def define_prediction_with_loss(self, node_embeddings):
    """Build a two-class classifier head over domain nodes, plus loss/accuracy ops.

    Each row gathered from ``node_embeddings`` by ``self.placeholders['domain']``
    is concatenated with the pooled embedding of its graph (looked up through
    ``self.placeholders['domain_node_graph_ids_list']``) and fed to an MLP that
    emits two logits.

    Populates ``self.ops`` with:
      * ``'loss'``: mean sparse softmax cross-entropy over all domain nodes,
        against ``self.placeholders['domain_labels']``.
      * ``'probabilities'``: softmax of the two logits, one row per domain node.
      * ``'accuracy'``: mean over graphs of the product of per-node correctness,
        i.e. the fraction of graphs whose domain nodes are all predicted
        correctly. NOTE(review): ``unsorted_segment_prod`` yields 1 for a graph
        with no domain nodes, so such graphs count as correct.

    Args:
        node_embeddings: node representations; rows are selected by the indices
            in ``self.placeholders['domain']``.
    """
    node_embeddings = tf.identity(node_embeddings)
    graph_embeddings = self.define_pooling(node_embeddings)

    placeholders = self.placeholders
    owner_graph_ids = placeholders['domain_node_graph_ids_list']
    labels = placeholders['domain_labels']

    # One row per domain node: its own embedding and its graph's embedding.
    node_features = tf.gather(params=node_embeddings, indices=placeholders['domain'])
    graph_features = tf.gather(params=graph_embeddings, indices=owner_graph_ids)

    classifier_input_width = node_features.get_shape()[1] + graph_features.get_shape()[1]
    classifier = MLP(in_size=classifier_input_width,
                     out_size=2,
                     hid_sizes=self.classifier_hidden_dims)
    logits = classifier(tf.concat([node_features, graph_features], axis=-1))

    per_node_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    self.ops['loss'] = tf.reduce_mean(per_node_loss)

    probabilities = tf.nn.softmax(logits)
    self.ops['probabilities'] = probabilities

    predicted_classes = tf.argmax(probabilities, -1, output_type=tf.int32)
    node_is_correct = tf.cast(tf.equal(predicted_classes, labels), tf.float32)
    # The product of 0/1 flags over a graph's nodes is 1.0 only if every node is right.
    graph_is_correct = tf.unsorted_segment_prod(
        data=node_is_correct,
        segment_ids=owner_graph_ids,
        num_segments=placeholders['num_graphs'])
    self.ops['accuracy'] = tf.reduce_mean(graph_is_correct)
def define_prediction_with_loss(self, node_embeddings):
    """Build a per-graph "select the domain node" head, plus loss/accuracy ops.

    Each domain node (rows of ``node_embeddings`` selected by
    ``self.placeholders['domain']``) gets a scalar score from an MLP applied to
    the node's embedding concatenated with its graph's pooled embedding. Scores
    are normalised by ``SegmentBasedSoftmax`` within each graph, where the graph
    of every domain node is given by
    ``self.placeholders['domain_node_graph_ids_list']``.

    Populates ``self.ops`` with:
      * ``'probabilities'``: per-domain-node probability under its graph's softmax.
      * ``'loss'``: mean over graphs of the summed ``-label * log_prob`` of the
        graph's domain nodes.
      * ``'accuracy'``: fraction of graphs in which every domain node's
        "has the graph's maximum score" flag (0/1) equals its label.

    Args:
        node_embeddings: node representations; rows are selected by the indices
            in ``self.placeholders['domain']``.
    """
    node_embeddings = tf.identity(node_embeddings)
    graph_embeddings = self.define_pooling(node_embeddings)

    placeholders = self.placeholders
    owner_graph_ids = placeholders['domain_node_graph_ids_list']
    graph_count = placeholders['num_graphs']
    labels = placeholders['domain_labels']

    node_features = tf.gather(params=node_embeddings, indices=placeholders['domain'])
    graph_features = tf.gather(params=graph_embeddings, indices=owner_graph_ids)

    scorer = MLP(in_size=node_features.get_shape()[1] + graph_features.get_shape()[1],
                 out_size=1,
                 hid_sizes=self.classifier_hidden_dims)
    scores = scorer(tf.concat([node_features, graph_features], axis=-1))
    # Flatten the MLP's out_size=1 output to one scalar score per domain node.
    scores = tf.reshape(scores, [-1])

    probabilities, log_probabilities = SegmentBasedSoftmax(
        data=scores,
        segment_ids=owner_graph_ids,
        num_segments=graph_count,
        return_log=True)
    self.ops['probabilities'] = probabilities

    weighted_nll = -tf.cast(labels, tf.float32) * log_probabilities
    nll_per_graph = tf.unsorted_segment_sum(
        data=weighted_nll, segment_ids=owner_graph_ids, num_segments=graph_count)
    self.ops['loss'] = tf.reduce_mean(nll_per_graph)

    # A node counts as "selected" when its score equals its graph's maximum
    # score; ties select every tied node.
    best_score_per_graph = tf.unsorted_segment_max(
        data=scores, segment_ids=owner_graph_ids, num_segments=graph_count)
    best_score_per_node = tf.gather(params=best_score_per_graph, indices=owner_graph_ids)
    is_selected = tf.cast(tf.equal(best_score_per_node, scores), dtype=tf.int32)

    node_is_correct = tf.cast(tf.equal(is_selected, labels), tf.float32)
    graph_is_correct = tf.unsorted_segment_prod(
        data=node_is_correct, segment_ids=owner_graph_ids, num_segments=graph_count)
    self.ops['accuracy'] = tf.reduce_mean(graph_is_correct)
def define_prediction_with_loss(self, node_embeddings):
    """Build a multi-step domain-node selection head, plus loss/accuracy ops.

    For every graph, an LSTM is unrolled for ``max_length`` steps over a constant
    per-graph context: [sum of the graph's domain-node embeddings, pooled graph
    embedding]. At each step, every domain node of the graph is scored by an MLP
    from [LSTM output of its graph at that step, the node's own embedding].
    Scores are normalised by ``SegmentBasedSoftmax`` over the segments given by
    ``placeholders['domain_node_timestep_graph_ids_list']``.

    Populates ``self.ops`` with:
      * ``'probabilities'``: (num domain nodes, max_length) selection probabilities.
      * ``'loss'``: mean over graphs of the summed ``-label * log_prob``,
        multiplied elementwise by ``placeholders['loss_mask']``.
      * ``'accuracy'``: mean over graphs of the product of per-entry correctness,
        where an entry is correct if its "is a segment maximum" flag equals its
        label, or if ``placeholders['acc_mask']`` is set for it.

    Args:
        node_embeddings: node representations; rows are selected by the indices
            in ``self.placeholders['domain']``.
    """
    node_embeddings = tf.identity(node_embeddings)
    graph_embeddings = self.define_pooling(node_embeddings)

    placeholders = self.placeholders
    graph_ids = placeholders['domain_node_graph_ids_list']
    num_graphs = placeholders['num_graphs']
    max_length = placeholders['max_length']
    labels = placeholders['domain_labels']
    # Segment ids in [0, num_graphs * max_length), presumably one per (graph, step).
    timestep_ids = placeholders['domain_node_timestep_graph_ids_list']
    # Segment ids in [0, num_graphs): maps each (node, step) entry to its graph.
    graph_ids_per_step = placeholders['domain_node_timestep_graph_ids_list_unshifted']
    num_timestep_segments = num_graphs * max_length

    # Shape (num-nodes-in-batch, H): embedding of every domain node.
    domain_nodes_embeddings = tf.gather(
        params=node_embeddings, indices=placeholders['domain'])
    # Shape (batch_size, H): per-graph sum of its domain-node embeddings.
    domain_nodes_pooled = tf.unsorted_segment_sum(
        data=domain_nodes_embeddings, segment_ids=graph_ids, num_segments=num_graphs)

    # The LSTM context must stay at graph granularity (one row per graph). Its
    # output is gathered below by graph id, which is only valid if the gathered
    # tensor has one row per graph. Gathering to per-node rows here would make
    # that later gather hand nodes the context of the wrong graph.
    tiling = [1, max_length, 1]
    # Shape is (batch_size, max_length, H) for both tensors below.
    tiled_graph_embeddings = tf.tile(tf.expand_dims(graph_embeddings, 1), tiling)
    tiled_domain_nodes_pooled = tf.tile(tf.expand_dims(domain_nodes_pooled, 1), tiling)
    # Shape (num-nodes-in-batch, max_length, H).
    tiled_domain_nodes_embeddings = tf.tile(
        tf.expand_dims(domain_nodes_embeddings, 1), tiling)

    # Shape (batch_size, max_length, 2H).
    rnn_input = tf.concat([tiled_domain_nodes_pooled, tiled_graph_embeddings], axis=-1)
    # Shape (batch_size, max_length, classifier_hidden_dims[0]).
    rnn_output = tf.keras.layers.LSTM(
        self.classifier_hidden_dims[0], return_sequences=True)(rnn_input)
    # Shape (num-nodes-in-batch, max_length, classifier_hidden_dims[0]): each
    # domain node receives the LSTM output of the graph it belongs to.
    rnn_output_per_node = tf.gather(params=rnn_output, indices=graph_ids)

    mlp_domain_nodes = MLP(
        in_size=(rnn_output_per_node.get_shape()[-1]
                 + tiled_domain_nodes_embeddings.get_shape()[-1]),
        out_size=1,
        hid_sizes=self.classifier_hidden_dims)
    # (num-nodes-in-batch, max_length, 1) -> (num-nodes-in-batch, max_length).
    domain_node_logits = mlp_domain_nodes(
        tf.concat([rnn_output_per_node, tiled_domain_nodes_embeddings], axis=-1))
    domain_node_logits = tf.squeeze(domain_node_logits, [-1])

    # Softmax over the entries that share a timestep segment id.
    probs, log_probs = SegmentBasedSoftmax(
        data=domain_node_logits,
        segment_ids=timestep_ids,
        num_segments=num_timestep_segments,
        return_log=True)
    self.ops['probabilities'] = probs

    loss_per_domain_node = -tf.cast(labels, tf.float32) * log_probs
    loss_per_domain_node *= placeholders['loss_mask']
    loss_per_graph = tf.unsorted_segment_sum(
        data=loss_per_domain_node,
        segment_ids=graph_ids_per_step,
        num_segments=num_graphs)
    self.ops['loss'] = tf.reduce_mean(loss_per_graph)

    # Predicted entries: those whose logit equals the maximum of their segment.
    domain_node_max_scores = tf.unsorted_segment_max(
        data=domain_node_logits,
        segment_ids=timestep_ids,
        num_segments=num_timestep_segments)
    max_scores_per_entry = tf.gather(params=domain_node_max_scores, indices=timestep_ids)
    selected_domain_nodes = tf.cast(
        tf.equal(max_scores_per_entry, domain_node_logits), dtype=tf.int32)
    correct_prediction_per_node = tf.cast(
        tf.equal(selected_domain_nodes, labels), tf.float32)
    # Entries with acc_mask set are treated as correct (assumes a 0/1 mask).
    correct_prediction_per_node = tf.maximum(
        correct_prediction_per_node, placeholders['acc_mask'])
    # The product of 0/1 flags is 1.0 only if every entry of the graph is correct.
    correct_prediction = tf.unsorted_segment_prod(
        data=correct_prediction_per_node,
        segment_ids=graph_ids_per_step,
        num_segments=num_graphs)
    self.ops['accuracy'] = tf.reduce_mean(correct_prediction)