def init_nvp(rng, dim, flip, init_batch=None):
    net_init, net_apply = stax.serial(
        Dense(512), Relu, Dense(512), Relu, Dense(dim)
    )
    in_shape = (-1, dim // 2)
    _, net_params = net_init(rng, in_shape)

    def shift_and_log_scale_fn(net_params, x1):
        s = net_apply(net_params, x1)
        return np.split(s, 2, axis=1)

    def nvp_forward(net_params, prev_sample, prev_logp=0.):
        d = dim // 2
        x1, x2 = prev_sample[:, :d], prev_sample[:, d:]
        if flip:
            x2, x1 = x1, x2
        shift, log_scale = shift_and_log_scale_fn(net_params, x1)
        y2 = x2 * np.exp(log_scale) + shift
        if flip:
            x1, y2 = y2, x1
        y = np.concatenate([x1, y2], axis=-1)
        return y, prev_logp + np.sum(log_scale, axis=-1)

    def nvp_reverse(net_params, next_sample, next_logp=0.):
        d = dim // 2
        y1, y2 = next_sample[:, :d], next_sample[:, d:]
        if flip:
            y1, y2 = y2, y1
        shift, log_scale = shift_and_log_scale_fn(net_params, y1)
        x2 = (y2 - shift) * np.exp(-log_scale)
        if flip:
            y1, x2 = x2, y1
        x = np.concatenate([y1, x2], axis=-1)
        return x, next_logp - np.sum(log_scale, axis=-1)

    return net_params, nvp_forward, nvp_reverse
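# Minimal usage sketch for the coupling layer above (illustrative values;
# `np` is jax.numpy, matching the snippet's convention). The reverse pass
# should recover the forward input up to numerical error.
from jax import random
import jax.numpy as np

rng = random.PRNGKey(0)
params, nvp_forward, nvp_reverse = init_nvp(rng, dim=2, flip=False)
x = random.normal(random.PRNGKey(1), (4, 2))
y, logp = nvp_forward(params, x)
x_back, _ = nvp_reverse(params, y)  # x_back ~= x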
def DenseReluNetwork(
    out_dim: int, hidden_layers: int, hidden_dim: int
) -> Tuple[Callable, Callable]:
    """Create a dense neural network with Relu after hidden layers.

    Parameters
    ----------
    out_dim : int
        The output dimension.
    hidden_layers : int
        The number of hidden layers.
    hidden_dim : int
        The dimension of the hidden layers.

    Returns
    -------
    init_fun : function
        The function that initializes the network. Note that this is the
        init_function defined in the Jax stax module, which is different
        from the functions of my InitFunction class.
    forward_fun : function
        The function that passes the inputs through the neural network.
    """
    init_fun, forward_fun = serial(
        *(Dense(hidden_dim), Relu) * hidden_layers,
        Dense(out_dim),
    )
    return init_fun, forward_fun
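# Usage sketch (illustrative dimensions, not from the original source):
# the tuple-repetition trick expands to hidden_layers Dense/Relu pairs.
from jax import random
import jax.numpy as jnp

init_fun, forward_fun = DenseReluNetwork(out_dim=1, hidden_layers=2, hidden_dim=32)
_, params = init_fun(random.PRNGKey(0), (-1, 4))   # 4 input features
out = forward_fun(params, jnp.ones((8, 4)))        # shape (8, 1)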
def create_surrogate(self):
    surrogate_init, surrogate = stax.serial(
        Dense(200), Relu, Dense(200), Relu, Dense(200), Relu, Dense(1)
    )
    # Note: returns (apply_fun, init_fun), the reverse of the usual
    # stax (init, apply) order.
    return surrogate, surrogate_init
def network(activation):
    # Use stax to set up network initialization and evaluation functions
    net_init, net_apply = stax.serial(
        Dense(40), activation, Dense(40), activation, Dense(1)
    )
    return net_init, net_apply
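# Usage sketch: any stax activation layer can be passed in. (Assumption:
# in recent JAX these live in jax.example_libraries.stax; older releases
# used jax.experimental.stax.)
from jax.example_libraries.stax import Relu, Tanh

net_init, net_apply = network(Relu)  # or network(Tanh)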
def generate_network(out_features, hidden_size):
    # Use stax to set up network initialization and evaluation functions
    net_init, net_apply = stax.serial(
        Dense(hidden_size), Relu,
        Dense(hidden_size), Relu,
        Dense(out_features), Sigmoid,
    )
    return net_init, net_apply
def DeepQNetwork():
    init_fun, predict_fun = stax.serial(
        Conv(16, (8, 8), strides=(4, 4)), Relu,
        Conv(32, (4, 4), strides=(2, 2)), Relu,
        Conv(64, (3, 3)), Relu,
        Flatten,
        Dense(256), Relu,
        Dense(6),
    )
    return init_fun, predict_fun
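# Initialization sketch. The input shape is an assumption: the 6-way output
# and the 8x8/stride-4 conv stem suggest an Atari-style setup with stacked
# 84x84 frames, but the original input shape is not shown here.
from jax import random

init_fun, predict_fun = DeepQNetwork()
out_shape, params = init_fun(random.PRNGKey(0), (-1, 84, 84, 4))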
def PolicyNetwork():
    """Policy network for the experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        helx.nn.rnn.LSTM(256),
        Dense(256),
        Relu,
        FanOut(2),
        parallel(Dense(1), Dense(1)),
    )
def prepare_single_layer_model(input_size, output_size, width, key):
    init_random_params, predict = stax.serial(
        Dense(width), Relu, Dense(output_size), LogSoftmax
    )
    key, split = random.split(key)
    _, params = init_random_params(split, (-1, input_size))
    cast = lambda x: x.astype(canonicalize_dtype(onp.float64))
    params = tree_util.tree_map(cast, params)
    return predict, params, key
def create_model_params(self):
    """Random weights for Autoencoder / InverseNet.

    These parameters are trained in the models.

    Returns:
        model_params (list): Contains stax layer constructors
        (init/apply function pairs) and activation functions.
    """
    if self.hyper_params['model_name'] == 'AE':
        model_params = [
            [
                Dense(64, b_init=zeros), Sigmoid,
                Dense(32, b_init=zeros), Sigmoid,
                Dense(self.hyper_params['z_latent'], b_init=zeros),
            ],
            [
                Dense(32, b_init=zeros), Sigmoid,
                Dense(64, b_init=zeros), Sigmoid,
                Dense(self.hyper_params['x_dim'], b_init=zeros),
            ],
        ]
    elif self.hyper_params['model_name'] == 'IV':
        model_params = [
            Dense(32, b_init=zeros), Sigmoid,
            Dense(64, b_init=zeros), Sigmoid,
            Dense(self.hyper_params['x_dim'], b_init=zeros),
        ]
    else:
        raise NameError('Wrong model name')
    return model_params
def init_nvp(D_in, D_out, rng):
    net_init, net_apply = stax.serial(
        Dense(256), Relu, Dense(256), Relu,
        Dense(D_out * 2),  # 2 for scale & shift
    )
    in_shape = (-1, D_in)
    out_shape, net_params = net_init(rng, in_shape)

    def shift_and_log_scale_fn(net_params, x1):
        s = net_apply(net_params, x1)
        return np.split(s, 2, axis=1)

    return net_params, shift_and_log_scale_fn
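# Usage sketch (illustrative dimensions; `np` is jax.numpy, as in the
# snippet above): the doubled output is split into per-feature shift and
# log-scale halves.
from jax import random
import jax.numpy as np

rng = random.PRNGKey(0)
params, shift_and_log_scale_fn = init_nvp(D_in=1, D_out=1, rng=rng)
shift, log_scale = shift_and_log_scale_fn(params, np.ones((4, 1)))  # each (4, 1)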
def Lpg(hparams):
    phi = serial(Dense(16), Dense(1))
    return serial(
        # FanOut(6),
        parallel(Identity, Identity, Identity, Identity, phi, phi),
        FanInConcat(),
        LSTMCell(hparams.hidden_size)[0:2],
        DiscardHidden(),
        Relu,
        FanOut(2),
        parallel(phi, phi),
    )
def state_encoder(output_num):
    return serial(
        Conv(4, (3, 3), (1, 1), "SAME"), Tanh,
        # BatchNorm(),
        Conv(4, (3, 3), (1, 1), "SAME"), Tanh,
        # BatchNorm(),
        Conv(4, (3, 3), (1, 1), "SAME"), Tanh,
        # BatchNorm(),
        Flatten,
        Dense(128), Tanh,  # adding BatchNorm makes the output a fixed value for some reason
        Dense(output_num),
    )
def feed_forward():
    init_fun, predict = stax.serial(
        Dense(1024), Relu,
        Dense(1024), Relu,
        Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28 * 28))[1]

    return init_params, predict
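# Usage sketch for MNIST-shaped inputs (flattened 28x28 images).
from jax import random
import jax.numpy as jnp

init_params, predict = feed_forward()
params = init_params(random.PRNGKey(0))
logits = predict(params, jnp.ones((32, 28 * 28)))  # shape (32, 10)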
def Cnn(n_actions: int, hidden_size) -> Module:
    return serial(
        Conv(32, (8, 8), (4, 4), "VALID"), Relu,
        Conv(64, (4, 4), (2, 2), "VALID"), Relu,
        Conv(64, (3, 3), (1, 1), "VALID"), Relu,
        Flatten,
        Dense(hidden_size), Relu,
        Dense(n_actions),
    )
def JaxDeepConvNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1,
                  sampler='Local'):
    """Complex deep convolutional Neural Network Machine implemented in Jax.

    Conv1d, ComplexReLu, Conv1d, ComplexReLu, Conv1d, ComplexReLu,
    Conv1d, ComplexReLu, Dense, ComplexReLu, Dense

    Args:
        hilbert (netket.hilbert): Hilbert space
        hamiltonian (netket.hamiltonian): Hamiltonian
        alpha (int): hidden layer density
        optimizer (str): possible choices are 'Sgd', 'Adam', or 'AdaMax'
        lr (float): learning rate
        sampler (str): possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

    Returns:
        ma (netket.machine): machine
        op (netket.optimizer): optimizer
        sa (netket.sampler): sampler
        machine_name (str): name of the machine, see get_operator
    """
    print('JaxDeepConvNN is used')
    input_size = hilbert.size
    init_fun, apply_fun = stax.serial(
        FixSrLayer,
        InputForConvLayer,
        Conv1d(alpha, (3,)), ComplexReLu,
        Conv1d(alpha, (3,)), ComplexReLu,
        Conv1d(alpha, (3,)), ComplexReLu,
        Conv1d(alpha, (3,)), ComplexReLu,
        stax.Flatten,
        Dense(input_size * alpha), ComplexReLu,
        Dense(1),
        FormatLayer,
    )
    ma = nk.machine.Jax(hilbert, (init_fun, apply_fun), dtype=complex)
    ma.init_random_parameters(seed=12, sigma=0.01)

    # Optimizer
    if optimizer == 'Sgd':
        op = Wrap(ma, SgdJax(lr))
    elif optimizer == 'Adam':
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))

    # Sampler
    if sampler == 'Local':
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif sampler == 'Exact':
        sa = nk.sampler.ExactSampler(machine=ma)
    elif sampler == 'VBS':
        sa = my_sampler.getVBSSampler(machine=ma)
    elif sampler == 'Inverse':
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma,
                                              hamiltonian=hamiltonian,
                                              n_chains=16)

    machine_name = 'JaxDeepConvNN'
    return ma, op, sa, machine_name
def JaxTransformedFFNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1,
                       sampler='Local'):
    """Complex Feed Forward Neural Network (fully connected) Machine
    implemented in Jax. One hidden layer. The input data is transformed
    in the beginning by the transformation 10.1103/physrevb.46.3486

    Dense, ComplexReLu, Dense

    Args:
        hilbert (netket.hilbert): Hilbert space
        hamiltonian (netket.hamiltonian): Hamiltonian
        alpha (int): hidden layer density
        optimizer (str): possible choices are 'Sgd', 'Adam', or 'AdaMax'
        lr (float): learning rate
        sampler (str): possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

    Returns:
        ma (netket.machine): machine
        op (netket.optimizer): optimizer
        sa (netket.sampler): sampler
        machine_name (str): name of the machine, see get_operator
    """
    print('JaxTransformedFFNN is used')
    input_size = hilbert.size
    init_fun, apply_fun = stax.serial(
        FixSrLayer,
        TransformedLayer,
        Dense(input_size * alpha), ComplexReLu,
        Dense(1),
        FormatLayer,
    )
    ma = nk.machine.Jax(hilbert, (init_fun, apply_fun), dtype=complex)
    ma.init_random_parameters(seed=12, sigma=0.01)

    # Optimizer
    if optimizer == 'Sgd':
        op = Wrap(ma, SgdJax(lr))
    elif optimizer == 'Adam':
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))

    # Sampler
    if sampler == 'Local':
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif sampler == 'Exact':
        sa = nk.sampler.ExactSampler(machine=ma)
    elif sampler == 'VBS':
        sa = my_sampler.getVBSSampler(machine=ma)
    elif sampler == 'Inverse':
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma,
                                              hamiltonian=hamiltonian,
                                              n_chains=16)

    machine_name = 'JaxTransformedFFNN'
    return ma, op, sa, machine_name
def _create_networks():
    encoder1_init, encode1 = stax.serial(Dense(200), Sigmoid)
    encoder2_init, encode2 = stax.serial(Dense(200), Sigmoid)
    decoder2_init, decode2 = stax.serial(Dense(200), Sigmoid)
    decoder1_init, decode1 = stax.serial(Dense(28 * 28), Sigmoid)
    encoder = (encode1, encode2)
    encoder_init = (encoder1_init, encoder2_init)
    decoder = (decode1, decode2)
    decoder_init = (decoder1_init, decoder2_init)
    return encoder, encoder_init, decoder, decoder_init
def init_NN(Q):
    layers = []
    num_layers = len(Q)
    for i in range(0, num_layers - 2):
        layers.append(Dense(Q[i + 1],
                            W_init=glorot_normal(dtype=np.float64),
                            b_init=normal(dtype=np.float64)))
        layers.append(Tanh)
    layers.append(Dense(Q[-1],
                        W_init=glorot_normal(dtype=np.float64),
                        b_init=normal(dtype=np.float64)))
    net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply
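# Usage sketch: Q lists the layer widths, Q[0] being the input dimension
# (illustrative values). The float64 initializers above assume x64 mode.
from jax import config, random
config.update("jax_enable_x64", True)

net_init, net_apply = init_NN(Q=[2, 50, 50, 1])  # 2 -> 50 -> 50 -> 1
_, params = net_init(random.PRNGKey(0), (-1, 2))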
def SyntheticReturn(features_network):
    """Synthetic return module as described in:
    https://arxiv.org/abs/2102.12425, Raposo, D.,
    Synthetic Returns for Long-Term Credit Assignment, 2021."""
    # sigmoid gate
    g = lambda: serial(Dense(256), Relu, Dense(1), Relu, Dense(1), Sigmoid)
    # state utility contribution
    c = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    # state utility baseline
    b = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    return serial(features_network, Flatten, FanOut(3), parallel(g(), c(), b()))
def __init__(self, rng, learning_rate=0.001, nplayers=2, nparams=5,
             hidden_size=1000, name='Network'):
    self.key = rng
    self.init_fun, self.apply_fun = stax.serial(
        Flatten,
        Dense(hidden_size), Relu,
        Dense(hidden_size), Relu,
        Dense(hidden_size), Relu,
        Dense(nplayers),
    )
    self.in_shape = (-1, nplayers, nparams)
    _, self.net_params = self.init_fun(self.key, self.in_shape)
    self.opt_init, self.opt_update, self.get_params = optimizers.adam(
        step_size=learning_rate)
    self.opt_state = self.opt_init(self.net_params)
    self.loss = np.inf
def mlstm64():
    """Return mLSTM64 model's initialization and forward pass functions.

    The initializer function returned will give us random weights as a
    starting point. The model forward pass function will accept any weights
    compatible with those generated by the initializer function.

    The model implemented here has a trainable embedding, four consecutive
    mLSTM layers each with 64 nodes, and a single dense layer to predict
    the next amino acid identity. This is the simplest model published by
    the original UniRep authors.
    """
    model_layers = (
        AAEmbedding(10),
        mLSTM(64),
        mLSTMHiddenStates(),
        mLSTM(64),
        mLSTMHiddenStates(),
        mLSTM(64),
        mLSTMHiddenStates(),
        mLSTM(64),
        mLSTMHiddenStates(),
        Dense(25),
        Softmax,
    )
    init_fun, apply_fun = serial(*model_layers)
    return init_fun, apply_fun
def mlstm256():
    """Return mLSTM256 model's initialization and forward pass functions.

    The initializer function returned will give us random weights as a
    starting point. The model forward pass function will accept any weights
    compatible with those generated by the initializer function.

    The model implemented here has a trainable embedding, four consecutive
    mLSTM layers each with 256 nodes, and a single dense layer to predict
    the next amino acid identity. It's a simpler but nonetheless still
    complex version of the UniRep model that can be trained to generate
    protein representations.
    """
    model_layers = (
        AAEmbedding(10),
        mLSTM(256),
        mLSTMHiddenStates(),
        mLSTM(256),
        mLSTMHiddenStates(),
        mLSTM(256),
        mLSTMHiddenStates(),
        mLSTM(256),
        mLSTMHiddenStates(),
        Dense(25),
        Softmax,
    )
    init_fun, apply_fun = serial(*model_layers)
    return init_fun, apply_fun
def ResNet50(num_classes):
    return stax.serial(
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(), Relu,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        ConvBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        ConvBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        AvgPool((7, 7)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )
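# Initialization sketch. Note the "HWCN" input spec on the stem convolution:
# inputs are laid out height x width x channels x batch, so for 224x224 RGB
# images the init shape is (224, 224, 3, batch). (Batch size and class count
# below are illustrative; ConvBlock/IdentityBlock are assumed defined
# elsewhere in the source.)
from jax import random

init_fun, apply_fun = ResNet50(num_classes=1000)
_, params = init_fun(random.PRNGKey(0), (224, 224, 3, 8))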
def feature_extractor(rng, dim):
    """Feature extraction network."""
    init_params, forward = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(dim), Relu,
        Dense(dim),
    )
    temp, rng = random.split(rng)
    params = init_params(temp, (-1, 28, 28, 1))[1]
    return params, forward
def _get_model(self):
    """Return policy network."""
    layers = []
    # inner / hidden network layers + non-linearities
    for l in self.network_layers:
        layers.append(Dense(l))
        layers.append(Relu)
    # output layer (no non-linearity)
    layers.append(Dense(self.output_dimension))
    return stax.serial(*layers)
def create_q_net(
    obs_dim, action_dim, rngkey=jax.random.PRNGKey(0)
) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    q_init, q_fn = serial(
        Dense(64, he_normal(), zeros), Relu,
        Dense(64, he_normal(), zeros), Relu,
        Dense(action_dim, he_normal(), zeros),
    )
    output_shape, q_params = q_init(rngkey, (1, obs_dim + action_dim))

    @jit
    def q_fn2(q, S, A):
        return q_fn(q, jnp.hstack([S, A]))

    return q_params, q_fn2
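# Usage sketch (illustrative dimensions): the jitted wrapper q_fn2
# concatenates state and action along the feature axis before the
# forward pass.
import jax.numpy as jnp

q_params, q_fn2 = create_q_net(obs_dim=3, action_dim=1)
q_values = q_fn2(q_params, jnp.ones((10, 3)), jnp.ones((10, 1)))  # (10, 1)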
def conv():
    init_fun, predict = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(32), Relu,
        Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28, 28, 1))[1]

    return init_params, predict
def LeNet5(num_classes):
    return stax.serial(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), Relu,
        AvgPool((3, 3)),
        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(), Relu,
        AvgPool((3, 3)),
        Flatten,
        Dense(num_classes * 10),
        Dense(num_classes * 5),
        Dense(num_classes),
        LogSoftmax,
    )
def value_appoximator():
    """Approximator for value of the input state.

    Returns
    -------
    (init, apply) tuple
    """
    init, apply = serial(
        Flatten,
        Dense(2048),  # 1024
        Relu,
        Dense(1024),  # 512
        Relu,
        Dense(512),  # 256
        Relu,
        Dense(1),
    )
    return init, apply
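# Usage sketch. The input shape is purely illustrative (the original state
# shape is not shown); Flatten makes the network shape-agnostic apart from
# the total feature count.
from jax import random
import jax.numpy as jnp

init, apply = value_appoximator()
_, params = init(random.PRNGKey(0), (-1, 8, 8, 4))
values = apply(params, jnp.ones((5, 8, 8, 4)))  # shape (5, 1)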
def _get_model(self):
    """Return jax network initialisation and forward method."""
    layers = []
    # inner / hidden network layers + non-linearities
    for l in self.network_layers:
        layers.append(Dense(l))
        layers.append(Relu)
    # output layer (no non-linearity)
    layers.append(Dense(self.output_dimension))
    # make jax stax object
    model = stax.serial(*layers)
    return model