def cov_estimate(*, optimization_path: Sequence[jnp.ndarray],
                 optimization_path_grads: Sequence[jnp.ndarray],
                 history: int):
    """Estimate covariance from an optimization path."""
    dim = optimization_path[0].shape[0]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    approximations: List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]] = []
    diagonal_estimate = jnp.ones(dim)
    for j in range(len(optimization_path) - 1):
        _, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        position_diff = optimization_path[j + 1] - optimization_path[j]
        gradient_diff = optimization_path_grads[j] - optimization_path_grads[j + 1]
        b = position_diff @ gradient_diff
        gradient_diff_norm = gradient_diff**2
        new_diagonal_estimate = diagonal_estimate
        # Curvature condition: only accept the update pair when s.z is
        # sufficiently positive relative to ||z||^2 (cf. the positivity check
        # in `lbfgs` below); otherwise keep the previous diagonal.
        if b > 1e-12 * jnp.sum(gradient_diff_norm):
            position_diffs = jnp.column_stack(
                (position_diffs[:, -history + 1:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, -history + 1:], gradient_diff))
            a = gradient_diff @ (diagonal_estimate * gradient_diff)
            c = position_diff @ (position_diff / diagonal_estimate)
            new_diagonal_estimate = 1.0 / (
                a / (b * diagonal_estimate) + gradient_diff_norm / b -
                (a * position_diff**2) / (b * c * diagonal_estimate**2))
        approximations.append(
            (diagonal_estimate, thin_factors, scaling_outer_product))
        diagonal_estimate = new_diagonal_estimate
    return approximations
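# Reading off the update above: with s = position_diff, z = gradient_diff and
# the current diagonal D, the quantities are a = z' diag(D) z, b = s'z and
# c = s' D^{-1} s, and the elementwise diagonal refresh is
#     D_new = 1 / (a / (b * D) + z**2 / b - a * s**2 / (b * c * D**2)),
# a diagonal BFGS-style inverse-Hessian estimate that is only well defined
# when the curvature condition b > eps * ||z||**2 holds.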
def lbfgs(
    *,
    log_target_density: Callable[[jnp.ndarray], jnp.ndarray],
    initial_value: jnp.ndarray,  # theta_init
    inverse_hessian_history: int = 6,  # J
    relative_tolerance: float = 1e-13,  # tau_rel
    max_iters: int = 1000,  # L
    wolfe_bounds: Tuple[float, float] = (1e-4, 0.9),
    positivity_threshold: float = 2.2e-16,
):
    """LBFGS implementation which returns the optimization path and gradients."""
    dim = initial_value.shape[0]
    grad_log_density = jax.grad(log_target_density)
    optimization_path = [initial_value]
    current_lp = log_target_density(initial_value)
    grad_optimization_path = [grad_log_density(initial_value)]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    for _ in range(max_iters):
        diagonal_estimate, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        grad_lp = grad_optimization_path[-1]
        search_direction = diagonal_estimate * grad_lp + thin_factors @ (
            scaling_outer_product @ (jnp.transpose(thin_factors) @ grad_lp))
        # Backtracking line search enforcing the Wolfe conditions (for
        # maximization of the log density).
        step_size = 1.0
        while step_size > 1e-8:
            proposed = optimization_path[-1] + step_size * search_direction
            proposed_lp = log_target_density(proposed)
            # Sufficient-increase (Armijo) condition.
            if proposed_lp >= current_lp + (wolfe_bounds[0] * grad_lp) @ (
                    step_size * search_direction):
                proposed_grad = grad_log_density(proposed)
                # Curvature condition.
                if (proposed_grad @ search_direction <=
                        wolfe_bounds[1] * grad_lp @ search_direction):
                    break
            step_size = 0.5 * step_size
        optimization_path.append(proposed)
        grad_optimization_path.append(proposed_grad)
        if (proposed_lp - current_lp) / jnp.abs(current_lp) < relative_tolerance:
            return optimization_path, grad_optimization_path
        current_lp = proposed_lp
        position_diff: jnp.ndarray = optimization_path[-1] - optimization_path[-2]
        grad_diff = -grad_optimization_path[-1] + grad_optimization_path[-2]
        # Only store the update pair when the curvature (positivity)
        # condition holds.
        if position_diff @ grad_diff > positivity_threshold * jnp.sum(grad_diff**2):
            position_diffs = jnp.column_stack(
                (position_diffs[:, -inverse_hessian_history + 1:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, -inverse_hessian_history + 1:], grad_diff))
    return optimization_path, grad_optimization_path
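# Usage sketch for the two functions above (illustrative, not from the
# original source; assumes `bfgs_inverse_hessian` from the same module is in
# scope): maximize a toy 2-D Gaussian log density, then build the per-step
# inverse-Hessian (covariance) approximations along the path.
def _toy_log_density(theta):
    return -0.5 * jnp.sum(theta**2 / jnp.array([1.0, 4.0]))

path, grads = lbfgs(log_target_density=_toy_log_density,
                    initial_value=jnp.array([3.0, -2.0]))
approximations = cov_estimate(optimization_path=path,
                              optimization_path_grads=grads,
                              history=6)
print(len(path), len(approximations))  # len(approximations) == len(path) - 1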
def main():
    smc = np.load('mc_ddpip_3d_smeared.npy')
    print(smc.shape)
    data = np.column_stack(get_vars(smc)[:3])
    print(data.shape)
    do_fit(data, 100)
def _indices(key):
    # `sparse_shape`, `n_sparse`, `nse`, `sparse_size` and `unique_indices`
    # are closed over from the enclosing scope.
    if not sparse_shape:
        return jnp.empty((nse, n_sparse), dtype=int)
    flat_ind = random.choice(key, sparse_size, shape=(nse,),
                             replace=not unique_indices)
    return jnp.column_stack(jnp.unravel_index(flat_ind, sparse_shape))
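# Self-contained illustration of the indexing trick used above (names here
# are local to the example): sample unique flat positions, then unravel and
# column-stack them into an (nse, ndim) index array.
import math

import jax.numpy as jnp
from jax import random

shape = (4, 5)
nse = 6
flat_ind = random.choice(random.PRNGKey(0), math.prod(shape), shape=(nse,),
                         replace=False)
ind = jnp.column_stack(jnp.unravel_index(flat_ind, shape))
print(ind.shape)  # (6, 2); each row is a unique (row, col) coordinate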
def momentum_from_cluster(clu: Cluster) -> (Momentum):
    """ Assuming photon (massless particle) """
    sinth = clu.sinth
    return Momentum.from_ndarray(np.column_stack([
        clu.energy * sinth * np.sin(clu.phi),
        clu.energy * sinth * np.cos(clu.phi),
        clu.energy * clu.costh,
    ]))
def test_cluster_from_ndarray():
    """ """
    N = 100
    energy = rjax.uniform(rng, (N,), minval=0., maxval=3.)
    costh = rjax.uniform(rng, (N,), minval=-1., maxval=1.)
    phi = rjax.uniform(rng, (N,), minval=-np.pi, maxval=np.pi)

    clu = Cluster.from_ndarray(np.column_stack([energy, costh, phi]))

    assert np.allclose(clu.energy, energy)
    assert np.allclose(clu.costh, costh)
    assert np.allclose(clu.phi, phi)
    assert clu.as_array.shape == (N, 3)
def sample(events: np.ndarray) -> (np.ndarray):
    """ Resolution sampler for MC events

    Args:
        - events: [E (MeV), m^2(DD) (GeV^2), m^2(Dpi) (GeV^2)]
    """
    e = events[:, 0] * 10**-3
    tdd = jnp.sqrt(events[:, 1]) - 2 * mdn
    mdpi = jnp.sqrt(events[:, 2])
    offsets = jax.random.normal(jax.random.PRNGKey(1), events.shape)
    return jnp.column_stack([
        (e + smddpi(e, tdd) * offsets[:, 0]) * 10**3,
        (tdd + stdd(tdd) * offsets[:, 1] + 2 * mdn)**2,
        (mdpi + smdstp() * offsets[:, 2])**2,
    ])
def test_linear_regression(in_dim=5, out_dim=2, shape=(1000,)):
    key = jr.PRNGKey(time.time_ns())
    key1, key2 = jr.split(key, 2)
    covariates = jr.normal(key1, shape + (in_dim,))
    covariates = np.column_stack([covariates, np.ones(shape)])
    data = jr.normal(key2, shape + (out_dim,))
    lr = regrs.GaussianLinearRegression.fit(
        dict(data=data, covariates=covariates))

    # Compare to the least-squares fit. Note that the covariance matrix is
    # only the covariance of the residuals if we fit the intercept term.
    what = np.linalg.lstsq(covariates, data)[0].T
    assert np.allclose(lr.weights, what)
    resid = data - covariates @ what.T
    assert np.allclose(lr.covariance_matrix,
                       np.cov(resid, rowvar=False, bias=True),
                       atol=1e-6)
def real_fourier_basis(n: int) -> Tuple[np.ndarray, np.ndarray]:
    """Construct a real unitary Fourier basis of size n.

    Args:
        n: The basis size.

    Returns:
        A tuple of the basis, and the frequencies at which to evaluate the
        spectral distribution function to get the variances for the
        Fourier-domain coefficients.
    """
    assert n > 1
    dc = np.ones((n,))
    dc_freq = 0

    cosine_basis_vectors = []
    cosine_freqs = []
    sine_basis_vectors = []
    sine_freqs = []
    ts = np.arange(n)
    for w in range(1, 1 + (n - 1) // 2):
        x = w * (2 * np.pi / n) * ts
        cosine_basis_vectors.append(np.sqrt(2) * np.cos(x))
        cosine_freqs.append(w)
        sine_basis_vectors.append(-np.sqrt(2) * np.sin(x))
        sine_freqs.append(w)

    if n % 2 == 0:
        w = n // 2
        x = w * 2 * np.pi * ts / n
        cosine_basis_vectors.append(np.cos(x))
        cosine_freqs.append(w)

    basis = np.column_stack(
        (dc, *cosine_basis_vectors, *sine_basis_vectors[::-1]))
    freqs = np.concatenate([
        np.array([dc_freq]),
        np.array(cosine_freqs),
        np.array(sine_freqs[::-1]),
    ])
    return basis / np.sqrt(n), freqs / n
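# Quick property check for the basis above (illustrative; relies on
# real_fourier_basis as defined): the columns are orthonormal, so the
# transform is unitary.
import numpy as np

B, freqs = real_fourier_basis(8)
assert B.shape == (8, 8) and freqs.shape == (8,)
assert np.allclose(B.T @ B, np.eye(8), atol=1e-12)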
def __call__(self, task_features, node_features, neighbor_relations,
             neighbor_features):
    """Maps task, node, edge and neighbors to logits.

    We use both multiplicative and additive dependencies between the
    representations of task, node, neighbor_relations and neighbor_features.

    Args:
        task_features: Not used.
        node_features: Integer tensor of size `num_neighbors x
            num_node_features` with (duplicated) features of the current node.
        neighbor_relations: Integer tensor of size `num_neighbors`.
        neighbor_features: Integer tensor of size `num_neighbors x
            num_node_features` of neighbor features.

    Returns:
        A float tensor of size `num_neighbors` with logits associated with the
        probability of transitioning to that neighbor.
    """
    del task_features
    num_neighbors = len(neighbor_relations)

    # Embed all inputs.
    inputs = (node_features, neighbor_relations, neighbor_features)
    embeddings = [emb(x) for emb, x in zip(self.input_embedding, inputs)]

    # Reshape.
    node_embs, relation_embs, neighbor_embs = embeddings
    node_embs = node_embs.reshape(num_neighbors, -1)
    relation_embs = relation_embs.reshape(num_neighbors, -1)
    neighbor_embs = neighbor_embs.reshape(num_neighbors, -1)

    local_feature_embs = jnp.column_stack((node_embs, neighbor_embs))
    hidden_local_features = self.node_hidden(local_feature_embs)
    hidden_relation_embs = self.relation_hidden_mul(relation_embs)
    hidden_relation_bias = self.relation_hidden_sum(relation_embs)

    # FiLM-like layer to mix representations of the neighborhood
    # (node + neighbor) and the edge.
    mixed_representation = (hidden_local_features * hidden_relation_embs +
                            hidden_relation_bias)
    mixed_representation = nn.relu(mixed_representation)
    neighbor_logits = self.output_layer(mixed_representation)
    return neighbor_logits.squeeze(-1)
def __call__(self, param_name, param):
    """Shuffles the weight matrix/mask for a given parameter, per-neuron.

    This is to be used with mask_map, and accepts the standard mask_map
    function parameters.

    Args:
        param_name: The parameter's name.
        param: The parameter's weight or mask matrix.

    Returns:
        A shuffled weight/mask matrix, with each neuron shuffled independently.
    """
    del param_name  # Unused.
    incoming_connections = jnp.prod(jnp.array(param.shape[:-1]))
    num_neurons = param.shape[-1]

    # Ensure each input neuron has at least one connection unmasked.
    mask = _fill_diagonal_wrap((incoming_connections, num_neurons), 1,
                               dtype=jnp.uint8)
    # Randomly shuffle which of the neurons have these connections.
    mask = jax.random.shuffle(self._get_rng(), mask, axis=0)

    # Add extra required random connections to mask to satisfy sparsity.
    mask_cols = []
    for col in range(mask.shape[-1]):
        neuron_mask = mask[:, col]
        off_diagonal_count = max(
            round((1 - self._sparsity) * incoming_connections) -
            jnp.count_nonzero(neuron_mask), 0)
        zero_indices = jnp.flatnonzero(neuron_mask == 0)
        random_entries = _random_neuron_mask(len(zero_indices),
                                             off_diagonal_count,
                                             self._get_rng())
        neuron_mask = neuron_mask.at[zero_indices].set(random_entries)
        mask_cols.append(neuron_mask)

    return jnp.column_stack(mask_cols).reshape(param.shape)
def table_ambisonics_order_vs_rE(max_order=20):
    """Return a dataframe with rE as a function of order."""
    order = np.arange(1, max_order + 1, dtype=np.int32)
    rE3 = np.array(list(map(shelf.max_rE_3d, order)))
    drE3 = np.append(np.nan, rE3[1:] - rE3[:-1])
    rE2 = np.array(list(map(shelf.max_rE_2d, order)))
    drE2 = np.append(np.nan, rE2[1:] - rE2[:-1])
    df = pd.DataFrame(np.column_stack((
        order,
        rE2,
        100 * drE2 / rE2,
        2 * np.arccos(rE2) * 180 / π,
        rE3,
        100 * drE3 / rE3,
        2 * np.arccos(rE3) * 180 / π,
    )),
                      columns=('order', '2D', '% change', 'asw',
                               '3D', '% change', 'asw'))
    return df
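# Hedged sketch of the max-rE values referenced above (assumes shelf.max_rE_2d
# and shelf.max_rE_3d follow the standard closed forms: cos(pi / (2N + 2)) in
# 2D, and the largest root of the Legendre polynomial P_{N+1} in 3D, as in
# Zotter & Frank, "All-Round Ambisonic Panning and Decoding").
import numpy as np

def max_rE_2d(order):
    return np.cos(np.pi / (2 * order + 2))

def max_rE_3d(order):
    # Coefficient vector selecting P_{order+1} in the Legendre basis.
    coeffs = np.zeros(order + 2)
    coeffs[-1] = 1.0
    return np.max(np.polynomial.legendre.legroots(coeffs))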
def do_fit(data, steps=100):
    """ """
    norm_sample_size = 10**6
    norm_space = (
        (-2.5, 15),      # energy range (MeV)
        (0, 22),         # T(DD) range (MeV)
        (2.004, 2.026),  # m(Dpi) range (GeV)
    )

    _, initial_params = NN.init(rjax.PRNGKey(0), data)
    print('Model initialized:')
    print(jax.tree_map(np.shape, initial_params))

    rng = rjax.PRNGKey(10)
    rng, key1, key2, key3 = rjax.split(rng, 4)
    keys = [key1, key2, key3]
    # Draw each normalization coordinate with its own subkey.
    norm_sample = np.column_stack([
        rjax.uniform(key, (norm_sample_size,), minval=lo, maxval=hi)
        for key, (lo, hi) in zip(keys, norm_space)
    ])

    def loglh_loss(model):
        """ Loss function for the unbinned maximum likelihood fit """
        return -np.sum(np.log(model(data))) + \
            data.shape[0] * np.log(np.sum(model(norm_sample)))

    model = nn.Model(NN, initial_params)
    adam = optim.Adam(learning_rate=0.03)
    optimizer = adam.create(model)

    for i in range(steps):
        loss, grad = jax.value_and_grad(loglh_loss)(optimizer.target)
        optimizer = optimizer.apply_gradient(grad)
        print(f'{i}/{steps}: loss: {loss:.3f}')

    with open('nn_model_3d.dat', 'wb') as ofile:
        state = TrainState(optimizer=optimizer)
        data = flax.serialization.to_bytes(state)
        print(f'Model serialized, num bytes: {len(data)}')
        ofile.write(data)
def append(self, state: Tuple) -> None:
    """Append a trace or new elements to the current trace.

    This is useful when performing repeated inference on the same dataset, or
    using the generator runtime. Sequential inference should use different
    traces for each sequence.

    Parameters
    ----------
    state
        A tuple that contains the chain state and the corresponding sampling
        info.
    """
    sample, sample_info = state

    concatenate = lambda cur, new: jnp.concatenate((cur, new), axis=1)
    concatenate_1d = lambda cur, new: jnp.column_stack((cur, new))

    if self.raw.samples is None:
        stacked_chain = sample
    else:
        try:
            stacked_chain = jax.tree_multimap(concatenate, self.raw.samples,
                                              sample)
        except TypeError:
            stacked_chain = jax.tree_multimap(concatenate_1d,
                                              self.raw.samples, sample)

    if self.raw.sampling_info is None:
        stacked_info = sample_info
    else:
        try:
            stacked_info = jax.tree_multimap(concatenate,
                                             self.raw.sampling_info,
                                             sample_info)
        except TypeError:
            stacked_info = jax.tree_multimap(concatenate_1d,
                                             self.raw.sampling_info,
                                             sample_info)

    self.raw = replace(self.raw, samples=stacked_chain,
                       sampling_info=stacked_info)
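# Sketch of the two stacking modes used above (illustrative; uses
# jax.tree_util.tree_map, the current name for the deprecated tree_multimap).
import jax
import jax.numpy as jnp

cur = {'theta': jnp.zeros((2, 3))}  # 2 chains, 3 draws so far
new = {'theta': jnp.zeros((2, 1))}  # one more draw per chain
out = jax.tree_util.tree_map(
    lambda c, n: jnp.concatenate((c, n), axis=1), cur, new)
print(out['theta'].shape)  # (2, 4)

# For 1-D leaves, concatenating on axis=1 raises, and column_stack is the
# fallback: (3,) stacked with (3,) becomes (3, 2).
print(jnp.column_stack((jnp.zeros(3), jnp.zeros(3))).shape)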
def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False):
    shape = (2000, 2000)
    nse = 10000
    size = np.prod(shape)
    rng = np.random.RandomState(1701)
    data = rng.randn(nse)
    indices = np.unravel_index(rng.choice(size, size=nse, replace=False),
                               shape=shape)
    mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)),
                      shape=shape)

    f = lambda mat: mat.todense()
    if jit or compile:
        f = jax.jit(f)

    if compile:
        while state:
            f.lower(mat).compile()
    else:
        f(mat).block_until_ready()
        while state:
            f(mat).block_until_ready()
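# Standalone round trip for the same BCOO construction (illustrative): build
# a small sparse matrix from (data, indices) and densify it.
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

shape = (6, 6)
rng = np.random.RandomState(0)
flat = rng.choice(np.prod(shape), size=5, replace=False)
mat = sparse.BCOO(
    (jnp.ones(5), jnp.column_stack(np.unravel_index(flat, shape))),
    shape=shape)
assert mat.todense().sum() == 5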
def cartesian_to_cluster(pos: Position, mom: Momentum, R=1000) -> (Cluster):
    """ Make cluster from Position and Momentum """
    # Find intersection with calorimeter (see the geometry sketch)
    pos_mom_scalar_product = pos.x * mom.px + pos.y * mom.py + pos.z * mom.pz
    mom_total = mom.ptot
    mom_total_squared = mom_total**2

    full_computation = False
    if full_computation:
        alpha = (-pos_mom_scalar_product +
                 np.sqrt(pos_mom_scalar_product**2 + mom_total_squared *
                         (R**2 - pos.r**2))) / mom_total_squared
    else:
        # takes into account that R >> particle flight length
        alpha = R / mom_total - pos_mom_scalar_product / mom_total_squared

    cluster_position = Position.from_ndarray(
        pos.as_array + alpha.reshape(-1, 1) * mom.as_array)

    return Cluster.from_ndarray(
        np.column_stack(
            [mom_total, cluster_position.costh, cluster_position.phi]))
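# Numeric sanity check of the R >> flight-length approximation above
# (illustrative, with plain vectors instead of Position/Momentum): alpha
# solves |pos + alpha * mom| = R exactly via the quadratic formula.
import numpy as np

R = 1000.0
pos = np.array([1.0, -2.0, 0.5])
mom = np.array([0.3, 0.4, 1.2])
pm, m2 = pos @ mom, mom @ mom
alpha_full = (-pm + np.sqrt(pm**2 + m2 * (R**2 - pos @ pos))) / m2
alpha_approx = R / np.sqrt(m2) - pm / m2
assert abs(alpha_full - alpha_approx) / alpha_full < 1e-5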
def _add_supernode(self, node_features, dense_submat, dense_q):
    """Adds supernode with full incoming and outgoing connectivity.

    Adds a row and column of 1s to `dense_submat`, and normalizes. Also adds
    a row to `node_features` for the supernode, embedded with a distinct
    integer value. Appends the mean of `dense_q` as the supernode weight.

    Args:
        node_features: Shape (num_nodes, feature_dim) Matrix of node features.
        dense_submat: Shape (num_nodes, num_nodes) Adjacency matrix.
        dense_q: Shape (num_nodes,) Node weights.

    Returns:
        node_features: Shape (num_nodes + 1, feature_dim) Matrix of node
            features.
        dense_submat: Shape (num_nodes + 1, num_nodes + 1) Adjacency matrix.
        dense_q: Shape (num_nodes + 1,) Node weights.
    """
    dense_submat = jnp.row_stack(
        (dense_submat, jnp.ones(dense_submat.shape[1])))
    dense_submat = jnp.column_stack(
        (dense_submat, jnp.ones(dense_submat.shape[0])))
    # Normalize nonzero elements.
    # The sum is bounded away from 0, so this is always differentiable.
    # TODO(gnegiar): Do we want this? It means the supernode gets half the
    # outgoing weights.
    dense_submat = dense_submat / dense_submat.sum(axis=-1, keepdims=True)
    # Add a weight to the supernode.
    dense_q = jnp.append(dense_q, jnp.mean(dense_q))
    # We embed the supernode using a distinct value.
    # TODO(gnegiar): Should we use another embedding?
    node_features = jnp.append(node_features,
                               jnp.full((1, node_features.shape[1]), 2,
                                        dtype=int),
                               axis=0)
    return node_features, dense_submat, dense_q
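# Minimal numeric demo of the supernode growth pattern above (illustrative):
# append a row and column of ones, then row-normalize.
import jax.numpy as jnp

A = jnp.full((2, 2), 0.5)
A = jnp.row_stack((A, jnp.ones(A.shape[1])))
A = jnp.column_stack((A, jnp.ones(A.shape[0])))
A = A / A.sum(axis=-1, keepdims=True)
print(A.shape, A.sum(axis=-1))  # (3, 3), rows sum to 1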
def as_array(self) -> (np.ndarray):
    return np.column_stack([self.x, self.y, self.z])
def column_stack(arrays):
    arrays = [a.value if isinstance(a, JaxArray) else a for a in arrays]
    return JaxArray(jnp.column_stack(arrays))
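# Self-contained sketch of the unwrap/rewrap pattern above (the JaxArray here
# is a hypothetical stand-in for the real container class, and the wrapper is
# duplicated so the snippet runs standalone).
import jax.numpy as jnp

class JaxArray:
    def __init__(self, value):
        self.value = value

def column_stack(arrays):
    arrays = [a.value if isinstance(a, JaxArray) else a for a in arrays]
    return JaxArray(jnp.column_stack(arrays))

stacked = column_stack([JaxArray(jnp.arange(3.0)), jnp.ones(3)])
print(stacked.value.shape)  # (3, 2)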
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):
    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    # --------------------------------------
    # Data
    n_atoms = 4
    batch_size = 768  # 1024#768#512#256#128#64#32

    Dtr, Dval, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xval, gXval, gXcval, yval = Dval
    Xt, gXt, gXct, yt = Dt
    print(gXtr.shape, gXtr.shape, gXctr.shape, ytr.shape)

    # --------------------------------
    # BATCHES
    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield (Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx],
                       ytr[batch_idx])

    batches = data_stream()

    # --------------------------------
    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0), file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    # --------------------------------------
    # Initialize NN
    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]
        # x = jnp.ones((1, Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

    # Initialize parameters
    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)
        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
        # print(params)
    f.close()

    init_params = params

    # --------------------------------------
    # Phys functions
    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    # --------------------------------------
    # Training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  # Ha2cm*
        return jnp.linalg.norm(diff_y, axis=0)

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs  # jnp.linalg.norm(diff_g_X, axis=0)
        diff_g_X0 = diff_g_X[:, 0, :]
        diff_g_X1 = diff_g_X[:, 1, :]
        l0 = jnp.linalg.norm(diff_g_X0)
        l1 = jnp.linalg.norm(diff_g_X1)
        return jnp.stack([l0, l1])

    # ------
    @jit
    def f_loss(params, rho_g, batch):
        rho_g = jnp.exp(rho_g)
        loss_ad_energy = f_loss_ad_energy(params, batch)
        loss_jac_energy = f_loss_jac(params, batch)
        loss = jnp.vdot(jnp.ones_like(loss_ad_energy), loss_ad_energy) + \
            jnp.vdot(rho_g, loss_jac_energy)
        return loss

    # --------------------------------------
    # Optimization and Training

    # Perform a single training step.
    @jit
    def train_step(optimizer, rho_g, batch):  # , learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, rho_g, batch)
        optimizer = optimizer.apply_gradient(grad)  # , {"learning_rate": lr}
        return optimizer, (loss, grad)

    # @jit
    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr,
                               weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g,
                                                      next(batches))
                loss, grad = loss_and_grad
                # f = open(f_out, 'a+')
                # print(i, loss, file=f)
                # f.close()
                train_loss.append(loss)
            # params = optimizer.target
            # loss_tot = f_validation(params)
        nn_params = optimizer.target
        return nn_params, loss_and_grad, train_loss

    @jit
    def val_step(optimizer, nn_params):  # , learning_rate_fn, model
        rho_g_prev = optimizer.target
        nn_params, loss_and_grad_train, train_loss_iter = train(
            rho_g_prev, nn_params)
        loss_train, grad_loss_train = loss_and_grad_train
        grad_fn_val = jax.value_and_grad(f_loss, argnums=1)
        loss_val, grad_val = grad_fn_val(nn_params, optimizer.target, Dval)
        optimizer = optimizer.apply_gradient(grad_val)  # , {"learning_rate": lr}
        return (optimizer, nn_params,
                (loss_val, loss_train, train_loss_iter),
                (grad_loss_train, grad_val))

    # Initialize rho_G
    rng = random.PRNGKey(0)
    rng, subkey = jax.random.split(rng)
    rho_G0 = random.uniform(subkey, shape=(2,), minval=5E-4, maxval=0.025)
    rho_G0 = jnp.log(rho_G0)
    print('Initial lambdas', rho_G0)
    init_G = rho_G0

    optimizer_out = optim.Adam(learning_rate=2E-4,
                               weight_decay=0.).create(init_G)
    optimizer_out = jax.device_put(optimizer_out)

    f_params = init_params
    for i in range(50000):
        start_va_time = time.time()
        optimizer_out, f_params, loss_all, grad_all = val_step(
            optimizer_out, f_params)
        rho_g = optimizer_out.target
        loss_val, loss_train, train_loss_iter = loss_all
        grad_loss_train, grad_val = grad_all

        loss0_tot = f_loss(f_params, rho_g, Dt)

        dict_output = serialization.to_state_dict(f_params)
        jnp.save(f_w_nn, dict_output)  # unfreeze()

        f = open(f_out, 'a+')
        # print(i, rho_g, loss0_tot, (time.time() - start_va_time), file=f)
        print(i, loss_val, loss_train, (time.time() - start_va_time), file=f)
        print(jnp.exp(rho_g), file=f)
        print(grad_val, file=f)
        # print(train_loss_iter, file=f)
        # print(grad_val, file=f)
        # print(grad_loss_train, file=f)
        f.close()

    # --------------------------------------
    # Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)
    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    # `loss_train` is the last training loss from the outer loop.
    print('RMSE(tr) = {} [cm-1]'.format(loss_train), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss_train,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g,
    }
    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time = %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
def cbind(x, y, backend="cpu"):
    # if len(x.shape) == 1 or len(y.shape) == 1:
    sys_platform = platform.system()
    if backend in ("gpu", "tpu") and (sys_platform in ("Linux", "Darwin")):
        return jnp.column_stack((x, y))
    return np.column_stack((x, y))
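# Example call for cbind (illustrative): the default backend takes the NumPy
# branch; requesting "gpu"/"tpu" on Linux or Darwin routes through jax.numpy.
import numpy as np

x, y = np.arange(3.0), np.arange(3.0) * 2
print(cbind(x, y).shape)  # (3, 2)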
ring = Ring(10, .1)
gauss = Gaussian([0, 0], 9)
ring_target = Setup(ring, gauss)
ring_proposal = Setup(gauss, ring)

target = Ring(10, .1)
proposal = Ring(15, .1)
double_ring = Setup(target, proposal)

target = Squiggle([0, 0], [1, .1])
proposal = Gaussian([-2, 0], [1, 1])
squiggle_target = Setup(target, proposal)

means = np.array([np.exp(2j * np.pi * x) for x in np.linspace(0, 1, 6)[:-1]])
means = np.column_stack((means.real, means.imag))
target = GaussianMixture(means, .03, np.ones(5))
proposal = Gaussian([-2, 0], [.5, .5])
mix_of_gauss = Setup(target, proposal)

target = Gaussian([0, 0], [1e-4, 9])
proposal = Gaussian([0, 0], [1, 1])
thin_target = Setup(target, proposal)

d = 50
variances = np.logspace(-2, 0, num=d)
target = Gaussian(np.zeros(d), variances)
proposal = Gaussian(np.zeros(d), np.ones(d))
high_d_gaussian = Setup(target, proposal)

# rotated gaussian
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):
    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    # --------------------------------------
    # Data
    n_atoms = 4
    batch_size = 768  # 1024#768#512#256#128#64#32

    Dtr, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xt, gXt, gXct, yt = Dt
    print(gXtr.shape, gXtr.shape, gXctr.shape, ytr.shape)

    # --------------------------------
    # BATCHES
    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield (Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx],
                       ytr[batch_idx])

    batches = data_stream()

    # --------------------------------
    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(
        N, n_atoms, l, i0), file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    # --------------------------------------
    # Initialize NN
    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]
        # x = jnp.ones((1, Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

    # Initialize parameters
    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)
        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
    f.close()

    init_params = params

    # --------------------------------------
    # Phys functions
    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    '''
    # WRONG
    @jit
    def f_nac_coup_i(gH_diab, eigvect_):
        # for a single cartesian dimension
        temp = jnp.dot(gH_diab, eigvect_[:, 0])
        return jnp.vdot(eigvect_[:, 1], temp)

    @jit
    def f_nac_coup(params, x):
        eigval_, eigvect_ = f_adiab(params, x)
        gy_diab = jac_nn_diab(params, x)
        gy_diab = jnp.reshape(gy_diab.T, (12, 2, 2))
        g_coup = vmap(f_nac_coup_i, (0, None))(gy_diab, eigvect_)
        return g_coup
    '''

    # --------------------------------------
    # Validation loss functions
    @jit
    def f_validation(params):
        y_pred = nn_adiab(params, Xt)
        diff_y = y_pred - yt
        z = jnp.linalg.norm(diff_y)
        return z

    @jit
    def f_jac_validation(params):
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, Xt)
        diff_y = gX_pred - gXt
        z = jnp.linalg.norm(diff_y)
        return z

    '''
    @jit
    def f_nac_validation(params):
        g_nac_coup = vmap(f_nac_coup, (None, 0))(params, Xt)
        diff_y = g_nac_coup - gXct
        z = jnp.linalg.norm(diff_y)
        return z
    '''

    # --------------------------------------
    # Training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  # Ha2cm*
        loss = jnp.linalg.norm(diff_y)
        return loss

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        return jnp.linalg.norm(diff_g_X)

    '''
    @jit
    def f_loss_nac(params, batch):
        X_inputs, _, gXc_inputs, y_true = batch
        g_nac_coup = vmap(f_nac_coup, (None, 0))(params, x)
        diff_y = g_nac_coup - gXc_inputs
        z = jnp.linalg.norm(diff_y)
        return z
    '''

    # ------
    @jit
    def f_loss(params, batch):
        loss_ad_energy = f_loss_ad_energy(params, batch)
        # loss_jac_energy = f_loss_jac(params, batch)
        loss = loss_ad_energy  # + rho_g*loss_jac_energy
        return loss

    # --------------------------------------
    # Optimization and Training

    # Perform a single training step.
    @jit
    def train_step(optimizer, batch):  # , learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, batch)
        optimizer = optimizer.apply_gradient(grad)  # , {"learning_rate": lr}
        return optimizer, loss

    optimizer = optim.Adam(learning_rate=lr,
                           weight_decay=w_decay).create(init_params)
    optimizer = jax.device_put(optimizer)

    loss0 = 1E16
    loss0_tot = 1E16
    itercount = itertools.count()
    f_params = init_params
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            optimizer, loss = train_step(optimizer, next(batches))

        params = optimizer.target
        loss_tot = f_validation(params)

        if epoch % 10 == 0:
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss < loss0:
            loss0 = loss
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss_tot < loss0_tot:
            loss0_tot = loss_tot
            f_params = params
            dict_output = serialization.to_state_dict(params)
            jnp.save(f_w_nn, dict_output)  # unfreeze()

    f = open(f_out, 'a+')
    print('---------------------------------', file=f)
    print('Training time = %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()

    # --------------------------------------
    # Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g),
          file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)
    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g,
    }
    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time = %.6f seconds ---' % ((time.time() - start_time)),
          file=f)
    print('---------------------------------', file=f)
    f.close()
def as_array(self) -> (np.ndarray):
    return np.column_stack([self.energy, self.costh, self.phi])
def as_array(self) -> (np.ndarray):
    """ Helix parameters as np.ndarray """
    return np.column_stack([
        self.d0, self.phi0, self.omega, self.z0, self.tanl])