Beispiel #1
0
def run(net, train_dataloader, test_dataloader, num_epochs, ctx, lr):
    '''
    Train a test sentiment model
    '''

    # Adam optimizer driving the network parameters.
    trainer = mx.gluon.Trainer(net.collect_params(), 'adam',
                               {'learning_rate': lr})

    # Loss plus a deliberately nested composite metric:
    # (accuracy + loss) wrapped together with a second accuracy.
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    acc = mx.gluon.metric.Accuracy()
    metrics = mx.gluon.metric.CompositeEvalMetric()
    metrics.add([acc, mx.gluon.metric.Loss()])
    nested_metrics = mx.gluon.metric.CompositeEvalMetric()
    nested_metrics.add([metrics, mx.gluon.metric.Accuracy()])

    # Assemble the estimator and run the full training loop.
    est = estimator.Estimator(net=net,
                              loss=loss,
                              train_metrics=nested_metrics,
                              trainer=trainer,
                              context=ctx)
    est.fit(train_data=train_dataloader,
            val_data=test_dataloader,
            epochs=num_epochs)
    return acc
Beispiel #2
0
def test_estimator_cpu():
    '''
    Test estimator by doing one pass over each model with synthetic data
    '''
    context = mx.cpu()
    for model_name in ['resnet18_v1', 'FCN']:
        net, input_shape, label_shape, loss_axis = get_net(model_name, context)
        # Single-sample synthetic datasets shaped for this particular model.
        train_dataset = gluon.data.dataset.ArrayDataset(
            mx.nd.random.uniform(shape=input_shape),
            mx.nd.zeros(shape=label_shape))
        val_dataset = gluon.data.dataset.ArrayDataset(
            mx.nd.random.uniform(shape=input_shape),
            mx.nd.zeros(shape=label_shape))
        train_data = gluon.data.DataLoader(train_dataset, batch_size=1)
        val_data = gluon.data.DataLoader(val_dataset, batch_size=1)
        net.hybridize()
        # Build the estimator with loss/trainer constructed inline.
        est = estimator.Estimator(
            net=net,
            loss=gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis),
            metrics=mx.metric.Accuracy(),
            trainer=gluon.Trainer(net.collect_params(), 'sgd',
                                  {'learning_rate': 0.001}),
            context=context)
        # One full pass over the synthetic data.
        est.fit(train_data=train_data,
                val_data=val_data,
                epochs=1)
Beispiel #3
0
def test_estimator_gpu():
    '''
    Test estimator by training resnet18_v1 for 5 epochs on MNIST and verify accuracy
    '''
    batch_size = 128
    num_epochs = 5
    context = mx.gpu(0)
    net, _, _, _ = get_net('resnet18_v1', context)
    train_data, test_data = load_data_mnist(batch_size, resize=224)
    net.hybridize()
    acc = mx.metric.Accuracy()
    # Estimator with inline loss/trainer; `acc` is kept so the final
    # training accuracy can be inspected after fit().
    est = estimator.Estimator(
        net=net,
        loss=gluon.loss.SoftmaxCrossEntropyLoss(),
        metrics=acc,
        trainer=gluon.Trainer(net.collect_params(), 'sgd',
                              {'learning_rate': 0.001}),
        context=context)
    est.fit(train_data=train_data,
            val_data=test_data,
            epochs=num_epochs)

    # After 5 epochs the training accuracy should be above 80%.
    assert acc.get()[1] > 0.80
Beispiel #4
0
def gluon_model(model_data):
    """Build a small dense classifier, train it for 3 epochs and return it."""
    train_data, train_label, _ = model_data
    loader = DataLoader(list(zip(train_data, train_label)),
                        batch_size=128,
                        last_batch="discard")

    model = HybridSequential()
    for units in (128, 64):
        model.add(Dense(units, activation="relu"))
    model.add(Dense(10))
    model.initialize()
    model.hybridize()

    trainer = Trainer(model.collect_params(),
                      "adam",
                      optimizer_params={"learning_rate": .001,
                                        "epsilon": 1e-07})
    est = estimator.Estimator(net=model,
                              loss=SoftmaxCrossEntropyLoss(),
                              metrics=Accuracy(),
                              trainer=trainer)
    # Suppress the deprecation/nightly warnings the Estimator emits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(loader, epochs=3)
    return model
