Example #1
def train(model, args):
    for epoch in range(args.epochs):
        if mlh.is_schedule_update_time(epoch, args):
            args.partition = args.partition_scheduler(model, args)

        train_logpx, train_elbo = model.step_epoch(args.train_data_loader, step=epoch)

        log_scalar(train_elbo=train_elbo, train_logpx=train_logpx, step=epoch)

        if mlh.is_gradient_time(epoch, args):
            # Save grads
            grad_variance = util.calculate_grad_variance(model, args)
            log_scalar(grad_variance=grad_variance, step=epoch)

        if mlh.is_test_time(epoch, args):
            test_logpx, test_kl = model.test(args.test_data_loader, step=epoch)
            log_scalar(test_logpx=test_logpx, test_kl=test_kl, step=epoch)

        # ------ end of training loop ---------

    if args.train_only:
        # evaluation never ran, so report placeholder zeros
        test_logpx, test_kl = 0, 0

    results = {
        "test_logpx": test_logpx,
        "test_kl": test_kl,
        "train_logpx": train_logpx,
        "train_elbo": train_elbo
    }

    return results, model
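
The loop above delegates all of its scheduling decisions to `mlh` predicates whose bodies are not shown. A minimal sketch of what such helpers could look like, assuming each is a simple modular check on the epoch index; `args.schedule_update_frequency` appears in Example #2 below, but the `grad_frequency` and `test_frequency` attribute names are assumptions, not the library's actual API:

def is_schedule_update_time(epoch, args):
    # Refresh the partition every `schedule_update_frequency` epochs.
    return epoch % args.schedule_update_frequency == 0

def is_gradient_time(epoch, args):
    # Record gradient variance every `grad_frequency` epochs (assumed field).
    return epoch % args.grad_frequency == 0

def is_test_time(epoch, args):
    # Evaluate every `test_frequency` epochs (assumed field) and on the final
    # epoch, so the results dict is populated unless args.train_only is set.
    return epoch % args.test_frequency == 0 or epoch == args.epochs - 1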
Example #2
def train(args):
    # read data
    train_data_loader, test_data_loader = get_data(args)

    # attach data to args
    args.train_data_loader = train_data_loader
    args.test_data_loader = test_data_loader

    # Make model
    model = get_model(train_data_loader, args)

    # Make optimizer
    if args.loss in DUAL_OBJECTIVES:
        optimizer_phi = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if args.phi_tag in name),
            lr=args.lr)
        optimizer_theta = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if args.theta_tag in name),
            lr=args.lr)

    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in tqdm(range(args.epochs)):
        if mlh.is_schedule_update_time(epoch, args):
            args.partition = args.partition_scheduler(model, args)
            if len(args.Y_ori) > 1 and len(args.Y_ori) % args.increment_update_frequency == 0:
                args.schedule_update_frequency += 1
                print("args.schedule_update_frequency =", args.schedule_update_frequency)

        if args.loss in DUAL_OBJECTIVES:
            train_logpx, train_elbo, train_tvo_log_evidence = model.train_epoch_dual_objectives(
                train_data_loader, optimizer_phi, optimizer_theta, epoch=epoch)
        else:
            # additional recording happens within model.base
            train_logpx, train_elbo, train_tvo_log_evidence = model.train_epoch_single_objective(
                train_data_loader, optimizer, epoch=epoch)

        log_scalar(train_elbo=train_elbo,
                   train_logpx=train_logpx,
                   train_tvo_log_evidence=train_tvo_log_evidence,
                   step=epoch)

        # store the information
        args.betas_all = np.vstack((args.betas_all,
                                    np.reshape(format_input(args.partition),
                                               (1, args.K + 1))))
        args.logtvopx_all = np.append(args.logtvopx_all,
                                      train_tvo_log_evidence)

        if mlh.is_gradient_time(epoch, args):
            # Save grads
            grad_variance = util.calculate_grad_variance(model, args)
            log_scalar(grad_variance=grad_variance, step=epoch)

        if mlh.is_test_time(epoch, args):
            test_logpx, test_kl = model.evaluate_model_and_inference_network(
                test_data_loader, epoch=epoch)
            log_scalar(test_logpx=test_logpx, test_kl=test_kl, step=epoch)

        if mlh.is_checkpoint_time(epoch, args):
            opt = ([optimizer_phi, optimizer_theta]
                   if args.loss in DUAL_OBJECTIVES else [optimizer])
            save_checkpoint(model, epoch, train_elbo, train_logpx, opt, args)

        # ------ end of training loop ---------
    opt = ([optimizer_phi, optimizer_theta]
           if args.loss in DUAL_OBJECTIVES else [optimizer])
    save_checkpoint(model, args.epochs, train_elbo, train_logpx, opt, args)

    if args.train_only:
        test_logpx, test_kl = 0, 0

    results = {
        "test_logpx": test_logpx,
        "test_kl": test_kl,
        "train_logpx": train_logpx,
        "train_elbo": train_elbo,
        "train_tvo_px": train_tvo_log_evidence,
        "average_y":
        args.average_y,  # average tvo_logpx within this bandit iteration
        "X": args.X_ori,  # this is betas
        # this is utility score y=f(betas)= ave_y[-1] - ave_y[-2]
        "Y": args.Y_ori
    }

    return results, model
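
The dual-objective branch above builds two Adam optimizers by substring-matching parameter names against `args.phi_tag` and `args.theta_tag` (hard-coded to 'encoder' and 'decoder' in Example #3). A self-contained sketch of that split; the two-layer module is a hypothetical stand-in for the real model:

import torch
import torch.nn as nn

class TinyVAE(nn.Module):
    # Hypothetical model whose parameter names contain 'encoder'/'decoder',
    # mirroring the phi/theta tag convention used in train().
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 4)
        self.decoder = nn.Linear(4, 10)

model = TinyVAE()

optimizer_phi = torch.optim.Adam(
    (p for name, p in model.named_parameters() if 'encoder' in name), lr=1e-3)
optimizer_theta = torch.optim.Adam(
    (p for name, p in model.named_parameters() if 'decoder' in name), lr=1e-3)

# Sanity check: the two tags should cover every parameter exactly once,
# otherwise some weights would silently never be updated.
n_phi = sum(len(g['params']) for g in optimizer_phi.param_groups)
n_theta = sum(len(g['params']) for g in optimizer_theta.param_groups)
assert n_phi + n_theta == len(list(model.parameters()))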
Example #3
File: main.py Project: vmasrani/tvo_all_in
def train(args):
    # read data
    train_data_loader, test_data_loader = get_data(args)

    # attach data to args
    args.train_data_loader = train_data_loader
    args.test_data_loader = test_data_loader

    # Make model
    model = get_model(train_data_loader, args)

    # Make optimizer
    if args.loss in DUAL_LIST:
        optimizer_phi = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if 'encoder' in name),
            lr=args.lr)
        optimizer_theta = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if 'decoder' in name),
            lr=args.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):

        _record = mlh.is_record_time(epoch, args)

        if args.loss in DUAL_LIST:
            train_logpx, train_elbo = model.train_epoch_dual_objectives(
                train_data_loader,
                optimizer_phi,
                optimizer_theta,
                record=_record)
        else:
            train_logpx, train_elbo = model.train_epoch_single_objective(
                train_data_loader, optimizer, record=_record)

        log_scalar(train_elbo=train_elbo, train_logpx=train_logpx, step=epoch)

        if mlh.is_test_time(epoch, args):
            args.test_time = True
            test_logpx, test_kl = model.evaluate_model_and_inference_network(
                test_data_loader)
            log_scalar(test_logpx=test_logpx, test_kl=test_kl, step=epoch)
            args.test_time = False

        if mlh.is_checkpoint_time(epoch, args):
            opt = ([optimizer_phi, optimizer_theta]
                   if args.loss in DUAL_LIST else [optimizer])
            save_checkpoint(model, epoch, train_elbo, train_logpx, opt, args)

        if mlh.is_schedule_update_time(epoch, args):
            args.partition = args.partition_scheduler(model, args)

        # ------ end of training loop ---------

    if args.train_only:
        test_logpx, test_kl = 0, 0

    results = {
        "test_logpx": test_logpx,
        "test_kl": test_kl,
        "train_logpx": train_logpx,
        "train_elbo": train_elbo
    }

    return results, model
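
All three variants delegate persistence to `save_checkpoint(model, epoch, train_elbo, train_logpx, opt, args)`, whose body is not shown. A plausible minimal implementation built on `torch.save`; the `args.checkpoint_dir` field and the file naming scheme are assumptions:

import os
import torch

def save_checkpoint(model, epoch, train_elbo, train_logpx, opt, args):
    # `opt` is a list because the dual-objective losses use two optimizers.
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dicts': [o.state_dict() for o in opt],
        'train_elbo': train_elbo,
        'train_logpx': train_logpx,
    }
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    torch.save(state,
               os.path.join(args.checkpoint_dir, 'checkpoint_{}.pt'.format(epoch)))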