Example #1
def init(config, run):
    # general init
    args = SimpleNamespace(**config)
    args = assertions.validate_args(args)
    mlh.seed_all(args.seed)
    args._run = run
    args.wandb = wandb

    # init scheduler
    args.partition_scheduler = schedules.get_partition_scheduler(args)
    args.partition = util.get_partition(args)

    # init data
    train_data_loader, test_data_loader = get_data(args)
    args.train_data_loader = train_data_loader
    args.test_data_loader = test_data_loader

    # init model
    model = get_model(train_data_loader, args)

    # init optimizer
    model.init_optimizer()

    return model, args
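
A minimal sketch of how this init helper might be driven; the config keys and the run handle below are assumptions for illustration, not part of the example above.

# Hypothetical driver code: the config keys and the experiment `run` handle
# are invented for illustration only.
config = {"seed": 0, "lr": 1e-3, "epochs": 100, "loss": "elbo", "K": 10}
model, args = init(config, run=None)  # `run` would normally be the tracking handle (e.g. a Sacred run)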
Example #2
def train(args):
    # read data
    train_data_loader, test_data_loader = get_data(args)

    # attach data to args
    args.train_data_loader = train_data_loader
    args.test_data_loader = test_data_loader

    # Make models
    model = get_model(train_data_loader, args)

    # Make optimizer
    if args.loss in DUAL_OBJECTIVES:
        optimizer_phi = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if args.phi_tag in name),
            lr=args.lr)
        optimizer_theta = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if args.theta_tag in name),
            lr=args.lr)

    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in tqdm(range(args.epochs)):
        if mlh.is_schedule_update_time(epoch, args):
            args.partition = args.partition_scheduler(model, args)
            if len(args.Y_ori) > 1 and len(args.Y_ori) % args.increment_update_frequency == 0:
                args.schedule_update_frequency += 1
                print("args.schedule_update_frequency=", args.schedule_update_frequency)

        if args.loss in DUAL_OBJECTIVES:
            train_logpx, train_elbo, train_tvo_log_evidence = model.train_epoch_dual_objectives(
                train_data_loader, optimizer_phi, optimizer_theta, epoch=epoch)
        else:
            # addl recording within model.base
            train_logpx, train_elbo, train_tvo_log_evidence = model.train_epoch_single_objective(
                train_data_loader, optimizer, epoch=epoch)

        log_scalar(train_elbo=train_elbo,
                   train_logpx=train_logpx,
                   train_tvo_log_evidence=train_tvo_log_evidence,
                   step=epoch)

        # store the information
        args.betas_all = np.vstack((args.betas_all,
                                    np.reshape(format_input(args.partition),
                                               (1, args.K + 1))))
        args.logtvopx_all = np.append(args.logtvopx_all,
                                      train_tvo_log_evidence)

        if mlh.is_gradient_time(epoch, args):
            # Save grads
            grad_variance = util.calculate_grad_variance(model, args)
            log_scalar(grad_variance=grad_variance, step=epoch)

        if mlh.is_test_time(epoch, args):
            test_logpx, test_kl = model.evaluate_model_and_inference_network(
                test_data_loader, epoch=epoch)
            log_scalar(test_logpx=test_logpx, test_kl=test_kl, step=epoch)

        if mlh.is_checkpoint_time(epoch, args):
            opt = ([optimizer_phi, optimizer_theta]
                   if args.loss in DUAL_OBJECTIVES else [optimizer])
            save_checkpoint(model, epoch, train_elbo, train_logpx, opt, args)

        # ------ end of training loop ---------
    opt = ([optimizer_phi, optimizer_theta]
           if args.loss in DUAL_OBJECTIVES else [optimizer])
    save_checkpoint(model, args.epochs, train_elbo, train_logpx, opt, args)

    if args.train_only:
        test_logpx, test_kl = 0, 0

    results = {
        "test_logpx": test_logpx,
        "test_kl": test_kl,
        "train_logpx": train_logpx,
        "train_elbo": train_elbo,
        "train_tvo_px": train_tvo_log_evidence,
        # average tvo_logpx within this bandit iteration
        "average_y": args.average_y,
        "X": args.X_ori,  # the betas
        # the utility score y = f(betas) = ave_y[-1] - ave_y[-2]
        "Y": args.Y_ori,
    }

    return results, model
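
One way the returned results dictionary might be consumed after this train call; the output filename and the scalar-only filtering below are assumptions for illustration.

import json

import numpy as np

results, model = train(args)

# Keep only scalar metrics; args.X_ori / args.Y_ori are arrays and are skipped here.
summary = {k: float(v) for k, v in results.items() if np.isscalar(v)}
with open("results.json", "w") as f:
    json.dump(summary, f, indent=2)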
Example #3
def _get_model(config):
    model = get_model(config)
    return model
Example #4
import os

import tensorflow as tf

# Configs, get_dataset and get_model are assumed to be project-level helpers
# imported elsewhere in the original file.
flags = tf.flags
# flags.DEFINE_string("model", "siamese", "Model name")
flags.DEFINE_string("model", "mAP", "Model name")
# flags.DEFINE_string("dataset", "omniglot", "Dataset name")
flags.DEFINE_string("dataset", "mini_imagenet", "Dataset name")

FLAGS = flags.FLAGS

OUTDIR = os.path.join("results", FLAGS.dataset)

if __name__ == "__main__":
    configs = Configs()
    config = configs.get_config(FLAGS.dataset, FLAGS.model)

    test_dataset = get_dataset(FLAGS.dataset, config, "test")
    model = get_model(config)
    saver = tf.train.Saver(tf.global_variables())

    with tf.Session() as sess:

        ckpt = tf.train.latest_checkpoint(config.saveloc)
        if ckpt:
            saver.restore(sess, ckpt)
            print("Restored weights from {}".format(config.saveloc))

            # Find out the uidx that we are restoring from
            with open(os.path.join(config.saveloc, "checkpoint"), "r") as f:
                lines = f.readlines()
            model_checkpoint_line = lines[0].strip()
            dash_ind = model_checkpoint_line.rfind('-')
            uidx = int(model_checkpoint_line[dash_ind + 1:-1])
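
The uidx parsing above depends on the exact text layout of the checkpoint index file. A shorter sketch, assuming the checkpoints were saved with a numeric global_step suffix, reads the step straight off the path returned by tf.train.latest_checkpoint:

# Alternative sketch: the path returned by tf.train.latest_checkpoint ends in
# "-<global_step>" when the saver was given a global step.
ckpt = tf.train.latest_checkpoint(config.saveloc)
if ckpt:
    uidx = int(ckpt.rsplit('-', 1)[-1])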
Example #5
def train(args):
    # read data
    train_data_loader, test_data_loader = get_data(args)

    # attach data to args
    args.train_data_loader = train_data_loader
    args.test_data_loader = test_data_loader

    # Make models
    model = get_model(train_data_loader, args)

    # Make optimizer
    if args.loss in DUAL_LIST:
        optimizer_phi = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if 'encoder' in name),
            lr=args.lr)
        optimizer_theta = torch.optim.Adam(
            (params for name, params in model.named_parameters()
             if 'decoder' in name),
            lr=args.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):

        _record = mlh.is_record_time(epoch, args)

        if args.loss in DUAL_LIST:
            train_logpx, train_elbo = model.train_epoch_dual_objectives(
                train_data_loader,
                optimizer_phi,
                optimizer_theta,
                record=_record)
        else:
            train_logpx, train_elbo = model.train_epoch_single_objective(
                train_data_loader, optimizer, record=_record)

        log_scalar(train_elbo=train_elbo, train_logpx=train_logpx, step=epoch)

        if mlh.is_test_time(epoch, args):
            args.test_time = True
            test_logpx, test_kl = model.evaluate_model_and_inference_network(
                test_data_loader)
            log_scalar(test_logpx=test_logpx, test_kl=test_kl, step=epoch)
            args.test_time = False

        if mlh.is_checkpoint_time(epoch, args):
            opt = ([optimizer_phi, optimizer_theta]
                   if args.loss in DUAL_LIST else [optimizer])
            save_checkpoint(model, epoch, train_elbo, train_logpx, opt, args)

        if mlh.is_schedule_update_time(epoch, args):
            args.partition = args.partition_scheduler(model, args)

        # ------ end of training loop ---------

    if args.train_only:
        test_logpx, test_kl = 0, 0

    results = {
        "test_logpx": test_logpx,
        "test_kl": test_kl,
        "train_logpx": train_logpx,
        "train_elbo": train_elbo
    }

    return results, model
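
The dual-objective branch above builds two Adam optimizers over disjoint parameter groups selected by name. A self-contained sketch of that pattern, with a toy encoder/decoder model invented purely for illustration:

import torch
import torch.nn as nn

# Toy model invented for illustration; only the name-based parameter split matters.
model = nn.ModuleDict({"encoder": nn.Linear(4, 2), "decoder": nn.Linear(2, 4)})

optimizer_phi = torch.optim.Adam(
    (p for name, p in model.named_parameters() if "encoder" in name), lr=1e-3)
optimizer_theta = torch.optim.Adam(
    (p for name, p in model.named_parameters() if "decoder" in name), lr=1e-3)

# Each optimizer now steps only its own subset of parameters, so the encoder
# (phi) and decoder (theta) can be updated on different objectives.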