def gluon_model(model_data):
    """Build a small dense classifier, train it for 3 epochs and return it."""
    train_data, train_label, _ = model_data
    dataset = mx.gluon.data.ArrayDataset(train_data, train_label)
    loader = DataLoader(dataset,
                        batch_size=128,
                        last_batch="discard")

    model = HybridSequential()
    for units in (128, 64):
        model.add(Dense(units, activation="relu"))
    model.add(Dense(10))
    model.initialize()
    model.hybridize()

    trainer = Trainer(model.collect_params(),
                      "adam",
                      optimizer_params={"learning_rate": 0.001,
                                        "epsilon": 1e-07})

    # `metrics` was renamed in mxnet 1.6.0: https://github.com/apache/incubator-mxnet/pull/17048
    if Version(mx.__version__) < Version("1.6.0"):
        metric_kwargs = {"metrics": Accuracy()}
    else:
        metric_kwargs = {"train_metrics": Accuracy()}
    est = estimator.Estimator(net=model,
                              loss=SoftmaxCrossEntropyLoss(),
                              trainer=trainer,
                              **metric_kwargs)
    # Suppress the deprecation/nightly warnings the Estimator emits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(loader, epochs=3)
    return model
Beispiel #6
0
def gluon_random_data_run():
    """Run one autologged training on random log data and return the MLflow run."""
    mlflow.gluon.autolog()

    with mlflow.start_run() as run:
        data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")
        validation = DataLoader(LogsDataset(),
                                batch_size=128,
                                last_batch="discard")

        model = HybridSequential()
        for units in (64, 64):
            model.add(Dense(units, activation="relu"))
        model.add(Dense(10))
        model.initialize()
        model.hybridize()

        trainer = Trainer(model.collect_params(),
                          "adam",
                          optimizer_params={"learning_rate": 0.001,
                                            "epsilon": 1e-07})
        est = estimator.Estimator(net=model,
                                  loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(),
                                  trainer=trainer)

        # Suppress the deprecation/nightly warnings the Estimator emits.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(data, epochs=3, val_data=validation)

    # Fetch the finished run back from the tracking store.
    return mlflow.tracking.MlflowClient().get_run(run.info.run_id)
Beispiel #7
0
def test_autolog_ends_auto_created_run():
    """The run that autolog() creates implicitly must be ended after fit()."""
    mlflow.gluon.autolog()

    data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    model = HybridSequential()
    for units in (64, 64):
        model.add(Dense(units, activation="relu"))
    model.add(Dense(10))
    model.initialize()
    model.hybridize()

    trainer = Trainer(model.collect_params(),
                      "adam",
                      optimizer_params={"learning_rate": 0.001,
                                        "epsilon": 1e-07})
    # get_metrics() supplies the version-appropriate metrics kwarg.
    est = estimator.Estimator(net=model,
                              loss=SoftmaxCrossEntropyLoss(),
                              trainer=trainer,
                              **get_metrics())

    # Suppress the deprecation/nightly warnings the Estimator emits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(data, epochs=3)

    # No run was started explicitly, so none may be left active.
    assert mlflow.active_run() is None
