Example #1
def init_nvp(rng, dim, flip, init_batch=None):
    net_init, net_apply = stax.serial(Dense(512), Relu, Dense(512), Relu,
                                      Dense(dim))
    in_shape = (-1, dim // 2)
    _, net_params = net_init(rng, in_shape)

    def shift_and_log_scale_fn(net_params, x1):
        s = net_apply(net_params, x1)
        return np.split(s, 2, axis=1)

    def nvp_forward(net_params, prev_sample, prev_logp=0.):
        d = dim // 2
        x1, x2 = prev_sample[:, :d], prev_sample[:, d:]
        if flip:
            x2, x1 = x1, x2
        shift, log_scale = shift_and_log_scale_fn(net_params, x1)
        y2 = x2 * np.exp(log_scale) + shift
        if flip:
            x1, y2 = y2, x1
        y = np.concatenate([x1, y2], axis=-1)
        return y, prev_logp + np.sum(log_scale, axis=-1)

    def nvp_reverse(net_params, next_sample, next_logp=0.):
        d = dim // 2
        y1, y2 = next_sample[:, :d], next_sample[:, d:]
        if flip:
            y1, y2 = y2, y1
        shift, log_scale = shift_and_log_scale_fn(net_params, y1)
        x2 = (y2 - shift) * np.exp(-log_scale)
        if flip:
            y1, x2 = x2, y1
        x = np.concatenate([y1, x2], axis=-1)
        return x, next_logp - np.sum(log_scale, axis=-1)

    return net_params, nvp_forward, nvp_reverse
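
A minimal usage sketch for the coupling layer above (not part of the original snippet; it assumes np is jax.numpy and the stax imports from the surrounding code, and dim=2, flip=False are arbitrary choices):
import jax

rng = jax.random.PRNGKey(0)
params, nvp_forward, nvp_reverse = init_nvp(rng, dim=2, flip=False)
x = np.ones((16, 2))                      # batch of 16 two-dimensional samples
y, logp = nvp_forward(params, x)          # forward pass plus accumulated log-det-Jacobian
x_rec, _ = nvp_reverse(params, y)         # inverse pass; x_rec matches x up to float error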
Example #2
def DenseReluNetwork(out_dim: int, hidden_layers: int,
                     hidden_dim: int) -> Tuple[Callable, Callable]:
    """Create a dense neural network with Relu after hidden layers.

    Parameters
    ----------
    out_dim : int
        The output dimension.
    hidden_layers : int
        The number of hidden layers.
    hidden_dim : int
        The dimension of the hidden layers.

    Returns
    -------
    init_fun : function
        The function that initializes the network. Note that this is the
        init_function defined in the Jax stax module, which is different
        from the functions of my InitFunction class.
    forward_fun : function
        The function that passes the inputs through the neural network.
    """
    init_fun, forward_fun = serial(
        *(Dense(hidden_dim), Relu) * hidden_layers,
        Dense(out_dim),
    )
    return init_fun, forward_fun
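
A hedged usage sketch of the (init_fun, forward_fun) pair returned above; the input dimension of 5 and batch size of 8 are illustrative, not from the original code:
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
init_fun, forward_fun = DenseReluNetwork(out_dim=3, hidden_layers=2, hidden_dim=64)
out_shape, params = init_fun(rng, (-1, 5))     # -1 stands in for the batch dimension
y = forward_fun(params, jnp.ones((8, 5)))      # y has shape (8, 3)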
Example #3
 def create_surrogate(self):
     surrogate_init, surrogate = stax.serial(
         Dense(200), Relu,
         Dense(200), Relu,
         Dense(200), Relu,
         Dense(1)
     )
     return surrogate, surrogate_init
Example #4
def network(activation):
    # Use stax to set up network initialization and evaluation functions
    net_init, net_apply = stax.serial(
        Dense(40), activation,
        Dense(40), activation,
        Dense(1)
    )
    return net_init, net_apply
Example #5
def generate_network(out_features, hidden_size):
    # Use stax to set up network initialization and evaluation functions
    net_init, net_apply = stax.serial(
        Dense(hidden_size), Relu,
        Dense(hidden_size), Relu,
        Dense(out_features), Sigmoid
    )

    return net_init, net_apply
Example #6
def DeepQNetwork():
    init_fun, predict_fun = stax.serial(
        Conv(16, (8, 8), strides=(4, 4)), Relu,
        Conv(32, (4, 4), strides=(2, 2)), Relu,
        Conv(64, (3, 3)), Relu,
        Flatten,
        Dense(256), Relu,
        Dense(6)
    )
    return init_fun, predict_fun
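
A hedged usage sketch; the 84x84x4 NHWC observation shape is an assumption (typical Atari preprocessing) rather than anything stated in the snippet:
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
init_fun, predict_fun = DeepQNetwork()
_, params = init_fun(rng, (-1, 84, 84, 4))
q_values = predict_fun(params, jnp.zeros((1, 84, 84, 4)))   # shape (1, 6): one Q-value per action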
Example #7
def PolicyNetwork():
    """Policy network for the experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        helx.nn.rnn.LSTM(256),
        Dense(256),
        Relu,
        FanOut(2),
        parallel(Dense(1), Dense(1)),
    )
Example #8
def prepare_single_layer_model(input_size, output_size, width, key):
    init_random_params, predict = stax.serial(Dense(width), Relu,
                                              Dense(output_size), LogSoftmax)

    key, split = random.split(key)
    _, params = init_random_params(split, (-1, input_size))

    cast = lambda x: x.astype(canonicalize_dtype(onp.float64))
    params = tree_util.tree_map(cast, params)
    return predict, params, key
Example #9
    def create_model_params(self):
        """
            Random Weights for Autoencoder / InverseNet. These parameters are trained in the models.

            Returns:
                model_params (list): Contains numpy.arrays. These are different layers and activation functions.

         """
        if (self.hyper_params['model_name'] == 'AE'):
            model_params = [[
                Dense(64, b_init=zeros), Sigmoid,
                Dense(32, b_init=zeros), Sigmoid,
                Dense(self.hyper_params['z_latent'], b_init=zeros)
            ],
                            [
                                Dense(32, b_init=zeros), Sigmoid,
                                Dense(64, b_init=zeros), Sigmoid,
                                Dense(self.hyper_params['x_dim'], b_init=zeros)
                            ]]
        elif (self.hyper_params['model_name'] == 'IV'):
            model_params = [
                Dense(32, b_init=zeros), Sigmoid,
                Dense(64, b_init=zeros), Sigmoid,
                Dense(self.hyper_params['x_dim'], b_init=zeros)
            ]
        else:
            raise NameError('Wrong model name')
        return model_params
Example #10
def init_nvp(D_in, D_out, rng):
    net_init, net_apply = stax.serial(Dense(256), Relu, Dense(256), Relu,
                                      Dense(D_out * 2))  # 2 for scale & shift
    in_shape = (-1, D_in)
    out_shape, net_params = net_init(rng, in_shape)

    def shift_and_log_scale_fn(net_params, x1):
        s = net_apply(net_params, x1)
        return np.split(s, 2, axis=1)

    return net_params, shift_and_log_scale_fn
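
A hedged sketch of calling the returned shift_and_log_scale_fn (again assuming np is jax.numpy); the D_in=2, D_out=2 sizes are illustrative:
import jax

rng = jax.random.PRNGKey(0)
net_params, s_t_fn = init_nvp(D_in=2, D_out=2, rng=rng)
shift, log_scale = s_t_fn(net_params, np.ones((4, 2)))   # each output has shape (4, D_out)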
Example #11
def Lpg(hparams):
    phi = serial(Dense(16), Dense(1))
    return serial(
        # FanOut(6),
        parallel(Identity, Identity, Identity, Identity, phi, phi),
        FanInConcat(),
        LSTMCell(hparams.hidden_size)[0:2],
        DiscardHidden(),
        Relu,
        FanOut(2),
        parallel(phi, phi),
    )
Example #12
 def state_encoder(output_num):
     return serial(
         Conv(4, (3, 3), (1, 1), "SAME"),
         Tanh,  # BatchNorm(),
         Conv(4, (3, 3), (1, 1), "SAME"),
         Tanh,  # BatchNorm(),
         Conv(4, (3, 3), (1, 1), "SAME"),
         Tanh,  # BatchNorm(),
         Flatten,
         Dense(128),
         Tanh,  # for some reason the output collapses to a constant value when BatchNorm is added
         Dense(output_num))
Example #13
def feed_forward():
    init_fun, predict = stax.serial(
        Dense(1024),
        Relu,
        Dense(1024),
        Relu,
        Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28 * 28))[1]

    return init_params, predict
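
A hedged usage sketch for the curried initializer above, assuming flattened 28x28 inputs as the (-1, 28 * 28) shape suggests:
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
init_params, predict = feed_forward()
params = init_params(rng)                            # init_fun is already bound to the input shape
logits = predict(params, jnp.zeros((32, 28 * 28)))   # shape (32, 10)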
Example #14
def Cnn(n_actions: int, hidden_size) -> Module:
    return serial(
        Conv(32, (8, 8), (4, 4), "VALID"),
        Relu,
        Conv(64, (4, 4), (2, 2), "VALID"),
        Relu,
        Conv(64, (3, 3), (1, 1), "VALID"),
        Relu,
        Flatten,
        Dense(hidden_size),
        Relu,
        Dense(n_actions),
    )
Example #15
def JaxDeepConvNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Complex deep convolutional Neural Network Machine implemented in Jax.
        Conv1d, complexReLU, Conv1d, complexReLU, Conv1d, complexReLU,
        Conv1d, complexReLU, Dense, complexReLU, Dense

            Args:
                hilbert (netket.hilbert) : hilbert space
                hamiltonian (netket.hamiltonian) : hamiltonian
                alpha (int) : hidden layer density
                optimizer (str) : possible choices are 'Sgd', 'Adam', or 'AdaMax'
                lr (float) : learning rate
                sampler (str) : possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

            Returns:
                ma (netket.machine) : machine
                op (netket.optimizer) : optimizer
                sa (netket.sampler) : sampler
                machine_name (str) : name of the machine, see get_operator
                                                    """
    print('JaxDeepConvNN is used')
    input_size = hilbert.size
    init_fun, apply_fun = stax.serial(FixSrLayer, InputForConvLayer, Conv1d(alpha, (3,)), ComplexReLu,
                                      Conv1d(alpha, (3,)), ComplexReLu, Conv1d(alpha, (3,)), ComplexReLu,
                                      Conv1d(alpha, (3,)), ComplexReLu, stax.Flatten,
                                      Dense(input_size * alpha), ComplexReLu, Dense(1), FormatLayer)
    ma = nk.machine.Jax(
        hilbert,
        (init_fun, apply_fun), dtype=complex
    )
    ma.init_random_parameters(seed=12, sigma=0.01)
    # Optimizer
    if (optimizer == 'Sgd'):
        op = Wrap(ma, SgdJax(lr))
    elif (optimizer == 'Adam'):
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))
    # Sampler
    if (sampler == 'Local'):
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif (sampler == 'Exact'):
        sa = nk.sampler.ExactSampler(machine=ma)
    elif (sampler == 'VBS'):
        sa = my_sampler.getVBSSampler(machine=ma)
    elif (sampler == 'Inverse'):
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)
    machine_name = 'JaxDeepConvNN'
    return ma, op, sa, machine_name
Example #16
def JaxTransformedFFNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Complex Feed Forward Neural Network (fully connected) Machine implemented in Jax. One hidden layer.

        The input data is transformed in the beginning by the transformation 10.1103/physrevb.46.3486
        Dense, ComplexReLU, Dense

            Args:
                hilbert (netket.hilbert) : hilbert space
                hamiltonian (netket.hamiltonian) : hamiltonian
                alpha (int) : hidden layer density
                optimizer (str) : possible choices are 'Sgd', 'Adam', or 'AdaMax'
                lr (float) : learning rate
                sampler (str) : possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

            Returns:
                ma (netket.machine) : machine
                op (netket.optimizer) : optimizer
                sa (netket.sampler) : sampler
                machine_name (str) : name of the machine, see get_operator
                                                """
    print('JaxTransformedFFNN is used')
    input_size = hilbert.size
    init_fun, apply_fun = stax.serial(FixSrLayer, TransformedLayer,
        Dense(input_size * alpha), ComplexReLu,
        Dense(1), FormatLayer)
    ma = nk.machine.Jax(
        hilbert,
        (init_fun, apply_fun), dtype=complex
    )
    ma.init_random_parameters(seed=12, sigma=0.01)
    # Optimizer
    if (optimizer == 'Sgd'):
        op = Wrap(ma, SgdJax(lr))
    elif (optimizer == 'Adam'):
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))
    # Sampler
    if (sampler == 'Local'):
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif (sampler == 'Exact'):
        sa = nk.sampler.ExactSampler(machine=ma)
    elif(sampler == 'VBS'):
        sa = my_sampler.getVBSSampler(machine=ma)
    elif (sampler == 'Inverse'):
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)
    machine_name = 'JaxTransformedFFNN'
    return ma, op, sa, machine_name
Example #17
def _create_networks():
    encoder1_init, encode1 = stax.serial(Dense(200), Sigmoid)

    encoder2_init, encode2 = stax.serial(Dense(200), Sigmoid)

    decoder2_init, decode2 = stax.serial(Dense(200), Sigmoid)

    decoder1_init, decode1 = stax.serial(Dense(28 * 28), Sigmoid)

    encoder = (encode1, encode2)
    encoder_init = (encoder1_init, encoder2_init)
    decoder = (decode1, decode2)
    decoder_init = (decoder1_init, decoder2_init)

    return encoder, encoder_init, decoder, decoder_init
Example #18
def init_NN(Q):
    layers = []
    num_layers = len(Q)
    for i in range(0, num_layers - 2):
        layers.append(
            Dense(Q[i + 1],
                  W_init=glorot_normal(dtype=np.float64),
                  b_init=normal(dtype=np.float64)))
        layers.append(Tanh)
    layers.append(
        Dense(Q[-1],
              W_init=glorot_normal(dtype=np.float64),
              b_init=normal(dtype=np.float64)))
    net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply
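
A hedged usage sketch: Q lists the layer widths, with Q[0] read as the input dimension (the loop above starts the Dense layers at Q[1]); the concrete sizes are illustrative, and the float64 initializers suggest the original setup enabled jax_enable_x64:
import jax
import jax.numpy as np

rng = jax.random.PRNGKey(0)
Q = [2, 50, 50, 1]                            # hypothetical widths: 2 -> 50 -> 50 -> 1
net_init, net_apply = init_NN(Q)
_, params = net_init(rng, (-1, Q[0]))         # Q[0] is the input feature dimension
y = net_apply(params, np.ones((10, Q[0])))    # y has shape (10, 1)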
Example #19
def SyntheticReturn(features_network):
    """Synthetic return module as described in:
    https://arxiv.org/abs/2102.12425,
    Raposo, D., Synthetic Returns for Long-Term Credit Assignment, 2021."""
    #  sigmoid gate
    g = lambda: serial(Dense(256), Relu, Dense(1), Relu, Dense(1), Sigmoid)
    #  state utility contribution
    c = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    #  state utility baseline
    b = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    return serial(features_network, Flatten, FanOut(3),
                  parallel(g(), c(), b()))
Example #20
 def __init__(self, rng, learning_rate=0.001, nplayers=2,
              nparams=5, hidden_size=1000,
              name='Network'):
     self.key = rng
     self.init_fun, self.apply_fun = stax.serial(
         Flatten,
         Dense(hidden_size), Relu,
         Dense(hidden_size), Relu,
         Dense(hidden_size), Relu,
         Dense(nplayers)
     )
     self.in_shape = (-1, nplayers, nparams)
     _, self.net_params = self.init_fun(self.key, self.in_shape)
     self.opt_init, self.opt_update, self.get_params = optimizers.adam(step_size=learning_rate)
     self.opt_state = self.opt_init(self.net_params)
     self.loss = np.inf
Example #21
def mlstm64():
    """Return mLSTM64 model's initialization and forward pass functions.

    The initializer function returned will give us random weights as a starting point.

    The model forward pass function will accept any weights compatible with those
    generated by the initializer function.
    The model implemented here has a trainable embedding,
    four consecutive mLSTM layers each with 64 nodes,
    and a single dense layer to predict the next amino acid identity.

    This is the simplest model published by the original UniRep authors.
    """
    model_layers = (
        AAEmbedding(10),
        mLSTM(64),
        mLSTMHiddenStates(),
        mLSTM(64),
        mLSTMHiddenStates(),
        mLSTM(64),
        mLSTMHiddenStates(),
        mLSTM(64),
        mLSTMHiddenStates(),
        Dense(25),
        Softmax,
    )
    init_fun, apply_fun = serial(*model_layers)
    return init_fun, apply_fun
Example #22
def mlstm256():
    """Return mLSTM256 model's initialization and forward pass functions.

    The initializer function returned will give us random weights as a starting point.

    The model forward pass function will accept any weights compatible with those
    generated by the initializer function.
    The model implemented here has a trainable embedding,
    four consecutive mLSTM layers each with 256 nodes,
    and a single dense layer to predict the next amino acid identity.

    It's a simpler but nonetheless still complex version of the UniRep model
    that can be trained to generate protein representations.
    """
    model_layers = (
        AAEmbedding(10),
        mLSTM(256),
        mLSTMHiddenStates(),
        mLSTM(256),
        mLSTMHiddenStates(),
        mLSTM(256),
        mLSTMHiddenStates(),
        mLSTM(256),
        mLSTMHiddenStates(),
        Dense(25),
        Softmax,
    )
    init_fun, apply_fun = serial(*model_layers)
    return init_fun, apply_fun
Example #23
def ResNet50(num_classes):
    return stax.serial(
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        ConvBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        ConvBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        AvgPool((7, 7)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )
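
A hedged note on calling this model: the GeneralConv spec puts the batch dimension last ("HWCN"), so initialization and inputs follow that layout; the 224x224 RGB inputs and the ConvBlock/IdentityBlock helpers are assumed from the standard JAX ResNet-50 example and are not defined in this snippet:
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
init_fun, predict_fun = ResNet50(num_classes=1000)
_, params = init_fun(rng, (224, 224, 3, 8))                   # (H, W, C, N): batch of 8
log_probs = predict_fun(params, jnp.zeros((224, 224, 3, 8)))  # log-probabilities, shape (8, 1000)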
Example #24
def feature_extractor(rng, dim):
    """Feature extraction network."""
    init_params, forward = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(dim),
        Relu,
        Dense(dim),
    )
    temp, rng = random.split(rng)
    params = init_params(temp, (-1, 28, 28, 1))[1]
    return params, forward
Example #25
    def _get_model(self):
        """
        Returns policy network
        """
        layers = []

        # inner / hidden network layers + non-linearities
        for l in self.network_layers:
            layers.append(Dense(l))
            layers.append(Relu)

        # output layer (no non-linearity)
        layers.append(Dense(self.output_dimension))

        return stax.serial(*layers)

Example #26
def create_q_net(
    obs_dim, action_dim, rngkey=jax.random.PRNGKey(0)
) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    q_init, q_fn = serial(
        Dense(64, he_normal(), zeros),
        Relu,
        Dense(64, he_normal(), zeros),
        Relu,
        Dense(action_dim, he_normal(), zeros),
    )
    output_shape, q_params = q_init(rngkey, (1, obs_dim + action_dim))

    @jit
    def q_fn2(q, S, A):
        return q_fn(q, jnp.hstack([S, A]))

    return q_params, q_fn2
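
A hedged usage sketch for the state-action Q-function wrapper above; obs_dim=4 and action_dim=2 are illustrative values:
import jax.numpy as jnp

q_params, q_fn = create_q_net(obs_dim=4, action_dim=2)
q = q_fn(q_params, jnp.ones((8, 4)), jnp.ones((8, 2)))   # states and actions are hstacked; q has shape (8, 2)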
Example #27
def conv():
    init_fun, predict = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(32),
        Relu,
        Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28, 28, 1))[1]

    return init_params, predict
Example #28
def LeNet5(num_classes):
    return stax.serial(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),

        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),

        Flatten,
        Dense(num_classes * 10),
        Dense(num_classes * 5),
        Dense(num_classes),
        LogSoftmax
    )
Example #29
def value_approximator():
    """
    Approximator for value of the input state.

    Returns
    -------
    (init, apply) tuple
    """
    init, apply = serial(
        Flatten,
        Dense(2048),  # 1024
        Relu,
        Dense(1024),  # 512
        Relu,
        Dense(512),  # 256
        Relu,
        Dense(1))
    return init, apply
Example #30
    def _get_model(self):
        """
        Return jax network initialisation and forward method.
        """
        layers = []

        # inner / hidden network layers + non-linearities
        for l in self.network_layers:
            layers.append(Dense(l))
            layers.append(Relu)

        # output layer (no non-linearity)
        layers.append(Dense(self.output_dimension))

        # make jax stax object
        model = stax.serial(*layers)

        return model