def test_horovod_initialized_correctly(self):
        estimator = Estimator.from_torch(
            model=model_creator,
            optimizer=optimizer_creator,
            loss=nn.MSELoss(),
            scheduler_creator=scheduler_creator,
            config={
                "lr": 1e-2,  # used in optimizer_creator
                "hidden_size": 1  # used in model_creator
            },
            backend="horovod",
            workers_per_node=2)

        def get_size():
            import horovod.torch as hvd
            return hvd.size()

        results = estimator.horovod_runner.run(get_size)
        assert results == [2, 2]

        def get_rank():
            import horovod.torch as hvd
            return hvd.rank()

        results = estimator.horovod_runner.run(get_rank)
        results = sorted(results)
        assert results == [0, 1]
        estimator.shutdown()
Example #2
    def test_bigdl_pytorch_estimator_dataframe_fit_evaluate(self):
        class SimpleModel(nn.Module):
            def __init__(self):
                super(SimpleModel, self).__init__()
                self.fc = nn.Linear(5, 5)

            def forward(self, x):
                x = self.fc(x)
                return F.log_softmax(x, dim=1)

        model = SimpleModel()

        def loss_func(input, target):
            return nn.CrossEntropyLoss().forward(input, target.flatten().long())

        rdd = self.sc.range(0, 100)
        df = rdd.map(lambda x: ([float(x)] * 5,
                                [int(np.random.randint(0, 2,
                                                       size=()))])).toDF(["feature", "label"])

        with tempfile.TemporaryDirectory() as temp_dir_name:
            estimator = Estimator.from_torch(model=model, loss=loss_func, metrics=[Accuracy()],
                                             optimizer=SGD(learningrate_schedule=Default()),
                                             model_dir=temp_dir_name)
            estimator.fit(data=df, epochs=4, batch_size=2, validation_data=df,
                          checkpoint_trigger=EveryEpoch(),
                          feature_cols=["feature"], label_cols=["label"])
            eval_result = estimator.evaluate(df, batch_size=2,
                                             feature_cols=["feature"], label_cols=["label"])
            assert isinstance(eval_result, dict)
Example #3
    def test_bigdl_pytorch_estimator_dataloader(self):
        class SimpleModel(nn.Module):
            def __init__(self):
                super(SimpleModel, self).__init__()
                self.dense1 = nn.Linear(2, 4)
                self.bn1 = torch.nn.BatchNorm1d(4)
                self.dense2 = nn.Linear(4, 1)

            def forward(self, x):
                x = self.dense1(x)
                x = self.bn1(x)
                x = torch.sigmoid(self.dense2(x))
                return x

        model = SimpleModel()

        estimator = Estimator.from_torch(model=model, loss=nn.BCELoss(),
                                         metrics=[Accuracy()],
                                         optimizer=Adam())

        inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
        targets = torch.Tensor([[0], [0], [0], [1], [1], [1]])
        train_loader = torch.utils.data.DataLoader(
            TensorDataset(inputs, targets),
            batch_size=2,
        )
        val_loader = torch.utils.data.DataLoader(
            TensorDataset(inputs, targets),
            batch_size=2,
        )
        estimator.fit(data=train_loader, epochs=2, validation_data=val_loader,
                      checkpoint_trigger=EveryEpoch())
        estimator.evaluate(data=val_loader)
Example #4
    def test_bigdl_pytorch_estimator_dataframe_predict(self):
        def loss_func(input, target):
            return nn.CrossEntropyLoss().forward(input, target.flatten().long())

        class IdentityNet(nn.Module):
            def __init__(self):
                super().__init__()
                # needed so that the optimizer does not raise on an empty parameter list
                self.fc1 = nn.Linear(5, 5)

            def forward(self, input_):
                return input_

        model = IdentityNet()
        rdd = self.sc.range(0, 100)
        df = rdd.map(lambda x: ([float(x)] * 5,
                                [int(np.random.randint(0, 2,
                                                       size=()))])).toDF(["feature", "label"])

        with tempfile.TemporaryDirectory() as temp_dir_name:
            estimator = Estimator.from_torch(model=model, loss=loss_func,
                                             optimizer=SGD(learningrate_schedule=Default()),
                                             model_dir=temp_dir_name)
            result = estimator.predict(df, feature_cols=["feature"])
            expr = "sum(cast(feature <> to_array(prediction) as int)) as error"
            assert result.selectExpr(expr).first()["error"] == 0
Example #5
    def test_bigdl_pytorch_estimator_pandas_dataframe(self):
        class SimpleModel(nn.Module):
            def __init__(self):
                super(SimpleModel, self).__init__()
                self.fc = nn.Linear(1, 10)

            def forward(self, x):
                x = torch.unsqueeze(x, dim=1)
                x = self.fc(x)
                return F.log_softmax(x, dim=1)

        def loss_func(input, target):
            return nn.CrossEntropyLoss().forward(input,
                                                 target.flatten().long())

        model = SimpleModel()

        OrcaContext.pandas_read_backend = "pandas"
        file_path = os.path.join(resource_path,
                                 "orca/learn/simple_feature_label.csv")
        data_shard = read_csv(file_path)

        with tempfile.TemporaryDirectory() as temp_dir_name:
            estimator = Estimator.from_torch(
                model=model,
                loss=loss_func,
                metrics=[Accuracy()],
                optimizer=SGD(learningrate_schedule=Default()),
                model_dir=temp_dir_name)
            estimator.fit(data=data_shard,
                          epochs=1,
                          batch_size=4,
                          feature_cols=['feature'],
                          label_cols=['label'],
                          validation_data=data_shard,
                          checkpoint_trigger=EveryEpoch())
            estimator.evaluate(data_shard,
                               batch_size=4,
                               feature_cols=['feature'],
                               label_cols=['label'])
            est2 = Estimator.from_torch(model=model,
                                        loss=loss_func,
                                        metrics=[Accuracy()],
                                        optimizer=None)
            est2.load_orca_checkpoint(temp_dir_name)
            est2.predict(data_shard, batch_size=4, feature_cols=['feature'])
Example #6
def get_estimator(workers_per_node=1, model_fn=get_model):
    estimator = Estimator.from_torch(model=model_fn,
                                     optimizer=get_optimizer,
                                     loss=nn.BCELoss(),
                                     metrics=Accuracy(),
                                     config={"lr": 1e-2},
                                     workers_per_node=workers_per_node,
                                     backend="spark")
    return estimator
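
For reference, the get_model and get_optimizer creators passed above are not part of this snippet. A minimal sketch of the expected creator signatures, with hypothetical bodies for illustration only:

import torch
import torch.nn as nn

def get_model(config):
    # hypothetical model creator: receives the config dict given to from_torch
    return nn.Sequential(nn.Linear(50, 1), nn.Sigmoid())

def get_optimizer(model, config):
    # hypothetical optimizer creator: "lr" matches the config used above
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))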
Example #7
    def test_save_and_restore(self):
        estimator1 = Estimator.from_torch(
            model=model_creator,
            optimizer=optimizer_creator,
            loss=nn.MSELoss(),
            scheduler_creator=scheduler_creator,
            config={
                "lr": 1e-2,  # used in optimizer_creator
                "hidden_size": 1  # used in model_creator
            },
            backend="horovod")
        with TemporaryDirectory() as tmp_path:
            estimator1.fit(train_data_creator, batch_size=4, epochs=1)
            checkpoint_path = os.path.join(tmp_path, "checkpoint")
            estimator1.save(checkpoint_path)

            model1 = estimator1.get_model()

            estimator1.shutdown()

            estimator2 = Estimator.from_torch(
                model=model_creator,
                optimizer=optimizer_creator,
                loss=nn.MSELoss(),
                scheduler_creator=scheduler_creator,
                config={
                    "lr": 1e-2,  # used in optimizer_creator
                    "hidden_size": 1  # used in model_creator
                },
                backend="horovod")
            estimator2.load(checkpoint_path)

            model2 = estimator2.get_model()

        model1_state_dict = model1.state_dict()
        model2_state_dict = model2.state_dict()

        assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())

        for k in model1_state_dict:
            assert torch.equal(model1_state_dict[k], model2_state_dict[k])
        estimator2.shutdown()
Example #8
def get_estimator(workers_per_node=1,
                  model_fn=get_model,
                  sync_stats=False,
                  log_level=logging.INFO,
                  loss=nn.BCELoss(),
                  optimizer=get_optimizer):
    estimator = Estimator.from_torch(model=model_fn,
                                     optimizer=optimizer,
                                     loss=loss,
                                     metrics=Accuracy(),
                                     config={"lr": 1e-2},
                                     workers_per_node=workers_per_node,
                                     backend="torch_distributed",
                                     sync_stats=sync_stats,
                                     log_level=log_level)
    return estimator
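
A hypothetical call showing the extra knobs; sync_stats and log_level are forwarded unchanged to Estimator.from_torch:

est = get_estimator(workers_per_node=2, sync_stats=True, log_level=logging.DEBUG)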
Example #9
    def test_data_parallel_sgd_correctness(self):
        sc = init_nncontext()
        rdd = sc.range(0, 100).repartition(2)

        # partition 0: [(0, 0), (0, 0)]
        # partition 1: [(1, 0), (1, 0)]
        # model: y = w * x
        # loss = (wx)^2
        # dloss/dw = 2x^2*w
        # end of first iteration:
        #    partition 0 loss: 0.0
        #    partition 1 loss: 1.0
        #    avg_grad = avg([0, 0, 2, 2]) = 1
        #    weight = 1.0 - 0.5 * avg_grad = 0.5
        # end of second iteration:
        #    partition 0 loss: 0.0
        #    partition 1 loss: 0.25
        #    avg_grad = avg([0, 0, 1, 1]) = 0.5
        #    weight = 0.5 - 0.5 * avg_grad = 0.25
        df = rdd.mapPartitionsWithIndex(
            lambda idx, iter: [([float(idx)], [0.0]) for _ in iter][:2]).toDF(
                ["feature", "label"])

        def get_optimizer(model, config):
            return torch.optim.SGD(model.parameters(), lr=0.5)

        estimator = Estimator.from_torch(model=lambda config: LinearModel(),
                                         optimizer=get_optimizer,
                                         loss=torch.nn.MSELoss(),
                                         metrics=Accuracy(),
                                         config={},
                                         workers_per_node=2,
                                         backend="torch_distributed",
                                         sync_stats=False)

        stats = estimator.fit(df,
                              batch_size=4,
                              epochs=2,
                              feature_cols=["feature"],
                              label_cols=["label"],
                              reduce_results=False)

        state = estimator.get_state_dict()
        assert state['models'][0]['fc1.weight'].item() == 0.25
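
The expected value 0.25 can be reproduced with a quick single-process sketch, assuming the LinearModel above is y = w * x with w initialized to 1.0 (its creator is not shown in this snippet):

import torch

w = torch.nn.Parameter(torch.tensor([[1.0]]))
opt = torch.optim.SGD([w], lr=0.5)
x = torch.tensor([[0.0], [0.0], [1.0], [1.0]])  # one global batch: two samples from each partition
y = torch.zeros(4, 1)
for _ in range(2):  # two epochs of a single iteration each
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(x @ w.t(), y)
    loss.backward()
    opt.step()
print(w.item())  # 1.0 -> 0.5 -> 0.25, matching the hand calculation in the comments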
Example #10
def train_yseq_hvd(workers_per_node, epochs, **config):
    from bigdl.orca.learn.pytorch import Estimator
    estimator = Estimator.from_torch(model=model_creator,
                                     optimizer=optimizer_creator,
                                     loss=loss_creator,
                                     workers_per_node=workers_per_node,
                                     config=config,
                                     backend="horovod")

    stats = estimator.fit(train_data_creator, epochs=epochs)
    for s in stats:
        for k, v in s.items():
            print(f"{k}: {v}")
    val_stats = estimator.evaluate(val_data_creator)
    val_loss = val_stats['val_loss']

    # retrieve the model
    yseq = estimator.get_model()
    estimator.shutdown()
    return yseq, val_loss
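
A hypothetical invocation; keyword arguments beyond workers_per_node and epochs are packed into config and handed to the creator functions (the lr key below is an assumption about what those creators read):

yseq_model, val_loss = train_yseq_hvd(workers_per_node=2, epochs=3, lr=1e-3)
print("validation loss: {}".format(val_loss))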
Example #11
    def test_train(self):
        estimator = Estimator.from_torch(
            model=model_creator,
            optimizer=optimizer_creator,
            loss=nn.MSELoss(),
            scheduler_creator=scheduler_creator,
            config={
                "lr": 1e-2,  # used in optimizer_creator
                "hidden_size": 1  # used in model_creator
            },
            backend="horovod",
            workers_per_node=2)
        stats1 = estimator.fit(train_data_creator, batch_size=4, epochs=5)
        train_loss1 = stats1[-1]["train_loss"]
        validation_loss1 = estimator.evaluate(
            validation_data_creator)["val_loss"]

        stats2 = estimator.fit(train_data_creator, batch_size=4, epochs=3)
        train_loss2 = stats2[-1]["train_loss"]
        validation_loss2 = estimator.evaluate(
            validation_data_creator)["val_loss"]

        # Verify syncing weights, i.e. the two workers have the same weights after training
        import ray
        import numpy as np
        remote_workers = estimator.remote_workers
        state_dicts = ray.get(
            [worker.get_state_dict.remote() for worker in remote_workers])
        weights = [state["models"] for state in state_dicts]
        worker1_weights = weights[0][0]
        worker2_weights = weights[1][0]
        for layer in list(worker1_weights.keys()):
            assert np.allclose(worker1_weights[layer].numpy(),
                               worker2_weights[layer].numpy())

        assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
        # todo this test maybe too strict, need to further check
        # assert validation_loss2 <= validation_loss1, (validation_loss2,
        #                                               validation_loss1)
        estimator.shutdown()
Example #12
def train_example(workers_per_node):
    estimator = Estimator.from_torch(
        model=model_creator,
        optimizer=optimizer_creator,
        loss=nn.MSELoss(),
        scheduler_creator=scheduler_creator,
        workers_per_node=workers_per_node,
        config={
            "lr": 1e-2,  # used in optimizer_creator
            "hidden_size": 1  # used in model_creator
        }, backend="horovod")

    # train 5 epochs
    stats = estimator.fit(train_data_creator, batch_size=4, epochs=5)
    print("train stats: {}".format(stats))
    val_stats = estimator.evaluate(validation_data_creator)
    print("validation stats: {}".format(val_stats))

    # retrieve the model
    model = estimator.get_model()
    print("trained weight: % .2f, bias: % .2f" % (
        model.weight.item(), model.bias.item()))
Example #13
    def test_bigdl_pytorch_estimator_dataloader_creator(self):
        class SimpleModel(nn.Module):
            def __init__(self, momentum):
                super(SimpleModel, self).__init__()
                self.dense1 = nn.Linear(2, 4)
                self.bn1 = torch.nn.BatchNorm1d(4, momentum=momentum)
                self.dense2 = nn.Linear(4, 1)

            def forward(self, x):
                x = self.dense1(x)
                x = self.bn1(x)
                x = torch.sigmoid(self.dense2(x))
                return x

        def model_creator(config):
            model = SimpleModel(momentum=config.get("momentum", 0.1))
            return model

        estimator = Estimator.from_torch(model=model_creator, loss=nn.BCELoss(),
                                         metrics=[Accuracy()],
                                         optimizer=Adam(),
                                         config={"momentum": 0.9})

        def get_dataloader(config, batch_size):
            inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
            targets = torch.Tensor([[0], [0], [0], [1], [1], [1]])
            data_loader = torch.utils.data.DataLoader(
                TensorDataset(inputs, targets),
                batch_size=batch_size,
                num_workers=config.get("threads", 1)
            )
            return data_loader

        estimator.fit(data=get_dataloader, epochs=2, batch_size=2, validation_data=get_dataloader,
                      checkpoint_trigger=EveryEpoch())
        estimator.evaluate(data=get_dataloader, batch_size=2)
        model = estimator.get_model()
        assert isinstance(model, nn.Module)
Example #14
    def test_bigdl_pytorch_estimator_save_and_load(self):
        class Network(nn.Module):
            def __init__(self):
                super(Network, self).__init__()

                self.fc1 = nn.Linear(28 * 28, 500)
                self.fc2 = nn.Linear(500, 10)

            def forward(self, x):
                x = x.view(-1, 28 * 28)
                x = F.relu(self.fc1(x))
                x = self.fc2(x)
                return F.log_softmax(x, dim=1)

        model = Network()
        model.train()
        criterion = nn.NLLLoss()
        adam = torch.optim.Adam(model.parameters(), 0.001)

        dir = "/tmp/dataset/"
        batch_size = 320

        images = torch.randn(1000 * 28 * 28,
                             dtype=torch.float32).view(1000, 1, 28, 28)
        labels = torch.randint(0, 10, (1000, ), dtype=torch.long)

        dataset = torch.utils.data.TensorDataset(images, labels)
        train_loader = torch.utils.data.DataLoader(dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False)

        # epoch 1
        est = Estimator.from_torch(model=model,
                                   optimizer=adam,
                                   loss=criterion,
                                   metrics=[Accuracy()])

        est.fit(data=train_loader,
                epochs=1,
                validation_data=test_loader,
                checkpoint_trigger=EveryEpoch())
        paras1 = list(est.get_model().named_parameters())
        est.save("model_epoch_1")

        # epoch 2
        est.fit(data=train_loader,
                epochs=2,
                validation_data=test_loader,
                checkpoint_trigger=EveryEpoch())
        paras2 = list(est.get_model().named_parameters())
        est.load("model_epoch_1")
        paras3 = list(est.get_model().named_parameters())

        # parameters should differ from the post-epoch-2 values once the
        # epoch-1 checkpoint has been loaded back
        load_success = any(
            not torch.equal(para2, para3)
            for (_, para2), (_, para3) in zip(paras2, paras3))
        if not load_success:
            raise Exception(
                "Load failed. Parameters did not change after loading.")

        for i in range(len(paras1)):
            name1, para1 = paras1[i]
            name3, para3 = paras3[i]
            if not torch.all(torch.eq(para1, para3)):
                raise Exception("After reloading the model," + name1 +
                                "does not match.")
        print("pass")
Example #15
    def test_bigdl_pytorch_estimator_shard(self):
        class SimpleModel(nn.Module):
            def __init__(self):
                super(SimpleModel, self).__init__()
                self.fc = nn.Linear(2, 2)

            def forward(self, x):
                x = self.fc(x)
                return F.log_softmax(x, dim=1)

        model = SimpleModel()

        def loss_func(input, target):
            return nn.CrossEntropyLoss().forward(input,
                                                 target.flatten().long())

        def transform(df):
            result = {
                "x":
                np.stack([df['user'].to_numpy(), df['item'].to_numpy()],
                         axis=1),
                "y":
                df['label'].to_numpy()
            }
            return result

        def transform_del_y(d):
            result = {"x": d["x"]}
            return result

        OrcaContext.pandas_read_backend = "pandas"
        file_path = os.path.join(resource_path, "orca/learn/ncf.csv")
        data_shard = read_csv(file_path)
        data_shard = data_shard.transform_shard(transform)

        with tempfile.TemporaryDirectory() as temp_dir_name:
            estimator = Estimator.from_torch(
                model=model,
                loss=loss_func,
                metrics=[Accuracy()],
                optimizer=SGD(learningrate_schedule=Default()),
                model_dir=temp_dir_name)
            estimator.fit(data=data_shard,
                          epochs=4,
                          batch_size=2,
                          validation_data=data_shard,
                          checkpoint_trigger=EveryEpoch())
            state_dict1 = estimator.get_model().state_dict()

            estimator.evaluate(data_shard, batch_size=2)
            est2 = Estimator.from_torch(model=model,
                                        loss=loss_func,
                                        metrics=[Accuracy()],
                                        optimizer=None)
            est2.load_orca_checkpoint(temp_dir_name)
            state_dict2 = est2.get_model().state_dict()

            for name in state_dict1:
                para1 = state_dict1[name]
                para2 = state_dict2[name]
                assert torch.all(torch.eq(para1, para2)), "After reloading the model, " \
                                                          "%r does not match" % name

            est2.fit(data=data_shard,
                     epochs=8,
                     batch_size=2,
                     validation_data=data_shard,
                     checkpoint_trigger=EveryEpoch())
            est2.evaluate(data_shard, batch_size=2)
            pred_result = est2.predict(data_shard)
            pred_c = pred_result.collect()
            assert isinstance(pred_result, SparkXShards)
            pred_shard = data_shard.transform_shard(transform_del_y)
            pred_result2 = est2.predict(pred_shard)
            pred_c_2 = pred_result2.collect()
            assert (pred_c[0]["prediction"] == pred_c_2[0]["prediction"]).all()
Example #16
def optim_creator(model, config):
    return optim.Adam(model.parameters(), lr=config.get("lr", 0.01))
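
The model_creator and data creators used below are defined elsewhere; a hypothetical model_creator with the matching (config) signature, based on the keys passed to it further down:

def model_creator(config):
    # hypothetical: seed and build the super-resolution network from config
    torch.manual_seed(config["seed"])
    return Net(upscale_factor=config["upscale_factor"])  # Net is assumed here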


criterion = nn.MSELoss()
model_dir = opt.data_dir + "/models"

if opt.backend == "bigdl":
    model = model_creator(config={
        "upscale_factor": opt.upscale_factor,
        "seed": opt.seed
    })
    optimizer = optim_creator(model, config={"lr": opt.lr})
    estimator = Estimator.from_torch(model=model,
                                     optimizer=optimizer,
                                     loss=criterion,
                                     metrics=[MSE()],
                                     model_dir=model_dir,
                                     backend="bigdl")

    train_loader = train_data_creator(config={
        "upscale_factor": opt.upscale_factor,
        "threads": opt.threads
    },
                                      batch_size=opt.batch_size)
    test_loader = validation_data_creator(config={
        "upscale_factor": opt.upscale_factor,
        "threads": opt.threads
    },
                                          batch_size=opt.batch_size)

    estimator.fit(data=train_loader,
Example #17
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))

dataiter = iter(test_loader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))

if args.backend == "bigdl":
    net = model_creator(config={})
    optimizer = optim_creator(model=net, config={"lr": 0.001})
    orca_estimator = Estimator.from_torch(model=net,
                                          optimizer=optimizer,
                                          loss=criterion,
                                          metrics=[Accuracy()],
                                          backend="bigdl")

    orca_estimator.fit(data=train_loader, epochs=args.epochs, validation_data=test_loader,
                       checkpoint_trigger=EveryEpoch())

    res = orca_estimator.evaluate(data=test_loader)
    print("Accuracy of the network on the test images: %s" % res)
elif args.backend in ["torch_distributed", "spark"]:
    orca_estimator = Estimator.from_torch(model=model_creator,
                                          optimizer=optim_creator,
                                          loss=criterion,
                                          metrics=[Accuracy()],
                                          backend=args.backend,
                                          config={"lr": 0.001,
Example #18
def main():
    parser = argparse.ArgumentParser(description='PyTorch Tensorboard Example')
    parser.add_argument('--cluster_mode', type=str, default="local",
                        help='The cluster mode, such as local, yarn, spark-submit or k8s.')
    parser.add_argument('--backend', type=str, default="bigdl",
                        help='The backend of PyTorch Estimator; '
                             'bigdl, torch_distributed and spark are supported.')
    parser.add_argument('--batch_size', type=int, default=64, help='The training batch size')
    parser.add_argument('--epochs', type=int, default=2, help='The number of epochs to train for')
    args = parser.parse_args()

    if args.cluster_mode == "local":
        init_orca_context()
    elif args.cluster_mode == "yarn":
        init_orca_context(cluster_mode=args.cluster_mode, cores=4, num_nodes=2)
    elif args.cluster_mode == "spark-submit":
        init_orca_context(cluster_mode=args.cluster_mode)

    tensorboard_dir = "runs"
    writer = SummaryWriter(tensorboard_dir + '/fashion_mnist_experiment_1')
    # constant for classes
    classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

    # plot some random training images
    dataiter = iter(train_data_creator(config={}, batch_size=4))
    images, labels = next(dataiter)

    # create grid of images
    img_grid = torchvision.utils.make_grid(images)

    # show images
    matplotlib_imshow(img_grid, one_channel=True)

    # write to tensorboard
    writer.add_image('four_fashion_mnist_images', img_grid)

    # inspect the model using tensorboard
    writer.add_graph(model_creator(config={}), images)
    writer.close()

    # training loss vs. epochs
    criterion = nn.CrossEntropyLoss()
    batch_size = args.batch_size
    epochs = args.epochs
    if args.backend == "bigdl":
        train_loader = train_data_creator(config={}, batch_size=batch_size)
        test_loader = validation_data_creator(config={}, batch_size=batch_size)

        net = model_creator(config={})
        optimizer = optimizer_creator(model=net, config={"lr": 0.001})
        orca_estimator = Estimator.from_torch(model=net,
                                              optimizer=optimizer,
                                              loss=criterion,
                                              metrics=[Accuracy()],
                                              backend="bigdl")

        orca_estimator.set_tensorboard(tensorboard_dir, "bigdl")

        orca_estimator.fit(data=train_loader, epochs=epochs, validation_data=test_loader,
                           checkpoint_trigger=EveryEpoch())

        res = orca_estimator.evaluate(data=test_loader)
        print("Accuracy of the network on the test images: %s" % res)
    elif args.backend in ["torch_distributed", "spark"]:
        orca_estimator = Estimator.from_torch(model=model_creator,
                                              optimizer=optimizer_creator,
                                              loss=criterion,
                                              metrics=[Accuracy()],
                                              backend=args.backend)
        stats = orca_estimator.fit(train_data_creator, epochs=epochs, batch_size=batch_size)

        for stat in stats:
            writer.add_scalar("training_loss", stat['train_loss'], stat['epoch'])
        print("Train stats: {}".format(stats))
        val_stats = orca_estimator.evaluate(validation_data_creator, batch_size=batch_size)
        print("Validation stats: {}".format(val_stats))
        orca_estimator.shutdown()
    else:
        raise NotImplementedError("Only bigdl and torch_distributed are supported "
                                  "as the backend, but got {}".format(args.backend))

    stop_orca_context()