Beispiel #8
0
def test_resume_checkpoint():
    """Resuming from a checkpoint only trains the epochs that are missing."""
    with TemporaryDirectory() as tmpdir:
        model_prefix = 'test_net'
        file_path = os.path.join(tmpdir, model_prefix)
        test_data = _get_test_data()

        net = _get_test_network()
        acc = mx.metric.Accuracy()
        est = estimator.Estimator(net,
                                  loss=loss.SoftmaxCrossEntropyLoss(),
                                  train_metrics=acc)

        # First run: train 2 epochs and checkpoint the last one.
        handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                  model_prefix=model_prefix,
                                                  monitor=acc,
                                                  max_checkpoints=1)
        est.fit(test_data, event_handlers=[handler], epochs=2)
        for suffix in ('-epoch1batch8.params', '-epoch1batch8.states'):
            assert os.path.isfile(file_path + suffix)

        # Second run: resume from that checkpoint and ask for 5 epochs total.
        handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                  model_prefix=model_prefix,
                                                  monitor=acc,
                                                  max_checkpoints=1,
                                                  resume_from_checkpoint=True)
        est.fit(test_data, event_handlers=[handler], epochs=5)
        # should only continue to train 3 epochs and last checkpoint file is epoch4
        assert est.max_epoch == 3
        assert os.path.isfile(file_path + '-epoch4batch20.states')
Beispiel #9
0
def test_autolog_persists_manually_created_run():
    """A run opened by the caller stays active (and unchanged) across fit()."""
    mlflow.gluon.autolog()

    data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    with mlflow.start_run() as run:

        model = HybridSequential()
        for units in (64, 64):
            model.add(Dense(units, activation="relu"))
        model.add(Dense(10))
        model.initialize()
        model.hybridize()

        trainer = Trainer(model.collect_params(),
                          "adam",
                          optimizer_params={"learning_rate": 0.001,
                                            "epsilon": 1e-07})
        est = estimator.Estimator(net=model,
                                  loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(),
                                  trainer=trainer)

        # Suppress the deprecation/nightly warnings the Estimator emits.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(data, epochs=3)

        # autolog must not end or replace the manually created run.
        assert mlflow.active_run().info.run_id == run.info.run_id
Beispiel #10
0
def test_checkpoint_handler():
    '''
    Exercise CheckpointHandler in epoch-period and batch-period modes and
    verify exactly which checkpoint files end up on disk.
    '''
    with TemporaryDirectory() as tmpdir:
        model_prefix = 'test_epoch'
        file_path = os.path.join(tmpdir, model_prefix)
        test_data = _get_test_data()

        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
        # Epoch-based checkpointing with best-model tracking enabled.
        checkpoint_handler = event_handler.CheckpointHandler(
            model_dir=tmpdir,
            model_prefix=model_prefix,
            monitor=acc,
            save_best=True,
            epoch_period=1)
        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1)
        assert checkpoint_handler.current_epoch == 1
        assert checkpoint_handler.current_batch == 4
        # Both the "best" snapshot and the per-epoch snapshot must exist.
        assert os.path.isfile(file_path + '-best.params')
        assert os.path.isfile(file_path + '-best.states')
        assert os.path.isfile(file_path + '-epoch0batch4.params')
        assert os.path.isfile(file_path + '-epoch0batch4.states')

        model_prefix = 'test_batch'
        file_path = os.path.join(tmpdir, model_prefix)
        # Hybridized network, so a -symbol.json export is expected too.
        net = _get_test_network(nn.HybridSequential())
        net.hybridize()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
        # Batch-based checkpointing: every 2 batches, keep at most 2 files.
        checkpoint_handler = event_handler.CheckpointHandler(
            model_dir=tmpdir,
            model_prefix=model_prefix,
            epoch_period=None,
            batch_period=2,
            max_checkpoints=2)
        est.fit(test_data, event_handlers=[checkpoint_handler], batches=10)
        assert checkpoint_handler.current_batch == 10
        assert checkpoint_handler.current_epoch == 3
        # save_best is off here, and max_checkpoints=2 must have evicted
        # the earliest checkpoints, leaving only the last two.
        assert not os.path.isfile(file_path + 'best.params')
        assert not os.path.isfile(file_path + 'best.states')
        assert not os.path.isfile(file_path + '-epoch0batch0.params')
        assert not os.path.isfile(file_path + '-epoch0batch0.states')
        assert os.path.isfile(file_path + '-symbol.json')
        assert os.path.isfile(file_path + '-epoch1batch7.params')
        assert os.path.isfile(file_path + '-epoch1batch7.states')
        assert os.path.isfile(file_path + '-epoch2batch9.params')
        assert os.path.isfile(file_path + '-epoch2batch9.states')
