コード例 #1
0
ファイル: gnn.py プロジェクト: nguyenducnhaty/tf2-gnn
    def __init__(self, params: Dict[str, Any]):
        """Initialise the layer from a hyperparameter dictionary.

        Only configuration values are read and empty containers set up here; the
        actual sub-layers are created later in the `build` method.

        Args:
            params: Hyperparameters. The following keys are read: "hidden_dim",
                "num_layers", "dense_every_num_layers", "residual_every_num_layers",
                "use_inter_layer_layernorm", "initial_node_representation_activation",
                "dense_intermediate_layer_activation", "message_calculation_class",
                "global_exchange_mode", "global_exchange_every_num_layers",
                "global_exchange_weighting_fun", "global_exchange_num_heads" and
                "global_exchange_dropout_rate". The dict itself is kept as
                `self._params`.

        Raises:
            ValueError: if params["global_exchange_mode"] is not one of "mean",
                "mlp" or "gru" (compared case-insensitively).
        """
        super().__init__()
        self._params = params
        self._hidden_dim = params["hidden_dim"]
        self._num_layers = params["num_layers"]
        self._dense_every_num_layers = params["dense_every_num_layers"]
        self._residual_every_num_layers = params["residual_every_num_layers"]
        self._use_inter_layer_layernorm = params["use_inter_layer_layernorm"]
        self._initial_node_representation_activation_fn = get_activation_function(
            params["initial_node_representation_activation"]
        )
        self._dense_intermediate_layer_activation_fn = get_activation_function(
            params["dense_intermediate_layer_activation"]
        )
        self._message_passing_class = get_message_passing_class(
            params["message_calculation_class"]
        )

        global_exchange_mode = params["global_exchange_mode"]
        normalised_mode = global_exchange_mode.lower()
        if normalised_mode not in {"mean", "mlp", "gru"}:
            raise ValueError(
                f"Unknown global_exchange_mode mode {global_exchange_mode} - has to be one of 'mean', 'mlp', 'gru'!"
            )
        # Store the lower-cased form: validation above is case-insensitive, so any
        # later comparison against "mean"/"mlp"/"gru" must see the same normalised value.
        self._global_exchange_mode = normalised_mode
        self._global_exchange_every_num_layers = params["global_exchange_every_num_layers"]
        self._global_exchange_weighting_fun = params["global_exchange_weighting_fun"]
        self._global_exchange_num_heads = params["global_exchange_num_heads"]
        self._global_exchange_dropout_rate = params["global_exchange_dropout_rate"]

        # Layer member variables. To be filled in in the `build` method.
        self._initial_projection_layer: tf.keras.layers.Layer = None
        self._mp_layers: List[MessagePassing] = []
        self._inter_layer_layernorms: List[tf.keras.layers.Layer] = []
        self._dense_layers: Dict[str, tf.keras.layers.Layer] = {}
        self._global_exchange_layers: Dict[str, GraphGlobalExchange] = {}
コード例 #2
0
    def __init__(self, params: Dict[str, Any], **kwargs):
        """Configure the message-passing layer from a hyperparameter dictionary.

        Args:
            params: Hyperparameters. Reads "hidden_dim" (cast to int),
                "aggregation_function" and "message_activation_function"; the last
                two are resolved to callables via the project's lookup helpers.
            **kwargs: Passed through unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self._hidden_dim = int(params["hidden_dim"])
        self._aggregation_fn = get_aggregation_function(params["aggregation_function"])
        self._activation_fn = get_activation_function(params["message_activation_function"])
コード例 #3
0
    def __init__(self, params: Dict[str, Any], **kwargs):
        """Initialise the layer from a hyperparameter dictionary.

        Args:
            params: Hyperparameters. Reads "hidden_dim" (cast to int),
                "aggregation_function", "message_activation_function" and
                "num_edge_MLP_hidden_layers".
            **kwargs: Passed through unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self._hidden_dim = int(params["hidden_dim"])

        # Resolve the configured aggregation function name to a callable.
        aggregation_fn_name = params["aggregation_function"]
        self._aggregation_fn = get_aggregation_function(aggregation_fn_name)

        # Resolve the configured message activation function name to a callable.
        activation_fn_name = params["message_activation_function"]
        self._activation_fn = get_activation_function(activation_fn_name)

        # Starts empty here; presumably one Keras layer per hyperedge type is appended
        # later (e.g. in `build`) — inferred from the name, confirm against the rest of the class.
        self._hyperedge_type_mlps: List[tf.keras.layers.Layer] = []

        # Stored as given (no int() cast, unlike hidden_dim above).
        self._num_edge_MLP_hidden_layers = params["num_edge_MLP_hidden_layers"]