Example #1
class GraphGlobalMLPExchange(GraphGlobalExchange):
    def __init__(
        self,
        hidden_dim: int,
        weighting_fun: str = "softmax",
        num_heads: int = 4,
        dropout_rate: float = 0.0,
    ):
        """Initialise the layer."""
        super().__init__(hidden_dim, weighting_fun, num_heads, dropout_rate)

    def build(self, tensor_shapes: GraphGlobalExchangeInput):
        with tf.name_scope(self._name):
            self._mlp = MLP(out_size=self._hidden_dim)
            self._mlp.build(tf.TensorShape((None, 2 * self._hidden_dim)))
            super().build(tensor_shapes)

    def call(self, inputs: GraphGlobalExchangeInput, training: bool = False):
        per_node_graph_representations = self._compute_per_node_graph_representations(
            inputs, training
        )
        cur_node_representations = self._mlp(
            tf.concat([per_node_graph_representations, inputs.node_embeddings], axis=-1),
            training=training,
        )
        return cur_node_representations
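
# A minimal usage sketch for the layer above. Hedged assumptions: GraphGlobalExchangeInput
# is a named tuple with node_embeddings, node_to_graph_map and num_graphs fields (the
# snippet itself only shows node_embeddings), and calling build() explicitly as below
# mirrors the signature shown in the example.
import tensorflow as tf

hidden_dim = 16
layer = GraphGlobalMLPExchange(hidden_dim=hidden_dim)
layer.build(GraphGlobalExchangeInput(
    node_embeddings=tf.TensorShape((None, hidden_dim)),
    node_to_graph_map=tf.TensorShape((None,)),
    num_graphs=tf.TensorShape(()),
))
layer_input = GraphGlobalExchangeInput(
    node_embeddings=tf.random.normal(shape=(5, hidden_dim)),         # [V, hidden_dim]
    node_to_graph_map=tf.constant([0, 0, 0, 1, 1], dtype=tf.int32),  # [V]
    num_graphs=tf.constant(2, dtype=tf.int32),
)
new_node_states = layer(layer_input, training=False)                 # [V, hidden_dim]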
Example #2
    def build(self, input_shapes: MessagePassingInput):
        node_embedding_shapes = input_shapes.node_embeddings
        adjacency_list_shapes = input_shapes.adjacency_lists
        num_edge_types = len(adjacency_list_shapes)

        for i in range(num_edge_types):
            with tf.name_scope(f"edge_type_{i}-FiLM"):
                film_mlp = MLP(
                    out_size=2 * self._hidden_dim,
                    hidden_layers=self._film_parameter_MLP_hidden_layers,
                )
                film_mlp.build(
                    tf.TensorShape((None, node_embedding_shapes[-1])))
                self._edge_type_film_layer_computations.append(film_mlp)

        super().build(input_shapes)
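
# The loop above creates one FiLM-parameter MLP per edge type, each emitting
# 2 * hidden_dim values. Purely illustrative (not the layer's actual message code):
# the two halves are the usual multiplicative and additive FiLM parameters.
import tensorflow as tf

hidden_dim = 8
film_params = tf.random.normal(shape=(3, 2 * hidden_dim))  # stand-in for film_mlp(node_states), [E, 2*hidden_dim]
raw_messages = tf.random.normal(shape=(3, hidden_dim))      # stand-in for per-edge messages, [E, hidden_dim]
gamma, beta = tf.split(film_params, num_or_size_splits=2, axis=-1)  # each [E, hidden_dim]
modulated_messages = gamma * raw_messages + beta             # [E, hidden_dim]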
Example #3
    def build(self, input_shapes: MessagePassingInput):
        node_embedding_shapes = input_shapes.node_embeddings
        adjacency_list_shapes = input_shapes.adjacency_lists
        num_edge_types = len(adjacency_list_shapes)

        if self._use_target_state_as_input:
            edge_layer_input_size = 2 * node_embedding_shapes[-1]
        else:
            edge_layer_input_size = node_embedding_shapes[-1]

        for i in range(num_edge_types):
            with tf.name_scope(f"edge_type_{i}"):
                mlp = MLP(out_size=self._hidden_dim,
                          hidden_layers=self._num_edge_MLP_hidden_layers)
                mlp.build(tf.TensorShape((None, edge_layer_input_size)))
            self._edge_type_mlps.append(mlp)

        super().build(input_shapes)
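
# The input size above reflects whether each edge message is computed from the source
# state alone (D) or from the concatenation of source and target states (2 * D).
# A minimal sketch of that concatenation with stand-in tensors (illustrative only):
import tensorflow as tf

D = 16
source_states = tf.random.normal(shape=(4, D))  # states of edge source nodes, [E, D]
target_states = tf.random.normal(shape=(4, D))  # states of edge target nodes, [E, D]
edge_mlp_input = tf.concat([source_states, target_states], axis=-1)  # [E, 2*D]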
Example #4
class GraphRegressionTask(GraphTaskModel):
    @classmethod
    def get_default_hyperparameters(cls,
                                    mp_style: Optional[str] = None
                                    ) -> Dict[str, Any]:
        super_params = super().get_default_hyperparameters(mp_style)
        these_hypers: Dict[str, Any] = {
            "use_intermediate_gnn_results": True,
            "graph_aggregation_output_size": 32,
            "graph_aggregation_num_heads": 4,
            "graph_aggregation_layers": [32, 32],
            "graph_aggregation_dropout_rate": 0.1,
            "regression_mlp_layers": [64, 32],
            "regression_mlp_dropout": 0.1,
        }
        super_params.update(these_hypers)
        return super_params

    def __init__(self,
                 params: Dict[str, Any],
                 dataset: GraphDataset,
                 name: str = None,
                 **kwargs):
        super().__init__(params, dataset=dataset, name=name, **kwargs)
        self._node_to_graph_aggregation = None

        # Construct sublayers:
        self._weighted_avg_of_nodes_to_graph_repr = WeightedSumGraphRepresentation(
            graph_representation_size=self._params["graph_aggregation_output_size"],
            num_heads=self._params["graph_aggregation_num_heads"],
            weighting_fun="softmax",
            scoring_mlp_layers=self._params["graph_aggregation_layers"],
            scoring_mlp_dropout_rate=self._params["graph_aggregation_dropout_rate"],
            scoring_mlp_activation_fun="elu",
            transformation_mlp_layers=self._params["graph_aggregation_layers"],
            transformation_mlp_dropout_rate=self._params["graph_aggregation_dropout_rate"],
            transformation_mlp_activation_fun="elu",
        )
        self._weighted_sum_of_nodes_to_graph_repr = WeightedSumGraphRepresentation(
            graph_representation_size=self._params["graph_aggregation_output_size"],
            num_heads=self._params["graph_aggregation_num_heads"],
            weighting_fun="sigmoid",
            scoring_mlp_layers=self._params["graph_aggregation_layers"],
            scoring_mlp_dropout_rate=self._params["graph_aggregation_dropout_rate"],
            scoring_mlp_activation_fun="elu",
            transformation_mlp_layers=self._params["graph_aggregation_layers"],
            transformation_mlp_dropout_rate=self._params["graph_aggregation_dropout_rate"],
            transformation_mlp_activation_fun="elu",
        )

        self._regression_mlp = MLP(
            out_size=1,
            hidden_layers=self._params["regression_mlp_layers"],
            dropout_rate=self._params["regression_mlp_dropout"],
            use_biases=True,
            activation_fun=tf.nn.relu,
        )

    def build(self, input_shapes):
        if self._params["use_intermediate_gnn_results"]:
            # We get the initial GNN input + results for all layers:
            node_repr_size = (input_shapes["node_features"][-1] +
                              self._params["gnn_hidden_dim"] *
                              self._params["gnn_num_layers"])
        else:
            node_repr_size = (input_shapes["node_features"][-1] +
                              self._params["gnn_hidden_dim"])

        node_to_graph_repr_input = NodesToGraphRepresentationInput(
            node_embeddings=tf.TensorShape((None, node_repr_size)),
            node_to_graph_map=tf.TensorShape((None,)),
            num_graphs=tf.TensorShape(()),
        )

        with tf.name_scope(self.__class__.__name__):
            with tf.name_scope("graph_representation_computation"):
                with tf.name_scope("weighted_avg"):
                    self._weighted_avg_of_nodes_to_graph_repr.build(
                        node_to_graph_repr_input)
                with tf.name_scope("weighted_sum"):
                    self._weighted_sum_of_nodes_to_graph_repr.build(
                        node_to_graph_repr_input)

            self._regression_mlp.build(
                tf.TensorShape(
                    (None, 2 * self._params["graph_aggregation_output_size"])))

        super().build(input_shapes)

    def compute_task_output(
        self,
        batch_features: Dict[str, tf.Tensor],
        final_node_representations: Union[tf.Tensor, Tuple[tf.Tensor,
                                                           List[tf.Tensor]]],
        training: bool,
    ) -> Any:
        if self._params["use_intermediate_gnn_results"]:
            _, intermediate_node_representations = final_node_representations
            # We want to skip the first "intermediate" representation, which is the output of
            # the initial feature -> GNN input layer:
            node_representations = tf.concat(
                (batch_features["node_features"], ) +
                intermediate_node_representations[1:],
                axis=-1,
            )
        else:
            node_representations = tf.concat(
                [batch_features["node_features"], final_node_representations],
                axis=-1)

        graph_representation_layer_input = NodesToGraphRepresentationInput(
            node_embeddings=node_representations,
            node_to_graph_map=batch_features["node_to_graph_map"],
            num_graphs=batch_features["num_graphs_in_batch"],
        )
        weighted_avg_graph_repr = self._weighted_avg_of_nodes_to_graph_repr(
            graph_representation_layer_input, training=training)
        weighted_sum_graph_repr = self._weighted_sum_of_nodes_to_graph_repr(
            graph_representation_layer_input, training=training)

        graph_representations = tf.concat(
            [weighted_avg_graph_repr, weighted_sum_graph_repr],
            axis=-1)  # shape: [G, GD]

        per_graph_results = self._regression_mlp(
            graph_representations, training=training)  # shape: [G, 1]

        return tf.squeeze(per_graph_results, axis=-1)

    def compute_task_metrics(
        self,
        batch_features: Dict[str, tf.Tensor],
        task_output: Any,
        batch_labels: Dict[str, tf.Tensor],
    ) -> Dict[str, tf.Tensor]:
        mse = tf.losses.mean_squared_error(batch_labels["target_value"],
                                           task_output)
        mae = tf.losses.mean_absolute_error(batch_labels["target_value"],
                                            task_output)
        num_graphs = tf.cast(batch_features["num_graphs_in_batch"], tf.float32)
        return {
            "loss": mse,
            "batch_squared_error": mse * num_graphs,
            "batch_absolute_error": mae * num_graphs,
            "num_graphs": num_graphs,
        }

    def compute_epoch_metrics(self,
                              task_results: List[Any]) -> Tuple[float, str]:
        total_num_graphs = sum(batch_task_result["num_graphs"]
                               for batch_task_result in task_results)
        total_absolute_error = sum(batch_task_result["batch_absolute_error"]
                                   for batch_task_result in task_results)
        total_squared_error = sum(batch_task_result["batch_squared_error"]
                                  for batch_task_result in task_results)
        epoch_mse = (total_squared_error / total_num_graphs).numpy()
        epoch_mae = (total_absolute_error / total_num_graphs).numpy()
        return epoch_mae, f" MSE = {epoch_mse:.3f} | MAE = {epoch_mae:.3f}"

    def evaluate_model(self, dataset: tf.data.Dataset) -> Dict[str, float]:
        import sklearn.metrics as metrics

        predictions = self.predict(dataset).numpy()
        labels = []
        for _, batch_labels in dataset:
            labels.append(batch_labels["target_value"])
        labels = tf.concat(labels, axis=0).numpy()

        results = dict(
            mae=metrics.mean_absolute_error(y_true=labels, y_pred=predictions),
            mse=metrics.mean_squared_error(y_true=labels, y_pred=predictions),
            max_err=metrics.max_error(y_true=labels, y_pred=predictions),
            expl_var=metrics.explained_variance_score(y_true=labels, y_pred=predictions),
            r2_score=metrics.r2_score(y_true=labels, y_pred=predictions),
        )

        return results
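
# A quick check of the sizes wired together in build() and compute_task_output() above.
# gnn_hidden_dim, gnn_num_layers and node_features_dim are assumed values for illustration,
# not defaults taken from this snippet.
gnn_hidden_dim, gnn_num_layers = 16, 4
node_features_dim = 7

# With "use_intermediate_gnn_results": True (the class default above), each node gets its
# raw features plus one hidden vector per GNN layer:
node_repr_size = node_features_dim + gnn_hidden_dim * gnn_num_layers
# The weighted-avg and weighted-sum graph representations (each of size
# "graph_aggregation_output_size" = 32) are concatenated, so the regression MLP is built
# on shape (None, 64) and emits one value per graph.
graph_repr_size = 2 * 32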
Example #5
class WeightedSumGraphRepresentation(NodesToGraphRepresentation):
    """Layer computing graph representations as weighted sum of node representations.
    The weights are either computed from the original node representations ("self-attentional")
    or by a softmax across the nodes of a graph.
    Supports splitting the operation into independent "heads", computed in parallel, which
    can focus on different aspects.

    Throughout we use the following abbreviations in shape descriptions:
        * V: number of nodes (across all graphs)
        * VD: node representation dimension
        * G: number of graphs
        * GD: graph representation dimension
        * H: number of heads
    """
    def __init__(
        self,
        graph_representation_size: int,
        num_heads: int,
        weighting_fun: str = "softmax",  # One of {"softmax", "sigmoid", "none", "average"}
        scoring_mlp_layers: List[int] = [128],
        scoring_mlp_activation_fun: str = "ReLU",
        scoring_mlp_use_biases: bool = False,
        scoring_mlp_dropout_rate: float = 0.2,
        transformation_mlp_layers: List[int] = [128],
        transformation_mlp_activation_fun: str = "ReLU",
        transformation_mlp_use_biases: bool = False,
        transformation_mlp_dropout_rate: float = 0.2,
        transformation_mlp_result_lower_bound: Optional[float] = None,
        transformation_mlp_result_upper_bound: Optional[float] = None,
        **kwargs,
    ):
        """
        Args:
            graph_representation_size: Size of the computed graph representation.
            num_heads: Number of independent heads to use to compute weights.
            weighting_fun: "sigmoid" ([0, 1] weights for each node computed from its
                representation), "softmax" ([0, 1] weights for each node computed
                from all nodes in same graph), "average" (weight is fixed to 1/num_nodes),
                or "none" (weight is fixed to 1).
            scoring_mlp_layers: MLP layer structure for computing raw scores turned into
                weights.
            scoring_mlp_activation_fun: MLP activation function for computing raw scores
                turned into weights.
            scoring_mlp_dropout_rate: MLP inter-layer dropout rate for computing raw scores
                turned into weights.
            transformation_mlp_layers: MLP layer structure for computing graph representations.
            transformation_mlp_activation_fun: MLP activation function for computing graph
                representations.
            transformation_mlp_dropout_rate: MLP inter-layer dropout rate for computing graph
                representations.
            transformation_mlp_result_lower_bound: Lower bound that results of the transformation
                MLP will be clipped to before being scaled and summed up.
                This is particularly useful to limit the magnitude of results when using "sigmoid"
                or "none" as weighting function.
            transformation_mlp_result_upper_bound: Upper bound that results of the transformation
                MLP will be clipped to before being scaled and summed up.
        """
        super().__init__(graph_representation_size, **kwargs)
        assert (
            graph_representation_size % num_heads == 0
        ), f"Number of heads {num_heads} needs to divide final representation size {graph_representation_size}!"
        assert weighting_fun.lower() in {
            "none",
            "average",
            "softmax",
            "sigmoid",
        }, f"Weighting function {weighting_fun} unknown, {{'softmax', 'sigmoid', 'none', 'average'}} supported."

        self._num_heads = num_heads
        self._weighting_fun = weighting_fun.lower()
        self._transformation_mlp_activation_fun = get_activation_function_by_name(
            transformation_mlp_activation_fun)
        self._transformation_mlp_result_lower_bound = transformation_mlp_result_lower_bound
        self._transformation_mlp_result_upper_bound = transformation_mlp_result_upper_bound

        # Build sub-layers:
        if self._weighting_fun not in ("none", "average"):
            self._scoring_mlp = MLP(
                out_size=self._num_heads,
                hidden_layers=scoring_mlp_layers,
                use_biases=scoring_mlp_use_biases,
                activation_fun=get_activation_function_by_name(
                    scoring_mlp_activation_fun),
                dropout_rate=scoring_mlp_dropout_rate,
                name="ScoringMLP",
            )

        self._transformation_mlp = MLP(
            out_size=self._graph_representation_size,
            hidden_layers=transformation_mlp_layers,
            use_biases=transformation_mlp_use_biases,
            activation_fun=self._transformation_mlp_activation_fun,
            dropout_rate=transformation_mlp_dropout_rate,
            name="TransformationMLP",
        )

    def build(self, input_shapes: NodesToGraphRepresentationInput):
        with tf.name_scope("WeightedSumGraphRepresentation"):
            if self._weighting_fun not in ("none", "average"):
                self._scoring_mlp.build(
                    tf.TensorShape((None, input_shapes.node_embeddings[-1])))
            self._transformation_mlp.build(
                tf.TensorShape((None, input_shapes.node_embeddings[-1])))

            super().build(input_shapes)

    @tf.function(input_signature=(
        NodesToGraphRepresentationInput(
            node_embeddings=tf.TensorSpec(shape=tf.TensorShape((None, None)),
                                          dtype=tf.float32),
            node_to_graph_map=tf.TensorSpec(shape=tf.TensorShape((None, )),
                                            dtype=tf.int32),
            num_graphs=tf.TensorSpec(shape=(), dtype=tf.int32),
        ),
        tf.TensorSpec(shape=(), dtype=tf.bool),
    ))
    def call(self,
             inputs: NodesToGraphRepresentationInput,
             training: bool = False):
        # (1) compute weights for each node/head pair:
        if self._weighting_fun not in ("none", "average"):
            scores = self._scoring_mlp(inputs.node_embeddings,
                                       training=training)  # Shape [V, H]
            if self._weighting_fun == "sigmoid":
                weights = tf.nn.sigmoid(scores)  # Shape [V, H]
            elif self._weighting_fun == "softmax":
                weights_per_head = []
                for head_idx in range(self._num_heads):
                    head_scores = scores[:, head_idx]  # Shape [V]
                    head_weights = unsorted_segment_softmax(
                        logits=head_scores,
                        segment_ids=inputs.node_to_graph_map,
                        num_segments=inputs.num_graphs,
                    )  # Shape [V]
                    weights_per_head.append(tf.expand_dims(head_weights, -1))
                weights = tf.concat(weights_per_head, axis=1)  # Shape [V, H]
            else:
                raise ValueError(f"Unknown weighting function: {self._weighting_fun}")

        # (2) compute representations for each node/head pair:
        node_reprs = self._transformation_mlp_activation_fun(
            self._transformation_mlp(inputs.node_embeddings,
                                     training=training))  # Shape [V, GD]
        if self._transformation_mlp_result_lower_bound is not None:
            node_reprs = tf.maximum(
                node_reprs, self._transformation_mlp_result_lower_bound)
        if self._transformation_mlp_result_upper_bound is not None:
            node_reprs = tf.minimum(
                node_reprs, self._transformation_mlp_result_upper_bound)
        node_reprs = tf.reshape(
            node_reprs,
            shape=(-1, self._num_heads,
                   self._graph_representation_size // self._num_heads),
        )  # Shape [V, H, GD//H]

        # (3) if necessary, weight representations and aggregate by graph:
        if self._weighting_fun == "none":
            node_reprs = tf.reshape(
                node_reprs,
                shape=(-1, self._graph_representation_size))  # Shape [V, GD]
            graph_reprs = tf.math.segment_sum(
                data=node_reprs,
                segment_ids=inputs.node_to_graph_map)  # Shape [G, GD]
        elif self._weighting_fun == "average":
            node_reprs = tf.reshape(
                node_reprs,
                shape=(-1, self._graph_representation_size))  # Shape [V, GD]
            graph_reprs = tf.math.segment_mean(
                data=node_reprs,
                segment_ids=inputs.node_to_graph_map)  # Shape [G, GD]
        else:
            weights = tf.expand_dims(weights, -1)  # Shape [V, H, 1]
            weighted_node_reprs = weights * node_reprs  # Shape [V, H, GD//H]

            weighted_node_reprs = tf.reshape(
                weighted_node_reprs,
                shape=(-1, self._graph_representation_size))  # Shape [V, GD]
            graph_reprs = tf.math.segment_sum(
                data=weighted_node_reprs,
                segment_ids=inputs.node_to_graph_map)  # Shape [G, GD]

        return graph_reprs
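
# A minimal usage sketch (shapes follow the V/G/GD abbreviations in the class docstring;
# the NodesToGraphRepresentationInput field names are taken from the input_signature above,
# and the direct call relies on Keras auto-building the layer from the input structure).
import tensorflow as tf

layer = WeightedSumGraphRepresentation(graph_representation_size=16, num_heads=4)
layer_input = NodesToGraphRepresentationInput(
    node_embeddings=tf.random.normal(shape=(5, 8)),                  # [V, VD]
    node_to_graph_map=tf.constant([0, 0, 0, 1, 1], dtype=tf.int32),  # [V]
    num_graphs=tf.constant(2, dtype=tf.int32),
)
graph_reprs = layer(layer_input, training=False)  # [G, GD] = [2, 16]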
Example #6
class RGIN(GNN_Edge_MLP):
    """Compute new graph states by neural message passing using MLPs for state updates
    and message computation.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := \sigma(MLP_{aggr}(\sum_\ell \sum_{(u, v) \in A_\ell} MLP_\ell(h^t_u)))
    The learnable parameters of this are the MLPs MLP_\ell.
    This is derived from Cor. 6 of arXiv:1810.00826, instantiating the functions f, \phi
    with _separate_ MLPs. This is more powerful than the GIN formulation in Eq. (4.1) of
    arXiv:1810.00826, as we want to be able to distinguish graphs of the form
     G_1 = (V={1, 2, 3}, E_1={(1, 2)}, E_2={(3, 2)})
    and
     G_2 = (V={1, 2, 3}, E_1={(3, 2)}, E_2={(1, 2)})
    from each other. If we treated all edges the same,
    G_1.E_1 \cup G_1.E_2 == G_2.E_1 \cup G_2.E_2 would imply that the two graphs
    become indistinguishable.
    Hence, we introduce per-edge-type MLPs, which also means that we have to drop
    the optimisation of modelling f \circ \phi by a single MLP used in the original
    GIN formulation.

    Note that RGIN is implemented as a special-case of GNN_Edge_MLP, setting some
    different default hyperparameters and adding a different message aggregation
    function, but re-using the message passing functionality.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * L: number of different edge types
    * E: number of edges of a given edge type
    * D: input node representation dimension
    * H: output node representation dimension (set as hidden_dim)

    >>> node_embeddings = tf.random.normal(shape=(5, 3))
    >>> adjacency_lists = (
    ...    tf.constant([[0, 1], [2, 4], [2, 4]], dtype=tf.int32),
    ...    tf.constant([[2, 3], [2, 4]], dtype=tf.int32),
    ...    tf.constant([[3, 1]], dtype=tf.int32),
    ... )
    ...
    >>> params = RGIN.get_default_hyperparameters()
    >>> params["hidden_dim"] = 12
    >>> layer = RGIN(params)
    >>> output = layer(MessagePassingInput(node_embeddings, adjacency_lists))
    >>> print(output)
    tf.Tensor(..., shape=(5, 12), dtype=float32)
    """
    @classmethod
    def get_default_hyperparameters(cls):
        these_hypers = {
            "use_target_state_as_input": False,
            "num_edge_MLP_hidden_layers": 1,
            "num_aggr_MLP_hidden_layers": None,
        }
        gnn_edge_mlp_hypers = super().get_default_hyperparameters()
        gnn_edge_mlp_hypers.update(these_hypers)
        return gnn_edge_mlp_hypers

    def __init__(self, params: Dict[str, Any], **kwargs):
        super().__init__(params, **kwargs)
        self._num_aggr_MLP_hidden_layers: Optional[int] = params["num_aggr_MLP_hidden_layers"]
        self._aggregation_mlp: Optional[MLP] = None

    def build(self, input_shapes: MessagePassingInput):
        node_embedding_shapes = input_shapes.node_embeddings
        if self._num_aggr_MLP_hidden_layers is not None:
            with tf.name_scope("aggregation_MLP"):
                self._aggregation_mlp = MLP(
                    out_size=self._hidden_dim,
                    hidden_layers=[self._hidden_dim] *
                    self._num_aggr_MLP_hidden_layers,
                )
                self._aggregation_mlp.build(
                    tf.TensorShape((None, self._hidden_dim)))
        super().build(input_shapes)

    def _compute_new_node_embeddings(
        self,
        cur_node_embeddings: tf.Tensor,
        messages_per_type: List[tf.Tensor],
        edge_type_to_message_targets: List[tf.Tensor],
        num_nodes: tf.Tensor,
        training: bool,
    ):
        # Let M be the number of messages (sum of all E):
        message_targets = tf.concat(edge_type_to_message_targets,
                                    axis=0)  # Shape [M]
        messages = tf.concat(messages_per_type, axis=0)  # Shape [M, H]

        aggregated_messages = self._aggregation_fn(data=messages,
                                                   segment_ids=message_targets,
                                                   num_segments=num_nodes)
        if self._aggregation_mlp is not None:
            aggregated_messages = self._aggregation_mlp(
                aggregated_messages, training)

        return self._activation_fn(aggregated_messages)
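
# Illustrative sketch of the aggregation step above, assuming self._aggregation_fn is an
# unsorted segment sum (matching the sum in the h^{t+1}_v equation of the class docstring;
# the actual aggregation function is configured elsewhere in the layer).
import tensorflow as tf

messages = tf.random.normal(shape=(4, 12))                   # [M, H], all per-edge messages concatenated
message_targets = tf.constant([0, 2, 2, 1], dtype=tf.int32)  # [M], target node index of each message
num_nodes = 3
aggregated_messages = tf.math.unsorted_segment_sum(
    data=messages, segment_ids=message_targets, num_segments=num_nodes)  # [V, H]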
Example #7
class QM9RegressionTask(GraphTaskModel):
    @classmethod
    def get_default_hyperparameters(cls,
                                    mp_style: Optional[str] = None
                                    ) -> Dict[str, Any]:
        super_params = super().get_default_hyperparameters(mp_style)
        these_hypers: Dict[str, Any] = {
            "use_intermediate_gnn_results": False,
            "out_layer_dropout_keep_prob": 1.0,
        }
        super_params.update(these_hypers)
        return super_params

    def __init__(self,
                 params: Dict[str, Any],
                 dataset: GraphDataset,
                 name: str = None,
                 **kwargs):
        super().__init__(params, dataset=dataset, name=name, **kwargs)
        assert isinstance(dataset, QM9Dataset)

        self._task_id = int(dataset._params["task_id"])

        self._regression_gate = MLP(
            out_size=1,
            hidden_layers=[],
            use_biases=True,
            dropout_rate=self._params["out_layer_dropout_keep_prob"],
            name="gate",
        )
        self._regression_transform = MLP(
            out_size=1,
            hidden_layers=[],
            use_biases=True,
            dropout_rate=self._params["out_layer_dropout_keep_prob"],
            name="transform")

    def build(self, input_shapes):
        with tf.name_scope(self.__class__.__name__):
            with tf.name_scope("node_gate"):
                self._regression_gate.build(
                    tf.TensorShape((
                        None,
                        input_shapes["node_features"][-1] +
                        self._params["gnn_hidden_dim"],
                    )))
            with tf.name_scope("node_transform"):
                self._regression_transform.build(
                    tf.TensorShape((None, self._params["gnn_hidden_dim"])))

        super().build(input_shapes)

    def compute_task_output(
        self,
        batch_features: Dict[str, tf.Tensor],
        final_node_representations: Union[tf.Tensor, Tuple[tf.Tensor,
                                                           List[tf.Tensor]]],
        training: bool,
    ) -> Any:
        if self._params["use_intermediate_gnn_results"]:
            final_node_representations, _ = final_node_representations

        # The per-node regression uses only final node representations:
        per_node_output = self._regression_transform(
            final_node_representations, training=training)  # Shape [V, 1]

        # The gating uses both initial and final node representations:
        per_node_weight = self._regression_gate(
            tf.concat(
                [batch_features["node_features"], final_node_representations],
                axis=-1),
            training=training,
        )  # Shape [V, 1]

        per_node_weighted_output = tf.squeeze(tf.nn.sigmoid(per_node_weight) *
                                              per_node_output,
                                              axis=-1)  # Shape [V]
        per_graph_output = tf.math.unsorted_segment_sum(
            data=per_node_weighted_output,
            segment_ids=batch_features["node_to_graph_map"],
            num_segments=batch_features["num_graphs_in_batch"],
        )  # Shape [G]

        return per_graph_output

    def compute_task_metrics(
        self,
        batch_features: Dict[str, tf.Tensor],
        task_output: Any,
        batch_labels: Dict[str, tf.Tensor],
    ) -> Dict[str, tf.Tensor]:
        mse = tf.losses.mean_squared_error(batch_labels["target_value"],
                                           task_output)
        mae = tf.losses.mean_absolute_error(batch_labels["target_value"],
                                            task_output)
        num_graphs = tf.cast(batch_features["num_graphs_in_batch"], tf.float32)
        return {
            "loss": mse,
            "batch_squared_error": mse * num_graphs,
            "batch_absolute_error": mae * num_graphs,
            "num_graphs": num_graphs,
        }

    def compute_epoch_metrics(self,
                              task_results: List[Any]) -> Tuple[float, str]:
        total_num_graphs = sum(batch_task_result["num_graphs"]
                               for batch_task_result in task_results)
        total_absolute_error = sum(batch_task_result["batch_absolute_error"]
                                   for batch_task_result in task_results)
        total_squared_error = sum(batch_task_result["batch_squared_error"]
                                  for batch_task_result in task_results)
        epoch_mse = (total_squared_error / total_num_graphs).numpy()
        epoch_mae = (total_absolute_error / total_num_graphs).numpy()
        return (
            epoch_mae,
            (f"Task {self._task_id} |"
             f" MSE = {epoch_mse:.3f} |"
             f" MAE = {epoch_mae:.3f} |"
             f" Error Ratio: {epoch_mae / CHEMICAL_ACC_NORMALISING_FACTORS[self._task_id]:.3f}"
             ),
        )
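
# The readout in compute_task_output() is a gated sum over nodes:
#   output_g = sum_{v in g} sigmoid(gate([x_v ; h_v])) * transform(h_v).
# A minimal numeric sketch of that pattern with stand-in tensors (not the task's real MLP outputs):
import tensorflow as tf

gate_scores = tf.random.normal(shape=(5, 1))  # stand-in for self._regression_gate(...), [V, 1]
node_values = tf.random.normal(shape=(5, 1))  # stand-in for self._regression_transform(...), [V, 1]
node_to_graph_map = tf.constant([0, 0, 0, 1, 1], dtype=tf.int32)  # [V]
per_graph_output = tf.math.unsorted_segment_sum(
    data=tf.squeeze(tf.nn.sigmoid(gate_scores) * node_values, axis=-1),  # [V]
    segment_ids=node_to_graph_map,
    num_segments=2,
)  # [G]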