コード例 #1
0
    def test_forecast_tcmf_distributed(self):
        """Exercise the TCMF forecaster end to end with 4 workers.

        The test fits ``self.model`` on ``self.id``/``self.data``, saves and
        reloads it, checks that the original and the reloaded model predict
        the same values, then runs ``fit_incremental`` on ``self.data_new``
        and checks that the prediction changed, and finally calls
        ``evaluate`` with the ``mse`` metric.
        """
        from bigdl.orca import init_orca_context, stop_orca_context

        train_input = {'id': self.id, 'y': self.data}

        init_orca_context(cores=4,
                          spark_log_level="INFO",
                          init_ray_on_spark=True,
                          object_store_memory="1g")
        # Stop the context even when a step or an assertion fails, so a
        # failing test does not leave a running context behind for the
        # tests that follow.
        try:
            self.model.fit(train_input, num_workers=4, **self.fit_params)

            with tempfile.TemporaryDirectory() as tempdirname:
                self.model.save(tempdirname)
                loaded_model = TCMFForecaster.load(tempdirname,
                                                   is_xshards_distributed=False)
            yhat = self.model.predict(horizon=self.horizon, num_workers=4)
            yhat_loaded = loaded_model.predict(horizon=self.horizon,
                                               num_workers=4)
            yhat_id = yhat_loaded["id"]
            np.testing.assert_equal(yhat_id, self.id)
            yhat = yhat["prediction"]
            yhat_loaded = yhat_loaded["prediction"]
            assert yhat.shape == (self.num_samples, self.horizon)
            np.testing.assert_equal(yhat, yhat_loaded)

            self.model.fit_incremental({'y': self.data_new})
            yhat_incr = self.model.predict(horizon=self.horizon)
            yhat_incr = yhat_incr["prediction"]
            assert yhat_incr.shape == (self.num_samples, self.horizon)
            # The incremental fit must change the forecast: asserting that
            # assert_array_equal raises means the two arrays differ.
            np.testing.assert_raises(AssertionError,
                                     np.testing.assert_array_equal,
                                     yhat, yhat_incr)

            target_value = {"y": self.data_new}
            assert self.model.evaluate(target_value=target_value,
                                       metric=['mse'])
        finally:
            stop_orca_context()
コード例 #2
0
    def test_s2s_forecaster_xshard_input(self):
        """Run Seq2SeqForecaster fit/predict/evaluate on XShards input.

        The data from ``create_data()`` is partitioned into XShards of
        ``{'x', 'y'}`` dicts (``{'x'}`` only for the test split) and fed to
        both a distributed and a non-distributed forecaster. The test only
        checks that the calls complete without raising.
        """
        train_data, val_data, test_data = create_data()
        print("original", train_data[0].dtype)
        init_orca_context(cores=4, memory="2g")
        # Stop the context even when a step fails, so a failing test does
        # not leave a running context behind for the tests that follow.
        try:
            from bigdl.orca.data import XShards

            def transform_to_dict(data):
                return {'x': data[0], 'y': data[1]}

            def transform_to_dict_x(data):
                return {'x': data[0]}

            train_data = XShards.partition(train_data).transform_shard(
                transform_to_dict)
            val_data = XShards.partition(val_data).transform_shard(
                transform_to_dict)
            test_data = XShards.partition(test_data).transform_shard(
                transform_to_dict_x)
            for distributed in [True, False]:
                forecaster = Seq2SeqForecaster(past_seq_len=24,
                                               future_seq_len=5,
                                               input_feature_num=1,
                                               output_feature_num=1,
                                               loss="mae",
                                               lr=0.01,
                                               distributed=distributed)
                forecaster.fit(train_data, epochs=2)
                # Results are not inspected; the calls just have to succeed.
                forecaster.predict(test_data)
                forecaster.evaluate(val_data)
        finally:
            stop_orca_context()
