示例#1
0
def get_estimator(style_img_path=None,
                  data_path=None,
                  style_weight=5.0,
                  content_weight=1.0,
                  tv_weight=1e-4,
                  steps_per_epoch=None,
                  validation_steps=None,
                  model_dir=None):
    """Build a FastEstimator estimator that trains a style-transfer network.

    Args:
        style_img_path: Path to the style reference image. When None, the
            Kandinsky "Composition 7" example image is downloaded via
            ``tf.keras.utils.get_file``.
        data_path: Forwarded to ``load_data`` to locate the content-image
            dataset.
        style_weight: Weight of the style term passed to ``StyleContentLoss``.
        content_weight: Weight of the content term passed to
            ``StyleContentLoss``.
        tv_weight: Weight of the total-variation term passed to
            ``StyleContentLoss``.
        steps_per_epoch: Forwarded to ``fe.Estimator``.
        validation_steps: Forwarded to ``fe.Estimator``.
        model_dir: Directory where ``ModelSaver`` writes the trained model.
            When None, a fresh temporary directory is created on each call.

    Returns:
        A configured ``fe.Estimator`` (2 epochs, pipeline batch size 4).

    Raises:
        ValueError: If the style reference image cannot be read.
    """
    # Create the default save directory per call; a `tempfile.mkdtemp()`
    # default would run once at definition time and be shared by all calls.
    if model_dir is None:
        model_dir = tempfile.mkdtemp()

    train_csv, _, path = load_data(data_path, load_object=False)

    if style_img_path is None:
        style_img_path = tf.keras.utils.get_file(
            'kandinsky.jpg',
            'https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_'
            '-_Composition_7.jpg')
    # cv2.imread returns None (rather than raising) when the file cannot be
    # read; check explicitly because `assert` is stripped under `python -O`.
    style_img = cv2.imread(style_img_path)
    if style_img is None:
        raise ValueError("Invalid style reference image")
    # Map pixel values from [0, 255] to [-1, 1]. The previous expression
    # divided twice instead of subtracting 127.5 first. Presumably this
    # matches the range produced by the `Rescale` op on content images — confirm.
    style_img = (style_img.astype(np.float32) - 127.5) / 127.5
    # Add a leading batch axis so the single style image can go through VGG.
    style_img_t = tf.convert_to_tensor(np.expand_dims(style_img, axis=0))

    tfr_save_dir = os.path.join(path, 'tfrecords')
    writer = RecordWriter(train_data=train_csv,
                          save_dir=tfr_save_dir,
                          ops=[
                              ImageReader(inputs="image",
                                          parent_path=path,
                                          outputs="image"),
                              Resize(inputs="image",
                                     target_size=(256, 256),
                                     outputs="image")
                          ])

    pipeline = fe.Pipeline(batch_size=4,
                           data=writer,
                           ops=[Rescale(inputs="image", outputs="image")])

    model = fe.build(model_def=styleTransferNet,
                     model_name="style_transfer_net",
                     loss_name="loss",
                     optimizer=tf.keras.optimizers.Adam(1e-3))

    network = fe.Network(ops=[
        ModelOp(inputs="image", model=model, outputs="image_out"),
        # The style target is a constant tensor, supplied through a callable
        # instead of a pipeline key.
        ExtractVGGFeatures(inputs=lambda: style_img_t, outputs="y_style"),
        ExtractVGGFeatures(inputs="image", outputs="y_content"),
        ExtractVGGFeatures(inputs="image_out", outputs="y_pred"),
        StyleContentLoss(style_weight=style_weight,
                         content_weight=content_weight,
                         tv_weight=tv_weight,
                         inputs=('y_pred', 'y_style', 'y_content',
                                 'image_out'),
                         outputs='loss')
    ])

    estimator = fe.Estimator(network=network,
                             pipeline=pipeline,
                             epochs=2,
                             steps_per_epoch=steps_per_epoch,
                             validation_steps=validation_steps,
                             traces=ModelSaver(model_name="style_transfer_net",
                                               save_dir=model_dir))
    return estimator
