コード例 #1
0
ファイル: dsn_util.py プロジェクト: cunningham-lab/dsn
def initialize_gauss_nf(D, arch_dict, random_seed, gauss_initdir, bounds=None):
    """Train a normalizing flow on a Gaussian target until it converges.

    A "normal" family is used when ``bounds`` is None, otherwise a
    "truncated_normal" family built from ``bounds[0]`` and ``bounds[1]``.
    ``train_nf`` is called repeatedly; after every run that does not report
    convergence the iteration cap is multiplied by four and training is
    started again. There is no upper limit on the number of retries.

    Args:
        D: Dimensionality handed to the family constructor; also the length
            of the zero vector used when no initial mean is supplied.
        arch_dict: Architecture dictionary. Must contain the keys
            ``"mu_init"`` and ``"Sigma_init"``; it is also forwarded
            unchanged to ``train_nf``.
        random_seed: Seed forwarded to ``train_nf``.
        gauss_initdir: Forwarded to ``train_nf`` as ``savedir``.
        bounds: Optional indexable pair; element 0 is passed to the family
            as ``a`` and element 1 as ``b``.

    Returns:
        The (truthy) convergence value returned by the final ``train_nf``
        call.
    """
    mu_init = arch_dict["mu_init"]
    Sigma_init = arch_dict["Sigma_init"]

    # TODO: make this more flexible for single bounds
    if bounds is None:
        family = family_from_str("normal")(D)
    else:
        family = family_from_str("truncated_normal")(
            D, a=bounds[0], b=bounds[1])

    # A missing mean defaults to the zero vector; Sigma is passed as given.
    params = {
        "mu": np.zeros((D, )) if mu_init is None else mu_init,
        "Sigma": Sigma_init,
        "dist_seed": 0,
    }

    # Fixed training settings, forwarded positionally to train_nf.
    n = 1000
    lr_order = -3
    check_rate = 100
    min_iters = 20000
    max_iters = 50000

    while True:
        converged = train_nf(
            family,
            params,
            arch_dict,
            n,
            lr_order,
            random_seed,
            min_iters,
            max_iters,
            check_rate,
            None,
            profile=False,
            savedir=gauss_initdir,
        )
        if converged:
            break
        # Not converged: allow four times as many iterations and retry.
        max_iters *= 4

    print("done initializing gaussian NF")
    return converged
コード例 #2
0
# Command-line arguments 3-5. `exp_fam` and `D` are used below but are not
# defined in this excerpt -- presumably set from earlier arguments
# (TODO confirm against the full script).
give_inverse_hint = (int(sys.argv[3]) == 1)
random_seed = int(sys.argv[4])
dir_str = str(sys.argv[5])

profile = True

# Look up the hyperparameters for this family/dimension, then override the
# layer count and learning-rate order with fixed values.
TIF_flow_type, nlayers, scale_layer, lr_order = model_opt_hps(exp_fam, D)
nlayers = 30
lr_order = -3

# NOTE: 'scale_layer' is hard-coded to False here; the `scale_layer` value
# returned by model_opt_hps is not used in this dictionary.
flow_dict = {
    'latent_dynamics': None,
    'scale_layer': False,
    'TIF_flow_type': TIF_flow_type,
    'repeats': nlayers,
}

fam_class = family_from_str(exp_fam)
family = fam_class(D)

family.load_data()
family.select_train_test_sets(500)

param_net_input_type = 'eta'
cost_type = 'KL'
K_eta = 100
M_eta = 200
stochastic_eta = True
dist_seed = 0

