Example 1
def test_is_done_with_max_iters():
    """_is_done honors max_iters: done exactly when iteration reaches it."""
    running = State(iteration=100,
                    epoch=1,
                    max_epochs=3,
                    epoch_length=100,
                    max_iters=250)
    assert not Engine._is_done(running)

    finished = State(iteration=250,
                     epoch=1,
                     max_epochs=3,
                     epoch_length=100,
                     max_iters=250)
    assert Engine._is_done(finished)
Example 2
def test__is_done():
    """_is_done is false mid-run and true once max_epochs*epoch_length is hit."""
    mid_run = State(iteration=10, epoch=1, max_epochs=100, epoch_length=100)
    assert not Engine._is_done(mid_run)

    exhausted = State(iteration=1000, max_epochs=10, epoch_length=100)
    assert Engine._is_done(exhausted)
Example 3
def train(to_save,
          train_spec: RunSpec,
          eval_spec: RunSpec,
          eval_event=Events.EPOCH_COMPLETED,
          save_event=Events.EPOCH_COMPLETED,
          n_saved=10,
          mlflow_enable=True,
          mlflow_tracking_uri=None,
          mlflow_experiment_name=None,
          mlflow_run_name=None,
          model_dir='output',
          checkpoint_dir='output',
          output_dir='output',
          parameters=None,
          device=None,
          max_epochs=None):
    """
    Train a model with optional evaluation, periodic checkpointing,
    automatic resume from the latest checkpoint, and MLflow tracking.

    :param to_save: dict of name -> stateful object (model, optimizer, ...)
        checkpointed via ``state_dict``/``load_state_dict``; the trainer
        engine itself is added under the key ``'trainer'``.
    :param train_spec: RunSpec for training (loader, max_epochs,
        epoch_length, ...).
    :param eval_spec: RunSpec, or dict of tag -> RunSpec, for evaluation;
        ``None`` disables evaluation.
    :param eval_event: trainer event on which evaluators run.
    :param save_event: trainer event on which a checkpoint is written.
    :param n_saved: number of checkpoints retained by ModelCheckpoint.
    :param mlflow_enable: enable the MLflow context/logger.
    :param mlflow_tracking_uri: optional MLflow tracking URI.
    :param mlflow_experiment_name: optional MLflow experiment name.
    :param mlflow_run_name: optional MLflow run name.
    :param model_dir: directory for the final exported ``model.pt``.
    :param checkpoint_dir: directory for intermediate checkpoints.
    :param output_dir: directory for other run outputs.
    :param parameters: parameters passed to ``mlflow_ctx`` for logging.
    :param device: device forwarded to ``build_engine``.
    :param max_epochs: if truthy, overrides ``train_spec.max_epochs``.
    :return: metrics collected from the trainer engine.
    """
    # NOTE(review): truthiness check — max_epochs=0 is silently ignored.
    if max_epochs:
        train_spec.max_epochs = max_epochs
    if mlflow_tracking_uri is not None:
        mlflow.set_tracking_uri(mlflow_tracking_uri)
    # When launched under an existing MLflow run, namespace all output
    # directories by the run id so concurrent runs do not collide.
    if 'MLFLOW_RUN_ID' in os.environ:
        run_id = os.environ['MLFLOW_RUN_ID']
        output_dir = os.path.join(output_dir, run_id)
        model_dir = os.path.join(model_dir, run_id)
        checkpoint_dir = os.path.join(checkpoint_dir, run_id)

    ctx = mlflow_ctx(output_dir=output_dir,
                     checkpoint_dir=checkpoint_dir,
                     mlflow_enable=mlflow_enable,
                     experiment_name=mlflow_experiment_name,
                     run_name=mlflow_run_name,
                     parameters=parameters)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    with ctx:
        mlflow_logger = get_mlflow_logger(output_dir=output_dir,
                                          checkpoint_dir=checkpoint_dir,
                                          mlflow_enable=mlflow_enable)
        # Create trainer
        trainer = build_engine(spec=train_spec,
                               output_dir=output_dir,
                               mlflow_logger=mlflow_logger,
                               tag='train',
                               device=device)
        # Include the trainer itself so its state (epoch/iteration) is
        # saved and restored along with the model/optimizer.
        to_save = {'trainer': trainer, **to_save}

        # Saver
        checkpoint_handler = ModelCheckpoint(checkpoint_dir,
                                             filename_prefix="",
                                             n_saved=n_saved,
                                             require_empty=False)

        def safe_checkpoint_handler(engine, to_save):
            # Write a checkpoint only when training has progressed past the
            # last saved iteration — avoids duplicate/stale checkpoints.
            if engine.state.iteration and engine.state.iteration > 0:
                _, last_iteration = get_last_checkpoint(
                    checkpoint_handler=checkpoint_handler)
                if last_iteration is None or last_iteration < engine.state.iteration:
                    checkpoint_handler(engine=engine, to_save=to_save)

        trainer.add_event_handler(event_name=save_event,
                                  handler=safe_checkpoint_handler,
                                  to_save=to_save)

        # Optional evaluation
        if eval_spec is not None:
            assert eval_event is not None
            # Normalize a single spec to a {tag: spec} mapping.
            if not isinstance(eval_spec, dict):
                eval_spec = {'eval': eval_spec}
            # Build evaluators
            evaluators = [(build_engine(spec=spec,
                                        output_dir=output_dir,
                                        mlflow_logger=mlflow_logger,
                                        tag=tag,
                                        trainer=trainer,
                                        metric_cls=SafeAverage,
                                        is_training=False,
                                        device=device), spec)
                          for tag, spec in eval_spec.items()]

            # Add evaluation hook to trainer

            def evaluation(engine):
                for evaluator, spec in evaluators:
                    evaluator.run(spec.loader,
                                  max_epochs=spec.max_epochs,
                                  epoch_length=spec.epoch_length)

            trainer.add_event_handler(event_name=eval_event,
                                      handler=evaluation)

        # Handle ctrl-C or other exceptions
        def exception_callback(engine):
            # Save on exit
            safe_checkpoint_handler(engine=engine, to_save=to_save)

        trainer.add_event_handler(event_name=Events.EXCEPTION_RAISED,
                                  handler=handle_exception,
                                  callback=exception_callback)

        # Get last checkpoint
        checkpoint_file, _ = get_last_checkpoint(checkpoint_handler)
        with capture_signals():
            if checkpoint_file:
                # Load checkpoint and restore every tracked object's state.
                checkpoint_data = torch.load(checkpoint_file)
                for key, value in to_save.items():
                    value.load_state_dict(checkpoint_data[key])
                tqdm.write(
                    LOADED.format(checkpoint_file, trainer.state.epoch,
                                  trainer.state.iteration))
                if Engine._is_done(trainer.state):
                    # Training complete
                    tqdm.write(COMPLETE)
                else:
                    # Continue training; epoch/epoch_length come from the
                    # restored trainer state, so no kwargs here.
                    trainer.run(train_spec.loader)
            else:
                # Start training
                trainer.run(train_spec.loader,
                            max_epochs=train_spec.max_epochs,
                            epoch_length=train_spec.epoch_length)
        # Final checkpoint after a normal run (no-op if already saved).
        safe_checkpoint_handler(engine=trainer, to_save=to_save)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
            torch.save({k: v.state_dict()
                        for k, v in to_save.items()},
                       os.path.join(model_dir, 'model.pt'))
        return get_metrics(engine=trainer)