Beispiel #11
0
def test_validation_handler():
    '''
    ValidationHandler runs the estimator's evaluate step during fit() and
    triggers the nested event handlers it was given.
    '''
    test_data = _get_test_data()

    net = _get_test_network()
    ce_loss = loss.SoftmaxCrossEntropyLoss()
    acc = mx.gluon.metric.Accuracy()
    est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
    val_handler = ValidationHandler(val_data=test_data,
                                    eval_fn=est.evaluate,
                                    event_handlers=Handler())

    est.fit(train_data=test_data, val_data=test_data,
            event_handlers=[val_handler], epochs=2)
    # PEP 8: assert truthiness directly instead of comparing with `== True`.
    assert est.run_test_handler
Beispiel #12
0
def build_estimator(_ctx):
    """Create an Estimator around a CIFAR ResNet-20 classifier.

    Parameters
    ----------
    _ctx : Context or list of Context
        Device(s) to initialize the network on and train with.

    Returns
    -------
    estimator.Estimator
        Estimator configured with NAG optimization and an accuracy metric.
    """
    _net = get_model('cifar_resnet20_v1', classes=10)
    # Bug fix: initialize on the context that was passed in (`_ctx`);
    # the original referenced an undefined/global `ctx`.
    _net.initialize(mx.init.Xavier(), ctx=_ctx)
    _loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    _optimizer = 'nag'
    _optimizer_params = {'learning_rate': 0.1, 'wd': 0.0001, 'momentum': 0.9}
    _trainer = gluon.Trainer(_net.collect_params(), _optimizer,
                             _optimizer_params)

    train_acc = mx.metric.Accuracy()
    _est = estimator.Estimator(net=_net,
                               loss=_loss_fn,
                               metrics=train_acc,
                               trainer=_trainer,
                               context=_ctx)

    return _est
Beispiel #13
0
def test_early_stopping():
    """EarlyStoppingHandler halts fit() once the monitored metric stops improving."""
    test_data = _get_test_data()

    acc = mx.gluon.metric.Accuracy()
    est = estimator.Estimator(_get_test_network(),
                              loss=loss.SoftmaxCrossEntropyLoss(),
                              train_metrics=acc)

    # With zero patience in 'min' mode, training stops almost immediately.
    stopper = event_handler.EarlyStoppingHandler(monitor=acc,
                                                 patience=0,
                                                 mode='min')
    est.fit(test_data, event_handlers=[stopper], epochs=5)
    assert stopper.current_epoch == 2
    assert stopper.stopped_epoch == 1

    # With patience=2 and only one epoch requested, stopping never fires.
    stopper = event_handler.EarlyStoppingHandler(monitor=acc,
                                                 patience=2,
                                                 mode='auto')
    est.fit(test_data, event_handlers=[stopper], epochs=1)
    assert stopper.current_epoch == 1
Beispiel #14
0
def test_custom_handler():
    """A user-defined event handler can stop training at a batch or epoch limit."""
    class CustomStopHandler(event_handler.TrainBegin,
                            event_handler.BatchEnd,
                            event_handler.EpochEnd):
        """Counts batches/epochs and requests a stop at the configured limits."""

        def __init__(self, batch_stop=None, epoch_stop=None):
            self.batch_stop = batch_stop
            self.epoch_stop = epoch_stop
            self.num_batch = 0
            self.num_epoch = 0
            self.stop_training = False

        def train_begin(self, estimator, *args, **kwargs):
            # Reset the counters at the start of every fit() call.
            self.num_batch = 0
            self.num_epoch = 0

        def batch_end(self, estimator, *args, **kwargs):
            self.num_batch += 1
            self.stop_training = (self.stop_training
                                  or self.num_batch == self.batch_stop)
            return self.stop_training

        def epoch_end(self, estimator, *args, **kwargs):
            self.num_epoch += 1
            self.stop_training = (self.stop_training
                                  or self.num_epoch == self.epoch_stop)
            return self.stop_training

    # total data size is 32, batch size is 8 -> 4 batches per epoch
    test_data = _get_test_data()
    est = estimator.Estimator(_get_test_network(),
                              loss=loss.SoftmaxCrossEntropyLoss(),
                              train_metrics=mx.gluon.metric.Accuracy())

    # Batch limit (3) is hit in the middle of the first epoch.
    custom_handler = CustomStopHandler(3, 2)
    est.fit(test_data, event_handlers=[custom_handler], epochs=3)
    assert custom_handler.num_batch == 3
    assert custom_handler.num_epoch == 1

    # Batch limit (100) never reached; epoch limit (5) triggers the stop.
    custom_handler = CustomStopHandler(100, 5)
    est.fit(test_data, event_handlers=[custom_handler], epochs=10)
    assert custom_handler.num_batch == 5 * 4
    assert custom_handler.num_epoch == 5
