Exemplo n.º 1
0
    def RunBatch(
        self,
        epoch_type: epoch_pb2.EpochType,
        batch_data: BatchData,
        ctx: ProgressContext = NullContext,
    ) -> BatchResults:
        """Run a batch of data through the model.

        During TRAIN epochs the model is first updated on the batch with
        `train_on_batch()`. For every epoch type, predictions are then computed
        with `predict_on_batch()`, reshaped with `ReshapePaddedModelOutput()`,
        and flattened alongside the targets so the two can be compared.

        Args:
          epoch_type: The type of the current epoch.
          batch_data: A batch of graphs and model data. Its `model_data` must
            provide `encoded_sequences` of shape
            (batch_size, padded_sequence_length), `selector_vectors` of shape
            (batch_size, padded_sequence_length, 2), `node_labels` (fed to the
            model as the training labels), and `targets`, a sequence of arrays
            that is concatenated into a single flat targets array.
          ctx: A logging context. Not used by this implementation.

        Returns:
          A BatchResults built from the flattened targets and predictions and
          the training loss. The loss is None for non-TRAIN epochs.
        """
        model_data: LstmBatchData = batch_data.model_data

        # Sanity-check the padded input shapes expected by the model.
        assert model_data.encoded_sequences.shape == (
            self.batch_size,
            self.padded_sequence_length,
        ), model_data.encoded_sequences.shape
        assert model_data.selector_vectors.shape == (
            self.batch_size,
            self.padded_sequence_length,
            2,
        ), model_data.selector_vectors.shape

        x = [model_data.encoded_sequences, model_data.selector_vectors]
        y = [model_data.node_labels]

        if epoch_type == epoch_pb2.TRAIN:
            # Only the first value returned by train_on_batch() (the loss) is
            # kept; any additional values are discarded.
            loss, *_ = self.model.train_on_batch(x, y)
        else:
            loss = None

        padded_predictions = self.model.predict_on_batch(x)

        # Reshape the outputs.
        predictions = self.ReshapePaddedModelOutput(batch_data,
                                                    padded_predictions)

        # Flatten the targets and predictions lists so that we can compare them.
        # Shape (batch_node_count, node_y_dimensionality).
        targets = np.concatenate(model_data.targets)
        predictions = np.concatenate(predictions)

        return BatchResults.Create(
            targets=targets,
            predictions=predictions,
            loss=loss,
        )
Exemplo n.º 2
0
def test_RollingResults_iteration_count(weight: float):
    """Check that RollingResults aggregates iteration count and convergence.

    Ten identical single-graph batches are fed in. Each reports one model
    iteration and a converged model, so the aggregated values must equal
    exactly those per-batch values.
    """
    accumulator = RollingResults()

    single_graph_batch = BatchData(graph_count=1, model_data=None)
    batch_outcome = BatchResults.Create(
        targets=np.array([[0, 1, 2]]),
        predictions=np.array([[0, 1, 2]]),
        iteration_count=1,
        model_converged=True,
    )

    update_count = 10
    for _ in range(update_count):
        accumulator.Update(single_graph_batch, batch_outcome, weight=weight)

    assert accumulator.iteration_count == 1
    assert accumulator.model_converged == 1
Exemplo n.º 3
0
    def RunBatch(
        self,
        epoch_type: epoch_pb2.EpochType,
        batch: BatchData,
        ctx: ProgressContext = NullContext,
    ) -> BatchResults:
        """Process a mini-batch of data through the GGNN.

        TRAIN epochs:
          - The model is switched to training mode.
          - The loss is back-propagated.
          - Gradients are clipped if `self.clip_gradient_norm > 0`.
          - An optimizer step is taken.
          - The LR scheduler (if any) may be stepped.

        Other epochs:
          - The model is switched to eval mode.
          - The forward pass runs under `torch.no_grad()`.

        Args:
          epoch_type: The type of epoch being run.
          batch: The batch data returned by MakeBatch().
          ctx: A logging context. Not used by this implementation.

        Returns:
          A batch results instance.
        """
        model_inputs = self.PrepareModelInputs(epoch_type, batch)
        unroll_steps = np.array(
            GetUnrollSteps(epoch_type, batch, FLAGS.unroll_strategy),
            dtype=np.int64,
        )

        # Set the model into the correct mode and feed through the batch data.
        if epoch_type == epoch_pb2.TRAIN:
            if not self.model.training:
                self.model.train()
            outputs = self.model(**model_inputs)
        else:
            if self.model.training:
                self.model.eval()
                # Clear accumulated gradients when switching out of training.
                self.model.opt.zero_grad()
            # Inference only, don't trace the computation graph.
            with torch.no_grad():
                outputs = self.model(**model_inputs)

        # Outputs beyond the first three, when present, are unroll statistics.
        # Index 0 is used as the iteration count and index 1 as the
        # convergence flag (see below).
        (
            targets,
            logits,
            graph_features,
            *unroll_stats,
        ) = outputs

        loss = self.model.loss((logits, graph_features), targets)

        if epoch_type == epoch_pb2.TRAIN:
            loss.backward()
            # TODO(github.com/ChrisCummins/ProGraML/issues/27): NB, pytorch clips by
            # norm of the gradient of the model, while tf clips by norm of the grad
            # of each tensor separately. Therefore we change default from 1.0 to 6.0.
            # TODO(github.com/ChrisCummins/ProGraML/issues/27): Anyway: Gradients
            # shouldn't really be clipped if not necessary?
            if self.clip_gradient_norm > 0.0:
                # The PyTorch in-place total-norm clipping function is
                # `clip_grad_norm_`; `clip_gradient_norm_` does not exist and
                # would raise AttributeError here.
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.clip_gradient_norm)
            self.model.opt.step()
            self.model.opt.zero_grad()
            # Step the LR scheduler (if one exists) whenever opt_step_count is
            # a multiple of FLAGS.lr_decay_steps. NOTE(review): opt_step_count
            # is not incremented in this method; presumably it is maintained
            # elsewhere as a count of optimizer steps — confirm.
            if self.opt_step_count % FLAGS.lr_decay_steps == 0:
                if self.model.scheduler is not None:
                    old_learning_rate = self.model.learning_rate
                    self.model.scheduler.step()
                    app.Log(
                        1,
                        "LR Scheduler step. New learning rate is %s (was %s)",
                        self.model.learning_rate,
                        old_learning_rate,
                    )

        # Fall back to "not converged" and the requested unroll steps when the
        # model does not report unroll statistics.
        model_converged = unroll_stats[1] if unroll_stats else False
        iteration_count = unroll_stats[0] if unroll_stats else unroll_steps

        return BatchResults.Create(
            targets=batch.model_data.node_labels,
            predictions=logits.detach().cpu().numpy(),
            model_converged=model_converged,
            learning_rate=self.model.learning_rate,
            iteration_count=iteration_count,
            loss=loss.item(),
        )