def test_input(self):
     """Rotate configured with only an image key: forward() yields a list whose
     first element has the expected single-image shape."""
     results = Rotate(image_in='x').forward(data=self.single_input, state={})
     with self.subTest('Check output type'):
         self.assertEqual(type(results), list)
     with self.subTest('Check output image shape'):
         rotated_image = results[0]
         self.assertEqual(rotated_image.shape, self.single_output_shape)
 def test_input_image_and_mask(self):
     """Rotate configured with image and mask keys: forward() yields a list whose
     first two elements both have the expected image-and-mask shape."""
     results = Rotate(image_in='x', mask_in='x_mask').forward(
         data=self.input_image_and_mask, state={})
     with self.subTest('Check output type'):
         self.assertEqual(type(results), list)
     with self.subTest('Check output image shape'):
         rotated_image = results[0]
         self.assertEqual(rotated_image.shape,
                          self.image_and_mask_output_shape)
     with self.subTest('Check output mask shape'):
         rotated_mask = results[1]
         self.assertEqual(rotated_mask.shape,
                          self.image_and_mask_output_shape)
def get_estimator(epochs=20,
                  batch_size=4,
                  train_steps_per_epoch=None,
                  eval_steps_per_epoch=None,
                  save_dir=None,
                  log_steps=20,
                  data_dir=None):
    """Assemble the pipeline, network and traces for the "lung_segmentation" UNet.

    Args:
        epochs: Forwarded to `fe.Estimator` as `epochs`.
        batch_size: Forwarded to `fe.Pipeline` as `batch_size`.
        train_steps_per_epoch: Forwarded to `fe.Estimator` unchanged.
        eval_steps_per_epoch: Forwarded to `fe.Estimator` unchanged.
        save_dir: Directory handed to `BestModelSaver`. When None, a fresh
            temporary directory is created for this call.
        log_steps: Forwarded to `fe.Estimator` as `log_steps`.
        data_dir: Passed to `montgomery.load_data` as `root_dir`.

    Returns:
        The configured `fe.Estimator` instance.
    """
    # Resolve the default lazily: a `tempfile.mkdtemp()` default in the
    # signature would run once at import time, creating a directory even when
    # the function is never called and sharing it between all calls.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()

    # Step 1: data pipeline.
    train_data = montgomery.load_data(root_dir=data_dir)
    # NOTE(review): split(0.2) presumably moves that fraction of records out of
    # `train_data` into the returned dataset -- confirm against the dataset API.
    eval_data = train_data.split(0.2)
    parent_path = train_data.parent_path

    pipeline_ops = [
        ReadImage(inputs="image",
                  parent_path=parent_path,
                  outputs="image",
                  color_flag='gray'),
        ReadImage(inputs="mask_left",
                  parent_path=parent_path,
                  outputs="mask_left",
                  color_flag='gray',
                  mode='!infer'),
        ReadImage(inputs="mask_right",
                  parent_path=parent_path,
                  outputs="mask_right",
                  color_flag='gray',
                  mode='!infer'),
        # Merge the two per-side masks into the single "mask" key used below.
        CombineLeftRightMask(inputs=("mask_left", "mask_right"),
                             outputs="mask",
                             mode='!infer'),
        Resize(image_in="image", width=512, height=512),
        Resize(image_in="mask", width=512, height=512, mode='!infer'),
        # Augmentations are restricted to mode='train' and wrapped in
        # Sometimes; image and mask go through the same op together.
        Sometimes(numpy_op=HorizontalFlip(
            image_in="image", mask_in="mask", mode='train')),
        Sometimes(numpy_op=Rotate(image_in="image",
                                  mask_in="mask",
                                  limit=(-10, 10),
                                  border_mode=cv2.BORDER_CONSTANT,
                                  mode='train')),
        Minmax(inputs="image", outputs="image"),
        Minmax(inputs="mask", outputs="mask", mode='!infer'),
    ]
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           batch_size=batch_size,
                           ops=pipeline_ops)

    # Step 2: model and network. The input size matches the 512x512
    # single-channel ('gray') images produced by the pipeline above.
    model = fe.build(
        model_fn=lambda: UNet(input_size=(512, 512, 1)),
        optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=0.0001),
        model_name="lung_segmentation")
    network = fe.Network(ops=[
        ModelOp(inputs="image", model=model, outputs="pred_segment"),
        CrossEntropy(
            inputs=("pred_segment", "mask"), outputs="loss", form="binary"),
        UpdateOp(model=model, loss_name="loss"),
    ])

    # Step 3: traces and estimator. BestModelSaver is keyed on the 'Dice'
    # metric with save_best_mode='max'.
    traces = [
        Dice(true_key="mask", pred_key="pred_segment"),
        BestModelSaver(model=model,
                       save_dir=save_dir,
                       metric='Dice',
                       save_best_mode='max'),
    ]
    estimator = fe.Estimator(network=network,
                             pipeline=pipeline,
                             epochs=epochs,
                             log_steps=log_steps,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch)

    return estimator