Example 4
def train(
    to_save,
    model,
    train_spec: RunSpec,
    eval_spec: RunSpec = None,
    eval_event=Events.EPOCH_COMPLETED,
    save_event=Events.EPOCH_COMPLETED,
    n_saved=10,
    mlflow_enable=True,
    mlflow_tracking_uri=None,
    mlflow_tracking_username=None,
    mlflow_tracking_password=None,
    mlflow_tracking_secret_name=None,
    mlflow_tracking_secret_profile=None,
    mlflow_tracking_secret_region=None,
    mlflow_experiment_name=None,
    mlflow_run_name=None,
    model_dir='output',
    checkpoint_dir='output',
    output_dir='output',
    parameters=None,
    device=None,
    max_epochs=None,
    is_sagemaker=False,
    sagemaker_job_name=None,
    inference_spec=None,
    inference_args=None,
    eval_pbar=None,
    train_pbar=None,
    train_print_event=None,
    eval_print_event=None,
    eval_log_event=None,
    train_log_event=None
):
    """
    Train a model with optional evaluation, periodic checkpointing,
    automatic resume from the latest checkpoint, MLflow tracking
    (credentials supplied directly or via an AWS secret), and final
    model export.

    :param to_save: dict of name -> stateful object (model, optimizer, ...)
        checkpointed via ``state_dict``/``load_state_dict``; the trainer
        engine itself is added under the key ``'trainer'``.
    :param model: model passed to ``export_all`` for final export.
    :param train_spec: RunSpec for training (loader, max_epochs,
        epoch_length, pbar/print/log events, ...).
    :param eval_spec: RunSpec, or dict of tag -> RunSpec, for evaluation;
        ``None`` disables evaluation.
    :param eval_event: trainer event on which evaluators run.
    :param save_event: trainer event on which a checkpoint is written.
    :param n_saved: number of checkpoints retained by ModelCheckpoint.
    :param mlflow_enable: enable the MLflow context/logger.
    :param mlflow_tracking_uri: optional MLflow tracking URI.
    :param mlflow_tracking_username: optional MLflow username (exported
        as ``MLFLOW_TRACKING_USERNAME``).
    :param mlflow_tracking_password: optional MLflow password (exported
        as ``MLFLOW_TRACKING_PASSWORD``).
    :param mlflow_tracking_secret_name: optional AWS secret holding
        ``uri``/``username``/``password``; overrides the direct values.
    :param mlflow_tracking_secret_profile: AWS profile for the secret.
    :param mlflow_tracking_secret_region: AWS region for the secret.
    :param mlflow_experiment_name: optional MLflow experiment name.
    :param mlflow_run_name: optional MLflow run name.
    :param model_dir: directory for the final exported model; falsy skips
        export.
    :param checkpoint_dir: directory for intermediate checkpoints.
    :param output_dir: directory for other run outputs.
    :param parameters: parameters passed to ``mlflow_ctx`` for logging.
    :param device: device forwarded to ``build_engine``.
    :param max_epochs: if truthy, overrides ``train_spec.max_epochs``.
    :param is_sagemaker: forwarded to ``mlflow_ctx``.
    :param sagemaker_job_name: forwarded to ``mlflow_ctx``.
    :param inference_spec: forwarded to ``export_all``.
    :param inference_args: forwarded to ``export_all``.
    :param eval_pbar: if not None, overrides ``eval_spec.enable_pbar``.
    :param train_pbar: if not None, overrides ``train_spec.enable_pbar``.
    :param train_print_event: if not None, overrides
        ``train_spec.print_event``.
    :param eval_print_event: if not None, overrides
        ``eval_spec.print_event``.
    :param eval_log_event: if not None, overrides ``eval_spec.log_event``.
    :param train_log_event: if not None, overrides ``train_spec.log_event``.
    :raises ValueError: when ``mlflow_tracking_secret_name`` is given but
        the secret cannot be retrieved.
    :return: metrics collected from the trainer engine.
    """
    # Normalize event arguments (e.g. strings) to ignite Events.
    save_event = event_argument(save_event)
    eval_event = event_argument(eval_event)
    # Apply per-call overrides to the run specs; None means "keep the
    # spec's own value".
    if eval_spec is not None:
        if eval_pbar is not None:
            eval_spec.enable_pbar = eval_pbar
        if eval_print_event is not None:
            eval_spec.print_event = event_argument(eval_print_event)
        if eval_log_event is not None:
            eval_spec.log_event = event_argument(eval_log_event)
    if train_pbar is not None:
        train_spec.enable_pbar = train_pbar
    if train_print_event is not None:
        train_spec.print_event = event_argument(train_print_event)
    if train_log_event is not None:
        train_spec.log_event = event_argument(train_log_event)
    # NOTE(review): truthiness check — max_epochs=0 is silently ignored.
    if max_epochs:
        train_spec.max_epochs = max_epochs
    # MLflow credentials: direct arguments first, then (if configured)
    # values from an AWS secret take precedence.
    if mlflow_tracking_uri:
        mlflow.set_tracking_uri(mlflow_tracking_uri)
    if mlflow_tracking_username:
        os.environ['MLFLOW_TRACKING_USERNAME'] = mlflow_tracking_username
    if mlflow_tracking_password:
        os.environ['MLFLOW_TRACKING_PASSWORD'] = mlflow_tracking_password
    if mlflow_tracking_secret_name:
        secret = get_secret(
            profile_name=mlflow_tracking_secret_profile,
            secret_name=mlflow_tracking_secret_name,
            region_name=mlflow_tracking_secret_region)
        if not secret:
            raise ValueError("Could not get secret [{}]. Check secret name, region, and role permissions".format(
                mlflow_tracking_secret_name))
        uri = secret.get('uri', None)
        username = secret.get('username', None)
        password = secret.get('password', None)
        if uri:
            mlflow.set_tracking_uri(uri)
        if username:
            os.environ['MLFLOW_TRACKING_USERNAME'] = username
        if password:
            os.environ['MLFLOW_TRACKING_PASSWORD'] = password
    # NOTE(review): unlike the v1 train(), directories are intentionally
    # NOT namespaced by MLFLOW_RUN_ID here (that code was disabled).

    ctx = mlflow_ctx(
        output_dir=output_dir, checkpoint_dir=checkpoint_dir, mlflow_enable=mlflow_enable,
        experiment_name=mlflow_experiment_name, run_name=mlflow_run_name,
        parameters=parameters, is_sagemaker=is_sagemaker, sagemaker_job_name=sagemaker_job_name)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    with ctx:
        mlflow_logger = get_mlflow_logger(
            output_dir=output_dir,
            checkpoint_dir=checkpoint_dir,
            mlflow_enable=mlflow_enable
        )
        # Create trainer
        trainer = build_engine(
            spec=train_spec,
            output_dir=output_dir,
            mlflow_logger=mlflow_logger,
            tag='train',
            device=device
        )
        # Include the trainer itself so its state (epoch/iteration) is
        # saved and restored along with the model/optimizer.
        to_save = {'trainer': trainer, **to_save}

        # Saver
        checkpoint_handler = ModelCheckpoint(
            checkpoint_dir, filename_prefix="", n_saved=n_saved, require_empty=False)

        def safe_checkpoint_handler(engine, to_save):
            # Write a checkpoint only when training has progressed past the
            # last saved iteration — avoids duplicate/stale checkpoints.
            if engine.state.iteration and engine.state.iteration > 0:
                _, last_iteration = get_last_checkpoint(
                    checkpoint_handler=checkpoint_handler)
                if last_iteration is None or last_iteration < engine.state.iteration:
                    checkpoint_handler(
                        engine=engine,
                        to_save=to_save
                    )

        trainer.add_event_handler(
            event_name=save_event,
            handler=safe_checkpoint_handler,
            to_save=to_save
        )

        # Optional evaluation
        if eval_spec is not None:
            assert eval_event is not None
            # Normalize a single spec to a {tag: spec} mapping.
            if not isinstance(eval_spec, dict):
                eval_spec = {
                    'eval': eval_spec
                }
            # Build evaluators
            evaluators = [
                (
                    build_engine(
                        spec=spec,
                        output_dir=output_dir,
                        mlflow_logger=mlflow_logger,
                        tag=tag,
                        trainer=trainer,
                        metric_cls=SafeAverage,
                        is_training=False,
                        device=device
                    ),
                    spec
                )
                for tag, spec in eval_spec.items()
            ]
            # Add evaluation hook to trainer

            def evaluation(engine):
                for evaluator, spec in evaluators:
                    evaluator.run(
                        spec.loader,
                        max_epochs=spec.max_epochs,
                        epoch_length=spec.epoch_length)
            trainer.add_event_handler(
                event_name=eval_event,
                handler=evaluation)

        # Handle ctrl-C or other exceptions

        def exception_callback(engine):
            # Save on exit
            safe_checkpoint_handler(engine=engine, to_save=to_save)
        trainer.add_event_handler(
            event_name=Events.EXCEPTION_RAISED,
            handler=handle_exception,
            callback=exception_callback
        )

        # Get last checkpoint
        checkpoint_file, _ = get_last_checkpoint(checkpoint_handler)
        with capture_signals():
            if checkpoint_file:
                # Load checkpoint and restore every tracked object's state.
                checkpoint_data = torch.load(checkpoint_file)
                for key, value in to_save.items():
                    value.load_state_dict(checkpoint_data[key])
                tqdm.write(LOADED.format(
                    checkpoint_file, trainer.state.epoch, trainer.state.iteration),
                    file=sys.stdout)
                if Engine._is_done(trainer.state):
                    # Training complete
                    tqdm.write(
                        COMPLETE,
                        file=sys.stdout)
                else:
                    # Continue training; epoch/epoch_length come from the
                    # restored trainer state, so no kwargs here.
                    trainer.run(train_spec.loader)
            else:
                # Start training
                trainer.run(
                    train_spec.loader,
                    max_epochs=train_spec.max_epochs,
                    epoch_length=train_spec.epoch_length)
        # Final checkpoint after a normal run (no-op if already saved).
        safe_checkpoint_handler(engine=trainer, to_save=to_save)
        if model_dir:
            export_all(
                model_dir=model_dir,
                model=model,
                inference_args=inference_args,
                inference_spec=inference_spec
            )

        return get_metrics(engine=trainer)