Example No. 1
def launch_job(args):
    log_dir = build_log_dir(args)
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data',
                       train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    model = Net().to(device)

    if args.optim == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    elif args.optim == "adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == "amsgrad":
        optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    elif args.optim == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
    else:
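        # The remaining choices are natural-gradient optimizers from the
        # fisher.optim package, which take extra curvature arguments.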
        import fisher.optim as fisher_optim
        if args.optim == 'ngd':
            optimizer = fisher_optim.NGD(model.parameters(),
                                         lr=args.lr,
                                         shrunk=args.shrunk,
                                         lanczos_iters=args.lanczos_iters,
                                         batch_size=args.batch_size)
        elif args.optim == 'natural_adam':
            optimizer = fisher_optim.NaturalAdam(
                model.parameters(),
                lr=args.lr,
                shrunk=args.shrunk,
                lanczos_iters=args.lanczos_iters,
                batch_size=args.batch_size,
                betas=(args.beta1, args.beta2),
                assume_locally_linear=args.approx_adaptive)
        elif args.optim == 'natural_amsgrad':
            optimizer = fisher_optim.NaturalAmsgrad(
                model.parameters(),
                lr=args.lr,
                shrunk=args.shrunk,
                lanczos_iters=args.lanczos_iters,
                batch_size=args.batch_size,
                betas=(args.beta1, args.beta2),
                assume_locally_linear=args.approx_adaptive)
        elif args.optim == 'natural_adagrad':
            optimizer = fisher_optim.NaturalAdagrad(
                model.parameters(),
                lr=args.lr,
                shrunk=args.shrunk,
                lanczos_iters=args.lanczos_iters,
                batch_size=args.batch_size,
                assume_locally_linear=args.approx_adaptive)
        else:
            raise NotImplementedError

    accuracies = []
    losses = []
    times = [0.0]

    if args.decay_lr:
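        # LambdaLR scales the base learning rate by 1/sqrt(epoch + 1).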
        lambda_lr = lambda epoch: 1.0 / np.sqrt(epoch + 1)
        scheduler = LambdaLR(optimizer, lr_lambda=[lambda_lr])
    for epoch in range(1, args.epochs + 1):
        if args.decay_lr:
            scheduler.step()
        train(args, model, device, train_loader, test_loader, optimizer, epoch,
              [accuracies, losses, times])

    log_stats(accuracies, losses, times, args, model, device, test_loader,
              epoch, 'inf')
Example No. 2
        # optA = fisher_optim.NGD([fac.A, fac.B],
        #                        lr=0.001,
        #                        curv_type='gauss_newton',
        #                        shrinkage_method=None, #'cg', #'lanczos',
        #                        lanczos_iters=0,
        #                        batch_size=bs)
        # optB = fisher_optim.NGD([fac.B],
        #                        lr=0.001,
        #                        curv_type='gauss_newton',
        #                        shrinkage_method=None, #'cg', #'lanczos',
        #                        lanczos_iters=0,
        #                        batch_size=bs)
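        # NaturalAdam preconditioners for the factorization: optA updates both
        # factors A and B, optB updates only B.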
        optA = fisher_optim.NaturalAdam([fac.A, fac.B],
                                         lr=0.001,
                                         curv_type='gauss_newton',
                                         shrinkage_method=None,
                                         batch_size=bs,
                                         betas=(0.9, 0.99),
                                         assume_locally_linear=True)
        optB = fisher_optim.NaturalAdam([fac.B],
                                         lr=0.001,
                                         curv_type='gauss_newton',
                                         shrinkage_method=None,
                                         batch_size=bs,
                                         betas=(0.9, 0.9),
                                         assume_locally_linear=True)
    else:
        optA = optim.Adam([fac.A, fac.B], lr=0.001)
        # optB = optim.Adam([fac.B], lr=0.001)

    P = fac(ids=Wind)
Example No. 3
def make_optimizer(args, model):
    if args.optim == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    elif args.optim == "adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == "amsgrad":
        optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    elif args.optim == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
    else:
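        # Natural-gradient optimizers from fisher.optim share the
        # curvature/CG/Lanczos arguments collected in common_kwargs below.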
        import fisher.optim as fisher_optim

        common_kwargs = dict(
            lr=args.lr,
            curv_type=args.curv_type,
            cg_iters=args.cg_iters,
            cg_residual_tol=args.cg_residual_tol,
            cg_prev_init_coef=args.cg_prev_init_coef,
            cg_precondition_empirical=args.cg_precondition_empirical,
            cg_precondition_regu_coef=args.cg_precondition_regu_coef,
            cg_precondition_exp=args.cg_precondition_exp,
            shrinkage_method=args.shrinkage_method,
            lanczos_amortization=args.lanczos_amortization,
            lanczos_iters=args.lanczos_iters,
            batch_size=args.batch_size)
        if args.optim == 'ngd_bd':
            raise NotImplementedError
            # optimizer = fisher_optim.NGD_BD([{'params': model.fc1.parameters()},
            #                                  {'params': model.fc2.parameters()}],
            #                                 lr=args.lr,
            #                                 curv_type='gauss_newton',
            #                                 shrinkage_method=None,
            #                                 lanczos_iters=args.lanczos_iters,
            #                                 batch_size=args.batch_size)
            # optimizer = fisher_optim.NGD_BD([
            #                                  {'params': model.conv1.parameters()},
            #                                  {'params': model.conv2.parameters()},
            #                                  {'params': model.fc1.parameters()},
            #                                  {'params': model.fc2.parameters()}],
            #                                 lr=args.lr,
            #                                 curv_type='gauss_newton',
            #                                 shrinkage_method='cg',
            #                                 lanczos_iters=args.lanczos_iters,
            #                                 batch_size=args.batch_size)
        elif args.optim == 'ngd':
            optimizer = fisher_optim.NGD(model.parameters(), **common_kwargs)
        elif args.optim == 'natural_adam':
            optimizer = fisher_optim.NaturalAdam(
                model.parameters(),
                **common_kwargs,
                betas=(args.beta1, args.beta2),
                assume_locally_linear=args.approx_adaptive)
        elif args.optim == 'natural_amsgrad':
            optimizer = fisher_optim.NaturalAmsgrad(
                model.parameters(),
                **common_kwargs,
                betas=(args.beta1, args.beta2),
                assume_locally_linear=args.approx_adaptive)
        elif args.optim == 'natural_adagrad':
            optimizer = fisher_optim.NaturalAdagrad(
                model.parameters(),
                **common_kwargs,
                assume_locally_linear=args.approx_adaptive)
        else:
            raise NotImplementedError

    return optimizer
Example No. 4
def fit(data):
    model = Model()
    algos = ['natural_adam', 'natural_amsgrad', 'natural_adagrad', 'ngd']
    # algos = ['ngd']
    # algos = ['natural_adagrad', 'natural_adam', 'natural_amsgrad']
    trace_dict = {}
    for algo in algos:

        if algo in [
                'ngd', 'natural_amsgrad', 'natural_adagrad', 'natural_adam'
        ]:
            import fisher.optim as fisher_optim
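            # Smaller step size for the natural-gradient variants
            # (the SGD/Adam branches below use 0.05 and 0.1).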
            fisher_lr = 0.002
            if algo == 'ngd':
                opt = fisher_optim.NGD(model.parameters(),
                                       lr=fisher_lr,
                                       shrunk=False,
                                       lanczos_iters=1,
                                       batch_size=1000)
            elif algo == 'natural_adam':
                opt = fisher_optim.NaturalAdam(model.parameters(),
                                               lr=fisher_lr,
                                               shrunk=False,
                                               lanczos_iters=0,
                                               batch_size=1000,
                                               betas=(0.1, 0.1),
                                               assume_locally_linear=False)
            elif algo == 'natural_amsgrad':
                opt = fisher_optim.NaturalAmsgrad(model.parameters(),
                                                  lr=fisher_lr,
                                                  shrunk=False,
                                                  lanczos_iters=0,
                                                  batch_size=1000,
                                                  betas=(0.1, 0.1),
                                                  assume_locally_linear=False)
            elif algo == 'natural_adagrad':
                opt = fisher_optim.NaturalAdagrad(model.parameters(),
                                                  lr=fisher_lr,
                                                  shrunk=False,
                                                  lanczos_iters=0,
                                                  batch_size=1000,
                                                  assume_locally_linear=False)
        else:
            if algo == 'sgd':
                opt = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
            elif algo == 'adam':
                opt = optim.Adam(model.parameters(), lr=0.1)

        loss_fn = torch.nn.BCELoss()
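        # Reset the weights to a fixed starting point so every optimizer
        # begins from the same place.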
        model.fc.weight.data = torch.FloatTensor([[-12.5, 2.0]])
        X, y = data

        trace = [tuple(model.fc.weight.data.numpy().squeeze())]

        for _ in range(20):
            Xvar = torch.from_numpy(X).float()
            yvar = torch.from_numpy(y).float()

            opt.zero_grad()

            output = model(Xvar)
            loss = loss_fn(output, yvar)
            loss.backward()

            if algo in [
                    'ngd', 'natural_amsgrad', 'natural_adagrad', 'natural_adam'
            ]:
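                # The natural-gradient optimizers take a Fisher-vector-product
                # closure built from the current batch.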
                Fvp_fn = build_Fvp(model, Xvar, yvar, mean_kl_multinomial)
                opt.step(Fvp_fn)
            else:
                opt.step()

            trace.append(tuple(model.fc.weight.data.numpy().squeeze()))
            print(loss.item())

        trace_dict[algo] = trace

    plot_contours(X, y, model, loss_fn, traces=trace_dict)