def test_logging():
    """LoggingHandler writes its log file and tracks batch/epoch counters."""
    with TemporaryDirectory() as tmpdir:
        test_data = _get_test_data()
        file_name = 'test_log'
        output_dir = os.path.join(tmpdir, file_name)

        est = estimator.Estimator(_get_test_network(),
                                  loss=loss.SoftmaxCrossEntropyLoss(),
                                  metrics=mx.metric.Accuracy())
        # Ask the estimator for its prepared train/val metric lists.
        train_metrics, val_metrics = est.prepare_loss_and_metrics()
        logging_handler = event_handler.LoggingHandler(
            file_name=file_name,
            file_location=tmpdir,
            train_metrics=train_metrics,
            val_metrics=val_metrics)
        est.fit(test_data, event_handlers=[logging_handler], epochs=3)
        assert logging_handler.batch_index == 0
        assert logging_handler.current_epoch == 3
        # The handler must have created its log file in the temp dir.
        assert os.path.isfile(output_dir)
Beispiel #16
0
def test_logging():
    """Log records routed through est.logger end up in a file on disk."""
    with TemporaryDirectory() as tmpdir:
        test_data = _get_test_data()
        file_name = 'test_log'
        output_dir = os.path.join(tmpdir, file_name)

        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(_get_test_network(),
                                  loss=loss.SoftmaxCrossEntropyLoss(),
                                  train_metrics=acc)

        # Route the estimator's log records into a file in the temp dir.
        est.logger.addHandler(logging.FileHandler(output_dir))

        train_metrics = est.train_metrics
        val_metrics = est.val_metrics
        logging_handler = event_handler.LoggingHandler(metrics=train_metrics)
        est.fit(test_data, event_handlers=[logging_handler], epochs=3)
        assert logging_handler.batch_index == 0
        assert logging_handler.current_epoch == 3
        assert os.path.isfile(output_dir)
        del est  # Clean up estimator and logger before deleting tmpdir
