def testRegistryAdditions(self):
    """Round-trip an experiment that uses a user-registered runner and metric.

    Defines a local ``Runner`` subclass and ``Metric`` subclass and registers
    them via ``register_runner`` / ``register_metric``. It then attaches them
    to an experiment and saves the experiment to a temporary JSON file using
    the deprecated encoder registries. Finally it loads the file back with the
    deprecated decoder registries and asserts the result equals the original.
    The temporary file is always removed, even if saving or loading raises.
    """

    class MyRunner(Runner):
        # The original overrides omitted `self`, so calling them on an
        # instance raised TypeError.
        def run(self, trial):
            pass

        # NOTE(review): assumes the base `Runner.staging_required` is a
        # property. A plain-method override would make
        # `runner.staging_required` a truthy bound method instead of False.
        @property
        def staging_required(self):
            return False

    class MyMetric(Metric):
        pass

    register_metric(MyMetric)
    register_runner(MyRunner)

    experiment = get_experiment_with_batch_and_single_trial()
    experiment.runner = MyRunner()
    experiment.add_tracking_metric(MyMetric(name="my_metric"))

    # Create the file only to reserve a unique path, and close the handle
    # before writing. Save/load then never reopen a path that is still held
    # open (this fails on Windows).
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
        path = f.name
    try:
        save_experiment(
            experiment,
            path,
            encoder_registry=DEPRECATED_ENCODER_REGISTRY,
            class_encoder_registry=DEPRECATED_CLASS_ENCODER_REGISTRY,
        )
        loaded_experiment = load_experiment(
            path,
            decoder_registry=DEPRECATED_DECODER_REGISTRY,
            class_decoder_registry=DEPRECATED_CLASS_DECODER_REGISTRY,
        )
    finally:
        # Clean up even when save/load raises, so failures do not leak files.
        os.remove(path)
    self.assertEqual(loaded_experiment, experiment)
def testSaveAndLoad(self):
    """Save ``self.experiment`` to JSON and load it back with default registries.

    Asserts that the loaded experiment equals the original. The temporary
    file is always removed, even if saving or loading raises.
    """
    # Create the file only to reserve a unique path, and close the handle
    # before writing. Save/load then never reopen a path that is still held
    # open (this fails on Windows).
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
        path = f.name
    try:
        save_experiment(self.experiment, path)
        loaded_experiment = load_experiment(path)
    finally:
        # Clean up even when save/load raises, so failures do not leak files.
        os.remove(path)
    self.assertEqual(loaded_experiment, self.experiment)
def testSaveAndLoad(self):
    """Save ``self.experiment`` to JSON and load it back with deprecated registries.

    Uses the deprecated encoder and decoder registries (plain and class
    variants) and asserts that the loaded experiment equals the original.
    The temporary file is always removed, even if saving or loading raises.
    """
    # NOTE(review): another method named `testSaveAndLoad` appears in this
    # chunk. If both live in the same TestCase, only the later definition
    # runs; confirm and give one of them a distinct name.
    #
    # Create the file only to reserve a unique path, and close the handle
    # before writing. Save/load then never reopen a path that is still held
    # open (this fails on Windows).
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
        path = f.name
    try:
        save_experiment(
            self.experiment,
            path,
            encoder_registry=DEPRECATED_ENCODER_REGISTRY,
            class_encoder_registry=DEPRECATED_CLASS_ENCODER_REGISTRY,
        )
        loaded_experiment = load_experiment(
            path,
            decoder_registry=DEPRECATED_DECODER_REGISTRY,
            class_decoder_registry=DEPRECATED_CLASS_DECODER_REGISTRY,
        )
    finally:
        # Clean up even when save/load raises, so failures do not leak files.
        os.remove(path)
    self.assertEqual(loaded_experiment, self.experiment)