def RunBatch(
  self,
  epoch_type: epoch_pb2.EpochType,
  batch_data: BatchData,
  ctx: ProgressContext = NullContext,
) -> BatchResults:
  """Run a batch of data through the model.

  Args:
    epoch_type: The type of the current epoch.
    batch_data: A batch of graphs and model data. This requires that the batch
      data has 'x' and 'y' properties that return lists of model inputs, a
      `targets` property that returns a flattened list of targets, and a
      `GetPredictions()` method that receives as input the data generated by
      the model and returns a flattened array of the same shape as `targets`.
    ctx: A logging context.

  Returns:
    A batch results instance.
  """
  model_data: LstmBatchData = batch_data.model_data

  assert model_data.encoded_sequences.shape == (
    self.batch_size,
    self.padded_sequence_length,
  ), model_data.encoded_sequences.shape
  assert model_data.selector_vectors.shape == (
    self.batch_size,
    self.padded_sequence_length,
    2,
  ), model_data.selector_vectors.shape

  x = [model_data.encoded_sequences, model_data.selector_vectors]
  y = [model_data.node_labels]

  if epoch_type == epoch_pb2.TRAIN:
    loss, *_ = self.model.train_on_batch(x, y)
  else:
    loss = None

  padded_predictions = self.model.predict_on_batch(x)
  # Reshape the padded model outputs to per-graph arrays.
  predictions = self.ReshapePaddedModelOutput(batch_data, padded_predictions)

  # Flatten the targets and predictions lists so that we can compare them.
  # Shape (batch_node_count, node_y_dimensionality).
  targets = np.concatenate(model_data.targets)
  predictions = np.concatenate(predictions)

  return BatchResults.Create(
    targets=targets,
    predictions=predictions,
    loss=loss,
  )
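# A minimal sketch of how the padded predictions might be mapped back to
# per-graph arrays. ReshapePaddedModelOutput() is not shown here; the function
# below is an assumption based on how it is called above, and
# `graph_node_counts` is a hypothetical BatchData attribute holding the true
# (unpadded) node count of each graph in the batch.
def ReshapePaddedModelOutputSketch(batch_data, padded_outputs):
  """Strip sequence padding from model outputs.

  Args:
    batch_data: The batch, assumed to expose per-graph node counts.
    padded_outputs: Array of shape
      (batch_size, padded_sequence_length, node_y_dimensionality).

  Returns:
    A list of arrays, one per graph, of shape
    (node_count, node_y_dimensionality).
  """
  outputs = []
  for i, node_count in enumerate(batch_data.graph_node_counts):
    # Keep only the predictions for real nodes, dropping the padding.
    outputs.append(padded_outputs[i, :node_count])
  return outputs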
def test_RollingResults_iteration_count(weight: float):
  """Test aggregation of model iteration count and convergence."""
  rolling_results = RollingResults()

  data = BatchData(graph_count=1, model_data=None)
  results = BatchResults.Create(
    targets=np.array([[0, 1, 2]]),
    predictions=np.array([[0, 1, 2]]),
    iteration_count=1,
    model_converged=True,
  )
  for _ in range(10):
    rolling_results.Update(data, results, weight=weight)

  assert rolling_results.iteration_count == 1
  assert rolling_results.model_converged == 1
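# A minimal sketch of the weighted aggregation exercised by the test above,
# assuming RollingResults keeps a weighted running mean of iteration_count and
# model_converged over positive weights (the real class tracks more fields).
# With a constant per-batch value of 1, the weighted mean stays 1 for any
# weight, which is what the assertions check.
class RollingResultsSketch:
  """Hypothetical stand-in for RollingResults, for illustration only."""

  def __init__(self):
    self.weight_sum = 0.0
    self.weighted_iteration_count = 0.0
    self.weighted_model_converged = 0.0

  def Update(self, data: BatchData, results: BatchResults, weight: float):
    self.weight_sum += weight
    self.weighted_iteration_count += weight * results.iteration_count
    self.weighted_model_converged += weight * float(results.model_converged)

  @property
  def iteration_count(self) -> float:
    return self.weighted_iteration_count / self.weight_sum

  @property
  def model_converged(self) -> float:
    return self.weighted_model_converged / self.weight_sum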
def RunBatch(
  self,
  epoch_type: epoch_pb2.EpochType,
  batch: BatchData,
  ctx: ProgressContext = NullContext,
) -> BatchResults:
  """Process a mini-batch of data through the GGNN.

  Args:
    epoch_type: The type of epoch being run.
    batch: The batch data returned by MakeBatch().
    ctx: A logging context.

  Returns:
    A batch results instance.
  """
  model_inputs = self.PrepareModelInputs(epoch_type, batch)
  unroll_steps = np.array(
    GetUnrollSteps(epoch_type, batch, FLAGS.unroll_strategy),
    dtype=np.int64,
  )

  # Set the model into the correct mode and feed through the batch data.
  if epoch_type == epoch_pb2.TRAIN:
    if not self.model.training:
      self.model.train()
    outputs = self.model(**model_inputs)
  else:
    if self.model.training:
      self.model.eval()
      self.model.opt.zero_grad()
    # Inference only, don't trace the computation graph.
    with torch.no_grad():
      outputs = self.model(**model_inputs)

  (
    targets,
    logits,
    graph_features,
    *unroll_stats,
  ) = outputs

  loss = self.model.loss((logits, graph_features), targets)

  if epoch_type == epoch_pb2.TRAIN:
    loss.backward()
    # TODO(github.com/ChrisCummins/ProGraML/issues/27): NB, pytorch clips by
    # norm of the gradient of the model, while tf clips by norm of the grad
    # of each tensor separately. Therefore we change default from 1.0 to 6.0.
    # TODO(github.com/ChrisCummins/ProGraML/issues/27): Anyway: Gradients
    # shouldn't really be clipped if not necessary?
    if self.clip_gradient_norm > 0.0:
      nn.utils.clip_grad_norm_(
        self.model.parameters(), self.clip_gradient_norm
      )
    self.model.opt.step()
    self.model.opt.zero_grad()

    # Check for LR scheduler stepping.
    if self.opt_step_count % FLAGS.lr_decay_steps == 0:
      # If a scheduler exists, then step it after every epoch.
      if self.model.scheduler is not None:
        old_learning_rate = self.model.learning_rate
        self.model.scheduler.step()
        app.Log(
          1,
          "LR Scheduler step. New learning rate is %s (was %s)",
          self.model.learning_rate,
          old_learning_rate,
        )

  model_converged = unroll_stats[1] if unroll_stats else False
  iteration_count = unroll_stats[0] if unroll_stats else unroll_steps

  return BatchResults.Create(
    targets=batch.model_data.node_labels,
    predictions=logits.detach().cpu().numpy(),
    model_converged=model_converged,
    learning_rate=self.model.learning_rate,
    iteration_count=iteration_count,
    loss=loss.item(),
  )
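# The TODO above notes that torch.nn.utils.clip_grad_norm_ rescales gradients
# by the *global* norm computed over all given parameters, rather than
# clipping each tensor's gradient separately. A standalone illustration of
# that call, independent of the model code above:
import torch
from torch import nn

toy_model = nn.Linear(4, 2)
toy_loss = toy_model(torch.randn(8, 4)).sum()
toy_loss.backward()

# Returns the combined gradient norm before clipping; gradients are rescaled
# in place so that their combined norm is at most max_norm.
total_norm = nn.utils.clip_grad_norm_(toy_model.parameters(), max_norm=6.0)
print(float(total_norm))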