def train(args: argparse.Namespace) -> HybridBlock:
    """Train a small dense two-class classifier with MLflow autologging.

    Reads MLflow MySQL credentials from AWS Secrets Manager, points MLflow
    at that tracking store, loads train/validation CSVs from the channel
    directories in ``args``, trains for ``args.epochs`` epochs and returns
    the trained network.
    """
    session = boto3.session.Session()

    # Fetch the MLflow backend-store credentials from Secrets Manager.
    client = session.client(service_name="secretsmanager",
                            region_name="us-east-1")
    mlflow_secret = client.get_secret_value(SecretId=args.mlflow_secret)
    mlflowdb_conf = json.loads(mlflow_secret["SecretString"])

    # Register np.float64 with the DB driver's converters so numpy floats
    # are escaped like plain floats (presumably pymysql — TODO confirm).
    converters.encoders[np.float64] = converters.escape_float
    converters.conversions = converters.encoders.copy()
    converters.conversions.update(converters.decoders)

    mlflow.set_tracking_uri(
        f"mysql+pymysql://{mlflowdb_conf['username']}:{mlflowdb_conf['password']}@{mlflowdb_conf['host']}/mlflow"
    )

    # Create the experiment on first use, then make it the active one.
    if mlflow.get_experiment_by_name(args.mlflow_experiment) is None:
        mlflow.create_experiment(args.mlflow_experiment,
                                 args.mlflow_artifacts_location)
    mlflow.set_experiment(args.mlflow_experiment)

    # First CSV column is the target; the rest are 21 kinematic features.
    col_names = ["target"] + [f"kinematic_{i}" for i in range(1, 22)]

    train_df = pd.read_csv(f"{args.train_channel}/train.csv.gz",
                           header=None,
                           names=col_names)

    val_df = pd.read_csv(f"{args.validation_channel}/val.csv.gz",
                         header=None,
                         names=col_names)

    # Split features/labels and wrap them in Gluon data loaders.
    train_X = train_df.drop("target", axis=1)
    train_y = train_df["target"]
    train_dataset = ArrayDataset(train_X.to_numpy(dtype="float32"),
                                 train_y.to_numpy(dtype="float32"))
    train = DataLoader(train_dataset, batch_size=args.batch_size)

    val_X = val_df.drop("target", axis=1)
    val_y = val_df["target"]
    val_dataset = ArrayDataset(val_X.to_numpy(dtype="float32"),
                               val_y.to_numpy(dtype="float32"))
    validation = DataLoader(val_dataset, batch_size=args.batch_size)

    # Use all requested GPUs, otherwise fall back to CPU.
    ctx = [gpu(i) for i in range(args.gpus)] if args.gpus > 0 else cpu()

    mlflow.gluon.autolog()

    with mlflow.start_run():
        # Dense funnel 256 -> 64 -> 16 -> 2 with light dropout.
        net = HybridSequential()
        with net.name_scope():
            net.add(Dense(256))
            net.add(Dropout(.2))
            net.add(Dense(64))
            net.add(Dropout(.1))
            net.add(Dense(16))
            net.add(Dense(2))

        net.initialize(Xavier(magnitude=2.24), ctx=ctx)
        net.hybridize()

        trainer = Trainer(net.collect_params(), "sgd",
                          {"learning_rate": args.learning_rate})
        est = estimator.Estimator(net=net,
                                  loss=SoftmaxCrossEntropyLoss(),
                                  trainer=trainer,
                                  train_metrics=Accuracy(),
                                  context=ctx)
        est.fit(train, epochs=args.epochs, val_data=validation)

    return net
Beispiel #18
0
def train(train_dl, test_dl, exp, setting, tags=[]):
    """Train an MLP with the Gluon Estimator API and log the run to Azure ML.

    ``setting`` supplies 'epochs', 'opt', 'opt_params', 'gpu_count' and
    'model_params'; ``exp`` provides start_logging(); ``tags`` are attached
    to the created run.

    NOTE(review): ``tags=[]`` is a mutable default argument; it is only
    iterated here, never mutated, but ``tags=None`` would be safer.
    """

    num_epochs = setting['epochs']
    opt = setting['opt']

    # gpu setting
    gpu_count = setting['gpu_count']
    ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu()

    net = mlp(**setting['model_params'])

    net.initialize(init=mx.init.Xavier(), ctx=ctx, force_reinit=True)
    net.hybridize(static_alloc=True, static_shape=True)

    trainer = gluon.Trainer(net.collect_params(), opt, setting['opt_params'])

    # metrics
    train_acc = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    est = estimator.Estimator(net=net,
                              loss=loss_fn,
                              metrics=train_acc,
                              trainer=trainer,
                              context=ctx)

    # Start logging for this run
    run = exp.start_logging()

    try:
        # Attach the tags to the run
        for t in tags:
            run.tag(t)

        # Persist the settings
        log_dict(run, setting)

        # Callback that checkpoints the model.
        # /tmp is fine since the file is uploaded to the cloud afterwards.
        checkpoint_handler = CheckpointHandler(model_dir='/tmp',
                                               model_prefix='model',
                                               monitor=train_acc,
                                               save_best=True,
                                               max_checkpoints=0)

        # Log to AML through the run object
        record_handler = AMLRecordHandler(run)

        # ignore warnings for nightly test on CI only
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            est.fit(train_data=train_dl,
                    val_data=test_dl,
                    epochs=num_epochs,
                    event_handlers=[checkpoint_handler, record_handler])

        # Upload the best model
        run.upload_file(name='model-best.params',
                        path_or_stream='/tmp/model-best.params')

        # Mark the run status as complete
        run.complete()

    except Exception as e:
        # Mark the run status as failed
        run.fail(e)
        raise ValueError('error occured: {}'.format(e))
