Ejemplo n.º 1
0
    def __init__(self,
                 polymer_length: int,
                 polymer_dimensions: int,
                 n_recurrences: int,
                 mlp_sizes: Tuple[int, ...],
                 mlp_kwargs: Dict[str, Any] = None,
                 name: str = 'Energy'):
        """Builds the graph-net encoder and the global MLP decoder.

        Args:
          polymer_length: Multiplied by ``polymer_dimensions`` to size the
            decoder's final layer.
          polymer_dimensions: See ``polymer_length``.
          n_recurrences: Forwarded as the first argument of
            ``nn.GraphNetEncoder`` (presumably the number of message-passing
            rounds -- confirm against that class).
          mlp_sizes: Hidden layer sizes. Forwarded to ``nn.GraphNetEncoder``
            and used as the leading layers of the decoder. Any sequence of
            ints is accepted; it is converted to a tuple.
          mlp_kwargs: Keyword arguments forwarded to ``nn.GraphNetEncoder``
            and unpacked into ``hk.nets.MLP``. When ``None``, defaults to a
            ``VarianceScaling()`` weight initializer, a
            ``VarianceScaling(0.1)`` bias initializer and a softplus
            activation.
          name: Module name passed to the base-class constructor.
        """
        super().__init__(name=name)

        # Normalize once so list (or other sequence) inputs also work with
        # the tuple concatenation below; tuple() returns a tuple unchanged.
        mlp_sizes = tuple(mlp_sizes)

        if mlp_kwargs is None:
            # Built per call rather than as a default argument, so no mutable
            # dict is shared between instances.
            mlp_kwargs = {
                'w_init': hk.initializers.VarianceScaling(),
                'b_init': hk.initializers.VarianceScaling(0.1),
                'activation': jax.nn.softplus
            }

        self._graph_net = nn.GraphNetEncoder(n_recurrences, mlp_sizes,
                                             mlp_kwargs)

        structure_size = polymer_length * polymer_dimensions
        # Final layer has structure_size + 1 units. Presumably this is a
        # flattened polymer structure plus one extra scalar (e.g. the
        # energy) -- NOTE(review): confirm against how the output is split.
        self._decoder = hk.nets.MLP(output_sizes=mlp_sizes +
                                    (structure_size + 1, ),
                                    activate_final=False,  # no final activation
                                    name='GlobalDecoder',
                                    **mlp_kwargs)
Ejemplo n.º 2
0
  def __init__(self, n_recurrences, mlp_sizes, mlp_kwargs=None, name='Energy'):
    """Builds the graph-net encoder and a global MLP decoder with one output.

    Args:
      n_recurrences: Forwarded as the first argument of `nn.GraphNetEncoder`
        (presumably the number of message-passing rounds -- confirm there).
      mlp_sizes: Hidden layer sizes. Forwarded to `nn.GraphNetEncoder` and
        used as the leading layers of the decoder. Any sequence of ints is
        accepted; it is converted to a tuple.
      mlp_kwargs: Keyword arguments forwarded to `nn.GraphNetEncoder` and
        unpacked into `hk.nets.MLP`. If None, defaults to a
        `VarianceScaling()` weight initializer, a `VarianceScaling(0.1)` bias
        initializer and a softplus activation.
      name: Module name passed to the base-class constructor.
    """
    super().__init__(name=name)

    # Normalize so list (or other sequence) inputs also support the tuple
    # concatenation below; tuple() hands a tuple argument back unchanged.
    mlp_sizes = tuple(mlp_sizes)

    if mlp_kwargs is None:
      # Built per call rather than as a default argument, so no mutable dict
      # is shared between instances.
      mlp_kwargs = {
        'w_init': hk.initializers.VarianceScaling(),
        'b_init': hk.initializers.VarianceScaling(0.1),
        'activation': jax.nn.softplus
      }

    self._graph_net = nn.GraphNetEncoder(n_recurrences, mlp_sizes, mlp_kwargs)
    # Single output unit; presumably the scalar energy (confirm in __call__).
    self._decoder = hk.nets.MLP(output_sizes=mlp_sizes + (1,),
                                activate_final=False,  # no final activation
                                name='GlobalDecoder',
                                **mlp_kwargs)