コード例 #3
0
def friesian_context_fixture(request):
    """Yield an Orca context for the tests, then stop it.

    ``OrcaContext._eager_mode`` is switched on first. When both
    ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` are set in the
    environment they are passed to ``init_orca_context`` through ``env``;
    otherwise ``env`` is ``None``.
    """
    import os
    from bigdl.orca import OrcaContext, init_orca_context, stop_orca_context

    OrcaContext._eager_mode = True

    credentials = {
        "AWS_ACCESS_KEY_ID": os.getenv("AWS_ACCESS_KEY_ID"),
        "AWS_SECRET_ACCESS_KEY": os.getenv("AWS_SECRET_ACCESS_KEY")
    }
    # Pass the credentials on only when the pair is complete.
    has_credentials = all(value is not None
                          for value in credentials.values())
    env = credentials if has_credentials else None

    sc = init_orca_context(cores=4, spark_log_level="INFO", env=env)
    yield sc
    stop_orca_context()
コード例 #4
0
def orca_context_fixture():
    """Start an 8-core Orca context, register two Spark UDFs, then stop it.

    ``to_array`` turns one vector into a list of doubles and ``flatten``
    concatenates a sequence of vectors into a single list of doubles; both
    rely on the ``toArray()`` method of their inputs.
    """

    def to_array_(v):
        return v.toArray().tolist()

    def flatten_(v):
        return [value for elem in v for value in elem.toArray().tolist()]

    sc = init_orca_context(cores=8)
    # The try/finally also covers the setup below: if creating the session
    # or registering a UDF raises before the yield, the context is still
    # stopped instead of being leaked.
    try:
        spark = SparkSession(sc)
        spark.udf.register("to_array", to_array_, ArrayType(DoubleType()))
        spark.udf.register("flatten", flatten_, ArrayType(DoubleType()))
        yield
    finally:
        stop_orca_context()
コード例 #5
0
    def test_s2s_forecaster_distributed(self):
        """Check that a distributed Seq2SeqForecaster matches its local form.

        A forecaster is trained with ``distributed=True``, converted with
        ``to_local()``, and the two predictions are compared to 5 decimals.
        When ``onnx`` and ``onnxruntime`` can be imported, the ONNX
        prediction is compared against the distributed one as well.
        """
        train_data, val_data, test_data = create_data()

        init_orca_context(cores=4, memory="2g")
        # Stop the context even when a step or an assertion fails, so a
        # failing test does not leave a running context behind for the
        # tests that follow.
        try:
            forecaster = Seq2SeqForecaster(past_seq_len=24,
                                           future_seq_len=5,
                                           input_feature_num=1,
                                           output_feature_num=1,
                                           loss="mae",
                                           lr=0.01,
                                           distributed=True)

            forecaster.fit(train_data, epochs=2)
            distributed_pred = forecaster.predict(test_data[0])
            forecaster.evaluate(val_data)

            model = forecaster.get_model()
            assert isinstance(model, torch.nn.Module)

            forecaster.to_local()
            local_pred = forecaster.predict(test_data[0])
            forecaster.evaluate(val_data)

            np.testing.assert_almost_equal(distributed_pred, local_pred,
                                           decimal=5)

            # The ONNX checks are optional: they are skipped when the ONNX
            # packages cannot be imported.
            try:
                import onnx  # noqa: F401
                import onnxruntime  # noqa: F401
                local_pred_onnx = forecaster.predict_with_onnx(test_data[0])
                forecaster.evaluate_with_onnx(val_data)
                np.testing.assert_almost_equal(distributed_pred,
                                               local_pred_onnx,
                                               decimal=5)
            except ImportError:
                pass

            # get_model() must still return a torch module after to_local().
            model = forecaster.get_model()
            assert isinstance(model, torch.nn.Module)
        finally:
            stop_orca_context()
コード例 #6
0
 def tearDown(self):
     """Stop the Orca context after each test.

     Counterpart of the per-test setup; calls ``stop_orca_context()``.
     """
     stop_orca_context()
