def _Build(self) -> BatchData:
  """Assemble the accumulated graphs into a batch and clear mutable state.

  Returns:
    A BatchData wrapping a GgnnBatchData with the built graph tuple,
    vocab/selector id arrays, and one-hot node labels.
  """
  graph_tuple = self.builder.Build()

  # One-hot encode the node labels via advanced indexing.
  # NOTE(review): the label space is hard-coded to 2 classes here —
  # confirm this matches the model's node label dimensionality.
  label_count = len(self.node_labels)
  one_hot = np.zeros((label_count, 2), dtype=np.int32)
  one_hot[np.arange(label_count), self.node_labels] = 1

  batch = BatchData(
    graph_count=graph_tuple.graph_size,
    model_data=GgnnBatchData(
      graph_tuple=graph_tuple,
      vocab_ids=np.array(self.vocab_ids, dtype=np.int32),
      selector_ids=np.array(self.selector_ids, dtype=np.int32),
      node_labels=one_hot,
    ),
  )

  # Clear the per-batch accumulators ready for the next batch.
  self.vocab_ids = []
  self.selector_ids = []
  self.node_labels = []
  self.node_size = 0
  return batch
def _Build(self) -> BatchData:
  """Pad the accumulated graphs up to batch_size and emit an LSTM batch.

  Returns:
    A BatchData wrapping an LstmBatchData whose sequences are pre-padded /
    post-truncated to `padded_sequence_length`. `graph_count` reflects only
    the real (non-padding) graphs.
  """
  # graph_node_sizes is never padded, so the real graph count is fixed here.
  graph_count = len(self.graph_node_sizes)

  # A batch may contain fewer graphs than the required batch_size. If so,
  # top up with empty "graphs"; these padding graphs are discarded once
  # processed.
  shortfall = self.batch_size - graph_count
  if shortfall > 0:
    self.vocab_ids.extend(
      [np.array([self._vocab_id_pad], dtype=np.int32)] * shortfall
    )
    self.selector_vectors.extend([self._selector_vector_pad] * shortfall)
    self.targets.extend([self._node_label_pad] * shortfall)

  pad_sequences = tf.compat.v1.keras.preprocessing.sequence.pad_sequences

  def _PadToModelLength(sequences, dtype, pad_value):
    # All three inputs share the same pre-pad / post-truncate policy.
    return pad_sequences(
      sequences,
      maxlen=self.padded_sequence_length,
      dtype=dtype,
      padding="pre",
      truncating="post",
      value=pad_value,
    )

  batch = BatchData(
    graph_count=graph_count,
    model_data=LstmBatchData(
      graph_node_sizes=np.array(self.graph_node_sizes, dtype=np.int32),
      encoded_sequences=_PadToModelLength(
        self.vocab_ids, "int32", self._vocab_id_pad
      ),
      selector_vectors=_PadToModelLength(
        self.selector_vectors, "float32", np.zeros(2, dtype=np.float32)
      ),
      node_labels=_PadToModelLength(
        self.targets,
        "float32",
        np.zeros(self.node_y_dimensionality, dtype=np.float32),
      ),
      # Targets are handed over as-is: neither padded nor truncated.
      targets=self.targets,
    ),
  )

  # Clear the per-batch accumulators ready for the next batch.
  self.graph_node_sizes = []
  self.vocab_ids = []
  self.selector_vectors = []
  self.targets = []
  return batch
def test_RollingResults_iteration_count(weight: float):
  """Test aggregation of model iteration count and convergence."""
  rolling = RollingResults()
  batch = BatchData(graph_count=1, model_data=None)
  batch_results = BatchResults.Create(
    targets=np.array([[0, 1, 2]]),
    predictions=np.array([[0, 1, 2]]),
    iteration_count=1,
    model_converged=True,
  )

  # Feeding the same result repeatedly must leave the running averages
  # unchanged, regardless of the weight used.
  for _ in range(10):
    rolling.Update(batch, batch_results, weight=weight)

  assert rolling.iteration_count == 1
  assert rolling.model_converged == 1