Example No. 1
def save_weights(file_name,
                 step,
                 params_vec,
                 batch_state,
                 saveonlylast=True,
                 savebatch_state=True):
    """Save params and step in appropiate format."""
    filetag = '_step' + str(step)

    np.save('weights/' + file_name + filetag + '.npy', params_vec)
    if savebatch_state:
        with open('weights/' + file_name + filetag + '.pkl', 'wb') as f:
            pickle.dump(batch_state, f)

    if not saveonlylast:
        return
    for el in sorted(os.listdir('weights/')):
        if el.startswith(file_name) and (el != file_name + filetag + '.npy' and
                                         el != file_name + filetag + '.pkl'):
            ll = [
                el2 for el2 in os.listdir('weights/')
                if el2.startswith(file_name)
            ]
            if len(ll) > 4:
                os.remove('weights/' + el)
                print('Deleted', el)
    return
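
# A minimal usage sketch for save_weights above (a hedged example, not part of
# the original source): the parameter and batch-state values are placeholders,
# and the imports match the names the function relies on.
import os
import pickle
import numpy as np

os.makedirs('weights', exist_ok=True)  # the function writes into 'weights/'
for step in range(0, 1000, 100):
    params_vec = np.zeros(10)       # placeholder for real parameters
    batch_state = {'step': step}    # placeholder for real batch statistics
    save_weights('my_run', step, params_vec, batch_state)

# Reloading the latest checkpoint:
params_vec = np.load('weights/my_run_step900.npy')
with open('weights/my_run_step900.pkl', 'rb') as f:
    batch_state = pickle.load(f)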
Example No. 2
def main():
    surface = Surface("../initFiles/axes/ellipticalAxis4Rotate.txt", 128, 32,
                      1.0)

    radii = np.linspace(0.0, 1.11111111, 15)

    start = time.time()

    N = 100
    coil_data, coil_params = CoilSet.get_initial_data(
        surface,
        input_file="../../tests/postresaxis/triple_comparison/fb.hdf5")
    rs, zs = Poincare.getPoincarePoints(N, 0.0, radii, surface, False,
                                        coil_data, coil_params)

    end = time.time()
    print(end - start)
    """
	font = {'family' : 'serif',
        'weight' : 'normal',
        'size'   : 12}

	plt.rc('font', **font)

	plt.plot(rs,zs,'ko', markersize=0.5, color='blue')
	plt.xlabel("R [m]")
	plt.ylabel("Z [m]")
	plt.show()
	"""

    np.save("rs.npy", rs)
    np.save("zs.npy", zs)
Example No. 3
def plot_points():
    onp.random.seed(0)
    params = np.array([0, 1, -1, 0.5, 2, 0, -1, -2])
    x, y = generate_data(params, 0.1)
    np.save('x.npy', x)
    np.save('y.npy', y)
    plt.plot(x, y, 'k+', markersize=5)
    plt.show()
Example No. 4
    def _checkpoint_model(self, step_count: int,
                          network_parameters: List) -> None:
        """
        Save a copy of the network parameters up to this point in training

        :param step_count: iteration number of training (meta-steps)
        :param network_parameters: parameters of network to save
        """
        os.makedirs(self.checkpoint_path, exist_ok=True)
        timestamp = datetime.datetime.fromtimestamp(
            time.time()).strftime('%H-%M-%S')
        # format of model checkpoint path: timestamp _ step_count
        PATH = '{}model_checkpoint_{}_{}.npy'.format(self.checkpoint_path,
                                                     timestamp,
                                                     str(step_count))
        np.save(PATH, {
            'step': step_count,
            'network_parameters': network_parameters
        })
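
# Hedged note (not in the original source): np.save stores the dict above as a
# 0-d object array, so reading the checkpoint back needs allow_pickle=True and
# .item() to recover the dict. The path below is hypothetical.
import numpy as np

checkpoint = np.load('checkpoints/model_checkpoint_12-00-00_500.npy',
                     allow_pickle=True).item()  # recover the dict
step_count = checkpoint['step']
network_parameters = checkpoint['network_parameters']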
Example No. 5
    def write_model(self):

        np.save(self.prefix + "pcacomponents" + self.suffix, self.pcacomponents)
        np.save(
            self.prefix + "components_prior_params" + self.suffix,
            self.components_prior_params,
        )
        np.save(
            self.prefix + "polynomials_prior_mean" + self.suffix,
            self.polynomials_prior_mean,
        )
        np.save(
            self.prefix + "polynomials_prior_loginvvar" + self.suffix,
            self.polynomials_prior_loginvvar,
        )
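
    # A hedged mirror-image loader sketch (hypothetical helper, not in the
    # original source). It assumes self.suffix already ends in '.npy', in
    # which case np.save left the file names above unchanged.
    def read_model(self):
        self.pcacomponents = np.load(self.prefix + "pcacomponents" + self.suffix)
        self.components_prior_params = np.load(
            self.prefix + "components_prior_params" + self.suffix)
        self.polynomials_prior_mean = np.load(
            self.prefix + "polynomials_prior_mean" + self.suffix)
        self.polynomials_prior_loginvvar = np.load(
            self.prefix + "polynomials_prior_loginvvar" + self.suffix)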
Example No. 6
def data_preprocessing():
    """Separates data (spin configurations) into test and training sets and generates labels."""
    rng = random.PRNGKey(0)

    temperatures = jnp.linspace(1.0, 4.0, 7)
    temperatures1 = [1.0, 1.5, 3.0, 3.5, 4.0]
    temperatures2 = [2.0, 2.5]

    x_train = []
    y_train = []
    x_test = []
    y_test = []
    for T in temperatures:
        configs = jnp.load('data/spins_T%s.npy' % T)
        magnetization_density = jnp.abs(
            jnp.array([jnp.sum(config) / config.size for config in configs]))
        labels = jnp.where(magnetization_density < 0.5, 0, 1)
        if T in temperatures2:
            x_test.append(configs)
            y_test.append(labels)
        else:
            indices = random.permutation(rng, labels.size)
            y_test.append(labels[indices[:int(0.2 * labels.size)]])
            y_train.append(labels[indices[int(0.2 * labels.size):]])
            x_test.append(configs[indices[:int(0.2 * labels.size)]])
            x_train.append(configs[indices[int(0.2 * labels.size):]])

    y_test_new = jnp.array(y_test[0])
    x_test_new = jnp.array(x_test[0])
    for i in range(len(y_test) - 1):
        y_test_new = jnp.concatenate((y_test_new, y_test[i + 1]))
        x_test_new = jnp.concatenate((x_test_new, x_test[i + 1]))

    L = jnp.array(x_train).shape[2]
    x_test = jnp.array(x_test_new).reshape((-1, L, L, 1)).astype(jnp.float64)
    y_test = jnp.array(y_test_new).reshape((-1, 1))
    x_train = jnp.array(x_train).reshape((-1, L, L, 1)).astype(jnp.float64)
    y_train = jnp.array(y_train).reshape((-1, 1))

    jnp.save('data/x_test.npy', x_test)
    jnp.save('data/y_test.npy', y_test)
    jnp.save('data/x_train.npy', x_train)
    jnp.save('data/y_train.npy', y_train)

    return x_train, y_train, x_test, y_test
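
# A hedged consumption sketch for the files saved above (not in the original
# source). Note jnp.float64 only takes effect when JAX's x64 mode is enabled.
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)  # required for float64 arrays
x_train = jnp.load('data/x_train.npy')
y_train = jnp.load('data/y_train.npy')
x_test = jnp.load('data/x_test.npy')
y_test = jnp.load('data/y_test.npy')
assert x_train.shape[1:] == x_test.shape[1:]  # both are (N, L, L, 1)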
Example No. 7
def train(env,
          batch_size=128,
          num_epochs=5,
          num_iterations=21,
          num_samples=101,
          print_every=10,
          episodes=1000,
          k_min=1,
          k_max=25,
          verbose=False,
          params_save_path=""):
    """
    Train the model function by generating simulations of random-play.
    On every epoch generate a new simulation and run multiple iterations.
    On every iteration evaluate the targets using the most recent model parameters
    and run multiple times through the dataset.
    At the end of every epoch check the performance and store the best performing params.
    If the performance drops then decay the step size parameter.

    @param env (Cube Object): A Cube object representing the environment.
    @param batch_size (int): Size of minibatches used to compute loss and gradient during training.         [optional]
    @param num_epochs (int): The number of epochs to run for during training.                               [optional]
    @param num_iterations (int): The number of iterations through the generated episodes.                   [optional]
    @param num_samples (int): The number of times the dataset is reused.                                    [optional]
    @param print_every (int): An integer. Training progress will be printed every `print_every` iterations. [optional]
    @param episodes (int): Number of episodes to be created.                                                [optional]
    @param k_min (int): Minimum length of sequence of backward moves.                                       [optional]
    @param k_max (int): Maximum length of sequence of backward moves.                                       [optional]
    @param verbose (bool): If set to false then no output will be printed during training.                  [optional]
    @param params_save_path (str): File path to save the model parameters.                                  [optional]
    @returns params (pytree): The best model parameters obtained during training.
    @returns loss_history (List): Loss history of iter_mean_loss and first_loss computed during training.
    """
    loss_history = {"iter_loss": [], "first_loss": []}
    trust_range = 0

    # Initialize model parameters: load a checkpoint when a save path is given,
    # otherwise start from a fresh initialization.
    if params_save_path:
        params = list(jnp.load(params_save_path + "params_cnn_0.npy", allow_pickle=True))
    else:
        params = init_fun(jax.random.PRNGKey(0), env.terminal_state.shape)[1]

    # Begin training.
    for e in range(num_epochs):
        tic = time.time()

        # Initialize the optimizer state at the beginning of each epoch.
        opt_state = opt_init(params)

        # Generate data from random-play using the environment.
        states, w, children, rewards = generate_episodes(
            env, episodes, k_max, trust_range)

        # Train the model on the generated data. Periodically recompute the target values.
        epoch_mean_loss = 0.0
        for it in range(num_iterations):
            tic_it = time.time()

            # Make targets for the generated episodes using the most recent params and build a batch generator.
            params = get_params(opt_state)
            tgt_vals = make_targets(children, rewards, params)
            data = {"X": states, "y": tgt_vals, "w": w}
            train_batches = batch_generator(data, batch_size)

            # Run through the dataset and update model params.
            total_loss = 0.0
            for i in range(num_samples):
                batch = next(train_batches)
                loss, opt_state = update(e, opt_state, batch)
                total_loss += loss
                if it == 0 and i == 0:
                    loss_history["first_loss"].append(loss)

            # Book-keeping.
            iter_mean_loss = total_loss / num_samples
            epoch_mean_loss = (it * epoch_mean_loss + iter_mean_loss) / (it + 1)
            loss_history["iter_loss"].append(iter_mean_loss)

            # Printout results.
            toc_it = time.time()
            if print_every != 0 and it % print_every == 0 and verbose:
                print(
                    "\t(Iteration({}/{}) took {:.3f} seconds) iter_mean_loss = {:.3f}"
                    .format(it + 1, num_iterations, (toc_it - tic_it),
                            iter_mean_loss))

        # Recompute the trust range using latest model params.
        trust_range = compute_trust_range(env, params)

        # Store the model parameters.
        if params_save_path:
            jnp.save(params_save_path + "params_cnn_%d" % (e + 1), params)

        # Record the time needed for a single epoch.
        toc = time.time()

        # Printout results.
        if verbose:
            print(
                "(Epoch ({}/{}) took {:.3f} seconds), epoch_mean_loss: {:.3f}, trust_range: {}"
                .format(e + 1, num_epochs, (toc - tic), epoch_mean_loss,
                        trust_range))

    # Save loss history (only when a save path was provided).
    if params_save_path:
        json.dump(loss_history,
                  open(params_save_path + "loss_history.json", "w"),
                  indent=2)

    return params, loss_history
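
# A hedged invocation sketch (not in the original source): `Cube` and the
# module-level init_fun/opt_init/update helpers the function relies on are
# assumed from the surrounding project.
env = Cube()
params, loss_history = train(env,
                             batch_size=128,
                             num_epochs=5,
                             episodes=1000,
                             verbose=True,
                             params_save_path="")  # empty path: fresh init, nothing saved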
Example No. 8
def save_array(filename, arr, upload_to_wandb=True):
    """ Save jax array. """
    filepath = str(get_result_path(filename))
    jnp.save(filepath, arr)
    if upload_to_wandb:
        safe_wandb_save(filepath)
Example No. 9
#         errs_a.append(np.stack(erra).mean())
#         errs_std_a.append(np.stack(erra).std())
#         errs_b.append(np.stack(errb).mean())
#         errs_std_b.append(np.stack(errb).std())

# Compute generalization error, rapid but heuristic
errs_a = []
errs_b = []
for ii, a in enumerate(range(K)):
    Xa = np.array(manifolds[a])
    for b in range(a + 1, K):
        Xb = np.array(manifolds[b])
        erra = []
        errb = []

        key, _ = random.split(key)
        erra, errb = mshot_err_fast(key, Xa, Xb)

        errs_a.append(erra)
        errs_b.append(errb)

    print('Manifold {} of {}. Avg. acc: {}'.format(ii, K,
                                                   1 - errs_a[-1].mean()))

# Combine errs_a and errs_b into K x K matrix
errs_full = np.triu(squareform(errs_a)) + np.tril(squareform(errs_b))

# Save
np.save(save_path, errs_full)
print('Finished with acc. ' + str(1 - np.mean(errs_full)) + '. Saved.')
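
# Hedged note (not in the original source): squareform, presumably from
# scipy.spatial.distance, expands a condensed list of pairwise values into a
# symmetric K x K matrix with a zero diagonal. A toy illustration:
import numpy as onp
from scipy.spatial.distance import squareform

condensed = onp.array([0.1, 0.2, 0.3])  # pairs (0,1), (0,2), (1,2) for K = 3
square = squareform(condensed)          # symmetric 3 x 3 matrix
upper = onp.triu(square)                # keep only the upper triangle, as above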
Example No. 10
    def load_spectrophotometry(
        self,
        input_dir="./",
        write_subset=False,
        use_subset=False,
        subsampling=1,
        spec=True,
        phot=True,
    ):

        if use_subset:
            suffix = "2.npy"
        else:
            suffix = ".npy"

        self.input_dir = input_dir

        self.lamgrid = onp.load(self.input_dir + "lamgrid.npy")
        self.lam_phot_eff = onp.load(self.input_dir + "lam_phot_eff.npy")
        self.lam_phot_size_eff = onp.load(self.input_dir +
                                          "lam_phot_size_eff.npy")

        self.redshifts = onp.load(self.input_dir + "redshifts" + suffix)
        n_obj = self.redshifts.size
        self.n_obj = n_obj
        assert_shape(self.redshifts, (n_obj, ))

        if phot:
            self.transferfunctions = (
                onp.load(self.input_dir + "transferfunctions.npy") * 1e-16)
            self.transferfunctions_zgrid = onp.load(
                self.input_dir + "transferfunctions_zgrid.npy")

            assert self.transferfunctions.shape[
                0] == self.transferfunctions_zgrid.size
            assert self.transferfunctions.shape[1] == self.lamgrid.size
            self.index_transfer_redshift = onp.load(self.input_dir +
                                                    "index_transfer_redshift" +
                                                    suffix)
            self.interprightindices_transfer = onp.load(
                self.input_dir + "interprightindices_transfer" + suffix)
            self.interpweights_transfer = onp.load(self.input_dir +
                                                   "interpweights_transfer" +
                                                   suffix)
            self.phot = fluxes = onp.load(self.input_dir + "phot" + suffix)
            self.phot_invvar = flux_ivars = onp.load(self.input_dir +
                                                     "phot_invvar" + suffix)
            assert_shape(self.index_transfer_redshift, (n_obj, ))
            assert_shape(self.interprightindices_transfer, (n_obj, ))
            assert_shape(self.interpweights_transfer, (n_obj, ))

            n_pix_phot = self.phot.shape[1]
            assert_shape(self.phot, (n_obj, n_pix_phot))
            assert_shape(self.phot_invvar, (n_obj, n_pix_phot))
            self.n_pix_phot = self.phot.shape[1]

        if spec:
            self.chi2s_sdss = onp.load(self.input_dir + "chi2s_sdss" + suffix)
            self.lamspec_waveoffset = int(
                onp.load(self.input_dir + "lamspec_waveoffset" + suffix))
            self.index_wave = onp.load(self.input_dir + "index_wave" + suffix)
            self.interprightindices = onp.load(self.input_dir +
                                               "interprightindices" + suffix)
            self.interpweights = onp.load(self.input_dir + "interpweights" +
                                          suffix)

            self.specmod_sdss = onp.load(self.input_dir + "spec_mod" + suffix)
            if True:
                self.spec = onp.load(self.input_dir + "spec" + suffix)
                self.spec_invvar = onp.load(self.input_dir + "spec_invvar" +
                                            suffix)
            else:
                self.spec = onp.load(self.input_dir + "spec_mod" + suffix)
                self.spec_invvar = (
                    onp.load(self.input_dir + "spec_invvar" + suffix) * 0 + 1)

            self.n_pix_spec = self.spec.shape[1]

            assert_shape(self.chi2s_sdss, (n_obj, ))
            assert_shape(self.index_wave, (n_obj, ))
            n_pix_spec = self.spec.shape[1]
            assert_shape(self.spec, (n_obj, n_pix_spec))
            assert_shape(self.specmod_sdss, (n_obj, n_pix_spec))
            assert_shape(self.spec_invvar, (n_obj, n_pix_spec))

        if write_subset:

            M = 50000
            suffix = "2.npy"

            self.index_wave = self.index_wave[:M]
            self.redshifts = self.redshifts[:M]
            self.chi2s_sdss = self.chi2s_sdss[:M]
            self.phot_invvar = self.phot_invvar[:M, :]
            self.index_transfer_redshift = self.index_transfer_redshift[:M]

            np.save(self.input_dir + "index_wave" + suffix,
                    self.index_wave[:M])
            np.save(
                self.input_dir + "interprightindices_transfer" + suffix,
                self.interprightindices_transfer[:M, :],
            )
            np.save(
                self.input_dir + "interpweights_transfer" + suffix,
                self.interpweights_transfer[:M, :],
            )
            np.save(
                self.input_dir + "index_transfer_redshift2.npy",
                self.index_transfer_redshift,
            )
            np.save(self.input_dir + "redshifts" + suffix, self.redshifts)
            np.save(self.input_dir + "spec" + suffix, self.spec)
            np.save(self.input_dir + "chi2s_sdss" + suffix, self.chi2s_sdss)
            np.save(self.input_dir + "spec_invvar" + suffix, self.spec_invvar)
            np.save(self.input_dir + "phot" + suffix, self.phot)
            np.save(self.input_dir + "phot_invvar" + suffix, self.phot_invvar)
            np.save(self.input_dir + "spec_mod" + suffix, self.specmod_sdss)

        if subsampling > 1:

            self.lamgrid = self.lamgrid[::subsampling]
            self.transferfunctions = (
                self.transferfunctions[:, ::subsampling, :][::subsampling, :, :])
            self.transferfunctions_zgrid = self.transferfunctions_zgrid[::subsampling]
            self.lamspec_waveoffset = self.lamspec_waveoffset // subsampling
            self.spec = self.spec[:, ::subsampling]
            self.specmod_sdss = self.specmod_sdss[:, ::subsampling]
            self.spec_invvar = self.spec_invvar[:, ::subsampling]
            self.index_wave = self.index_wave // subsampling
            self.index_transfer_redshift = self.index_transfer_redshift // subsampling
            self.interprightindices = (
                self.interprightindices[:, ::subsampling] // subsampling)
            self.interpweights = (self.interpweights[:, ::subsampling] /
                                  subsampling)  # dilution
            self.interprightindices_transfer = (
                self.interprightindices_transfer // subsampling
            )  # is it correct?
            self.interpweights_transfer = (self.interpweights_transfer /
                                           subsampling)  # is it correct?
Example No. 11
x_plot_grid = jnp.array([x_meshgrid[0].ravel(),
                         x_meshgrid[1].ravel()]).transpose()

#Fit copula obj
copula_classification_obj = fit_copula_classification(jnp.array(y),
                                                      jnp.array(x),
                                                      single_x_bandwidth=False,
                                                      n_perm_optim=10)
print('Bandwidth is {}'.format(copula_classification_obj.rho_opt))
print('Bandwidth is {}'.format(copula_classification_obj.rho_x_opt))
print('Preq loglik is {}'.format(copula_classification_obj.preq_loglik / n))

#Predict Yplot
logpmf1 = predict_copula_classification(copula_classification_obj, x_plot_grid)
pmf1 = jnp.exp(logpmf1)
jnp.save('plot_files/ccopula_moon_pmf', pmf1)

#Predictive Resample
B = 1000
T = 5000
logpmf_ytest_samp, logpmf_yn_samp, y_samp, x_samp, pdiff = predictive_resample_classification(
    copula_classification_obj, y, x, x_plot_grid, B, T)

jnp.save('plot_files/ccopula_moon_logpmf_ytest_pr', logpmf_ytest_samp)
jnp.save('plot_files/ccopula_moon_logpmf_yn_pr', logpmf_yn_samp)

#Convergence
T = 10000  #T = 10000, seed = 50 for i = 30
seed = 200
_, _, _, _, pdiff = predictive_resample_classification(
    copula_classification_obj, y, x, x_test[0:1], 1, T, seed=seed)
Example No. 12
y = jnp.array([y])[0]

#Fit copula obj
copula_density_obj = fit_copula_density(
    y, seed=50, single_bandwidth=False)  #or seed = 200?
print('Bandwidth is {}'.format(copula_density_obj.rho_opt))
print('Preq loglik is {}'.format(copula_density_obj.preq_loglik))

#Predict on yplot
logcdf_conditionals, logpdf_joints = predict_copula_density(
    copula_density_obj, y_plot)

#Predictive resample
T_fwdsamples = 5000  #T = N-n
B_postsamples = 1000

logcdf_conditionals_pr, logpdf_joints_pr = predictive_resample_density(
    copula_density_obj, y_plot, B_postsamples, T_fwdsamples, seed=50)

jnp.save('plot_files/copula_ozone_logpdf_pr', logpdf_joints_pr)
jnp.save('plot_files/copula_ozone_logcdf_pr', logcdf_conditionals_pr)

#Convergence plots
seed = 20
T_fwdsamples = 10000
logcdf_pr_conv, logpdf_pr_conv, pdiff, cdiff = check_convergence_pr(
    copula_density_obj, y_plot, 1, T_fwdsamples, seed)
pdf_pr_conv = jnp.exp(logpdf_pr_conv[0, :, -1])
jnp.save('plot_files/copula_ozone_pr_pdf_samp', pdf_pr_conv)
jnp.save('plot_files/copula_ozone_pr_pdiff', pdiff)
jnp.save('plot_files/copula_ozone_pr_cdiff', cdiff)
Example No. 13
    print('Preq loglik is {}'.format(copula_density_obj.preq_loglik))

    #Predict on yplot
    logcdf_conditionals, logpdf_joints = predict_copula_density(
        copula_density_obj, y_plot)
    pdf_cop = jnp.exp(logpdf_joints[:, -1])
    cdf_cop = jnp.exp(logcdf_conditionals[:, -1])

    #Predictive resample
    T_fwdsamples = 5000  #T = N-n
    B_postsamples = 1000
    logcdf_conditionals_pr, logpdf_joints_pr = predictive_resample_density(
        copula_density_obj, y_plot, B_postsamples, T_fwdsamples,
        seed=200)  #(seed = 200)

    jnp.save('plot_files/copula_gmm_logpdf_pr_n{}'.format(n), logpdf_joints_pr)
    jnp.save('plot_files/copula_gmm_logcdf_pr_n{}'.format(n),
             logcdf_conditionals_pr)

    #Convergence plots
    seed = 200
    T_fwdsamples = 10000
    _, _, pdiff, cdiff = check_convergence_pr(copula_density_obj, y_plot, 1,
                                              T_fwdsamples, seed)
    jnp.save('plot_files/copula_gmm_pr_pdiff_n{}'.format(n), pdiff)
    jnp.save('plot_files/copula_gmm_pr_cdiff_n{}'.format(n), cdiff)
### ###

### GALAXY ###
#Load data
print('Dataset: Galaxy')
Example No. 14
    net_params = get_params(opt_state)
    xL3 = get_xL(net_params, inputs, xL_0, xL_f[i])
    Overlap3 = Transport_Majorana(N, Vmax, sigma, xL3, xR, mu_phys, delta, t,
                                  b, alpha, dt)
    print('compare different overlaps')
    print(neurons1, Overlap1)
    print(neurons1, Overlap2)
    print(neurons1, Overlap3)

    #Postprocessing saving and plotting
    final_fidelity_neurons2.append(Overlap3)
    #Qloss_inbetween_final_wall = Transport_Majorana_all_fidelities(N,Vmax,sigma,xL3,xR,mu_phys,delta,t,b,alpha,dt)

    #save the data
    #np.save('Data2/Qloss(t)_{}'.format(filename_parameters),Qloss_inbetween_final_wall)
    np.save('Data2/Final_wall(t)_{}'.format(filename_parameters), xL3)
    np.save('Data2/Learning_curve_{}'.format(filename_parameters),
            learning_curve)
    np.save('Data2/NN_weight_params_{}'.format(filename_parameters),
            net_params)

    tlist = np.linspace(0, T_total, len(xL3))
    plt.figure(4)
    plt.xlabel('time t/dt')
    plt.ylabel('Left wall position x_L(t)')
    plt.savefig(
        'figures2/Path_search_space_N110_mu_03_Delta_1_t_1_T_5_dt_01_Vmax_301_sigma_1_xL0_5_xlF_10_layer_1_neurons_{}_layer_2_neurons_{}_activation1_Relu_activation2_sigmoid_neurons3_{}'
        .format(neurons1, neurons2, neurons3))

    #plot the final wall profile
    plt.figure(1)
Example No. 15
for i in range(1,len(nums)):
	epsilon = eps[i]
	n_iter = n_iters[i]
	theta_i = np.tile(np.linspace(0, 2 * np.pi, NS+1)[:-1], (NC,1))
	_, params_new = get_all_coil_data("../../../tests/w7x/scanold2/w7x_l{}.hdf5".format(nums[i]))
	fc_new, _ = params_new
	theta_i = find_minimum_theta_all_coils(fc_new, r_fil, theta_i)


	print("Size is {}".format(sizes[i]))
	print("The original delta r 1 is:")
	print(np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, np.tile(np.linspace(0, 2 * np.pi, NS+1)[:-1], (NC,1))), axis=-1)))
	print("The new delta r 1 is (with minimization):")
	print(np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
	mean_delta_rs.append(np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
	print("The max distance is")
	print(np.max(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
	max_delta_rs.append(np.max(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))


	difference = np.linalg.norm(r_fil - filament_real_space(fc_new, np.tile(np.linspace(0, 2 * np.pi, NS+1)[:-1], (NC,1))), axis=-1) - np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)
	new_diff = difference[difference > 0]
	larger = np.ravel(difference).shape[0] - new_diff.shape[0]
	if larger > 0:
		print(larger)

np.save("w7x_sizes.npy", np.asarray(sizes))
np.save("w7x_mean_delta_rs.npy", np.asarray(mean_delta_rs))
np.save("w7x_max_delta_rs.npy", np.asarray(max_delta_rs))

Example No. 16
        with tf.Session() as sess:
            meta_path = "gpt-2/models/" + WHICH + "/model.ckpt.meta"
            saver = tf.train.import_meta_graph(meta_path)

            saver.restore(sess, "gpt-2/models/" + WHICH + "/model.ckpt")

            r = []
            for v in [
                    x for x in tf.global_variables()
                    if x.name.startswith("model/")
            ]:
                print(v.name, v.shape)
                r.append(v)
            # TODO CLEAN ME UP
            np.save("/tmp/" + WHICH + ".npy", sess.run(r))

    loaded = onp.load("/tmp/" + WHICH + ".npy", allow_pickle=True)
    for (n, v), a in zip(model.vars().items(), loaded):
        if v.value.shape != a.shape:
            raise ValueError(
                f'Mismatched shapes between v {v.value.shape} and a {a.shape}')
        if isinstance(v, objax.variable.TrainVar):
            v._value = a
        else:
            v.value = a

    print("Enter a prefix to complete:")
    while True:
        print(">>>", end=" ")
        prefix = input()
Example No. 17
#res = scipy.optimize.minimize(m_log_h, x0_BFGS, jac=True, args=(ssigma_prior, ssigma_likelihood),options={'disp': True})

# Calling MALA sampler
# Initialization:
#x0 = mu_t
x_prev = np.load('Results/Samples_X0_MAP05_05_2020.npy')
x0 = x_prev[-1, :]
#x0=res.x
# Parameters of the proposal:
dt = 2e-10  ## Note: the step size here is the variance, so its square root enters the proposal.
#dt=1e-08 ## check
# Number of steps:
n = 30000
#X, acceptance_rate = mala(x0, n, dt)
X, acceptance_rate = mala(x0,
                          log_h,
                          n,
                          dt,
                          args=(ssigma_prior, ssigma_likelihood))

print('acceptance rate =', acceptance_rate)
np.save('Results/Samples_X0_MAP_v3' + date, X)

# Forward Solve for velocity

u_pred = onp.ndarray((X.shape[0], mesh.y.size))
# posterior samples
for n in range(X.shape[0]):
    u_pred[n, :] = solveRANS(X[-n - 1, :])
np.save('Results/U_pred_X0_MAP' + date, u_pred)
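
# A hedged sketch of one MALA proposal under the convention noted above, where
# dt is the proposal variance and its square root scales the noise. grad_log_h
# is hypothetical; this is not the project's mala implementation.
import numpy as onp

def mala_proposal(x, grad_log_h, dt, rng=onp.random):
    """Propose x' = x + (dt / 2) * grad log h(x) + sqrt(dt) * N(0, I)."""
    noise = rng.standard_normal(x.shape)
    return x + 0.5 * dt * grad_log_h(x) + onp.sqrt(dt) * noise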
Example No. 18
    for i in range(args.max_iter):
        # update parameters
        model.step()
        loss = model.objective()

        if not args.silent:
            print(
                "step: {:3d}, loss: {:7.4f}, ||Ax* - b||_2^2: {:6.4f}".format(
                    i + 1, loss, model.feval(model.x)))
        prox_optval.append(loss)

        if i > 1 and np.abs(prox_optval[-1] - prox_optval[-2]) < args.tol:
            break

    print("-" * 40)
    print('Parameters: gamma={}, solver={}'.format(args.gamma, args.opt))
    print('||Ax* - b||_2^2: {:6.3f}, Obj: {:6.3f}'.format(
        model.feval(model.x), prox_optval[-1]))
    if args.opt == 'ADMM':
        print("nnz of x*: ", jnp.count_nonzero(model.z))
    else:
        print("nnz of x*: ", jnp.count_nonzero(model.x))

    print('Writing the results...')
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    output_file = os.path.join(args.output_dir,
                               '{}_{}.npy'.format(args.opt, args.gamma))
    jnp.save(output_file, jnp.array(prox_optval))
    print('Done')
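
# A hedged read-back sketch for the objective history saved above (not in the
# original source; 'results' and the opt/gamma values stand in for args).
import os
import jax.numpy as jnp

history = jnp.load(os.path.join('results', 'ADMM_0.1.npy'))
print('final objective: {:6.3f}'.format(float(history[-1])))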
Example No. 19
   c2_plp = MSA_2PtCorrelations(f1_plp, f2_plp)

   key = random.PRNGKey(seed)

   #h_bml, e_bml, corr_err = BMLearn_Convergence(f1_plp, f2_plp, c2_plp, epsilon, niter=Niter, nflip=Nflip, nseq=Nseq, key=key)
   h_bml, e_bml, corr_err, f1_mcmc, f2_mcmc = BMLearn_Persistent(f1_plp, f2_plp, c2_plp,
                                                                epsilon, niter=Niter, nflip=Nflip,
                                                                nseq=Nseq, errthresh=errthresh,
                                                                key=key, potts_init=potts_init)


   pm_bml = PottsModel(h = h_bml, e = e_bml, abc = abc, L = msa_in.L)

   Potts_ShiftGaugeZeroSum(pm_bml)

   print("Peak memory usage (MB): ", np.int32(getrusage(RUSAGE_SELF).ru_maxrss / 1024))

   if args.errfile is not None:
      np.save(args.errfile[0], corr_err)

   if args.f1file:
      np.save(args.f1file[0], f1_mcmc)

   if args.f2file:
      np.save(args.f2file[0], f2_mcmc)


   Potts_Write(pm_bml, potts_outpath)
   #print("Peak memory usage (MB): ", np.int32(getrusage(RUSAGE_SELF).ru_maxrss / 1024))
Example No. 20
rng_key_train, rng_key_predict = random.split(random.PRNGKey(0))

num_warmup = 4000
num_samples = 8000
num_chains = 1
target_accept_prob = 0.85
settings = {'num_warmup': num_warmup,
            'num_samples': num_samples,
            'num_chains': num_chains,
            'target_accept_prob': target_accept_prob}
samples = model.train(settings, rng_key_train)  
print('True values: J0 = %f, k1 = %f, k2 = %f, k3 = %f, k4 = %f, k5 = %f, k6 = %f, k = %f, ka = %f, q = %f, KI = %f, phi = %f, Np = %f, A = %f, IC = %f' % (J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A, IC))

vmap_args = (samples['J0'], samples['k1'], samples['k2'], samples['k3'], samples['k4'], samples['k5'], samples['k6'], samples['k'], samples['ka'], samples['q'], samples['KI'], samples['phi'], samples['Np'], samples['A'], samples['IC'])

np.save('data/par_and_IC',np.array(vmap_args))
np.save('data/noise',np.array(samples['noise']))
np.save('data/hyp',np.array(samples['hyp']))
np.save('data/W',np.array(samples['W']))

def RBF(x1, x2, params):
    diffs = (x1 / params).T - x2 / params
    return np.exp(-0.5 * diffs**2)

Nt = N+1
N_fine = 100
t_test = np.linspace(0, Tf_test, 2*N_fine+1)
Nt_test = t_test.shape[0]

t_tr = t[:,None] /model.max_t
t_te = t_test[:,None] /model.max_t
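
# Hedged note (not in the original source): with (n, 1) inputs, RBF broadcasts
# (x1 / params).T of shape (1, n) against x2 / params of shape (m, 1), so the
# result is a (len(x2), len(x1)) kernel matrix. The lengthscale is illustrative.
lengthscale = 0.5
K_cross = RBF(t_tr, t_te, lengthscale)  # shape (len(t_te), len(t_tr))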
Example No. 21
    def supervised_optimization(self,
                                sup_density_list,
                                wiring_str,
                                save_supervised_result_bool,
                                dataset_str,
                                EXPLOITATION_NUM_EPOCHS,
                                EXPLOITATION_BATCH_SIZE,
                                OPTIMIZER_STR,
                                STEP_SIZE,
                                REG,
                                W_initializers_str='glorot_normal()',
                                b_initializers_str='normal()',
                                init_weight_rescale_bool=False,
                                EXPLOITATION_VALIDATION_FRACTION=0.1,
                                EXPLOIT_TRAIN_DATASET_FRACTION=1.0,
                                RECORD_ACC_FREQ=100,
                                DROPOUT_LAYER_POS=[],
                                **kwargs):
        """ 
        Train a neural network with loaded wiring from scratch.

        Args: 
            sup_density_list: a list of network density levels
            wiring_str: a string that represents the network wiring, e.g., trans, rand, snip
            dataset_str: a string used to retrieve the dataset
            EXPLOITATION_NUM_EPOCHS: the number of epochs used in supervised training
            EXPLOITATION_BATCH_SIZE: the batch size used in supervised training
            OPTIMIZER_STR: a string used to retrieve the optimizer
            STEP_SIZE: step size of the optimizer
            REG: l2 regularization constant
            EXPLOITATION_VALIDATION_FRACTION: the fraction of training data held out for validation purposes
            EXPLOIT_TRAIN_DATASET_FRACTION: the fraction of training data used in evaluation.
            RECORD_ACC_FREQ: the frequency for recording train and test results

        Returns:
            train_acc_list_runs: a list of training accuracy
            test_acc_list_runs: a list of testing accuracy
        """

        for density in sup_density_list:
            if density not in self.ntt_setup_dict['NN_DENSITY_LEVEL_LIST']:
                raise ValueError(
                    'The desired density level for supervised training is not used in NTT.'
                )

        dataset_info = Dataset(
            datasource=dataset_str,
            VALIDATION_FRACTION=EXPLOITATION_VALIDATION_FRACTION)

        dataset = dataset_info.dataset

        # configure the dataset
        gen_batches = dataset_info.data_stream(EXPLOITATION_BATCH_SIZE)

        batch_input_shape = [-1] + self.ntt_setup_dict['instance_input_shape']

        nr_training_samples = len(dataset['train']['input'])

        nr_training_samples_subset = int(nr_training_samples *
                                         EXPLOIT_TRAIN_DATASET_FRACTION)

        train_input = dataset['train'][
            'input'][:nr_training_samples_subset].reshape(batch_input_shape)
        train_label = dataset['train']['label'][:nr_training_samples_subset]

        test_input = dataset['test']['input'].reshape(batch_input_shape)
        test_label = dataset['test']['label']

        num_complete_batches, leftover = divmod(nr_training_samples,
                                                EXPLOITATION_BATCH_SIZE)

        num_mini_batches_per_epochs = num_complete_batches + bool(leftover)

        total_batch = EXPLOITATION_NUM_EPOCHS * num_mini_batches_per_epochs

        if len(DROPOUT_LAYER_POS) == 0:
            # in this case, dropout is NOT used
            init_fun_no_dropout, f_train = model_dict[self.model_str](
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            f_test = f_train
            f_no_dropout = f_train
            key_dropout = None
            subkey_dropout = None

        else:
            # in this case, dropout is used
            _, f_train = model_dict[self.model_str + '_dropout'](
                mode='train',
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            _, f_test = model_dict[self.model_str + '_dropout'](
                mode='test',
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)

            init_fun_no_dropout, f_no_dropout = model_dict[self.model_str](
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)

            key_dropout = random.PRNGKey(0)

        @jit
        def step(i, opt_state, x, y, masks, key):
            this_step_params = get_params(opt_state)
            masked_g = grad(softmax_cross_entropy_with_logits_l2_reg)(
                this_step_params,
                f_train,
                x,
                y,
                masks,
                L2_REG_COEFF=REG,
                key=key)
            return opt_update(i, masked_g, opt_state)

        train_results_dict = {}
        test_results_dict = {}
        trained_masked_dict = {}

        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        time.sleep(orig_random.uniform(1, 5))
        now_str = '__' + str(datetime.now().strftime("%D:%H:%M:%S")).replace(
            '/', ':')

        supervised_model_info = '[u]' + self.ntt_file_name + '_[s]' + dataset_str

        supervised_model_wiring_info = supervised_model_info + '_' + wiring_str

        supervised_model_wiring_dir = self.supervised_result_path + supervised_model_info + '/' + supervised_model_wiring_info + now_str

        if save_supervised_result_bool:

            while os.path.exists(supervised_model_wiring_dir):
                temp = supervised_model_wiring_dir + '_0'
                supervised_model_wiring_dir = temp
            # print(supervised_model_wiring_dir)
            os.makedirs(supervised_model_wiring_dir)

            logging.basicConfig(filename=supervised_model_wiring_dir +
                                "/supervised_learning_log.log",
                                format='%(asctime)s %(message)s',
                                filemode='w',
                                level=logging.DEBUG)
        else:
            logging.basicConfig(filename="supervised_learning_log.log",
                                format='%(asctime)s %(message)s',
                                filemode='w',
                                level=logging.DEBUG)

        for nn_density_level in sup_density_list:

            nn_density_level = onp.round(nn_density_level, 2)
            train_acc_list_runs = []
            test_acc_list_runs = []
            trained_masked_params_runs = []

            for run_index in range(1, self.ntt_setup_dict['NUM_RUNS'] + 1):

                if wiring_str == 'trans':
                    # load ntt masks and parameters
                    density_run_dir = '/' + 'density_' + str(
                        nn_density_level) + '/' + 'run_' + str(run_index)

                    transferred_masks_fileName = '/transferred_masks_' + self.model_str + density_run_dir.replace(
                        '/', '_') + '.npy'

                    transferred_param_fileName = '/transferred_params_' + self.model_str + density_run_dir.replace(
                        '/', '_') + '.npy'

                    masks = list(
                        np.load(self.ntt_result_path + density_run_dir +
                                transferred_masks_fileName,
                                allow_pickle=True))

                    masked_params = list(
                        np.load(self.ntt_result_path + density_run_dir +
                                transferred_param_fileName,
                                allow_pickle=True))

                elif wiring_str == 'rand':
                    # randomly initialize masks and parameters

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    masks = get_masks_from_jax_params(
                        params,
                        nn_density_level,
                        global_bool=self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'],
                        magnitude_base_bool=False,
                        reshuffle_seed=run_index)

                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)

                elif wiring_str == 'dense':
                    # randomly initialize masks and parameters

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    #                     masks = get_masks_from_jax_params(params, nn_density_level, global_bool = self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'], magnitude_base_bool = False, reshuffle_seed = run_index)
                    logger.info("Dense net!!")

                    masks = None
                    masked_params = params

                elif wiring_str == 'snip':
                    # randomly initialize masks and parameters
                    if dataset_str == 'cifar-10':
                        num_examples_snip = 128
                    else:
                        num_examples_snip = 100

                    snip_input = dataset['train']['input'][:num_examples_snip]

                    snip_label = dataset['train']['label'][:num_examples_snip]

                    snip_batch = (snip_input, snip_label)

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    if not self.ntt_setup_dict['GLOBAL_PRUNE_BOOL']:
                        logger.info("Use layerwise snip")

                    masks = get_snip_masks(
                        params, nn_density_level, f_no_dropout, snip_batch,
                        batch_input_shape,
                        self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'])

                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)

                elif wiring_str == 'logit_snip':
                    # randomly initialize masks and parameters
                    if dataset_str == 'cifar-10':
                        num_examples_snip = 128
                    else:
                        num_examples_snip = 100

                    snip_input = dataset['train']['input'][:num_examples_snip]

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    masks = get_logit_snip_masks(
                        params, nn_density_level, f_no_dropout, snip_input,
                        batch_input_shape,
                        self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'])
                    #                     get_snip_masks(params, nn_density_level, f_no_dropout, snip_batch, batch_input_shape)

                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)

                else:
                    raise ValueError('The wiring string is undefined.')

            # optionally, add dropout layers #Test without dropout masks
                if len(DROPOUT_LAYER_POS) > 100:
                    dropout_masked_params = [
                        ()
                    ] * (len(masked_params) + len(DROPOUT_LAYER_POS))

                    dropout_masks = [[]] * (len(masked_params) +
                                            len(DROPOUT_LAYER_POS))

                    print(len(masked_params))  #check dropout position
                    #pprint(masked_params) # check

                    num_inserted = 0
                    for i in range(len(dropout_masked_params)):
                        if i in DROPOUT_LAYER_POS:
                            num_inserted += 1
                        else:
                            dropout_masked_params[i] = masked_params[
                                i - num_inserted]
                            dropout_masks[i] = masks[i - num_inserted]

                    masks = dropout_masks
                    masked_params = dropout_masked_params

                if init_weight_rescale_bool == True:
                    logger.info(
                        "Init weight rescaled: W_scaled = W/sqrt(nn_density_level)"
                    )
                    scaled_params = []

                    for i in range(len(masked_params)):
                        if len(masked_params[i]) == 2:
                            scaled_params.append(
                                (masked_params[i][0] *
                                 np.sqrt(1 / nn_density_level),
                                 masked_params[i][1]))
                        else:
                            scaled_params.append(masked_params[i])

                    masked_params = scaled_params

                optimizer_with_params = optimizer_dict[OPTIMIZER_STR](
                    step_size=STEP_SIZE)

                opt_init, opt_update, get_params = optimizer_with_params

                opt_state = opt_init(masked_params)

                train_acc_list = []

                test_acc_list = []

                itercount = itertools.count()

                for iteration in range(total_batch):

                    batch_xs, batch_ys = next(gen_batches)

                    batch_xs = batch_xs.reshape(batch_input_shape)

                    if key_dropout is not None:
                        key_dropout, subkey_dropout = random.split(key_dropout)

                    opt_state = step(next(itercount),
                                     opt_state,
                                     batch_xs,
                                     batch_ys,
                                     masks=masks,
                                     key=subkey_dropout)

                    if iteration % RECORD_ACC_FREQ == 0:

                        masked_trans_params = get_params(opt_state)

                        train_acc = accuracy(masked_trans_params, f_test,
                                             train_input, train_label,
                                             key_dropout)
                        test_acc = accuracy(masked_trans_params, f_test,
                                            test_input, test_label,
                                            key_dropout)

                        train_acc_list.append(train_acc)
                        test_acc_list.append(test_acc)

                        logger.info(
                            "NN density %.2f | Run %03d/%03d | Iteration %03d/%03d | Train acc %.2f%% | Test acc %.2f%%",
                            nn_density_level, run_index,
                            self.ntt_setup_dict['NUM_RUNS'], iteration + 1,
                            total_batch, train_acc * 100, test_acc * 100)

                trained_masked_trans_params = get_params(opt_state)

                train_acc_list_runs.append(train_acc_list)
                test_acc_list_runs.append(test_acc_list)
                trained_masked_params_runs.append(trained_masked_trans_params)

            train_acc_list_runs = np.array(train_acc_list_runs)
            test_acc_list_runs = np.array(test_acc_list_runs)

            train_results_dict[str(nn_density_level)] = train_acc_list_runs
            test_results_dict[str(nn_density_level)] = test_acc_list_runs
            trained_masked_dict[str(
                nn_density_level)] = trained_masked_params_runs

            if save_supervised_result_bool:

                supervised_model_wiring_dir_run = supervised_model_wiring_dir + '/density_' + str(
                    round(nn_density_level, 2)) + '/'

                while os.path.exists(supervised_model_wiring_dir_run):
                    temp = supervised_model_wiring_dir_run + '_0'
                    supervised_model_wiring_dir_run = temp

                os.makedirs(supervised_model_wiring_dir_run)

                model_summary_str = '[u]' + self.ntt_file_name + '_[s]' + dataset_str + '_density_' + str(
                    round(nn_density_level, 2))

                np.save(
                    supervised_model_wiring_dir_run + '/' +
                    'supervised_trained_' + model_summary_str, [
                        nn_density_level, train_acc_list_runs,
                        test_acc_list_runs, trained_masked_params_runs
                    ])

        output = dict(train_results=train_results_dict,
                      test_results=test_results_dict,
                      trained_params=trained_masked_dict)

        return output
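
# A hedged call sketch (not in the original source): the `ntt` instance and
# every argument value below are illustrative assumptions.
output = ntt.supervised_optimization(
    sup_density_list=[0.1, 0.2],
    wiring_str='rand',
    save_supervised_result_bool=False,
    dataset_str='mnist',
    EXPLOITATION_NUM_EPOCHS=10,
    EXPLOITATION_BATCH_SIZE=128,
    OPTIMIZER_STR='adam',
    STEP_SIZE=1e-3,
    REG=1e-4)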
Example No. 22
def train(rng,
          env,
          batch_size=128,
          num_epochs=5,
          num_iterations=21,
          num_samples=100,
          print_every=10,
          episodes=100,
          k_min=1,
          k_max=25,
          verbose=False,
          params_filepath=None,
          savepath=None):
    """
    Train the model function by generating simulations of random-play.
    On every epoch generate a new simulation and run multiple iterations.
    On every iteration evaluate the targets using the most recent model parameters
    and run multiple times through the dataset.
    At the end of every epoch check the performance and store the best performing params.
    If the performance drops then decay the step size parameter.

    @param rng (PRNGKey): A pseudo-random number generator.
    @param env (Cube Object): A Cube object representing the environment.
    @param batch_size (int): Size of minibatches used to compute loss and gradient during training.         [optional]
    @param num_epochs (int): The number of epochs to run for during training.                               [optional]
    @param num_iterations (int): The number of iterations through the generated episodes.                   [optional]
    @param num_samples (int): The number of times the dataset is reused.                                    [optional]
    @param print_every (int): An integer. Training progress will be printed every `print_every` iterations. [optional]
    @param episodes (int): Number of episodes to be created.                                                [optional]
    @param k_min (int): Minimum length of sequence of backward moves.                                       [optional]
    @param k_max (int): Maximum length of sequence of backward moves.                                       [optional]
    @param verbose (bool): If set to false then no output will be printed during training.                  [optional]
    @param params_filepath (str): File path to load initial model parameters from.                          [optional]
    @param savepath (str): Directory in which to save parameters and progress output.                       [optional]
    @returns params (pytree): The best model parameters obtained during training.
    @returns loss_history (List): A list containing the mean loss computed during each iteration.
    """
    # Initialize model parameters and optimizer state.
    rng, init_rng = jax.random.split(rng)
    input_shape = (-1, ) + env.terminal_state.shape
    params = None
    if params_filepath is None:
        _, params = init_fun(init_rng, input_shape)
    else:
        params = list(jnp.load(params_filepath, allow_pickle=True))

    _solved_state = np.expand_dims(env.terminal_state, axis=0)
    _solved_state = jnp.array(_solved_state)
    # Set Numpy seed
    np.random.seed(17)

    # Generate test states
    test_set = []
    x = env()
    for _ in range(1000):
        x.reset()
        x.shuffle(np.random.randint(5, 15))
        test_set.append(x.state.copy())
    del x
    # Generate test episodes
    test_episodes_shape = (1000, 20)
    test_episodes = generate_episodes(env, *test_episodes_shape)[0]

    loss_history = []
    progress = []
    p_iteration_fmt = 'Iteration, {}, {}, {:.1f}, {:.3f}\n'
    p_epoch_fmt = 'Epoch {}, {}, {:.1f}, {:.3f}\n\n'

    # Begin training.
    decays = np.hstack(
        [np.ones(20, dtype=np.float32),
         np.linspace(1.0, 0.2, 64)])
    for e in range(num_epochs):
        decay = decays[e] if e < len(decays) else 0.2
        tic = time.time()
        opt_state = opt_init(params)

        # Generate data from random-play using the environment.
        states, w, children, rewards = generate_episodes(
            env, episodes, k_max, decay)

        # Train the model on the generated data. Periodically recompute the target values.
        epoch_mean_loss = 0.0
        for it in range(num_iterations):
            tic_it = time.time()

            # Make targets for the generated episodes using the most recent params and build a batch generator.
            params = get_params(opt_state)
            tgt_vals = make_targets(children, rewards, params)
            data = {"X": states, "y": tgt_vals, "w": w}
            rng, sub_rng = jax.random.split(rng)
            train_batches = batch_generator(sub_rng, data, batch_size)

            # Run through the dataset and update model params.
            total_loss = 0.0
            for i in range(num_samples):
                batch = next(train_batches)
                loss, opt_state = update(it * num_samples + i, opt_state,
                                         batch)
                total_loss += loss

            # Book-keeping.
            iter_mean_loss = total_loss / num_samples
            epoch_mean_loss = (it * epoch_mean_loss + iter_mean_loss) / (it + 1)
            loss_history.append(iter_mean_loss)

            # Iteration verbose
            if it % print_every == 0 and verbose:
                toc_it = time.time()
                progress.append(
                    p_iteration_fmt.format(it + 1, num_iterations,
                                           toc_it - tic_it, iter_mean_loss))
                print(progress[-1], end='')

        # Record the time needed for a single epoch.
        toc = time.time()
        progress.append(
            p_epoch_fmt.format(e + 1, num_epochs, toc - tic, epoch_mean_loss))
        # Do Value evaluation of TEST EPISODES
        Vs = apply_fun(params, test_episodes).reshape(test_episodes_shape)
        Vmeans = np.mean(Vs, axis=0)
        Vsolved = float(apply_fun(params, _solved_state))
        print('Distance  0 states mean V: {:.3f}'.format(Vsolved))
        for q, m in enumerate(Vmeans, 1):
            print('Distance {:2} states mean V: {:.3f}'.format(q, m))
        # Do Best First Search evaluation of TEST SET
        bs_solved = [beam_search(env, s, params, apply_fun) for s in test_set]
        bs_rate = sum(bs_solved) / len(bs_solved)
        print('GBFS solution rate: {:.2f}'.format(bs_rate))

        # Epoch verbose
        if verbose:
            print(progress[-1], end='')
        # Save parameters and output
        if savepath:
            pfilepath = os.path.join(savepath, 'params' + str(e))
            ofilepath = os.path.join(savepath, 'output' + str(e))
            jnp.save(pfilepath, params)
            with open(ofilepath, mode='w') as f:
                f.writelines(progress)

    return params, loss_history
Example No. 23
    def optimize(self, return_teacher_params_bool=False):
        """Carry out the optimization task.

        Args:
            return_teacher_params_bool: if True, also return the teacher
                networks' randomly drawn parameters.

        Returns:
            nt_trans_params_all_runs: the transferred parameters
            nt_trans_masks_all_runs: the transferred masks
            nt_trans_vali_all_runs: a collection of validation losses recorded during training.
        """

        gen_batches = self.DATASET.data_stream(self.BATCH_SIZE)
        
        num_complete_batches, leftover = divmod(self.DATASET.num_example['train'], self.BATCH_SIZE)

        # number of minibatches per epoch
        num_mini_batches_per_epochs = num_complete_batches +  bool(leftover)

        # number of total iterations
        num_total_iters  = self.NUM_EPOCHS * num_mini_batches_per_epochs

        # number of times that the sparsity levels get updated
        num_sparsity_updates = num_total_iters // self.MASK_UPDATE_FREQ

        mask_update_limit = num_total_iters - self.MASK_UPDATE_FREQ
    
        if self.SAVE_BOOL == True:
            # save the transferred results in the designated directory.

            trans_model_dir = self.unique_model_dir

#             while os.path.exists(trans_model_dir):
#                 trans_model_dir = trans_model_dir + '_0'
            
            os.makedirs(trans_model_dir)

            np.save(trans_model_dir + '/param_dict.npy', self.param_dict) 
            
            

        nt_trans_params_all_sparsities_all_runs = []
        nt_trans_masks_all_sparsities_all_runs = []
        nt_trans_vali_all_sparsities_all_runs = []
        teacher_params_all_sparsities_all_runs = []
        
        
        num_sparsity_levels = len(self.NN_DENSITY_LEVEL_LIST)
        num_runs = len(range(self.INIT_RUN_INDEX, self.INIT_RUN_INDEX + self.NUM_RUNS))
        all_density_all_run_num_total_iters = num_sparsity_levels * num_runs * num_total_iters
        
        
        for nn_density_level in self.NN_DENSITY_LEVEL_LIST:   
            
            
            nt_trans_params_all_runs = []
            nt_trans_masks_all_runs = []
            nt_trans_vali_all_runs = []
            teacher_params_all_runs = []


            for run_index in range(self.INIT_RUN_INDEX, self.INIT_RUN_INDEX + self.NUM_RUNS ):

                # do logging
                for handler in logging.root.handlers[:]:
                    logging.root.removeHandler(handler)

                # a string that summarizes the current ntt experiment
                model_summary_str =  self.model_str + '_density_' + str(round(nn_density_level, 2) ) + '_run_' + str(run_index)
                prof_model_summary_str = self.prof_model_str + '_density_' + str(round(nn_density_level, 2)) + '_run_' + str(run_index)

                if self.SAVE_BOOL == True:
                    model_dir_density_run = trans_model_dir + '/' + 'density_' + str(round(nn_density_level, 2) ) + '/' + 'run_' +  str(run_index) + '/'

                    os.makedirs(model_dir_density_run)
                    
                    logging.basicConfig(filename = model_dir_density_run + "/" + model_summary_str + "_log.log", format='%(asctime)s %(message)s', filemode='w', level=logging.DEBUG)

                else: 
                    logging.basicConfig(filename = model_summary_str + "_log.log" , format='%(asctime)s %(message)s', filemode='w', level=logging.DEBUG)
                
                
                # for different run indices, randomly draw teacher net's parameters
                _, teacher_net_params = self.init_fun(random.PRNGKey(run_index), tuple(self.batch_input_shape))
                _, prof_net_params = self.prof_init_fun(random.PRNGKey(run_index), tuple(self.prof_batch_input_shape))
                                
                # the prediction of the teacher net evaluated on validation samples
                #vali_teacher_prediction = self.apply_fn(teacher_net_params, self.vali_samples, rng=random.PRNGKey(run_index))
                vali_prof_prediction = self.prof_apply_fn(prof_net_params, self.prof_vali_samples, rng=random.PRNGKey(run_index))

                #vali_teacher_ntk_mat = self.emp_ntk_fn(self.vali_inputs_1, self.vali_inputs_2, teacher_net_params, keys=random.PRNGKey(run_index))
                vali_prof_ntk_mat = self.prof_emp_ntk_fn(self.prof_vali_inputs_1, self.prof_vali_inputs_2, prof_net_params, keys=random.PRNGKey(run_index))

                # the initial binary mask
                
                if self.PRUNE_METHOD == 'magnitude':
                    masks = get_masks_from_jax_params(teacher_net_params, nn_density_level, global_bool = self.GLOBAL_PRUNE_BOOL)
                elif self.PRUNE_METHOD == 'logit_snip':
                    logger.info("Use the logit-SNIP method to get the initial mask")
                    num_examples_snip = 128

                    snip_input = self.DATASET.dataset['train']['input'][:num_examples_snip, :]

                    if not self.GLOBAL_PRUNE_BOOL:
                        logger.warning("layerwise sparse net initialized with logit_snip")
                    # keyword spelling below follows get_logit_snip_masks' signature
                    masks = get_logit_snip_masks(teacher_net_params, nn_density_level, self.apply_fn, snip_input, self.batch_input_shape, GlOBAL_PRUNE_BOOL = self.GLOBAL_PRUNE_BOOL)
                else:
                    raise NotImplementedError("unknown PRUNE_METHOD: " + str(self.PRUNE_METHOD))
    

                # the initial student parameters
                masked_student_net_params = get_sparse_params_filtered_by_masks(teacher_net_params, masks)

                # instantiate the optimizer triple 
                opt_init, opt_update, get_params = self.OPTIMIZER_WITH_PARAMS

                opt_state = opt_init(teacher_net_params) #optimize toward teacher
                #opt_state = opt_init(prof_net_params) #optimize toward professor

                # one step of NTK transfer
                @jit
                def nt_transfer_step(i, opt_state, x, prof_x, masks):

                    # parameters in the current optimizer state
                    student_net_params = get_params(opt_state)

                    # gradients that flow through the binary masks
                    masked_g = grad(self.nt_transfer_loss)(student_net_params, masks, teacher_net_params, x, prof_x, prof_net_params, nn_density_level, key=run_index)

                    return opt_update(i, masked_g, opt_state)

                # a list of validation loss
                vali_loss_list = []

                # calculate the initial validation loss. 
                vali_loss = self.eval_nt_transfer_loss_on_vali_data(masked_student_net_params, vali_prof_prediction, vali_prof_ntk_mat, nn_density_level, key=run_index)

                vali_loss_list.append(vali_loss)

                logger.info("Before transfer: trans dist %.3f | ntk dist %.3f | targ dist %.3f | l2 pentalty %.3f | nn density %.2f", vali_loss[0], vali_loss[1], vali_loss[2], vali_loss[3], nn_density_level)
                itercount = itertools.count()

                t = time.time()

                # loop through iterations
                for num_iter in range(1, num_total_iters + 1): 
                    
                    # a batch of input data
                    batch_xs, _ = next(gen_batches)                

                    # reshape the batch to each network's expected input shape
                    batch_xs = batch_xs.reshape(self.batch_input_shape)
                    prof_batch_xs = batch_xs.reshape(self.prof_batch_input_shape)

                    # update the optimizer state
                    opt_state = nt_transfer_step(next(itercount), opt_state, batch_xs, prof_batch_xs, masks )


                    if num_iter % 100 == 0:
                        elapsed_time = time.time() - t
                        
                        if (num_iter <= 500) and (run_index == self.INIT_RUN_INDEX) and (nn_density_level == self.NN_DENSITY_LEVEL_LIST[0]):  
                            # estimate the program end time.
                            remaining_iter_num = all_density_all_run_num_total_iters - num_iter
                            remaining_seconds = elapsed_time * ( remaining_iter_num / 100 )
                            expected_end_time = str(datetime.now() + timedelta(seconds = remaining_seconds))

                        # get parameters from the current optimizer state
                        student_net_params = get_params(opt_state) 

                        # filter the parameters by the masks
                        masked_student_net_params = get_sparse_params_filtered_by_masks(student_net_params, masks)
                        
                        # validation loss
                        vali_loss = self.eval_nt_transfer_loss_on_vali_data(masked_student_net_params, vali_prof_prediction, vali_prof_ntk_mat, nn_density_level, key=run_index)

                        vali_loss_list.append(vali_loss)

                        logger.info('run: %02d/%02d | iter %04d/%04d | trans. dist %.3f | ntk dist %.3f | targ. dist %.3f | l2 %.3f | nn density %.2f | time %.2f [s] | expected finish time %s', run_index, self.NUM_RUNS + self.INIT_RUN_INDEX - 1, num_iter, num_total_iters, vali_loss[0], vali_loss[1], vali_loss[2], vali_loss[3], nn_density_level, elapsed_time, expected_end_time)
                        t = time.time()


                    if (num_iter % self.MASK_UPDATE_FREQ == 0) and (num_iter < mask_update_limit):
                        # get parameters from the current optimizer state
                        student_net_params = get_params(opt_state)

                        # refresh the masks by magnitude (even when the
                        # initial mask came from logit_snip)
                        masks = get_masks_from_jax_params(student_net_params, nn_density_level, global_bool = self.GLOBAL_PRUNE_BOOL)

                
                elapsed_time = time.time() - t
                
                student_net_params = get_params(opt_state) 
                
                # filter the parameters by the masks
                masked_student_net_params = get_sparse_params_filtered_by_masks(student_net_params, masks)
                
                vali_loss = self.eval_nt_transfer_loss_on_vali_data(masked_student_net_params, vali_prof_prediction, vali_prof_ntk_mat, nn_density_level, key=run_index)

                vali_loss_list.append(vali_loss)
                
                logger.info('run: %02d/%02d | iter %04d/%04d | trans. dist %.3f | ntk dist %.3f | targ. dist %.3f | l2 %.3f | nn density %.2f | time %.2f [s]', run_index, self.NUM_RUNS + self.INIT_RUN_INDEX - 1, num_iter, num_total_iters, vali_loss[0], vali_loss[1], vali_loss[2], vali_loss[3], nn_density_level, elapsed_time )
                            
                vali_loss_array = np.array(vali_loss_list)

                nt_trans_params_all_runs.append(masked_student_net_params)
                nt_trans_masks_all_runs.append(masks)
                nt_trans_vali_all_runs.append(vali_loss_array)
                teacher_params_all_runs.append(teacher_net_params )

                if self.SAVE_BOOL:

                    model_summary_str =  self.model_str + '_density_' + str(round(nn_density_level, 2) ) + '_run_' + str(run_index)
                    prof_model_summary_str = self.prof_model_str + '_density_' + str(round(nn_density_level, 2)) + '_run_' + str(run_index)

                    prof_param_fileName = model_dir_density_run + 'prof_params_' + prof_model_summary_str
                    np.save(prof_param_fileName, prof_net_params)

                    teacher_param_fileName = model_dir_density_run + 'teacher_params_' + model_summary_str
                    np.save(teacher_param_fileName, teacher_net_params)

                    student_param_fileName = model_dir_density_run + 'transferred_params_' + model_summary_str
                    np.save(student_param_fileName, masked_student_net_params)

                    mask_fileName = model_dir_density_run + 'transferred_masks_' + model_summary_str
                    np.save(mask_fileName, masks)

                    loss_array_fileName = model_dir_density_run + 'loss_array_' + model_summary_str
                    np.save(loss_array_fileName, vali_loss_array)
            

            # collect the results for this density level
            nt_trans_params_all_sparsities_all_runs.append(nt_trans_params_all_runs)
            nt_trans_masks_all_sparsities_all_runs.append(nt_trans_masks_all_runs)
            nt_trans_vali_all_sparsities_all_runs.append(nt_trans_vali_all_runs)
            teacher_params_all_sparsities_all_runs.append(teacher_params_all_runs)
                    
        if return_teacher_params_bool:
            return nt_trans_params_all_sparsities_all_runs, nt_trans_masks_all_sparsities_all_runs, nt_trans_vali_all_sparsities_all_runs, teacher_params_all_sparsities_all_runs

        else:
            return nt_trans_params_all_sparsities_all_runs, nt_trans_masks_all_sparsities_all_runs, nt_trans_vali_all_sparsities_all_runs
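For completeness, a hedged sketch of reading the saved metadata back; the directory name is hypothetical and must match self.unique_model_dir. np.save of a dict yields a 0-d object array, so .item() recovers the dict.

import numpy as onp

# Hypothetical directory; must match self.unique_model_dir above.
param_dict = onp.load('ntt_model_dir/param_dict.npy', allow_pickle=True).item()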
Exemplo n.º 24
0
def save_model(path, params, steps):
    with open(path, 'wb') as f:
        jnp.save(f, params)
    with open(str(path) + '.steps', 'wb') as f:
        jnp.save(f, steps)
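The snippet has no matching loader; a minimal load_model sketch, under the assumption that params is a pickled pytree (hence allow_pickle=True):

import numpy as onp

def load_model(path):
    # Mirror of save_model above: read the parameters and step counter back.
    with open(path, 'rb') as f:
        params = onp.load(f, allow_pickle=True)
    with open(str(path) + '.steps', 'rb') as f:
        steps = onp.load(f, allow_pickle=True)
    return params, steps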
Exemplo n.º 25
0
import jax.numpy as np
from jax import grad, value_and_grad, jit, vmap
from jax import random
import tensorflow as tf
import numpy as onp

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# dtype=object is needed because the four splits have different shapes
onp.save('mnist.npy', onp.array((x_train, y_train, x_test, y_test), dtype=object))
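A round-trip check for the bundle just written; because the object array is pickled, allow_pickle=True is mandatory on load (np.load defaults it to False):

import numpy as onp

x_train, y_train, x_test, y_test = onp.load('mnist.npy', allow_pickle=True)
assert x_train.shape[0] == y_train.shape[0]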
Exemplo n.º 26
0
    def solve_direct(self, states, controls, T, homotopy, boundaries):

        # sanity
        assert states.shape[0] == controls.shape[0]
        assert states.shape[1] == self.state_dim
        assert controls.shape[1] == self.control_dim

        # system parameters
        params = self.params.values()

        # number of collocation nodes
        n = states.shape[0]

        # decision vector bounds
        @jit
        def get_bounds():
            zl = np.hstack((self.state_lb, self.control_lb))
            zl = np.tile(zl, n)
            zl = np.hstack(([0.0], zl))
            zu = np.hstack((self.state_ub, self.control_ub))
            zu = np.tile(zu, n)
            zu = np.hstack(([np.inf], zu))
            return zl, zu

        # decision vector maker
        @jit
        def flatten(states, controls, T):
            z = np.hstack((states, controls)).flatten()
            z = np.hstack(([T], z))
            return z

        # decision vector translator
        @jit
        def unflatten(z):
            T = z[0]
            z = z[1:].reshape(n, self.state_dim + self.control_dim)
            states = z[:, :self.state_dim]
            controls = z[:, self.state_dim:]
            return states, controls, T

        # fitness vector
        print('Compiling fitness...')

        @jit
        def fitness(z):

            # translate decision vector
            states, controls, T = unflatten(z)

            # time grid
            n = states.shape[0]
            times = np.linspace(0, T, n)

            # objective
            L = vmap(lambda state, control: self.lagrangian(
                state, control, homotopy, *params))
            L = L(states, controls)
            J = np.trapz(L, dx=T / (n - 1))

            # Lagrangian state dynamics constraints, and boundary constraints
            # e0 = self.collocate_lagrangian(states, controls, times, costs, homotopy, *params)
            e1 = self.collocate_state(states, controls, times, *params)
            e2, e3 = boundaries(states[0, :], states[-1, :])
            e = np.hstack((e1.flatten(), e2, e3))**2

            # fitness vector
            return np.hstack((J, e))

        # z = flatten(states, controls, T)
        # fitness(z)

        # sparse Jacobian
        print('Compiling Jacobian and its sparsity...')
        gradient = jit(jacfwd(fitness))
        z = flatten(states, controls, T)
        sparse_id = np.vstack((np.nonzero(gradient(z)))).T
        sparse_gradient = jit(lambda z: gradient(z)[[*sparse_id.T]])
        gradient_sparsity = jit(lambda: sparse_id)
        print('Jacobian has {} elements.'.format(sparse_id.shape[0]))

        # assign PyGMO problem methods
        self.fitness = fitness
        self.gradient = sparse_gradient
        self.gradient_sparsity = gradient_sparsity
        self.get_bounds = get_bounds
        self.get_nobj = jit(lambda: 1)
        nec = fitness(z).shape[0] - 1
        self.get_nec = jit(lambda: nec)

        # plot before
        states, controls, T = unflatten(z)
        self.plot('../img/direct_before.png', states, dpi=1000)

        # solve NLP with IPOPT
        print('Solving...')
        prob = pg.problem(udp=self)
        algo = pg.ipopt()
        algo.set_integer_option('max_iter', 1000)
        algo = pg.algorithm(algo)
        algo.set_verbosity(1)
        pop = pg.population(prob=prob, size=0)
        pop.push_back(z)
        pop = algo.evolve(pop)

        # save and plot solution
        z = pop.champion_x
        np.save('decision.npy', z)
        states, controls, T = unflatten(z)
        self.plot('../img/direct_after.png', states, dpi=1000)
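A hedged sketch of consuming decision.npy later: the layout mirrors flatten()/unflatten() above, with z[0] holding the final time T and the rest the (states | controls) grid. The dimensions below are placeholders, not values from the source.

import numpy as onp

z = onp.load('decision.npy')
state_dim, control_dim = 6, 3  # hypothetical problem dimensions
T = z[0]
grid = z[1:].reshape(-1, state_dim + control_dim)
states, controls = grid[:, :state_dim], grid[:, state_dim:]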
Exemplo n.º 27
0
    def save_fake_data(n_obj, n_pix_sed, n_pix_spec, n_pix_phot,
                       n_pix_transfer):

        root = "data/fake/fake_"
        from jax.random import uniform, randint

        np.save(root + "lamgrid.npy", 8.1e2 + np.arange(n_pix_sed))
        np.save(root + "lam_phot_eff.npy", np.arange(n_pix_phot))
        np.save(root + "lam_phot_size_eff.npy", np.arange(n_pix_phot))
        np.save(
            root + "transferfunctions.npy",
            uniform(key, (n_pix_transfer, n_pix_sed, n_pix_phot)),
        )
        np.save(root + "transferfunctions_zgrid.npy",
                np.arange(n_pix_transfer))

        np.save(root + "chi2s_sdss.npy", uniform(key, (n_obj, )))
        np.save(
            root + "lamspec_waveoffset.npy",
            randint(key, (1, ), 0, n_pix_sed - n_pix_spec - 1),
        )
        np.save(
            root + "index_wave.npy",
            randint(key, (n_obj, ), 0, n_pix_sed - n_pix_spec - 1),
        )
        np.save(
            root + "index_transfer_redshift.npy",
            randint(key, (n_obj, ), 0, n_pix_transfer),
        )
        np.save(
            root + "interprightindices.npy",
            randint(key, (n_obj, n_pix_spec), 0, n_pix_transfer),
        )
        np.save(
            root + "interpweights.npy",
            uniform(key, (n_obj, n_pix_spec)),
        )
        np.save(
            root + "interprightindices_transfer.npy",
            randint(key, (n_obj, ), 0, n_pix_transfer),
        )
        np.save(
            root + "interpweights_transfer.npy",
            uniform(key, (n_obj, )),
        )
        np.save(root + "spec.npy", uniform(key, (n_obj, n_pix_spec)))
        np.save(root + "spec_mod.npy", uniform(key, (n_obj, n_pix_spec)))
        np.save(root + "spec_invvar.npy", uniform(key, (n_obj, n_pix_spec)))
        np.save(root + "phot.npy", uniform(key, (n_obj, n_pix_phot)))
        np.save(root + "phot_invvar.npy", uniform(key, (n_obj, n_pix_phot)))
        np.save(root + "redshifts.npy", uniform(key, (n_obj, )))
Exemplo n.º 28
0
def save_to_disk(f, params):
    jnp.save(f, params)  # jnp.save returns None, so there is no value to propagate
Exemplo n.º 29
0
def save(file, arr, allow_pickle=True, fix_imports=True):
  # strip the library's array wrapper (if any) before delegating to jnp.save
  arr = _remove_jaxarray(arr)
  jnp.save(file, arr, allow_pickle, fix_imports)
Exemplo n.º 30
0
    def train(self, num_epochs=1, batch_size=200):
        """Performs one full run of training"""

        #train_images, train_labels, test_images, test_labels = data_preprocessing()

        train_images = jnp.load('data/x_train.npy')
        train_labels = jnp.load('data/y_train.npy')
        test_images = jnp.load('data/x_test.npy')
        test_labels = jnp.load('data/y_test.npy')

        if not self.CNN:
            train_images = jnp.reshape(train_images, self.input_shape)
            test_images = jnp.reshape(test_images, self.input_shape)

        num_train = train_images.shape[0]
        num_complete_batches, leftover = divmod(num_train, batch_size)
        num_batches = num_complete_batches + bool(leftover)
        batches = self.data_stream(train_images, train_labels, num_train,
                                   num_batches, batch_size)

        train_acc_list = []
        test_acc_list = []
        loss_list = []

        train_acc_list.append(
            self.accuracy(self.init_params, (train_images, train_labels)))
        test_acc_list.append(
            self.accuracy(self.init_params, (test_images, test_labels)))

        opt_state = self.opt_init(self.init_params)
        itercount = 0

        print("\nStarting training...")
        for epoch in range(num_epochs):
            start_time = time.time()
            for i in range(num_batches):
                opt_state, losses = self.update(itercount, opt_state,
                                                next(batches))
                loss_list.append(losses)

                # per-batch accuracy tracking is omitted here; accuracies
                # are computed once per epoch below

                itercount += 1
            epoch_time = time.time() - start_time

            params = self.get_params(opt_state)
            train_acc = self.accuracy(params, (train_images, train_labels))
            test_acc = self.accuracy(params, (test_images, test_labels))
            train_acc_list.append(train_acc)
            test_acc_list.append(test_acc)
            print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
            print("Training set accuracy {}".format(train_acc))
            print("Test set accuracy {}".format(test_acc))

        self.opt_state = opt_state
        self.params = params
        jnp.save('data/params.npy',
                 params)  # save network parameters for later use
        return train_acc_list, test_acc_list, loss_list
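The parameters dumped at the end of train() can be reloaded for inference; a minimal sketch, assuming the pytree was pickled as an object array:

import numpy as onp

# allow_pickle=True because the pytree is pickled; .tolist() restores
# the nested structure for use with the network's apply function.
params = onp.load('data/params.npy', allow_pickle=True).tolist()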