コード例 #7
0
    # NOTE(review): this is the tail of a function whose header and earlier
    # statements are not visible here (cat_sizes_dict, user_index,
    # cross_cols, count_cols, bins, args and start are defined above).

    # Both user-id columns get the size of the same user index.
    cat_sizes_dict['engaged_with_user_id'] = user_index.size()
    cat_sizes_dict['enaging_user_id'] = user_index.size()

    # Each group of crossed columns is named by joining its column names
    # with "_" and paired positionally with the entries of args.cross_sizes.
    cross_sizes_dict = dict(
        zip(["_".join(cross_names) for cross_names in cross_cols],
            args.cross_sizes))

    cat_sizes_dict.update(cross_sizes_dict)

    # Every count column gets the same size, len(bins).
    count_sizes_dict = dict(zip(count_cols, [len(bins)] * len(count_cols)))
    cat_sizes_dict.update(count_sizes_dict)
    print("cat size dict: ", cat_sizes_dict)

    # Make sure the "meta" sub-folder of the output folder exists.
    if not exists(os.path.join(args.output_folder, "meta")):
        makedirs(os.path.join(args.output_folder, "meta"))

    # Pickle the sizes into a temporary local file first, then hand that file
    # to put_local_file_to_remote with the destination under "meta/"
    # (presumably so that non-local output folders work too; verify).
    with tempfile.TemporaryDirectory() as local_path:
        with open(os.path.join(local_path, "categorical_sizes.pkl"),
                  'wb') as f:
            pickle.dump(cat_sizes_dict, f)
        put_local_file_to_remote(os.path.join(local_path,
                                              "categorical_sizes.pkl"),
                                 os.path.join(args.output_folder,
                                              "meta/categorical_sizes.pkl"),
                                 over_write=True)

    # Report the elapsed time since `start` (set earlier, not visible here).
    end = time()
    print("Preprocessing and save time: ", end - start)

    stop_orca_context()
コード例 #8
0
 def tearDown(self) -> None:
     """Stop the Orca context after each test."""
     # Imported locally, so the module itself does not import bigdl.orca
     # for this call.
     from bigdl.orca import stop_orca_context
     stop_orca_context()
コード例 #9
0
ファイル: conftest.py プロジェクト: EmiCareOfCell44/BigDL
def orca_context_fixture():
    """Run the tests inside an Orca context with Ray started on Spark.

    The context is created with 8 cores and a 1g object store before the
    tests run, and stopped once control returns after the ``yield``.
    """
    from bigdl.orca import init_orca_context, stop_orca_context

    context_options = {
        "cores": 8,
        "init_ray_on_spark": True,
        "object_store_memory": "1g",
    }
    init_orca_context(**context_options)
    yield
    stop_orca_context()