Beispiel #19
0
def train(exp, setting, tags=[]):
    """Load Fashion-MNIST from a pickle, train an MLP and record the run.

    ``setting`` supplies 'data_path', 'batch_size', 'epochs', 'opt',
    'opt_params', 'gpu_count' and 'model_params'; ``exp`` provides
    start_logging(tags). On any failure the run folder is deleted.

    NOTE(review): ``tags=[]`` is a mutable default argument; it is passed
    straight through without mutation, but ``tags=None`` would be safer.
    """

    # Fetch the data.
    # Transforms such as normalization could be applied at load time,
    # but they are kept separate here for the sake of the walkthrough.
    fashion_mnist_train, fashion_mnist_test = pd.read_pickle(
        setting['data_path'])

    batch_size = setting['batch_size']
    train_dl = gluon.data.DataLoader(fashion_mnist_train,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=4)
    test_dl = gluon.data.DataLoader(fashion_mnist_test,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=4)

    num_epochs = setting['epochs']
    opt = setting['opt']

    # gpu setting
    gpu_count = setting['gpu_count']
    ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu()

    net = mlp(**setting['model_params'])

    net.initialize(init=mx.init.Xavier(), ctx=ctx, force_reinit=True)
    net.hybridize(static_alloc=True, static_shape=True)

    trainer = gluon.Trainer(net.collect_params(), opt, setting['opt_params'])

    # metrics
    train_acc = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    est = estimator.Estimator(net=net,
                              loss=loss_fn,
                              metrics=train_acc,
                              trainer=trainer,
                              context=ctx)

    run = exp.start_logging(tags)

    try:
        # Record the current commit and the full settings alongside the run.
        setting['commit_hash'] = run.git_commit()
        run.save(setting, 'setting.json', mode='json')

        # Checkpoint the best model into the run's own folder.
        checkpoint_handler = CheckpointHandler(model_dir=str(run.path),
                                               model_prefix='model',
                                               monitor=train_acc,
                                               save_best=True,
                                               max_checkpoints=0)

        record_handler = RecordHandler(file_name='log.pkl',
                                       file_location=run.path)

        # ignore warnings for nightly test on CI only
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            est.fit(train_data=train_dl,
                    val_data=test_dl,
                    epochs=num_epochs,
                    event_handlers=[checkpoint_handler, record_handler])

    except Exception as e:
        # Remove the incomplete run folder before surfacing the error.
        run.delete()
        raise ValueError('error occured and delete run folder: {}'.format(e))
Beispiel #20
0
# Manual training loop over the dataset, followed by the Fit API variant.
epochs = 2
for e in range(epochs):
    cumulative_loss = 0
    for i, (X, y) in enumerate(zip(dataset, imdb_y_train)):
        X = X.as_in_context(model_ctx)
        y = y.as_in_context(model_ctx)
        with autograd.record():
            output = model(X)
            # Bug fix: the target variable in this scope is `y`; the
            # original passed an undefined name `label` (NameError).
            l = loss(output, y)
        l.backward()
        opt.step(X.shape[0])
        cumulative_loss += mxnet.nd.sum(l).asscalar()
    test_accuracy = evaluate_accuracy(test_data, model)
    train_accuracy = evaluate_accuracy(train_data, model)
    print("Epoch %s. Loss: %s, Train_acc %s, Test_acc %s" %
          (e, cumulative_loss/num_examples, train_accuracy, test_accuracy))


