Example #1
def main(argv):
    # Allow running multiple at once
    set_gpu_memory(FLAGS.gpumem)
    # Figure out the log and model directory filenames
    assert FLAGS.uid != "", "uid cannot be an empty string"
    model_dir, log_dir = get_directory_names()

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Write config file about what dataset we're using, sources, target, etc.
    file_utils.write_config_from_args(log_dir)

    # Load datasets
    source_datasets, target_dataset = load_datasets.load_da(FLAGS.dataset,
        FLAGS.sources, FLAGS.target, test=FLAGS.test)

    # Need to know which iteration for learning rate schedule
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Load the method, model, etc.
    method = methods.get_method(FLAGS.method,
        source_datasets=source_datasets,
        target_dataset=target_dataset,
        model_name=FLAGS.model,
        global_step=global_step,
        total_steps=FLAGS.steps,
        ensemble_size=FLAGS.ensemble,
        moving_average=FLAGS.moving_average,
        share_most_weights=FLAGS.share_most_weights)

    # Check that this method is supposed to be trainable. If not, we're done.
    # (Basically, we just wanted to write the config file for non-trainable
    # models.)
    if not method.trainable:
        print("Method not trainable. Exiting now.")
        return

    # Checkpoints
    checkpoint = tf.train.Checkpoint(
        global_step=global_step, **method.checkpoint_variables)
    checkpoint_manager = CheckpointManager(checkpoint, model_dir, log_dir)
    checkpoint_manager.restore_latest()

    # Metrics
    has_target_domain = target_dataset is not None
    metrics = Metrics(log_dir, method, source_datasets, target_dataset,
        has_target_domain)

    # Start training
    #
    # TODO maybe eventually rewrite this in the more-standard Keras way
    # See: https://www.tensorflow.org/guide/keras/train_and_evaluate
    for i in range(int(global_step), FLAGS.steps+1):
        t = time.time()
        data_sources, data_target = method.train_step()
        global_step.assign_add(1)
        t = time.time() - t

        if FLAGS.time_training:
            print(int(global_step), t, sep=",")
            continue  # skip evaluation, checkpointing, etc. when timing

        if i%1000 == 0:
            print("step %d took %f seconds"%(int(global_step), t))
            sys.stdout.flush()  # otherwise waits till the end to flush on Kamiak

        # Metrics on training/validation data
        if FLAGS.log_train_steps != 0 and i%FLAGS.log_train_steps == 0:
            metrics.train(data_sources, data_target, global_step, t)

        # Evaluate every log_val_steps but also at the last step
        validation_accuracy_source = None
        validation_accuracy_target = None
        if (FLAGS.log_val_steps != 0 and i%FLAGS.log_val_steps == 0) \
                or i == FLAGS.steps:
            validation_accuracy_source, validation_accuracy_target \
                = metrics.test(global_step)
            print("validation accuracy (source, target):",
                validation_accuracy_source, validation_accuracy_target)

        # Checkpoints -- Save either if at the right model step or if we found
        # a new validation accuracy. If this is better than the previous best
        # model, we need to make a new checkpoint so we can restore from this
        # step with the best accuracy.
        if (FLAGS.model_steps != 0 and i%FLAGS.model_steps == 0) \
                or validation_accuracy_source is not None:
            checkpoint_manager.save(int(global_step-1),
                validation_accuracy_source, validation_accuracy_target)

        # Plots
        if FLAGS.log_plots_steps != 0 and i%FLAGS.log_plots_steps == 0:
            metrics.plots(global_step)

    # We're done -- used for hyperparameter tuning
    file_utils.write_finished(log_dir)
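
The helper set_gpu_memory(FLAGS.gpumem) is not shown in this excerpt. Below is a minimal sketch of what such a helper might do with TensorFlow's device-configuration API, assuming gpumem is a per-process memory cap in megabytes and 0 means "no cap, just grow as needed" (both assumptions); it has to run before anything initializes the GPUs:

import tensorflow as tf

def set_gpu_memory(gpumem):
    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        return
    if gpumem == 0:
        # No fixed cap: allocate GPU memory on demand so several runs can coexist
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    else:
        # Hard-cap this process to `gpumem` MB on the first visible GPU
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=gpumem)])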
Example #2
import torch
from datasets import get_ds
from cfg import get_cfg
from methods import get_method

from eval.sgd import eval_sgd
from eval.knn import eval_knn
from eval.lbfgs import eval_lbfgs
from eval.get_data import get_data


if __name__ == "__main__":
    cfg = get_cfg()

    model_full = get_method(cfg.method)(cfg)
    model_full.cuda().eval()
    if cfg.fname is None:
        print("evaluating random model")
    else:
        model_full.load_state_dict(torch.load(cfg.fname))

    ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers)
    # Keep extracted features on CPU for the L-BFGS classifier, on GPU otherwise
    device = "cpu" if cfg.clf == "lbfgs" else "cuda"
    if cfg.eval_head:
        # Evaluate the projection head's output (dimension cfg.emb)
        model = lambda x: model_full.head(model_full.model(x))
        out_size = cfg.emb
    else:
        # Evaluate the backbone representation directly
        model = model_full.model
        out_size = model_full.out_size
    # Extract features and labels for the classifier's train and test splits
    x_train, y_train = get_data(model, ds.clf, out_size, device)
    x_test, y_test = get_data(model, ds.test, out_size, device)
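
The evaluation calls (eval_sgd / eval_knn / eval_lbfgs) are cut off above, and get_data is imported from eval.get_data but not shown. Purely as a sketch, assuming ds.clf and ds.test are ordinary DataLoaders and that model maps a CUDA batch to an out_size-dimensional embedding, a get_data-style extractor typically looks like the hypothetical extract_features below (the name and exact signature are made up, not the repo's API):

import torch

def extract_features(model, loader, out_size, device):
    # Preallocate feature/label buffers on the device the classifier will use
    xs = torch.empty(len(loader.dataset), out_size, device=device)
    ys = torch.empty(len(loader.dataset), dtype=torch.long, device=device)
    i = 0
    with torch.no_grad():
        for x, y in loader:
            n = x.size(0)
            # Run the frozen encoder on GPU, then move features to `device`
            xs[i:i + n] = model(x.cuda()).to(device)
            ys[i:i + n] = y.to(device)
            i += n
    return xs, ys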
Example #3
            T_mult=cfg.Tmult,
            eta_min=cfg.eta_min,
        )
    elif cfg.lr_step == "step":
        m = [cfg.epoch - a for a in cfg.drop]
        return MultiStepLR(optimizer, milestones=m, gamma=cfg.drop_gamma)
    else:
        return None


if __name__ == "__main__":
    cfg = get_cfg()
    wandb.init(project=cfg.wandb, config=cfg)

    ds = get_ds(cfg.dataset)(cfg.bs, cfg, cfg.num_workers)
    model = get_method(cfg.method)(cfg)
    model.cuda().train()
    if cfg.fname is not None:
        model.load_state_dict(torch.load(cfg.fname))

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.lr,
                           weight_decay=cfg.adam_l2)
    scheduler = get_scheduler(optimizer, cfg)

    eval_every = cfg.eval_every
    lr_warmup = 0 if cfg.lr_warmup else 500
    cudnn.benchmark = True

    for ep in trange(cfg.epoch, position=0):
        loss_ep = []
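
The epoch body is truncated here. Purely as a sketch (not the repo's code), assuming each batch fed to the method's forward pass returns a scalar training loss and that the 500 above is the number of linear-warmup optimizer steps, the inner loop could be factored as the hypothetical helper below; scheduler.step() would then run once per epoch after it returns:

def train_one_epoch(model, loader, optimizer, base_lr, lr_warmup):
    loss_ep = []
    for batch in loader:
        # Linear LR warmup over the first 500 optimizer steps (assumed)
        if lr_warmup < 500:
            lr_scale = (lr_warmup + 1) / 500
            for pg in optimizer.param_groups:
                pg["lr"] = base_lr * lr_scale
            lr_warmup += 1

        optimizer.zero_grad()
        loss = model(batch)  # assumes the method's forward pass returns the loss
        loss.backward()
        optimizer.step()
        loss_ep.append(loss.item())
    return loss_ep, lr_warmup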
Example #4
def process_model(log_dir, model_dir, config, gpumem, multi_gpu):
    """ Evaluate a model on the train/test data and compute the results """
    setup_gpu_for_process(gpumem, multi_gpu)

    dataset_name = config["dataset"]
    method_name = config["method"]
    model_name = config["model"]
    sources = config["sources"]
    target = config["target"]
    moving_average = config["moving_average"]
    ensemble_size = config["ensemble"]
    share_most_weights = config["share_most_weights"]

    # Load datasets
    source_datasets, target_dataset = load_datasets.load_da(dataset_name,
        sources, target, test=FLAGS.test)

    # Load the method, model, etc.
    # Note: global_step and total_steps are only used during training, so the
    # values passed here don't matter
    method = methods.get_method(method_name,
        source_datasets=source_datasets,
        target_dataset=target_dataset,
        model_name=model_name,
        global_step=1, total_steps=1,
        moving_average=moving_average,
        ensemble_size=ensemble_size,
        share_most_weights=share_most_weights)

    # Load model from checkpoint (if there's anything in the checkpoint)
    if len(method.checkpoint_variables) > 0:
        checkpoint = tf.train.Checkpoint(**method.checkpoint_variables)
        checkpoint_manager = CheckpointManager(checkpoint, model_dir, log_dir)

        if FLAGS.selection == "last":
            checkpoint_manager.restore_latest()
            max_accuracy_step = checkpoint_manager.latest_step()
            max_accuracy = None  # We don't really care...
            found = checkpoint_manager.found_last
        elif FLAGS.selection == "best_source":
            checkpoint_manager.restore_best_source()
            max_accuracy_step = checkpoint_manager.best_step_source()
            max_accuracy = checkpoint_manager.best_validation_source
            found = checkpoint_manager.found_best_source
        elif FLAGS.selection == "best_target":
            checkpoint_manager.restore_best_target()
            max_accuracy_step = checkpoint_manager.best_step_target()
            max_accuracy = checkpoint_manager.best_validation_target
            found = checkpoint_manager.found_best_target
        else:
            raise NotImplementedError("unknown --selection argument")
    else:
        max_accuracy_step = None
        max_accuracy = None
        found = True

    # Metrics
    has_target_domain = target_dataset is not None
    metrics = Metrics(log_dir, method, source_datasets, target_dataset,
        has_target_domain)

    # If not found, give up
    if not found:
        return log_dir, model_dir, config, {}, None, None

    # Evaluate on both datasets
    metrics.train_eval()
    metrics.test(evaluation=True)

    # Get results
    results = metrics.results()

    return log_dir, model_dir, config, results, max_accuracy_step, max_accuracy
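
process_model returns everything a caller needs to aggregate results across runs. A minimal driver sketch, assuming a module-level FLAGS with a gpumem flag (as in Example #1), multi_gpu=False, and that setup_gpu_for_process makes each worker safe to use the GPU; the helper names and pool size below are illustrative only:

import multiprocessing

def _evaluate_one(args):
    log_dir, model_dir, config = args
    return process_model(log_dir, model_dir, config,
        gpumem=FLAGS.gpumem, multi_gpu=False)

def evaluate_all(runs):
    # runs: list of (log_dir, model_dir, config) tuples, one per trained model
    with multiprocessing.Pool(processes=4) as pool:
        for log_dir, model_dir, config, results, step, acc in \
                pool.imap_unordered(_evaluate_one, runs):
            print(log_dir, "best step:", step, "val accuracy:", acc)
            print(results)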