optim_alg="SGD",
    learning_rate=0.1,
    lr_scheduler="MultiStepLR",
    lr_milestones=[60, 120, 160],
    lr_gamma=0.2,
    weight_decay=0.0005,
    momentum=0.9,
    nesterov_momentum=True,
    # ---- optimizer related
    hebbian_prune_perc=None,
    weight_prune_perc=0.2,
    pruning_early_stop=2,
    hebbian_grow=False,
)

# Ray Tune settings for the "comparison_iterative_pruning_2" run: one sample
# using one CPU and one GPU, results under ~/nta/results, and checkpointing
# turned off (checkpoint_freq=0, no checkpoint at the end).
tune_config = {
    "name": "comparison_iterative_pruning_2",
    "num_samples": 1,
    "local_dir": os.path.expanduser("~/nta/results"),
    "checkpoint_freq": 0,
    "checkpoint_at_end": False,
    "resources_per_trial": {"cpu": 1, "gpu": 1},
    "verbose": 2,
}

# Launch the experiment; fix_seed=True is forwarded to run_ray unchanged.
run_ray(tune_config, exp_config, fix_seed=True)
    lr_scheduler="StepLR",
    lr_step_size=1,
    lr_gamma=0.9825,

    # ---- Model ----
    model=ray.tune.grid_search(["StochasticSynapsesModel"]),
    # debug:
    use_tqdm=True,
    test_noise=False,
    debug_weights=True,
    debug_sparse=True,
)

# Ray Tune configuration. The experiment name is this script's file name
# without its extension. os.path.splitext removes only the trailing ".py";
# str.replace would also strip ".py" occurring elsewhere in the name.
tune_config = dict(
    name=os.path.splitext(os.path.basename(__file__))[0],
    num_samples=1,
    local_dir=os.path.expanduser("~/nta/results"),
    checkpoint_freq=0,
    checkpoint_at_end=False,
    # Stop each trial after 100 training iterations.
    stop={"training_iteration": 100},
    resources_per_trial={
        # 1 GPU per trial; the host's CPUs are divided by cuda_device_count
        # (presumably the number of visible GPUs, defined elsewhere -- confirm).
        # os.cpu_count() returns None when the count is undetermined, so fall
        # back to 1 instead of raising TypeError in the division.
        "cpu": (os.cpu_count() or 1) / cuda_device_count,
        "gpu": 1,
    },
    loggers=DEFAULT_LOGGERS,
    verbose=0,
)

run_ray(tune_config, base_exp_config)
# ---- Example no. 3 ----
    model="SparseModel",
    data_dir="~/nta/data",
    on_perc=0.2,
    batch_size_train=10,
    batch_size_test=10,
    debug_sparse=True,
    name="test2",
)

# Elasticsearch index for evaluation results: "<script name>_eval".
# Only the file's base name without extension is used. __file__ may be a full
# path (it is absolute for the main script since Python 3.9), and Elasticsearch
# index names may not contain "/" and must be lowercase. os.path.splitext also
# avoids str.replace stripping ".py" from elsewhere in the path (e.g. ".pyenv").
exp_config["elasticsearch_index"] = (
    os.path.splitext(os.path.basename(__file__))[0].lower() + "_eval"
)

# run
tune_config = dict(
    name="test",
    num_samples=3,
    local_dir=os.path.expanduser("~/nta/results"),
    # Checkpoint options are intentionally left unset here:
    # checkpoint_freq=0,
    # checkpoint_at_end=True,
    # A single training iteration per trial.
    stop={"training_iteration": 1},
    resources_per_trial={
        "cpu": 1,
        "gpu": 1
    },
    verbose=2,
    # Default Tune loggers plus an Elasticsearch logger.
    loggers=DEFAULT_LOGGERS + (ElasticsearchLogger, ),
)

run_ray(tune_config, exp_config)