# Minimum and maximum iteration counts are both 25, which is smaller than
# check_rate.
min_iters = max_iters = 25
check_rate = 100
コード例 #3
0
ファイル: plot_util.py プロジェクト: t-rutten/efn
def load_dim_sweep(exp_fam,
                   model,
                   datadir,
                   Ds,
                   K,
                   M,
                   give_hint,
                   max_iters,
                   num_rs=10):
    """Collect the latest diagnostics for every dimensionality in a sweep.

    For each entry of ``Ds`` the path of the matching ``opt_info.npz``
    file is built and handed to ``get_latest_diagnostics``; the returned
    values are stored in row ``i`` of the output arrays. Every looked-up
    file is recorded through ``log_fname`` and a summary is printed with
    ``print_file_statuses`` before returning.

    Args:
        exp_fam: Family name; passed to ``family_from_str`` and embedded in
            the file names.
        model: One of ``'EFN'``, ``'NF1'`` or ``'EFN1'``.
        datadir: String prefix that is concatenated directly with the run
            directory name, so any path separator must be included by the
            caller.
        Ds: Sequence of dimensionalities to load.
        K: Embedded in the ``'EFN'`` file name; also the number of columns
            of the outputs when ``model == 'EFN'``.
        M: Embedded in the ``'EFN'`` and ``'EFN1'`` file names.
        give_hint: When truthy, ``'giveInv_'`` is inserted into the EFN
            file names; also forwarded to ``family.get_efn_dims``.
        max_iters: Forwarded to ``get_latest_diagnostics``.
        num_rs: Number of random seeds for the non-``'EFN'`` models; the
            seeds used in the file names run from 1 to ``num_rs``.

    Returns:
        Tuple ``(elbos, R2s, KLs)`` of arrays with shape
        ``(len(Ds), num_dists)``, where ``num_dists`` is ``K`` for
        ``'EFN'`` and ``num_rs`` otherwise.

    Raises:
        ValueError: If ``model`` is not one of the supported names and a
            file lookup is attempted.
    """
    num_dists = K if model == 'EFN' else num_rs
    out_shape = (len(Ds), num_dists)
    elbos = np.zeros(out_shape)
    R2s = np.zeros(out_shape)
    KLs = np.zeros(out_shape)

    not_started = []
    in_progress = []
    unstable = []
    status_lists = [not_started, in_progress, unstable]

    give_inv_str = 'giveInv_' if give_hint else ''

    for i, D in enumerate(Ds):
        fam_class = family_from_str(exp_fam)
        family = fam_class(D)
        # Only the first returned dimension is needed here.
        D_Z, _, _, _ = family.get_efn_dims('eta', give_hint)
        # The number of planar flows is D, but never fewer than 20.
        planar_flows = max(D, 20)
        flow_dict = get_flowdict(0, planar_flows, 0, 0)
        flowstring = get_flowstring(flow_dict)
        L = max(int(np.ceil(np.sqrt(D_Z))), 4)

        if model == 'EFN':
            # A single run (rs=0) fills the whole row.
            fname = datadir + 'EFN_%s_stochasticEta_%sD=%d_K=%d_M=%d_flow=%s_L=%d_rs=%d/opt_info.npz' \
                                   % (exp_fam, give_inv_str, D, K, M, flowstring, L, 0)
            elbos[i, :], R2s[i, :], KLs[i, :], status = get_latest_diagnostics(
                fname, max_iters)
            log_fname(fname, status, status_lists)
            continue

        # One run per random seed; seeds in the file names are 1-based.
        for rs in range(num_rs):
            if model == 'NF1':
                fname = datadir + 'NF1/NF1_%s_D=%d_flow=%s_rs=%d/opt_info.npz' % (
                    exp_fam, D, flowstring, rs + 1)
            elif model == 'EFN1':
                fname = datadir + 'EFN1/EFN_%s_fixedEta_%sD=%d_K=%d_M=%d_flow=%s_L=%d_rs=%d/opt_info.npz' \
                       % (exp_fam, give_inv_str, D, 1, M, flowstring, L, rs+1)
            else:
                # Without this branch `fname` would be unbound below and the
                # call would fail with an opaque UnboundLocalError.
                raise ValueError(
                    "unsupported model %r; expected 'EFN', 'NF1' or 'EFN1'"
                    % (model, ))
            (elbos[i, rs], R2s[i, rs], KLs[i, rs],
             status) = get_latest_diagnostics(fname, max_iters)
            log_fname(fname, status, status_lists)

    print_file_statuses(status_lists)

    return elbos, R2s, KLs