Code example #1
 def test_pickle_serializable_experiment_success(self):
     experiment = PickleOnlySerializableExeperiment()
     runner = ExperimentRunner()
     specification = {"test": "test"}
     runner.run("test", [specification], experiment, specification_runner=MainRunner())
     self.assertIn("run.pkl", os.listdir(get_save_file_directory('test', specification)))
     self.assertIn("specification.json", os.listdir(get_save_file_directory('test', specification)))
Code example #2
 def test_checkpoint_handler_rotates_checkpoints_properly(self):
     experiment = SerializableExperimentFailsAfter4Steps()
     runner = ExperimentRunner()
     specification = {"test": "test"}
     runner.run("test", [specification], experiment, specification_runner=MainRunner())
     self.assertEqual(3, len(os.listdir(get_partial_save_directory("test", specification))))
     partial_experiment = CheckpointedExperimentHandler().load_most_recent("test", specification)
     self.assertEqual(partial_experiment.j, 3)
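
A plausible shape for the failing experiment used above is sketched here. The CheckpointedExperiment base class, its import path, and the initialize()/step() protocol are assumptions about a checkpointing API, not taken from the test; the idea is that the runner checkpoints the object after each successful step, so the most recent checkpoint restores j == 3 when the fourth step raises.

# Sketch only (assumed base class, import path, and step protocol).
from smallab.experiment_types.checkpointed_experiment import CheckpointedExperiment  # assumed path


class SerializableExperimentFailsAfter4Steps(CheckpointedExperiment):
    def initialize(self, specification):
        self.j = 0

    def step(self):
        self.j += 1
        if self.j >= 4:
            raise Exception("Deliberate failure on the fourth step")
        return None  # assumed convention: None means "not finished yet"
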
Code example #3
 def test_with_runner(self):
     experiment = SerializableExperiment()
     runner = ExperimentRunner()
     specification = {"test": "test"}
     runner.run("test", [specification],
                experiment,
                specification_runner=MainRunner())
     self.assertEqual(
         1, len(os.listdir(get_save_file_directory('test', specification))))
Code example #4
    def testmain(self):

        # Same specification as before
        generation_specification = {"seed": [1, 2, 3, 4, 5, 6, 7, 8], "num_calls": [[10, 20, 30]]}
        specifications = SpecificationGenerator().generate(generation_specification)

        output_generation_specification = {"seed": [1, 2, 3, 4, 5, 6, 7, 8], "num_calls": [10, 20, 30]}
        output_specifications = SpecificationGenerator().generate(output_generation_specification)

        name = "test"
        # Run all of the specifications (here with the single-process MainRunner)
        runner = ExperimentRunner()
        runner.run(name, specifications, SimpleExperiment(), specification_runner=MainRunner(),
                   use_dashboard=False, propagate_exceptions=True)
        for result in experiment_iterator(name):
            if result["result"] != []:
                output_specifications.remove(result["specification"])
        self.assertEqual([], output_specifications)
Code example #5
        f"Expt {name}:\t{len(specifications)/num_seeds} specs to run, over {num_seeds} seeds"
    )
    for spec in specifications:
        if spec["seed"] == 0:
            print(spec)

    runner = ExperimentRunner()
    map_memory(base_specs["file"], base_specs["state_space_dimensionality"])
    DEBUG = False

    if DEBUG:
        runner.run(name,
                   specifications,
                   PlanningExperiment(),
                   propagate_exceptions=True,
                   specification_runner=MainRunner(),
                   use_dashboard=False,
                   force_pickle=True,
                   context_type="fork")
    else:
        gpus = 4
        jobs_per_gpu = 2
        resources = list(product(list(range(gpus)), list(range(jobs_per_gpu))))
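        # Note: these (gpu, job-slot) pairs are built here but are not passed to the runner in this snippet.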
        runner.run(name,
                   specifications,
                   PlanningExperiment(),
                   propagate_exceptions=False,
                   specification_runner=MultiprocessingRunner(),
                   context_type="fork",
                   use_dashboard=True,
                   force_pickle=True)
Code example #6
 def test_un_serializable_experiment_failure(self):
     experiment = UnserializableExperiment()
     runner = ExperimentRunner()
     specification = {"test": "test"}
     runner.run("test", [specification], experiment, specification_runner=MainRunner())
     self.assertEqual(0, len(os.listdir(get_save_file_directory('test', specification, runner.diff_namer))))
Code example #7
        visualizer.tSNE(prefix_str+'tsne.png')
        logging.getLogger(self.get_logger_name()).info("Visualization complete.")

        return {"discriminative score mean": results[0], "predictive score mean": results[2]}

    def get_hash(self):
        return self.get_logger_name()


# In the generation specification, keys whose values are lists are cross-producted with the other
# list-valued keys to create many specifications (see the plain-Python sketch after this example).
# With the single-element lists below only one specification is generated; e.g. a key with 8 values
# crossed with a key with 3 values would give 8 * 3 = 24 specifications.
generation_specification = {  # just trying out random values for testing
    "total_iterations": [1], 
    "sub_iterations": [2],
    "data_size": [300],
    "max_seq_length": [12],
    "iterations": [10001],
    "batch_size": [128],
    # "module_name": ['gru', 'lstm', 'lstmLN']
}

# Call the generate method. Will create the cross product.
specifications = SpecificationGenerator().generate(generation_specification)
print(specifications)

expt = TsganExperiment()
name = "tsgan_unseen_metrics" #+expt.get_hash()
runner = ExperimentRunner()
runner.run(name, specifications, expt, specification_runner=MainRunner(), propagate_exceptions=True)
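
The cross product mentioned in the comment above can be illustrated with plain Python. This is an illustration of the idea only, not the library's implementation, and the dictionary values here are made up for the example.

# Illustration only: how list-valued keys expand into a cross product of specifications.
from itertools import product

generation_specification = {"seed": [1, 2], "num_calls": [10, 20, 30]}
keys = list(generation_specification)
specifications = [dict(zip(keys, values))
                  for values in product(*(generation_specification[k] for k in keys))]
print(len(specifications))  # 2 * 3 = 6 specifications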

Code example #8
                                 "seed": list(range(5)),
                                 "alpha_param": 6,
                                 "beta_param":1,
                                 "epsilon": 10,
                                 "delta": 0.1,
                                 "plan_commitment_algorithm": "n_steps",
                                 "plan_threshold": [1],
                                 "sample_observations": False,
                                 "use_expected_improvement":False,
                                 "planning_steps": 200
                                 }



    # Create shared memory
    map_memory(generation_specifications["file"], generation_specifications["state_space_dimensionality"])

    specifications = SpecificationGenerator().generate(generation_specifications)
    runner = ExperimentRunner()
    DEBUG = False


    if DEBUG:
        runner.run(name, specifications, PlanningExperiment(), propagate_exceptions=True,
                   specification_runner=MainRunner(), use_dashboard=False, force_pickle=True, context_type="fork")
    else:

        runner.run(name, specifications, PlanningExperiment(), propagate_exceptions=False,
                   specification_runner=MultiprocessingRunner(), context_type="fork", use_dashboard=True,
                   force_pickle=True)