# TODO: Explore Fit API.
from mxnet.gluon.contrib.estimator import estimator

# Same training expressed through the Gluon Estimator Fit API.
est = estimator.Estimator(
    net=model,
    loss=loss,
    metrics=mxnet.metric.Accuracy(),
    trainer=opt,
    context=model_ctx)

est.fit(train_data=train_data_loader,
        epochs=2)
def train(batch_size, epochs, learning_rate, hidden_size, num_layers, dropout, num_gpus, training_channel, model_dir):
    """Train the RUL prediction network and leave only the best model files.

    Runs the Gluon Estimator with a custom checkpoint handler, then rewrites
    the files in ``model_dir`` so that only the best model's params/symbol
    remain, and prints the final RMSE/loss metrics.
    """
    from mxnet.gluon.contrib.estimator import estimator as E
    from mxnet.gluon.contrib.estimator.event_handler import CheckpointHandler

    logging.getLogger().setLevel(logging.DEBUG)
    checkpoints_dir = '/opt/ml/checkpoints'
    # NOTE(review): checkpoints_enabled is computed but never used below.
    checkpoints_enabled = os.path.exists(checkpoints_dir)

    # Preparing datasets:
    logging.info('[### train ###] Loading data')
    train_loader = preprocessing(training_channel, batch_size)

    # Configuring network:
    logging.info('[### train ###] Initializing network')
    net = RULPredictionNet(hidden_size, num_layers, dropout).net
    net.hybridize()
    # Train on the first GPU if any were requested, else CPU.
    device = mx.gpu(0) if num_gpus > 0 else mx.cpu(0)
    net.initialize(mx.init.Xavier(), ctx=device)

    trainer = G.Trainer(
        params=net.collect_params(),
        optimizer='adam',
        optimizer_params={'learning_rate': learning_rate},
    )

    # Define the estimator, by passing to it the model, 
    # loss function, metrics, trainer object and context:
    estimator = E.Estimator(
        net=net,
        loss=G.loss.L2Loss(),
        train_metrics=[mx.metric.RMSE(), mx.metric.Loss()],
        trainer=trainer,
        context=device
    )

    # Keep the checkpoint with the lowest RMSE (train_metrics[0]).
    checkpoint_handler = CustomCheckpointHandler(
        model_dir=model_dir,
        model_prefix='model',
        monitor=estimator.train_metrics[0],
        mode='min',
        save_best=True
    )

    # Start training the model:
    logging.info('[### train ###] Training start')
    estimator.fit(train_data=train_loader, epochs=epochs, event_handlers=[checkpoint_handler])
    logging.info('[### train ###] Training end')

    # Cleanup model directory before SageMaker zips it to send it back to S3:
    # drop the handler's default "best" files, promote the custom-handler
    # files to the canonical names, then remove all per-epoch checkpoints.
    logging.info('[### train ###] Model directory clean up, only keeps the best model')
    model_name = 'model'
    os.remove(os.path.join(model_dir, model_name + '-best.params'))
    os.remove(os.path.join(model_dir, model_name + '-symbol.json'))
    os.rename(os.path.join(model_dir, model_name + '-custom-0000.params'), os.path.join(model_dir, model_name + '-best.params'))
    os.rename(os.path.join(model_dir, model_name + '-custom-symbol.json'), os.path.join(model_dir, model_name + '-symbol.json'))
    for files in os.listdir(model_dir):
        if (files[:len(model_name + '-epoch')] == model_name + '-epoch'):
            os.remove(os.path.join(model_dir, files))

    # Print final metrics so the hosting platform can scrape them.
    logging.info('[### train ###] Emitting metrics')
    training_rmse = estimator.train_metrics[0].get()
    training_loss = estimator.train_metrics[1].get()
    print('training rmse: {}'.format(training_rmse[1]))
    print('training loss: {}'.format(training_loss[1]))
    logging.getLogger().setLevel(logging.WARNING)