Ejemplo n.º 1
0
 def get_default_hyperparameters(cls,
                                 mp_style: Optional[str] = None
                                 ) -> Dict[str, Any]:
     """Return the default hyperparameter dictionary for the class.

     The GNN defaults are included under a ``gnn_`` prefix; the
     optimizer/training defaults defined below override on any key clash.
     """
     defaults: Dict[str, Any] = {}
     for key, value in GNN.get_default_hyperparameters(mp_style).items():
         defaults[f"gnn_{key}"] = value
     defaults.update({
         "optimizer": "Adam",  # one of "SGD", "RMSProp", "Adam"
         "learning_rate": 0.001,
         "learning_rate_warmup_steps": None,
         "learning_rate_decay_steps": None,
         "momentum": 0.85,
         # Decay of squared gradients in RMSProp (unused by other optimizers):
         "rmsprop_rho": 0.98,
         # Set to a float to clip each gradient separately:
         "gradient_clip_value": None,
         # Set to a value to clip gradients by their norm:
         "gradient_clip_norm": None,
         # Set to a value to clip gradients by their global norm:
         "gradient_clip_global_norm": None,
         "use_intermediate_gnn_results": False,
     })
     return defaults
Ejemplo n.º 2
0
 def get_default_hyperparameters(cls, mp_style: Optional[str] = None) -> Dict[str, Any]:
     """Return the default hyperparameter dictionary for the class.

     Combines the GNN defaults (each key prefixed with ``gnn_``) with the
     training defaults below; the local entries win on any key collision.
     """
     gnn_defaults = GNN.get_default_hyperparameters(mp_style)
     prefixed: Dict[str, Any] = {f"gnn_{k}": v for k, v in gnn_defaults.items()}
     these_hypers: Dict[str, Any] = {
         "optimizer": "Adam",  # one of "SGD", "RMSProp", "Adam"
         "learning_rate": 0.001,
         "learning_rate_decay": 0.98,
         "momentum": 0.85,
         "gradient_clip_value": 1.0,
         "use_intermediate_gnn_results": False,
     }
     return {**prefixed, **these_hypers}
Ejemplo n.º 3
0
 def get_default_hyperparameters(cls, mp_style: Optional[str] = None) -> Dict[str, Any]:
     """Return the default hyperparameter dictionary for the class.

     Starts from the GNN defaults (prefixed with ``gnn_``) and overlays the
     aggregation/encoding defaults below, which take precedence on clashes.
     """
     result: Dict[str, Any] = {}
     for name, value in GNN.get_default_hyperparameters(mp_style).items():
         result[f"gnn_{name}"] = value
     result.update({
         "graph_aggregation_size": 256,
         "graph_aggregation_num_heads": 16,
         "graph_aggregation_hidden_layers": [128],
         "graph_aggregation_dropout_rate": 0.2,
         "token_embedding_size": 64,
         "gnn_message_calculation_class": "gnn_edge_mlp",
         "gnn_hidden_dim": 64,
         "gnn_global_exchange_mode": "mlp",
         "gnn_global_exchange_every_num_layers": 10,
         "gnn_num_layers": 4,
         "graph_encoding_size": 256,
     })
     return result