示例#1
0
# Assemble the Sacred experiment from the dataset, model and training ingredients.
ex = Experiment(name='Dummy Experiment',
                ingredients=[dataset, model, training])

# Runtime options: the output folder is stored in the config under 'save'
# and is also the base directory of the file-storage observer below.
save_folder = '../../data/sims/test/'
ex.add_config(dict(save=save_folder, no_cuda=False))

# Register torch, at its imported version, as a package dependency.
ex.add_package_dependency('torch', torch.__version__)

# Store run records on the local file system under save_folder.
ex.observers.append(FileStorageObserver.create(save_folder))
# Alternative observer (disabled): log runs to a local MongoDB instead.
# ex.observers.append(
#     MongoObserver.create(url='127.0.0.1:27017', db_name='MY_DB'))

@ex.capture
def log_training(tracer):
    """Send the latest traced training loss to Sacred, then empty the trace."""
    trace = tracer.trace
    latest_loss = trace[-1]
    ex.log_scalar('training_loss', latest_loss)
    trace.clear()

@ex.capture
示例#2
0
def create_experiment(task,
                      name,
                      dataset_configs,
                      training_configs,
                      model_configs,
                      observers,
                      experiment_configs=None,
                      save_folder='../../data/sims/deladd/temp/'):
    """Build a Sacred experiment for ``task`` and the function that runs it.

    Args:
        task: Passed to ``get_dataset_ingredient`` to obtain the dataset
            ingredient and its loader.
        name: Name of the Sacred experiment.
        dataset_configs: Config updates for the dataset ingredient.
        training_configs: Config updates for the training ingredient.
        model_configs: Config updates for the model ingredient.
        observers: Sacred observers attached to the experiment.
        experiment_configs: Optional config updates for the experiment
            itself. They are applied after the runtime defaults, so they can
            override them (e.g. ``{'no_cuda': True}``).
        save_folder: Value passed to ``setup_training`` as ``save``.
            Defaults to the previously hard-coded location.

    Returns:
        A ``(ex, run_fn)`` tuple. ``ex`` is the configured ``Experiment`` and
        ``run_fn(_config, seed)`` trains, tests and logs a single run.
    """
    dataset, load_dataset = get_dataset_ingredient(task)

    # Create experiment
    ex = Experiment(name=name, ingredients=[dataset, model, training])

    update_configs_(dataset, dataset_configs)
    update_configs_(training, training_configs)
    update_configs_(model, model_configs)

    # Runtime defaults are registered BEFORE ``experiment_configs``. Sacred
    # evaluates config entries in order, with later entries overriding
    # earlier ones. Adding the defaults afterwards therefore reset any
    # caller-supplied 'no_cuda'. (Assumes update_configs_ registers its
    # updates as a config entry of the ingredient; confirm.)
    ex.add_config({
        'no_cuda': False,
    })

    if experiment_configs is not None:
        update_configs_(ex, experiment_configs)

    # Add dependencies and observers
    ex.add_source_file('../../src/model/subLSTM/nn.py')
    ex.add_source_file('../../src/model/subLSTM/functional.py')
    ex.add_package_dependency('torch', torch.__version__)
    ex.observers.extend(observers)

    def _log_training(tracer):
        """Log the last traced training loss, then empty the trace."""
        ex.log_scalar('training_loss', tracer.trace[-1])
        tracer.trace.clear()

    def _log_validation(engine):
        """Log every entry of ``engine.state.metrics`` with a 'val_' prefix."""
        for metric, value in engine.state.metrics.items():
            ex.log_scalar('val_{}'.format(metric), value)

    def _run_experiment(_config, seed):
        """Train and test one model, logging metrics and the checkpoint."""
        no_cuda = _config['no_cuda']
        batch_size = _config['training']['batch_size']

        device = set_seed_and_device(seed, no_cuda)
        training_set, test_set, validation_set = load_dataset(
            batch_size=batch_size)
        # Named ``net`` so it does not shadow the ``model`` ingredient above.
        net = init_model(device=device)

        # Only the first four values returned by setup_training are used.
        trainer, validator, checkpoint, metrics = setup_training(
            net,
            validation_set,
            save=save_folder,
            device=device,
            trace=False,
            time=False)[:4]

        # Log the training loss / validation metrics whenever the trainer's /
        # validator's epoch completes.
        tracer = Tracer().attach(trainer)
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  lambda e: _log_training(tracer))
        validator.add_event_handler(Events.EPOCH_COMPLETED, _log_validation)

        test_metrics = run_training(model=net,
                                    train_data=training_set,
                                    trainer=trainer,
                                    test_data=test_set,
                                    metrics=metrics,
                                    model_checkpoint=checkpoint,
                                    device=device)

        # Save test performance of the model under a 'test_' prefix.
        for metric, value in test_metrics.items():
            ex.log_scalar('test_{}'.format(metric), value)

        # NOTE(review): this reads the checkpoint handler's private ``_saved``
        # list and attaches the first path of its last entry. That entry is
        # presumably the best-scoring saved model; confirm against the
        # installed ignite version. It raises IndexError if ``_saved`` is
        # empty.
        ex.add_artifact(str(checkpoint._saved[-1][1][0]), 'trained-model')

    return ex, _run_experiment
示例#3
0
from comms_rl.agents.random_agent import RandomAgent
from comms_rl.agents.round_robin_agent import *
from comms_rl.agents.proportional_fair import *
path_abs = "../../"


def _load_json_config(file_name):
    """Parse and return the JSON file at ``path_abs + 'config/' + file_name``.

    The file is decoded explicitly as UTF-8, the encoding of JSON text,
    instead of the platform's locale-dependent default. It is closed as soon
    as it has been parsed.
    """
    with open(path_abs + 'config/' + file_name, encoding='utf-8') as f:
        return json.load(f)


# Load agent parameters
ac = _load_json_config('config_agent.json')

# Configure experiment
sc = _load_json_config('config_sacred.json')  # Sacred Configuration
# Number of points per episode to log in Sacred
ns = sc["sacred"]["n_metrics_points"]
ex = Experiment(ac["agent"]["agent_type"])
ex.add_package_dependency("comms_rl", "0.1")
ex.add_config(sc)
ex.add_config(ac)
# mongo_db_url = f'mongodb://{sc["sacred"]["sacred_user"]}:{sc["sacred"]["sacred_pwd"]}@' +\
#               f'{sc["sacred"]["sacred_host"]}:{sc["sacred"]["sacred_port"]}/{sc["sacred"]["sacred_db"]}'
# ex.observers.append(MongoObserver(url=mongo_db_url, db_name=sc["sacred"]["sacred_db"]))  # Uncomment to save to DB

# Load environment parameters
ec = _load_json_config('config_environment.json')
ex.add_config(ec)


@ex.automain
def main(_run):
    n_eps = _run.config["agent"]["n_episodes"]