コード例 #10
0
def main():
    """Fine-tune YoloV3 on VOC data with an Orca Keras Estimator.

    Steps:
      1. Parse the command-line options.
      2. Start the Orca context for the requested ``--cluster_mode``
         (``local``, ``k8s``, ``yarn`` or ``spark-submit``).
      3. Convert the darknet weights at ``--weights`` into the checkpoint at
         ``--checkpoint``.
      4. Write the train/val VOC splits as parquet under ``--output_data``.
      5. Train with ``Estimator.from_keras`` using the data creators below.

    The Orca context is stopped on exit, whether training succeeded or not.
    """
    anchors = yolo_anchors
    anchor_masks = yolo_anchor_masks

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        dest="data_dir",
                        help="Required. The path where data locates.")
    # The temporary default directory is created after parsing, and only
    # when the option was not supplied, instead of on every invocation.
    parser.add_argument(
        "--output_data",
        dest="output_data",
        default=None,
        help="Required. The path where voc parquet data locates.")
    parser.add_argument("--data_year",
                        dest="data_year",
                        default="2009",
                        help="Required. The voc data date.")
    parser.add_argument("--split_name_train",
                        dest="split_name_train",
                        default="train",
                        help="Required. Split name.")
    parser.add_argument("--split_name_test",
                        dest="split_name_test",
                        default="val",
                        help="Required. Split name.")
    parser.add_argument("--names",
                        dest="names",
                        help="Required. The path where class names locates.")
    parser.add_argument("--weights",
                        dest="weights",
                        default="./checkpoints/yolov3.weights",
                        help="Required. The path where weights locates.")
    parser.add_argument("--checkpoint",
                        dest="checkpoint",
                        default="./checkpoints/yolov3.tf",
                        help="Required. The path where checkpoint locates.")
    parser.add_argument(
        "--checkpoint_folder",
        dest="checkpoint_folder",
        default="./checkpoints",
        help="Required. The path where saved checkpoint locates.")
    parser.add_argument("--epochs",
                        dest="epochs",
                        type=int,
                        default=2,
                        help="Required. epochs.")
    parser.add_argument("--batch_size",
                        dest="batch_size",
                        type=int,
                        default=16,
                        help="Required. batch size.")
    parser.add_argument(
        "--cluster_mode",
        dest="cluster_mode",
        default="local",
        help="Required. Run on local/yarn/k8s/spark-submit mode.")
    parser.add_argument("--class_num",
                        dest="class_num",
                        type=int,
                        default=20,
                        help="Required. class num.")
    parser.add_argument(
        "--worker_num",
        type=int,
        default=1,
        help="The number of slave nodes to be used in the cluster. "
        "You can change it depending on your own cluster setting.")
    parser.add_argument(
        "--cores",
        type=int,
        default=4,
        help="The number of cpu cores you want to use on each node. "
        "You can change it depending on your own cluster setting.")
    parser.add_argument(
        "--memory",
        type=str,
        default="20g",
        help="The memory you want to use on each node. "
        "You can change it depending on your own cluster setting.")
    parser.add_argument(
        "--object_store_memory",
        type=str,
        default="10g",
        help="The object store memory you want to use on each node. "
        "You can change it depending on your own cluster setting.")
    # NOTE(review): no type is declared, so any value given on the command
    # line (including the string "False") arrives as a non-empty, truthy str.
    parser.add_argument("--enable_numa_binding",
                        dest="enable_numa_binding",
                        default=False,
                        help="enable_numa_binding")
    parser.add_argument('--k8s_master',
                        type=str,
                        default="",
                        help="The k8s master. "
                        "It should be k8s://https://<k8s-apiserver-host>: "
                        "<k8s-apiserver-port>.")
    parser.add_argument("--container_image",
                        type=str,
                        default="",
                        help="The runtime k8s image. ")
    parser.add_argument('--k8s_driver_host',
                        type=str,
                        default="",
                        help="The k8s driver localhost.")
    parser.add_argument('--k8s_driver_port',
                        type=str,
                        default="",
                        help="The k8s driver port.")
    parser.add_argument('--nfs_mount_path',
                        type=str,
                        default="",
                        help="nfs mount path")

    options = parser.parse_args()
    if options.output_data is None:
        options.output_data = tempfile.mkdtemp()

    if options.cluster_mode == "local":
        init_orca_context(cluster_mode="local",
                          cores=options.cores,
                          num_nodes=options.worker_num,
                          memory=options.memory,
                          init_ray_on_spark=True,
                          object_store_memory=options.object_store_memory)
    elif options.cluster_mode == "k8s":
        init_orca_context(
            cluster_mode="k8s",
            master=options.k8s_master,
            container_image=options.container_image,
            init_ray_on_spark=True,
            enable_numa_binding=options.enable_numa_binding,
            num_nodes=options.worker_num,
            cores=options.cores,
            memory=options.memory,
            object_store_memory=options.object_store_memory,
            conf={
                # argparse stores --k8s_driver_host/--k8s_driver_port under
                # the k8s_-prefixed attribute names.
                "spark.driver.host":
                options.k8s_driver_host,
                "spark.driver.port":
                options.k8s_driver_port,
                "spark.kubernetes.executor.volumes.persistentVolumeClaim."
                "nfsvolumeclaim.options.claimName":
                "nfsvolumeclaim",
                "spark.kubernetes.executor.volumes.persistentVolumeClaim."
                "nfsvolumeclaim.mount.path":
                options.nfs_mount_path,
                "spark.kubernetes.driver.volumes.persistentVolumeClaim."
                "nfsvolumeclaim.options.claimName":
                "nfsvolumeclaim",
                "spark.kubernetes.driver.volumes.persistentVolumeClaim."
                "nfsvolumeclaim.mount.path":
                options.nfs_mount_path
            })
    elif options.cluster_mode == "yarn":
        init_orca_context(cluster_mode="yarn-client",
                          cores=options.cores,
                          num_nodes=options.worker_num,
                          memory=options.memory,
                          init_ray_on_spark=True,
                          enable_numa_binding=options.enable_numa_binding,
                          object_store_memory=options.object_store_memory)
    elif options.cluster_mode == "spark-submit":
        init_orca_context(cluster_mode="spark-submit")

    # Everything after context creation runs under try/finally so the
    # context is stopped even when conversion, data preparation or training
    # fails.
    try:
        # convert yolov3 weights
        yolo = YoloV3(classes=80)
        load_darknet_weights(yolo, options.weights)
        yolo.save_weights(options.checkpoint)

        def model_creator(config):
            model = YoloV3(DEFAULT_IMAGE_SIZE,
                           training=True,
                           classes=options.class_num)
            anchors = yolo_anchors
            anchor_masks = yolo_anchor_masks

            # Load the converted 80-class weights and copy only the darknet
            # backbone into the model being trained, then freeze it.
            model_pretrained = YoloV3(DEFAULT_IMAGE_SIZE,
                                      training=True,
                                      classes=80)
            model_pretrained.load_weights(options.checkpoint)

            model.get_layer('yolo_darknet').set_weights(
                model_pretrained.get_layer('yolo_darknet').get_weights())
            freeze_all(model.get_layer('yolo_darknet'))

            optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
            loss = [
                YoloLoss(anchors[mask], classes=options.class_num)
                for mask in anchor_masks
            ]
            model.compile(optimizer=optimizer, loss=loss, run_eagerly=False)
            return model

        # prepare data
        with open(options.names) as names_file:
            class_map = {
                name: idx
                for idx, name in enumerate(names_file.read().splitlines())
            }
        dataset_path = os.path.join(options.data_dir, "VOCdevkit")
        voc_train_path = os.path.join(options.output_data, "train_dataset")
        voc_val_path = os.path.join(options.output_data, "val_dataset")

        write_parquet(format="voc",
                      voc_root_path=dataset_path,
                      output_path="file://" + voc_train_path,
                      splits_names=[(options.data_year,
                                     options.split_name_train)],
                      classes=class_map)
        write_parquet(format="voc",
                      voc_root_path=dataset_path,
                      output_path="file://" + voc_val_path,
                      splits_names=[(options.data_year,
                                     options.split_name_test)],
                      classes=class_map)

        output_types = {
            "image": tf.string,
            "label": tf.float32,
            "image_id": tf.string
        }
        output_shapes = {"image": (), "label": (None, 5), "image_id": ()}

        def train_data_creator(config, batch_size):
            train_dataset = read_parquet(format="tf_dataset",
                                         path=voc_train_path,
                                         output_types=output_types,
                                         output_shapes=output_shapes)
            train_dataset = train_dataset.map(
                lambda data_dict: (data_dict["image"], data_dict["label"]))
            train_dataset = train_dataset.map(parse_data_train)
            train_dataset = train_dataset.shuffle(buffer_size=512)
            train_dataset = train_dataset.batch(batch_size)
            train_dataset = train_dataset.map(lambda x, y: (
                transform_images(x, DEFAULT_IMAGE_SIZE),
                transform_targets(y, anchors, anchor_masks,
                                  DEFAULT_IMAGE_SIZE)))
            train_dataset = train_dataset.prefetch(
                buffer_size=tf.data.experimental.AUTOTUNE)
            return train_dataset

        def val_data_creator(config, batch_size):
            val_dataset = read_parquet(format="tf_dataset",
                                       path=voc_val_path,
                                       output_types=output_types,
                                       output_shapes=output_shapes)
            val_dataset = val_dataset.map(
                lambda data_dict: (data_dict["image"], data_dict["label"]))
            val_dataset = val_dataset.map(parse_data_train)
            val_dataset = val_dataset.batch(batch_size)
            val_dataset = val_dataset.map(lambda x, y: (
                transform_images(x, DEFAULT_IMAGE_SIZE),
                transform_targets(y, anchors, anchor_masks,
                                  DEFAULT_IMAGE_SIZE)))
            return val_dataset

        callbacks = [
            ReduceLROnPlateau(verbose=1),
            EarlyStopping(patience=3, verbose=1),
            ModelCheckpoint(
                options.checkpoint_folder + '/yolov3_train_{epoch}.tf',
                verbose=1,
                save_weights_only=True),
            TensorBoard(log_dir='logs')
        ]

        trainer = Estimator.from_keras(model_creator=model_creator)

        # NOTE(review): 3473 and 3581 look like the sample counts of the
        # default train/val splits - confirm when using other splits.
        trainer.fit(train_data_creator,
                    epochs=options.epochs,
                    batch_size=options.batch_size,
                    steps_per_epoch=3473 // options.batch_size,
                    callbacks=callbacks,
                    validation_data=val_data_creator,
                    validation_steps=3581 // options.batch_size)
    finally:
        stop_orca_context()
コード例 #11
0
 def teardown_method(self, method):
     """Stop the Orca context after each test method; ``method`` is unused."""
     stop_orca_context()
コード例 #12
0
def main():
    """Train a Fashion-MNIST model with Orca and log to TensorBoard.

    Steps:
      1. Parse the options and start the Orca context for ``--cluster_mode``.
      2. Write one batch of four training images and the model graph under
         ``runs/fashion_mnist_experiment_1``.
      3. Train and evaluate with the selected ``--backend``: ``bigdl``, or
         ``torch_distributed``/``spark`` (which also log the per-epoch
         training loss).

    The Orca context is stopped on exit, whether training succeeded or not.

    Raises:
        NotImplementedError: if ``--backend`` is not one of the supported
            backends.
    """
    parser = argparse.ArgumentParser(description='PyTorch Tensorboard Example')
    parser.add_argument('--cluster_mode', type=str, default="local",
                        help='The cluster mode, such as local, yarn, spark-submit or k8s.')
    parser.add_argument('--backend', type=str, default="bigdl",
                        help='The backend of PyTorch Estimator; '
                             'bigdl, torch_distributed and spark are supported.')
    parser.add_argument('--batch_size', type=int, default=64, help='The training batch size')
    parser.add_argument('--epochs', type=int, default=2, help='The number of epochs to train for')
    args = parser.parse_args()

    if args.cluster_mode == "local":
        init_orca_context()
    elif args.cluster_mode == "yarn":
        init_orca_context(cluster_mode=args.cluster_mode, cores=4, num_nodes=2)
    elif args.cluster_mode == "spark-submit":
        init_orca_context(cluster_mode=args.cluster_mode)

    # Everything after context creation runs under try/finally so the
    # context is stopped even when a step fails or the backend is rejected.
    try:
        tensorboard_dir = "runs"
        writer = SummaryWriter(tensorboard_dir + '/fashion_mnist_experiment_1')

        # plot some random training images
        dataiter = iter(train_data_creator(config={}, batch_size=4))
        # Use the built-in next(); iterators are not required to expose a
        # .next() method.
        images, _ = next(dataiter)

        # create grid of images
        img_grid = torchvision.utils.make_grid(images)

        # show images
        matplotlib_imshow(img_grid, one_channel=True)

        # write to tensorboard
        writer.add_image('four_fashion_mnist_images', img_grid)

        # inspect the model using tensorboard
        writer.add_graph(model_creator(config={}), images)
        writer.close()

        # training loss vs. epochs
        criterion = nn.CrossEntropyLoss()
        batch_size = args.batch_size
        epochs = args.epochs
        if args.backend == "bigdl":
            train_loader = train_data_creator(config={}, batch_size=batch_size)
            test_loader = validation_data_creator(config={}, batch_size=batch_size)

            net = model_creator(config={})
            optimizer = optimizer_creator(model=net, config={"lr": 0.001})
            orca_estimator = Estimator.from_torch(model=net,
                                                  optimizer=optimizer,
                                                  loss=criterion,
                                                  metrics=[Accuracy()],
                                                  backend="bigdl")

            orca_estimator.set_tensorboard(tensorboard_dir, "bigdl")

            orca_estimator.fit(data=train_loader, epochs=epochs, validation_data=test_loader,
                               checkpoint_trigger=EveryEpoch())

            res = orca_estimator.evaluate(data=test_loader)
            print("Accuracy of the network on the test images: %s" % res)
        elif args.backend in ["torch_distributed", "spark"]:
            orca_estimator = Estimator.from_torch(model=model_creator,
                                                  optimizer=optimizer_creator,
                                                  loss=criterion,
                                                  metrics=[Accuracy()],
                                                  backend=args.backend)
            stats = orca_estimator.fit(train_data_creator, epochs=epochs, batch_size=batch_size)

            for stat in stats:
                writer.add_scalar("training_loss", stat['train_loss'], stat['epoch'])
            # The writer was closed above and is used again for the scalars
            # (presumably it reopens its event file on use - verify); close
            # it once more so they are flushed to disk.
            writer.close()
            print("Train stats: {}".format(stats))
            val_stats = orca_estimator.evaluate(validation_data_creator, batch_size=batch_size)
            print("Validation stats: {}".format(val_stats))
            orca_estimator.shutdown()
        else:
            raise NotImplementedError("Only bigdl, torch_distributed and spark are supported "
                                      "as the backend, but got {}".format(args.backend))
    finally:
        stop_orca_context()
コード例 #13
0
def main(cluster_mode, max_epoch, file_path, batch_size, platform,
         non_interactive):
    """Train, evaluate and run a small U-Net with an Orca Keras Estimator.

    Args:
        cluster_mode: ``"local"``, ``"yarn"`` or ``"spark-submit"``; selects
            how the Orca context is created.
        max_epoch: number of epochs passed to ``est.fit``.
        file_path: directory passed to ``load_data``; ``train/``,
            ``train_masks/`` and ``train_masks.csv`` are read from it.
        batch_size: batch size used for fit and predict.
        platform: when ``"mac"`` and running interactively, matplotlib is
            switched to the ``qt5agg`` backend.
        non_interactive: when true, the result plot is skipped.

    The Orca context is stopped on exit, whether the run succeeded or not.
    """
    import matplotlib
    if not non_interactive and platform == "mac":
        matplotlib.use('qt5agg')

    if cluster_mode == "local":
        init_orca_context(cluster_mode="local", cores=4, memory="3g")
    elif cluster_mode == "yarn":
        init_orca_context(cluster_mode="yarn-client",
                          num_nodes=2,
                          cores=2,
                          driver_memory="3g")
    elif cluster_mode == "spark-submit":
        init_orca_context(cluster_mode="spark-submit")

    # Everything after context creation runs under try/finally so the
    # context is stopped even when loading, training or plotting fails.
    try:
        load_data(file_path)
        img_dir = os.path.join(file_path, "train")
        label_dir = os.path.join(file_path, "train_masks")

        # Here we only take the first 1000 files for simplicity
        df_train = pd.read_csv(os.path.join(file_path, 'train_masks.csv'))
        ids_train = df_train['img'].map(lambda s: s.split('.')[0])
        ids_train = ids_train[:1000]

        x_train_filenames = []
        y_train_filenames = []
        for img_id in ids_train:
            x_train_filenames.append(
                os.path.join(img_dir, "{}.jpg".format(img_id)))
            y_train_filenames.append(
                os.path.join(label_dir, "{}_mask.gif".format(img_id)))

        x_train_filenames, x_val_filenames, y_train_filenames, y_val_filenames = \
            train_test_split(x_train_filenames, y_train_filenames,
                             test_size=0.2, random_state=42)

        def load_and_process_image(path):
            # Resize to 128x128 and scale pixel values by 1/255.
            array = mpimg.imread(path)
            result = np.array(Image.fromarray(array).resize(size=(128, 128)))
            result = result.astype(float)
            result /= 255.0
            return result

        def load_and_process_image_label(path):
            # Same as above, but keep only channel 1 as a trailing
            # single-channel axis.
            array = mpimg.imread(path)
            result = np.array(Image.fromarray(array).resize(size=(128, 128)))
            result = np.expand_dims(result[:, :, 1], axis=-1)
            result = result.astype(float)
            result /= 255.0
            return result

        train_images = np.stack(
            [load_and_process_image(filepath)
             for filepath in x_train_filenames])
        train_label_images = np.stack([
            load_and_process_image_label(filepath)
            for filepath in y_train_filenames
        ])
        val_images = np.stack(
            [load_and_process_image(filepath)
             for filepath in x_val_filenames])
        val_label_images = np.stack([
            load_and_process_image_label(filepath)
            for filepath in y_val_filenames
        ])
        train_shards = XShards.partition({
            "x": train_images,
            "y": train_label_images
        })
        val_shards = XShards.partition({
            "x": val_images,
            "y": val_label_images
        })

        # Build the U-Net model
        def conv_block(input_tensor, num_filters):
            encoder = layers.Conv2D(num_filters, (3, 3),
                                    padding='same')(input_tensor)
            encoder = layers.Activation('relu')(encoder)
            encoder = layers.Conv2D(num_filters, (3, 3),
                                    padding='same')(encoder)
            encoder = layers.Activation('relu')(encoder)
            return encoder

        def encoder_block(input_tensor, num_filters):
            encoder = conv_block(input_tensor, num_filters)
            encoder_pool = layers.MaxPooling2D((2, 2),
                                               strides=(2, 2))(encoder)

            return encoder_pool, encoder

        def decoder_block(input_tensor, concat_tensor, num_filters):
            decoder = layers.Conv2DTranspose(num_filters, (2, 2),
                                             strides=(2, 2),
                                             padding='same')(input_tensor)
            decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
            decoder = layers.Activation('relu')(decoder)
            decoder = layers.Conv2D(num_filters, (3, 3),
                                    padding='same')(decoder)
            decoder = layers.Activation('relu')(decoder)
            decoder = layers.Conv2D(num_filters, (3, 3),
                                    padding='same')(decoder)
            decoder = layers.Activation('relu')(decoder)
            return decoder

        inputs = layers.Input(shape=(128, 128, 3))  # 128
        encoder0_pool, encoder0 = encoder_block(inputs, 16)  # 64
        encoder1_pool, encoder1 = encoder_block(encoder0_pool, 32)  # 32
        encoder2_pool, encoder2 = encoder_block(encoder1_pool, 64)  # 16
        encoder3_pool, encoder3 = encoder_block(encoder2_pool, 128)  # 8
        center = conv_block(encoder3_pool, 256)  # center
        decoder3 = decoder_block(center, encoder3, 128)  # 16
        decoder2 = decoder_block(decoder3, encoder2, 64)  # 32
        decoder1 = decoder_block(decoder2, encoder1, 32)  # 64
        decoder0 = decoder_block(decoder1, encoder0, 16)  # 128
        outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(decoder0)

        net = models.Model(inputs=[inputs], outputs=[outputs])

        # Define custom metrics
        def dice_coeff(y_true, y_pred):
            smooth = 1.
            # Flatten
            y_true_f = tf.reshape(y_true, [-1])
            y_pred_f = tf.reshape(y_pred, [-1])
            intersection = tf.reduce_sum(y_true_f * y_pred_f)
            score = (2. * intersection + smooth) / \
                    (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
            return score

        # Define custom loss function
        def dice_loss(y_true, y_pred):
            loss = 1 - dice_coeff(y_true, y_pred)
            return loss

        def bce_dice_loss(y_true, y_pred):
            loss = losses.binary_crossentropy(y_true, y_pred) + dice_loss(
                y_true, y_pred)
            return loss

        # compile model
        net.compile(optimizer=tf.keras.optimizers.Adam(2e-3),
                    loss=bce_dice_loss)
        # summary() prints the table itself and returns None, so it is not
        # wrapped in print().
        net.summary()

        # create an estimator from keras model
        est = Estimator.from_keras(keras_model=net)
        # fit with estimator
        est.fit(data=train_shards, batch_size=batch_size, epochs=max_epoch)
        # evaluate with estimator
        result = est.evaluate(val_shards)
        print(result)
        # predict with estimator
        val_shards.cache()
        val_image_shards = val_shards.transform_shard(
            lambda val_dict: {"x": val_dict["x"]})
        pred_shards = est.predict(data=val_image_shards,
                                  batch_size=batch_size)
        # Only the first collected element of each result is used for the
        # visualization below.
        pred = pred_shards.collect()[0]["prediction"]
        val_image_label = val_shards.collect()[0]
        val_image = val_image_label["x"]
        val_label = val_image_label["y"]
        if not non_interactive:
            # visualize 5 predicted results
            plt.figure(figsize=(10, 20))
            for i in range(5):
                img = val_image[i]
                label = val_label[i]
                predicted_label = pred[i]

                plt.subplot(5, 3, 3 * i + 1)
                plt.imshow(img)
                plt.title("Input image")

                plt.subplot(5, 3, 3 * i + 2)
                plt.imshow(label[:, :, 0], cmap='gray')
                plt.title("Actual Mask")
                plt.subplot(5, 3, 3 * i + 3)
                plt.imshow(predicted_label, cmap='gray')
                plt.title("Predicted Mask")
            plt.suptitle("Examples of Input Image, Label, and Prediction")

            plt.show()
    finally:
        stop_orca_context()