def get_estimator(data_path=None, model_dir=None, batch_size=2):
    """Build a FastEstimator estimator that trains RetinaNet on 1024x1024 images.

    Args:
        data_path: Forwarded to ``load_data`` as ``path`` to locate the
            train/validation CSVs and image root.
        model_dir: Directory where ``ModelSaver`` writes the best model (by
            mAP). When None, a fresh temporary directory is created on each
            call.
        batch_size: Batch size of the ``fe.Pipeline``.

    Returns:
        A configured ``fe.Estimator`` (7 epochs) with mAP evaluation, best-model
        saving and a step-mode learning-rate schedule.
    """
    # Create the default save directory per call; a `tempfile.mkdtemp()`
    # default would run once at definition time and be shared by all calls.
    if model_dir is None:
        model_dir = tempfile.mkdtemp()

    # Shared constants, so the writer, model and metrics cannot drift apart.
    image_size = (1024, 1024)
    input_shape = image_size + (3,)
    num_classes = 90

    # prepare dataset
    train_csv, val_csv, path = load_data(path=data_path)
    writer = fe.RecordWriter(
        save_dir=os.path.join(path, "retinanet_coco_1024"),
        train_data=train_csv,
        validation_data=val_csv,
        ops=[
            ImageReader(inputs="image", parent_path=path, outputs="image"),
            String2List(inputs=["x1", "y1", "width", "height", "obj_label"],
                        outputs=["x1", "y1", "width", "height", "obj_label"]),
            ResizeImageAndBbox(
                target_size=image_size,
                keep_ratio=True,
                inputs=["image", "x1", "y1", "width", "height"],
                outputs=["image", "x1", "y1", "width", "height"]),
            FlipImageAndBbox(inputs=[
                "image", "x1", "y1", "width", "height", "obj_label", "id"
            ],
                             outputs=[
                                 "image", "x1", "y1", "width", "height",
                                 "obj_label", "id"
                             ]),
            GenerateTarget(inputs=("obj_label", "x1", "y1", "width", "height"),
                           outputs=("cls_gt", "x1_gt", "y1_gt", "w_gt",
                                    "h_gt"))
        ],
        expand_dims=True,
        compression="GZIP",
        write_feature=[
            "image", "id", "cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt",
            "obj_label", "x1", "y1", "width", "height"
        ])

    # prepare pipeline
    # NOTE(review): Pad presumably brings each per-image box/target list to a
    # fixed length (2051) so samples can be batched — confirm the meaning of
    # 2051.
    pipeline = fe.Pipeline(batch_size=batch_size,
                           data=writer,
                           ops=[
                               Rescale(inputs="image", outputs="image"),
                               Pad(padded_shape=[2051],
                                   inputs=[
                                       "x1_gt", "y1_gt", "w_gt", "h_gt",
                                       "obj_label", "x1", "y1", "width",
                                       "height"
                                   ],
                                   outputs=[
                                       "x1_gt", "y1_gt", "w_gt", "h_gt",
                                       "obj_label", "x1", "y1", "width",
                                       "height"
                                   ])
                           ])

    # prepare network
    model = fe.build(model_def=lambda: RetinaNet(input_shape=input_shape,
                                                 num_classes=num_classes),
                     model_name="retinanet",
                     optimizer=tf.optimizers.SGD(momentum=0.9),
                     loss_name="total_loss")
    network = fe.Network(ops=[
        ModelOp(inputs="image", model=model, outputs=["cls_pred", "loc_pred"]),
        RetinaLoss(inputs=("cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt",
                           "cls_pred", "loc_pred"),
                   outputs=("total_loss", "focal_loss", "l1_loss")),
        # Box decoding only runs in eval mode, where it feeds MeanAvgPrecision.
        PredictBox(inputs=[
            "cls_pred", "loc_pred", "obj_label", "x1", "y1", "width", "height"
        ],
                   outputs=("pred", "gt"),
                   mode="eval",
                   input_shape=input_shape)
    ])

    # prepare estimator
    estimator = fe.Estimator(
        network=network,
        pipeline=pipeline,
        epochs=7,
        traces=[
            MeanAvgPrecision(num_classes, input_shape,
                             'pred',
                             'gt',
                             output_name=("mAP", "AP50", "AP75")),
            ModelSaver(model_name="retinanet",
                       save_dir=model_dir,
                       save_best='mAP',
                       save_best_mode='max'),
            LRController(model_name="retinanet",
                         lr_schedule=MyLRSchedule(schedule_mode="step"))
        ])
    return estimator