Example #1
def run(args):
    # set seed
    np.random.seed(0)
    torch.manual_seed(0)

    # generate initial params, set to model
    n_inputs = 784
    n_outputs = 10
    W = np.random.random((n_outputs, n_inputs))
    b = np.random.random((n_outputs, ))

    # print ("W, b: ", W, b)
    # input("")
    model = Model(n_inputs, n_outputs, param_inits=[W, b])

    # Load mnist by keras for consistency with tf
    (X_train, y_train), (X_test, y_test) = load_mnist()

    # fix batch size
    bs = 128

    # create optimizer
    import adacurv.torch.optim as fisher_optim
    common_kwargs = dict(
        lr=args.lr,
        curv_type=args.curv_type,
        cg_iters=args.cg_iters,
        cg_residual_tol=args.cg_residual_tol,
        cg_prev_init_coef=args.cg_prev_init_coef,
        cg_precondition_empirical=args.cg_precondition_empirical,
        cg_precondition_regu_coef=args.cg_precondition_regu_coef,
        cg_precondition_exp=args.cg_precondition_exp,
        shrinkage_method=args.shrinkage_method,
        lanczos_amortization=args.lanczos_amortization,
        lanczos_iters=args.lanczos_iters,
        batch_size=args.batch_size)

    optimizer = fisher_optim.NaturalAdam(
        model.parameters(),
        **common_kwargs,
        betas=(args.beta1, args.beta2),
        assume_locally_linear=args.approx_adaptive)

    # run one optimization step on the first mini-batch and log data
    for i in range(1):
        data = X_train[bs * i:bs * (i + 1)]
        target = y_train[bs * i:bs * (i + 1)]
        optimization_step(model, optimizer, data, target)
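
Example #1 assumes module-level imports and helpers (Model, load_mnist, optimization_step) defined elsewhere in the source file. Below is a minimal sketch of what they could look like, assuming a single linear softmax classifier and the Keras MNIST loader that the comment refers to; note that the adacurv optimizers may also expect a loss closure at step time (see Example #3), which this placeholder step omits.

# Sketch of the helpers assumed by Example #1 (names and shapes inferred
# from the call sites; not taken from the adacurv repository).
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    # Single linear layer (n_inputs -> n_outputs) whose parameters can be
    # initialized from the numpy arrays passed as param_inits=[W, b].
    def __init__(self, n_inputs, n_outputs, param_inits=None):
        super().__init__()
        self.fc = nn.Linear(n_inputs, n_outputs)
        if param_inits is not None:
            W, b = param_inits
            with torch.no_grad():
                self.fc.weight.copy_(torch.from_numpy(W).float())
                self.fc.bias.copy_(torch.from_numpy(b).float())

    def forward(self, x):
        return self.fc(x.view(x.shape[0], -1))

def load_mnist():
    # "Load mnist by keras for consistency with tf", flattened to 784-d vectors.
    from tensorflow.keras.datasets import mnist
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = X_train.reshape(-1, 784).astype('float32') / 255.0
    X_test = X_test.reshape(-1, 784).astype('float32') / 255.0
    return (X_train, y_train), (X_test, y_test)

def optimization_step(model, optimizer, data, target):
    # One forward/backward pass with cross-entropy loss.  The curvature-aware
    # optimizers in adacurv may additionally require a loss closure argument
    # to optimizer.step(); a plain step is shown here as a placeholder.
    data = torch.from_numpy(data).float()
    target = torch.from_numpy(target).long()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
    return float(loss)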
Example #2
def make_optimizer(args, model):
    if args.optim == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    elif args.optim == "adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == "amsgrad":
        optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    elif args.optim == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
    else:
        import adacurv.torch.optim as fisher_optim

        common_kwargs = dict(lr=args.lr,
                             curv_type=args.curv_type,
                             cg_iters=args.cg_iters,
                             cg_residual_tol=args.cg_residual_tol,
                             cg_prev_init_coef=args.cg_prev_init_coef,
                             cg_precondition_empirical=args.cg_precondition_empirical,
                             cg_precondition_regu_coef=args.cg_precondition_regu_coef,
                             cg_precondition_exp=args.cg_precondition_exp,
                             shrinkage_method=args.shrinkage_method,
                             lanczos_amortization=args.lanczos_amortization,
                             lanczos_iters=args.lanczos_iters,
                             batch_size=args.batch_size)
        if args.optim == 'ngd_bd':
            raise NotImplementedError
            # optimizer = fisher_optim.NGD_BD([{'params': model.fc1.parameters()},
            #                                  {'params': model.fc2.parameters()}],
            #                                 lr=args.lr,
            #                                 curv_type='gauss_newton',
            #                                 shrinkage_method=None,
            #                                 lanczos_iters=args.lanczos_iters,
            #                                 batch_size=args.batch_size)
            # optimizer = fisher_optim.NGD_BD([
            #                                  {'params': model.conv1.parameters()},
            #                                  {'params': model.conv2.parameters()},
            #                                  {'params': model.fc1.parameters()},
            #                                  {'params': model.fc2.parameters()}],
            #                                 lr=args.lr,
            #                                 curv_type='gauss_newton',
            #                                 shrinkage_method='cg',
            #                                 lanczos_iters=args.lanczos_iters,
            #                                 batch_size=args.batch_size)
        elif args.optim == 'ngd':
            optimizer = fisher_optim.NGD(model.parameters(), **common_kwargs)
        elif args.optim == 'natural_adam':
            optimizer = fisher_optim.NaturalAdam(model.parameters(),
                                                 **common_kwargs,
                                                 betas=(args.beta1, args.beta2),
                                                 assume_locally_linear=args.approx_adaptive)
        elif args.optim == 'natural_adam_bd':
            # one parameter group per child module for the block-diagonal variant
            block_diag_params = [{'params': m.parameters()} for m in model.children()]
            optimizer = fisher_optim.NaturalAdam_BD(block_diag_params,
                                                    **common_kwargs,
                                                    betas=(args.beta1, args.beta2),
                                                    assume_locally_linear=args.approx_adaptive)
        elif args.optim == 'natural_amsgrad':
            optimizer = fisher_optim.NaturalAmsgrad(model.parameters(),
                                                    **common_kwargs,
                                                    betas=(args.beta1, args.beta2),
                                                    assume_locally_linear=args.approx_adaptive)
        elif args.optim == 'natural_adagrad':
            optimizer = fisher_optim.NaturalAdagrad(model.parameters(),
                                                    **common_kwargs,
                                                    assume_locally_linear=args.approx_adaptive)
        else:
            raise NotImplementedError

    print(optimizer)
    return optimizer
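
make_optimizer reads a fairly long list of attributes off args. A hypothetical argparse setup that would supply them is sketched below; the default values are illustrative guesses, not defaults taken from the adacurv repository.

import argparse

def make_arg_parser():
    # Flags mirror the attributes make_optimizer reads; defaults are guesses.
    p = argparse.ArgumentParser()
    p.add_argument('--optim', default='natural_adam')
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--curv-type', default='fisher')
    p.add_argument('--cg-iters', type=int, default=10)
    p.add_argument('--cg-residual-tol', type=float, default=1e-10)
    p.add_argument('--cg-prev-init-coef', type=float, default=0.0)
    p.add_argument('--cg-precondition-empirical', action='store_true')
    p.add_argument('--cg-precondition-regu-coef', type=float, default=0.001)
    p.add_argument('--cg-precondition-exp', type=float, default=0.75)
    p.add_argument('--shrinkage-method', default=None)
    p.add_argument('--lanczos-amortization', type=int, default=10)
    p.add_argument('--lanczos-iters', type=int, default=20)
    p.add_argument('--batch-size', type=int, default=128)
    p.add_argument('--beta1', type=float, default=0.9)
    p.add_argument('--beta2', type=float, default=0.99)
    p.add_argument('--approx-adaptive', action='store_true')
    return p

# Usage (model built elsewhere):
# args = make_arg_parser().parse_args()
# optimizer = make_optimizer(args, model)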
Example #3
def run(rcp='rcp45', iters=500, batch=5000, use_gn=True, seed=0):
    np.random.seed(seed)
    torch.manual_seed(seed)

    # M in (m x n) of rank r
    # A is (m x r)
    # B is (n x r)
    # So A @ B.T is (m x n)

    M = np.load("climate_data/matrices/M_1900_2101_" + rcp + ".npy")
    W = np.load("climate_data/matrices/W_1900_2101_" + rcp + ".npy")

    bs = batch
    n_samples = np.count_nonzero(W)
    print("num sampes: ", n_samples)
    Wind = np.nonzero(W)
    Wind = randomize_windices(Wind)
    n_batches = int(np.ceil(n_samples / bs))
    print("n_batches: ", n_batches)

    r = 5
    m = M.shape[0]
    n = M.shape[1]
    M = torch.from_numpy(M).float()
    W = torch.from_numpy(W).float()

    fac = Factorization(m, n, r)

    print("r, m, n: ", r, m, n)
    print("A, B: ", fac.A.shape, fac.B.shape)

    if use_gn:
        optA = fisher_optim.NaturalAdam([fac.A, fac.B],
                                        lr=0.01,
                                        curv_type='gauss_newton',
                                        cg_prev_init_coef=0.0,
                                        cg_precondition_empirical=False,
                                        shrinkage_method=None,
                                        batch_size=bs,
                                        betas=(0.1, 0.1),
                                        assume_locally_linear=True)
    else:
        optA = optim.Adam([fac.A, fac.B], lr=0.01)

    P = fac(ids=Wind)
    print("P init: ", P.shape)

    init_error = mat_completion_loss(W, M, P, fac.A, fac.B, Wind)
    print("Init error: ", init_error / P.shape[0])

    # input("Start training?")
    best_error = float(init_error)

    from torch.optim.lr_scheduler import ReduceLROnPlateau
    scheduler = ReduceLROnPlateau(optA, 'min')

    for i in range(iters):
        Wind = randomize_windices(Wind)
        for j in range(n_batches):

            optA.zero_grad()

            ind1 = j * bs
            ind2 = (j + 1) * bs

            Windx, Windy = Wind
            batch_idx_idy = Windx[ind1:ind2], Windy[ind1:ind2]

            P = fac(ids=batch_idx_idy)

            error = mat_completion_loss(W, M, P, fac.A, fac.B, batch_idx_idy)
            error.backward()

            if use_gn:
                # import time
                # t1 = time.time()
                loss_closure = build_mat_completion_loss_closure_combined(
                    fac, W, M, batch_idx_idy)
                # t2 = time.time()
                # print ("Building loss closure time: ", (t2 - t1))
                optA.step(loss_closure)
            else:
                optA.step()

            if j % 10 == 0:
                print("Iter: ", i, ", batch: ", j, float(error) / P.shape[0])

        P = fac(Wind)
        error = float(mat_completion_loss(W, M, P, fac.A, fac.B, Wind))
        print("Iter: ", i, error / P.shape[0])
        scheduler.step(error)

        if error < best_error:
            P2 = fac()
            gn_str = 'gn' if use_gn else 'adam'
            np.save(
                'models/P_' + rcp + '_rank' + str(r) + '_' + gn_str + '.npy',
                P2.data.numpy())
            best_error = error
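
Example #3 relies on a Factorization module, a masked matrix-completion loss, and an index shuffler defined elsewhere in the file. A minimal sketch under the low-rank assumption stated in the comments (M ≈ A @ B.T, with the loss evaluated only on observed entries) is shown below; the regularization term is a guess, and build_mat_completion_loss_closure_combined is adacurv-specific and not reproduced here.

# Sketch of the helpers assumed by Example #3 (signatures inferred from the
# call sites; not taken from the adacurv repository).
import numpy as np
import torch
import torch.nn as nn

class Factorization(nn.Module):
    # Rank-r factorization M ~= A @ B.T with A of shape (m, r), B of shape (n, r).
    def __init__(self, m, n, r):
        super().__init__()
        self.A = nn.Parameter(0.1 * torch.randn(m, r))
        self.B = nn.Parameter(0.1 * torch.randn(n, r))

    def forward(self, ids=None):
        if ids is None:
            return self.A @ self.B.t()            # full m x n reconstruction
        rows, cols = ids
        return (self.A[rows] * self.B[cols]).sum(dim=1)  # observed entries only

def randomize_windices(Wind):
    # Shuffle the observed indices while keeping (row, col) pairs aligned.
    rows, cols = Wind
    perm = np.random.permutation(len(rows))
    return rows[perm], cols[perm]

def mat_completion_loss(W, M, P, A, B, ids, reg=0.0):
    # Squared error on the observed entries; W (the mask) is kept only for
    # interface compatibility.  The optional Frobenius penalty is a guess.
    rows, cols = ids
    loss = ((P - M[rows, cols]) ** 2).sum()
    if reg > 0:
        loss = loss + reg * (A.pow(2).sum() + B.pow(2).sum())
    return loss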
Example #4
def launch_job(tag, variant):

    print(len(variant))
    seed, env, algo, optim, curv_type, lr, batch_size, cg_iters, cg_residual_tol, cg_prev_init_coef, \
        cg_precondition_empirical, cg_precondition_regu_coef, cg_precondition_exp,  \
        shrinkage_method, lanczos_amortization, lanczos_iters, approx_adaptive, betas, use_nn_policy, gn_vfn_opt, total_samples = variant
    beta1, beta2 = betas

    iters = int(total_samples / batch_size)

    # NN policy
    # ==================================
    e = GymEnv(env)
    if use_nn_policy:
        policy = MLP(e.spec, hidden_sizes=(64, ), seed=seed)
    else:
        policy = LinearPolicy(e.spec, seed=seed)
    vfn_batch_size = 256 if gn_vfn_opt else 64
    vfn_epochs = 2  # same number of epochs in both settings
    # baseline = MLPBaseline(e.spec, reg_coef=1e-3, batch_size=64, epochs=2, learn_rate=1e-3)
    baseline = MLPBaseline(e.spec,
                           reg_coef=1e-3,
                           batch_size=vfn_batch_size,
                           epochs=vfn_epochs,
                           learn_rate=1e-3,
                           use_gauss_newton=gn_vfn_opt)
    # agent = NPG(e, policy, baseline, normalized_step_size=0.005, seed=SEED, save_logs=True)

    common_kwargs = dict(lr=lr,
                         curv_type=curv_type,
                         cg_iters=cg_iters,
                         cg_residual_tol=cg_residual_tol,
                         cg_prev_init_coef=cg_prev_init_coef,
                         cg_precondition_empirical=cg_precondition_empirical,
                         cg_precondition_regu_coef=cg_precondition_regu_coef,
                         cg_precondition_exp=cg_precondition_exp,
                         shrinkage_method=shrinkage_method,
                         lanczos_amortization=lanczos_amortization,
                         lanczos_iters=lanczos_iters,
                         batch_size=batch_size)

    if optim == 'ngd':
        optimizer = fisher_optim.NGD(policy.trainable_params, **common_kwargs)
    elif optim == 'natural_adam':
        optimizer = fisher_optim.NaturalAdam(
            policy.trainable_params,
            **common_kwargs,
            betas=(beta1, beta2),
            assume_locally_linear=approx_adaptive)
    elif optim == 'natural_adagrad':
        optimizer = fisher_optim.NaturalAdagrad(
            policy.trainable_params,
            **common_kwargs,
            betas=(beta1, beta2),
            assume_locally_linear=approx_adaptive)
    elif optim == 'natural_amsgrad':
        optimizer = fisher_optim.NaturalAmsgrad(
            policy.trainable_params,
            **common_kwargs,
            betas=(beta1, beta2),
            assume_locally_linear=approx_adaptive)
    else:
        raise NotImplementedError

    if algo == 'trpo':
        from mjrl.algos.trpo_delta import TRPO
        agent = TRPO(e, policy, baseline, optimizer, seed=seed, save_logs=True)
        # agent = TRPO(e, policy, baseline, seed=seed, save_logs=True)
    else:
        from mjrl.algos.npg_cg_delta import NPG
        agent = NPG(e, policy, baseline, optimizer, seed=seed, save_logs=True)

    save_dir = build_log_dir(tag, variant)
    os.makedirs(save_dir, exist_ok=True)

    # print ("Iters:", iters, ", num_traj: ", str(batch_size//1000))
    train_agent(job_name=save_dir,
                agent=agent,
                seed=seed,
                niter=iters,
                gamma=0.995,
                gae_lambda=0.97,
                num_cpu=1,
                sample_mode='samples',
                num_samples=batch_size,
                save_freq=5,
                evaluation_rollouts=5,
                verbose=False)
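
launch_job unpacks a 21-field variant tuple in a fixed order, so callers must build it in exactly that order. A hypothetical driver is sketched below; the concrete values are illustrative only and are not taken from the original experiment configuration.

# Hypothetical variant; field order must match the unpacking in launch_job.
variant = (
    0,               # seed
    'Hopper-v2',     # env
    'npg',           # algo
    'natural_adam',  # optim
    'fisher',        # curv_type
    0.01,            # lr
    5000,            # batch_size
    10,              # cg_iters
    1e-10,           # cg_residual_tol
    0.0,             # cg_prev_init_coef
    False,           # cg_precondition_empirical
    0.001,           # cg_precondition_regu_coef
    0.75,            # cg_precondition_exp
    None,            # shrinkage_method
    10,              # lanczos_amortization
    20,              # lanczos_iters
    False,           # approx_adaptive
    (0.9, 0.99),     # betas
    True,            # use_nn_policy
    False,           # gn_vfn_opt
    100000,          # total_samples
)

launch_job('example_tag', variant)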