Example #1
def get_best_archis_for_seed(
    seed,
    test_params,
    data_factory,
    gpus,
    methods_variant_params,
    checkpoint_dir,
):
    best_archis = {}
    for method_name, method_params in sorted(methods_variant_params.items()):
        method = archis.Method(method_params["method"])
        method_params = method_params.copy()
        del method_params["method"]

        _, trained_archi = xp.train_model(
            method,
            seed=seed,
            data_factory=data_factory,
            gpus=gpus,
            method_params=method_params,
            method_name=method_name,
            checkpoint_dir=checkpoint_dir,
            try_to_resume=True,
            **test_params,
        )
        best_archis[method_name] = trained_archi
    return best_archis
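
A minimal call sketch. Every value below is an illustrative assumption (not from the source); test_params is assumed to carry the train_params/archi_params that xp.train_model expects:

# Hypothetical invocation -- all values are assumptions for illustration.
best = get_best_archis_for_seed(
    seed=42,
    test_params={"train_params": train_params, "archi_params": archi_params},
    data_factory=data_factory,
    gpus=[0],
    methods_variant_params={"dann_default": {"method": "DANN"}},
    checkpoint_dir="./checkpoints",
)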
Example #2
def get_archi_or_file_for_seed(
    seed,
    methods,
    file_prefix,
    test_params,
    data_factory,
    checkpoint_dir,
    gpus,
    fig_names,
):
    best_archis = {}
    st.write("Generating models for plots.")
    st.text(f"Looking for {file_prefix}_{seed}_<method>_*.png")
    howtext = st.text("")
    pgbar = st.progress(0)
    for i, method in enumerate(methods):
        if (data_factory.is_semi_supervised()
                and not archis.Method(method).is_fewshot_method()):
            logging.warning(
                f"Skipping {method}: not suited for the semi-supervised setting."
            )
            continue
        png_files = glob.glob(f"{file_prefix}_{seed}_{method}_*.png")
        recompute = False
        for figname in fig_names:
            png_figs = glob.glob(
                f"{file_prefix}_{seed}_{method}_*{figname}*.png")
            if len(png_figs) == 0:
                recompute = True

        if recompute:
            howtext.text(f"Retraining {method}")
            _, trained_archi = xp.train_model(
                method,
                seed=seed,
                data_factory=data_factory,
                checkpoint_dir=checkpoint_dir,
                try_to_resume=True,
                gpus=gpus,
                **test_params,
            )
            best_archis[method] = trained_archi
        else:
            howtext.text("Restoring images")
            best_archis[method] = png_files
        pgbar.progress((i + 1) / len(methods))
    return best_archis
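
The retrain-or-restore decision above reduces to one glob check per expected figure. A self-contained sketch of that predicate, with the file-name pattern taken verbatim from the loop above:

import glob

def needs_recompute(file_prefix, seed, method, fig_names):
    """Return True if any expected figure PNG for this (seed, method) is missing."""
    return any(
        len(glob.glob(f"{file_prefix}_{seed}_{method}_*{figname}*.png")) == 0
        for figname in fig_names
    )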
Example #3
def get_best_archis_for_seed(
    seed,
    test_params,
    data_factory,
    gpus,
    methods_variant_params,
    mlflow_uri,
    tensorboard_dir,
    checkpoint_dir,
):
    best_archis = {}
    for method_name, method_params in sorted(methods_variant_params.items()):
        method = archis.Method(method_params["method"])
        method_params = method_params.copy()
        del method_params["method"]

        if (data_factory.is_semi_supervised()
                and not method.is_fewshot_method()):
            logging.warning(
                f"Skipping {method_name}: not suited for the semi-supervised setting."
            )
            continue

        _, trained_archi = xp.train_model(
            method,
            seed=seed,
            data_factory=data_factory,
            gpus=gpus,
            method_params=method_params,
            method_name=method_name,
            mlflow_uri=mlflow_uri,
            tensorboard_dir=tensorboard_dir,
            checkpoint_dir=checkpoint_dir,
            try_to_resume=True,
            **test_params,
        )
        best_archis[method_name] = trained_archi
    return best_archis
Example #4
    xp.record_hashes(hash_file, params_hash, record_params)
    output_file_prefix = os.path.join(output_dir, params_hash)

    test_csv_file = f"{output_file_prefix}.csv"
    checkpoint_dir = os.path.join(output_dir, "checkpoints", params_hash)

    results = xpr.XpResults.from_file(
        ["source acc", "target acc", "domain acc"], test_csv_file
    )
    do_plots = False

    methods_variant_params = xp.load_json_dict(args.method)

    archis_res = {}
    for method_name, method_params in sorted(methods_variant_params.items()):
        method = archis.Method(method_params["method"])
        method_params = method_params.copy()
        del method_params["method"]
        domain_archi = xp.loop_train_test_model(
            method,
            results,
            nseeds,
            test_csv_file,
            test_params=test_params,
            data_factory=data_factory,
            gpus=args.gpu,
            method_name=method_name,
            method_params=method_params,
            checkpoint_dir=checkpoint_dir,
        )
        archis_res[method_name] = domain_archi
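
xp.load_json_dict presumably parses a JSON file mapping each variant name to its parameters; every entry must carry a "method" key, which is deleted before the remaining keys are forwarded as method_params. A hypothetical shape of the resulting dict (variant and parameter names are assumptions, with use_random/random_dim taken from the CDAN branch in Example #6):

methods_variant_params = {
    "dann_default": {"method": "DANN"},
    "cdan_random": {"method": "CDAN", "use_random": True, "random_dim": 512},
}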
Example #5
def loop_train_test_model(
    method,
    results,
    nseeds,
    backup_file,
    test_params,
    data_factory,
    gpus,
    force_run=False,
    progress_callback=lambda percent: None,
    method_name=None,
    method_params=None,
    mlflow_uri=None,
    tensorboard_dir=None,
    checkpoint_dir=None,
):
    init_seed = 34875
    seeds = np.random.RandomState(init_seed).randint(100, 100000, size=nseeds)
    if isinstance(method, str):
        method = archis.Method(method)
    if method_name is None:
        method_name = method.value
    if method_params is None:
        method_params = {}

    if data_factory.is_semi_supervised() and not method.is_fewshot_method():
        logging.warning(
            f"Skipping {method_name}: not suited for the semi-supervised setting."
        )
        return None

    res_archis = {}
    for i, seed in enumerate(tqdm(seeds)):
        if results.already_computed(method_name, seed) and not force_run:
            progress_callback((i + 1) / nseeds)
            continue

        trainee, trained_archi = train_model(
            method,
            seed=seed,
            data_factory=data_factory,
            gpus=gpus,
            method_name=method_name,
            method_params=method_params,
            mlflow_uri=mlflow_uri,
            tensorboard_dir=tensorboard_dir,
            checkpoint_dir=checkpoint_dir,
            try_to_resume=not force_run,
            **test_params,
        )
        # validation scores
        results.update(
            is_validation=True,
            method_name=method_name,
            seed=seed,
            metric_values=trainee.callback_metrics,
        )
        # test scores
        trainee.test()
        results.update(
            is_validation=False,
            method_name=method_name,
            seed=seed,
            metric_values=trainee.callback_metrics,
        )
        results.to_csv(backup_file)
        results.print_scores(
            method_name, stdout=True, fdout=None, print_func=tqdm.write,
        )
        res_archis[seed] = trained_archi
        progress_callback((i + 1) / nseeds)

    best_archi_seed = results.get_best_archi_seed()
    if best_archi_seed not in res_archis:
        return None
    return res_archis[best_archi_seed]
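
Note that the per-run seeds are derived deterministically from the fixed init_seed, so re-running the loop skips or resumes exactly the same (method, seed) pairs via results.already_computed. A quick self-contained check:

import numpy as np

seeds_a = np.random.RandomState(34875).randint(100, 100000, size=5)
seeds_b = np.random.RandomState(34875).randint(100, 100000, size=5)
assert (seeds_a == seeds_b).all()  # identical seeds on every invocation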
Example #6
def train_model(
    method,
    data_factory,
    train_params=None,
    archi_params=None,
    method_name=None,
    method_params=None,
    seed=98347,
    fix_few_seed=0,
    gpus=None,
    mlflow_uri=None,
    tensorboard_dir=None,
    checkpoint_dir=None,
    fast=False,
    try_to_resume=True,
):
    """This is the main function where a single model is created and trained, for a single seed value.

    Args:
        method (archis.Method): type of method, used to decide which networks to build and
            how to use some parameters.
        data_factory (DataFactory): dataset description to get dataset loaders, as well as useful
            information for some networks.
        train_params (dict, optional): Hyperparameters for training (see network config). Defaults to None.
        archi_params (dict, optional): Parameters of the network (see network config). Defaults to None.
        method_name (string, optional): A unique name describing the method, with its parameters. Used for logging results.
            Defaults to None.
        method_params (dict, optional): Parameters to be fed to the model that are specific to `method`. Defaults to None.
        seed (int, optional): Global seed for reproducibility. Defaults to 98347.
        fix_few_seed (int, optional): Seed for the semi-supervised setting, fixing which target samples are labeled. Defaults to 0.
        gpus (list of int, optional): Which GPU ids to use. Defaults to None.
        mlflow_uri (string, optional): MLflow tracking URI. If given a bare port number (as a
            string), will try to log to an MLflow server on http://127.0.0.1:<port>; otherwise the
            string is used as the tracking URI directly. If None, MLflow logging is disabled.
            Defaults to None.
        tensorboard_dir (string, optional): Directory for TensorBoard logs. If None, TensorBoard
            logging is disabled. Defaults to None.
        checkpoint_dir (string, optional): Root directory for model checkpoints, also searched for
            a checkpoint to resume from. If None, checkpointing and resuming are disabled.
            Defaults to None.
        fast (bool, optional): Whether to activate the `fast_dev_run` option of PyTorch-Lightning,
            training only on 1 batch per epoch for debugging. Defaults to False.
        try_to_resume (bool, optional): Whether to resume from the most recent checkpoint found in
            `checkpoint_dir`, if any. Defaults to True.

    Returns:
        2-element tuple containing:

            - pl.Trainer: object containing the resulting metrics, used for evaluation.
            - BaseAdaptTrainer: pl.LightningModule object (derived class depending on `method`),
                containing both the dataset & trained networks.
    """
    if isinstance(method, str):
        method = archis.Method(method)
    if method_name is None:
        method_name = method.value
    train_params_local = deepcopy(train_params)

    set_all_seeds(seed)
    if fix_few_seed > 0:
        archi_params["random_state"] = fix_few_seed
    else:
        archi_params["random_state"] = seed

    dataset = data_factory.get_multi_domain_dataset(seed)
    n_classes, data_dim, args = data_factory.get_data_args()
    network_factory = NetworkFactory(archi_params)
    # setup feature extractor
    feature_network = network_factory.get_feature_extractor(data_dim, *args)
    # setup classifier
    feature_dim = feature_network.output_size()
    classifier_network = network_factory.get_task_classifier(feature_dim, n_classes)

    method_params = {} if method_params is None else method_params
    if method.is_mmd_method():
        model = archis.create_mmd_based(
            method=method,
            dataset=dataset,
            feature_extractor=feature_network,
            task_classifier=classifier_network,
            **method_params,
            **train_params_local,
        )
    else:
        critic_input_size = feature_dim
        # setup critic network
        if method.is_cdan_method():
            if method_params.get("use_random", False):
                critic_input_size = method_params["random_dim"]
            else:
                critic_input_size = feature_dim * n_classes
        critic_network = network_factory.get_critic_network(critic_input_size)

        model = archis.create_dann_like(
            method=method,
            dataset=dataset,
            feature_extractor=feature_network,
            task_classifier=classifier_network,
            critic=critic_network,
            **method_params,
            **train_params_local,
        )

    data_name = data_factory.get_data_short_name()

    if checkpoint_dir is not None:
        path_method_name = re.sub(r"[^-/\w\.]", "_", method_name)
        full_checkpoint_dir = os.path.join(
            checkpoint_dir, path_method_name, f"seed_{seed}"
        )
        checkpoint_callback = ModelCheckpoint(
            filepath=os.path.join(full_checkpoint_dir, "{epoch}"),
            monitor="last_epoch",
            mode="max",
        )
        checkpoints = sorted(
            glob.glob(f"{full_checkpoint_dir}/*.ckpt"), key=os.path.getmtime
        )
        if len(checkpoints) > 0 and try_to_resume:
            last_checkpoint_file = checkpoints[-1]
            if method is archis.Method.WDGRL:
                # WDGRL doesn't resume training gracefully
                last_epoch = (
                    train_params_local["nb_init_epochs"]
                    + train_params_local["nb_adapt_epochs"]
                )
                if f"epoch={last_epoch - 1}" not in last_checkpoint_file:
                    last_checkpoint_file = None
        else:
            last_checkpoint_file = None
    else:
        checkpoint_callback = None
        last_checkpoint_file = None

    if mlflow_uri is not None:
        if mlflow_uri.isdecimal():
            mlflow_uri = f"http://127.0.0.1:{mlflow_uri}"
        mlf_logger = MLFlowLogger(
            experiment_name=data_name,
            tracking_uri=mlflow_uri,
            tags=dict(
                method=method_name,
                data_variant=data_factory.get_data_long_name(),
                script=__file__,
            ),
        )
    else:
        mlf_logger = None

    if tensorboard_dir is not None:
        tnb_logger = TensorBoardLogger(
            save_dir=tensorboard_dir, name=f"{data_name}_{method_name}",
        )
    else:
        tnb_logger = None

    loggers = [logger for logger in [mlf_logger, tnb_logger] if logger is not None]
    if len(loggers) == 0:
        logger = False
    else:
        logger = LoggerCollection(loggers)
        logger.log_hyperparams(
            {
                "seed": seed,
                "feature_network": archi_params["feature"]["name"],
                "method group": method.value,
                "method": method_name,
                "start time": create_timestamp_string("%Y-%m-%d %H:%M:%S"),
            }
        )

    max_nb_epochs = (
        train_params_local["nb_adapt_epochs"] * 5
        if method is archis.Method.WDGRLMod
        else train_params_local["nb_adapt_epochs"]
    )
    pb_refresh = 1 if len(dataset) < 1000 else 10
    row_log_interval = max(10, len(dataset) // train_params_local["batch_size"] // 10)

    if gpus is not None and len(gpus) > 1 and method is archis.Method.WDGRL:
        logging.warning("WDGRL is not compatible with multi-GPU.")
        gpus = [gpus[0]]

    trainer = pl.Trainer(
        progress_bar_refresh_rate=pb_refresh,  # in steps
        row_log_interval=row_log_interval,
        min_epochs=train_params_local["nb_init_epochs"],
        max_epochs=max_nb_epochs + train_params_local["nb_init_epochs"],
        early_stop_callback=False,
        num_sanity_val_steps=5,
        check_val_every_n_epoch=1,
        checkpoint_callback=checkpoint_callback,
        resume_from_checkpoint=last_checkpoint_file,
        gpus=gpus,
        logger=logger,
        weights_summary=None,  # 'full' is default
        fast_dev_run=fast,
    )

    if last_checkpoint_file is None:
        logging.info(f"Training model with {method.name} {param_to_str(method_params)}")
    else:
        logging.info(
            f"Resuming training with {method.name} {param_to_str(method_params)}, from {last_checkpoint_file}."
        )
    trainer.fit(model)
    if trainer.interrupted:
        raise KeyboardInterrupt("Trainer was interrupted and shutdown gracefully.")

    if logger:
        logger.log_hyperparams(
            {"finish time": create_timestamp_string("%Y-%m-%d %H:%M:%S")}
        )
    return trainer, model
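
A minimal direct call, following the docstring. The train_params keys are the ones the function body actually reads (nb_init_epochs, nb_adapt_epochs, batch_size); data_factory, archi_params, and the method name "DANN" are assumptions standing in for the surrounding experiment setup:

# Hypothetical invocation -- all values are assumptions for illustration.
trainer, model = train_model(
    method="DANN",                # a plain string is converted to archis.Method
    data_factory=data_factory,
    train_params={"nb_init_epochs": 10, "nb_adapt_epochs": 50, "batch_size": 64},
    archi_params=archi_params,    # network config; the function sets its "random_state"
    seed=98347,
    gpus=[0],
    checkpoint_dir="./checkpoints",
)
trainer.test()                    # evaluate on the test split, as in loop_train_test_model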