def test_experiment_restoration(nested_dict_config, tmpdir):
    """Check that an experiment can be resumed by identifier and by path.

    The test proceeds in four steps:

    1. Create an experiment and register a ``temp`` directory on it.
    2. Construct a second experiment from the same config in the same root.
       This must raise ``ValueError``, because an experiment with that
       identifier already exists.
    3. Resume from the identifier, passing ``experiments_dir``.
    4. Resume from the full path to the experiment's directory, without
       ``experiments_dir``.

    Each resumed instance must reproduce the original config and expose the
    previously registered ``temp`` directory.
    """
    root = tmpdir.join("experiments").strpath

    def assert_restored(restored):
        # The resumed config must round-trip back to the original mapping,
        # and the directory registered before resuming must still exist.
        assert restored.config.to_dict() == nested_dict_config
        assert os.path.isdir(restored.temp)

    created = Experiment(nested_dict_config, experiments_dir=root)
    created.register_directory("temp")

    # An experiment with this identifier has already been created, so
    # constructing it again is expected to fail.
    with pytest.raises(ValueError):
        Experiment(nested_dict_config, experiments_dir=root)

    # Resume using only the identifier plus the experiments root.
    from_identifier = Experiment(
        resume_from=created.config.identifier,
        experiments_dir=root,
    )
    assert_restored(from_identifier)

    # Resume using the absolute experiment directory; no root is passed here.
    from_directory = Experiment(
        resume_from=os.path.join(root, from_identifier.config.identifier),
    )
    assert_restored(from_directory)
def test_experiment_register_directory(nested_dict_config, tmpdir):
    """Check what registering a ``temp`` directory does.

    Registration must create ``<experiment_dir>/temp`` on disk and expose
    that path as the experiment's ``temp`` attribute.
    """
    exp = Experiment(
        nested_dict_config,
        experiments_dir=tmpdir.join("experiments").strpath,
    )
    exp.register_directory("temp")

    # The registered directory lives directly under the experiment folder.
    expected_path = os.path.join(exp.experiment_dir, "temp")
    assert os.path.isdir(expected_path)
    assert exp.temp == expected_path
"batch_accumulation": args.batch_accumulation, "batch_size": args.batch_size, "warmup": args.warmup, "lr": args.lr, "folds": args.folds, "max_sequence_length": args.max_sequence_length, "max_title_length": args.max_title_length, "max_question_length": args.max_question_length, "max_answer_length": args.max_answer_length, "head_tail": args.head_tail, "label": args.label, "_pseudo_file": args.pseudo_file, "model_type": args.model_type, } experiment = Experiment(config, implicit_resuming=args.use_folds is not None) experiment.register_directory("checkpoints") experiment.register_directory("predictions") def seed_everything(seed: int): random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True logging.getLogger("transformers").setLevel(logging.ERROR) seed_everything(args.seed) # load the data