Example #1
0
def test_experiment_restoration(nested_dict_config, tmpdir):
    """An experiment can be resumed by identifier or by directory path.

    Also checks that creating a second experiment whose identifier already
    exists raises ``ValueError``, and that the ``temp`` directory registered
    on the first instance is available again after restoration.
    """
    root = tmpdir.join("experiments").strpath

    def check_restored(restored):
        # The config must round-trip back to the original dict, and the
        # directory registered before restoration must be re-attached.
        assert restored.config.to_dict() == nested_dict_config
        assert os.path.isdir(restored.temp)

    original = Experiment(nested_dict_config, experiments_dir=root)
    original.register_directory("temp")

    # Creating an experiment whose identifier already exists is an error.
    with pytest.raises(ValueError):
        Experiment(nested_dict_config, experiments_dir=root)

    # Resume by identifier, resolved relative to the experiments directory.
    by_identifier = Experiment(resume_from=original.config.identifier,
                               experiments_dir=root)
    check_restored(by_identifier)

    # Resume by the full path of the experiment's own directory.
    by_path = Experiment(resume_from=os.path.join(
        root, by_identifier.config.identifier))
    check_restored(by_path)
Example #2
0
def test_experiment_register_directory(nested_dict_config, tmpdir):
    """Registering a directory creates it and exposes its path as an attribute."""
    experiment = Experiment(nested_dict_config,
                            experiments_dir=tmpdir.join("experiments").strpath)
    experiment.register_directory("temp")

    expected_path = os.path.join(experiment.experiment_dir, "temp")
    # The directory must exist on disk ...
    assert os.path.isdir(expected_path)
    # ... and be reachable as the `temp` attribute of the experiment.
    assert experiment.temp == expected_path
Example #3
0
    "batch_accumulation": args.batch_accumulation,
    "batch_size": args.batch_size,
    "warmup": args.warmup,
    "lr": args.lr,
    "folds": args.folds,
    "max_sequence_length": args.max_sequence_length,
    "max_title_length": args.max_title_length,
    "max_question_length": args.max_question_length,
    "max_answer_length": args.max_answer_length,
    "head_tail": args.head_tail,
    "label": args.label,
    "_pseudo_file": args.pseudo_file,
    "model_type": args.model_type,
}
# Implicit resuming is enabled only when `args.use_folds` was supplied.
# NOTE(review): presumably this lets several fold-subset runs share one
# experiment; confirm against `Experiment`.
experiment = Experiment(
    config,
    implicit_resuming=(args.use_folds is not None),
)
# Create the `checkpoints` and `predictions` sub-directories; each one is
# exposed as an attribute of `experiment`.
experiment.register_directory("checkpoints")
experiment.register_directory("predictions")


def seed_everything(seed: int):
    """Seed every RNG this script relies on so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and CUDA generators. Also configures cuDNN to use deterministic kernels
    only.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only
    # affects Python subprocesses launched afterwards, not the current one.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device; silently ignored
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN benchmark mode autotunes kernels per input shape and may select
    # non-deterministic algorithms, defeating `deterministic = True`.
    torch.backends.cudnn.benchmark = False


# Let only ERROR-level (and above) records from the `transformers` logger through.
logging.getLogger("transformers").setLevel(level=logging.ERROR)
# Seed all RNGs before any data is loaded.
seed_everything(seed=args.seed)
# load the data