Example 1
def run(max_lr: float = 1e-1, steps_per_epoch: int = 1413, device: str = None, check: bool = False) -> tuple:
    config = copy.deepcopy(experiment_config)
    device = device or utils.get_device()
    print(f"device: {device}")

    utils.set_global_seed(SEED)

    # convert parquet files to zipped image archives
    parquet_to_images(TRAIN, ZIP_TRAIN_FILE, SIZE)
    parquet_to_images(TEST, ZIP_TEST_FILE, SIZE)

    config['monitoring_params']['name'] = EXPERIMENT_NAME
    config['stages']['state_params']['checkpoint_data']['image_size'] = SIZE

    # add scheduler to config
    config["stages"]["scheduler_params"] = {
        "scheduler": "OneCycleLR",
        "max_lr": max_lr,
        "epochs": config["stages"]["state_params"]["num_epochs"],
        "steps_per_epoch": steps_per_epoch,
        "div_factor": 200,
        "final_div_factor": 1e5,
    }
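    # Note (assuming Catalyst forwards these params to torch.optim.lr_scheduler.OneCycleLR):
    # the schedule starts at max_lr / div_factor = 1e-1 / 200 = 5e-4, ramps up to
    # max_lr, then anneals down to roughly initial_lr / final_div_factor = 5e-9
    # over num_epochs * steps_per_epoch optimizer steps.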
    experiment = Experiment(config)

    # run experiment
    runner = SupervisedWandbRunner(
        device=device,
        input_key="images",
        output_key=["logit_" + c for c in output_classes.keys()],
        input_target_key=list(output_classes.keys()),
    )

    runner.run_experiment(experiment, check=check)

    return experiment, runner
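
# Illustrative usage (not part of the original example). It assumes the
# module-level constants referenced above (experiment_config, TRAIN, TEST,
# SIZE, SEED, EXPERIMENT_NAME, output_classes) are defined in the script;
# `check=True` asks Catalyst for a short debug pass instead of full training.
if __name__ == "__main__":
    experiment, runner = run(max_lr=1e-1, steps_per_epoch=1413, check=True)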
Example 2
    return schedulers[scheduler_name](**scheduler_params, optimizer=optimizer_)


# @TODO: add metrics support
# (catalyst expects logits, rather than sigmoid outputs)
# metrics = [
#     smp.utils.metrics.IoUMetric(eps=1.),
#     smp.utils.metrics.FscoreMetric(eps=1.),
# ]
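
# Sketch of the scheduler registry that get_scheduler() above indexes into.
# The real mapping is defined elsewhere in the original script and is not
# shown here, so these entries are illustrative assumptions based on
# torch.optim.lr_scheduler:
#
# schedulers = {
#     "multi_step": torch.optim.lr_scheduler.MultiStepLR,
#     "plateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
#     "one_cycle": torch.optim.lr_scheduler.OneCycleLR,
# }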

if __name__ == '__main__':
    args = parse_args()
    config = safitty.load(args.config_path)

    runner = SupervisedWandbRunner()

    model = get_model(
        model_name=safitty.get(config, 'model', 'name', default='unet'),
        model_params=safitty.get(config, 'model', 'params', default={}))
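    # safitty.get performs safe nested lookup: for a config like
    # {"model": {"name": "unet", "params": {...}}} the call above resolves to
    # "unet", falling back to `default` when any key along the path is missing.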

    criterion = get_criterion(
        criterion_name=safitty.get(config, 'criterion', 'name', default='bce_dice'),
        criterion_params=safitty.get(config, 'criterion', 'params', default={}))

    optimizer = get_optimizer(
        optimizer_name=safitty.get(config, 'optimizer', 'name', default='adam'),
        optimizer_params=safitty.get(config, 'optimizer', 'params', default={}),
        model_=model)

    scheduler = get_scheduler(