import os
import tempfile

import cv2
import numpy as np
import tensorflow as tf

import fastestimator as fe
# Note: the dataset loader (load_data), the custom architecture and ops (styleTransferNet,
# ExtractVGGFeatures, StyleContentLoss, RetinaNet, RetinaLoss, PredictBox, GenerateTarget,
# String2List, ResizeImageAndBbox, FlipImageAndBbox, MyLRSchedule), and the FastEstimator
# ops/traces (RecordWriter, ImageReader, Resize, Rescale, Pad, ModelOp, ModelSaver,
# MeanAvgPrecision, LRController) are assumed to be imported from FastEstimator or defined
# elsewhere in the corresponding example scripts.


def get_estimator(style_img_path=None,
                  data_path=None,
                  style_weight=5.0,
                  content_weight=1.0,
                  tv_weight=1e-4,
                  steps_per_epoch=None,
                  validation_steps=None,
                  model_dir=tempfile.mkdtemp()):
    # Prepare the training data and the style reference image.
    train_csv, _, path = load_data(data_path, load_object=False)
    if style_img_path is None:
        style_img_path = tf.keras.utils.get_file(
            'kandinsky.jpg',
            'https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_'
            '-_Composition_7.jpg')
    style_img = cv2.imread(style_img_path)
    assert style_img is not None, "Invalid style reference image"
    tfr_save_dir = os.path.join(path, 'tfrecords')
    # Normalize the style image to the [-1, 1] range (matching the Rescale op applied to training images).
    style_img = style_img.astype(np.float32) / 127.5 - 1.0
    style_img_t = tf.convert_to_tensor(np.expand_dims(style_img, axis=0))

    # Write the training images into TFRecords, resizing them to 256x256.
    writer = RecordWriter(
        train_data=train_csv,
        save_dir=tfr_save_dir,
        ops=[
            ImageReader(inputs="image", parent_path=path, outputs="image"),
            Resize(inputs="image", target_size=(256, 256), outputs="image")
        ])

    # Pipeline: batch the records and rescale pixel values on the fly.
    pipeline = fe.Pipeline(batch_size=4, data=writer, ops=[Rescale(inputs="image", outputs="image")])

    # Network: the style-transfer model plus VGG feature extraction for the perceptual loss.
    model = fe.build(model_def=styleTransferNet,
                     model_name="style_transfer_net",
                     loss_name="loss",
                     optimizer=tf.keras.optimizers.Adam(1e-3))
    network = fe.Network(ops=[
        ModelOp(inputs="image", model=model, outputs="image_out"),
        ExtractVGGFeatures(inputs=lambda: style_img_t, outputs="y_style"),
        ExtractVGGFeatures(inputs="image", outputs="y_content"),
        ExtractVGGFeatures(inputs="image_out", outputs="y_pred"),
        StyleContentLoss(style_weight=style_weight,
                         content_weight=content_weight,
                         tv_weight=tv_weight,
                         inputs=('y_pred', 'y_style', 'y_content', 'image_out'),
                         outputs='loss')
    ])

    # Estimator: train for 2 epochs and save the model after each epoch.
    estimator = fe.Estimator(network=network,
                             pipeline=pipeline,
                             epochs=2,
                             steps_per_epoch=steps_per_epoch,
                             validation_steps=validation_steps,
                             traces=ModelSaver(model_name="style_transfer_net", save_dir=model_dir))
    return estimator
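# Minimal usage sketch for the style-transfer example above, assuming the usual
# FastEstimator pattern in which the returned Estimator is trained via fit();
# the variable name and default arguments here are illustrative.
if __name__ == "__main__":
    est = get_estimator()
    est.fit()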
def get_estimator(data_path=None, model_dir=tempfile.mkdtemp(), batch_size=2):
    # Prepare dataset: convert the COCO csv annotations into TFRecords.
    train_csv, val_csv, path = load_data(path=data_path)
    writer = fe.RecordWriter(
        save_dir=os.path.join(path, "retinanet_coco_1024"),
        train_data=train_csv,
        validation_data=val_csv,
        ops=[
            ImageReader(inputs="image", parent_path=path, outputs="image"),
            String2List(inputs=["x1", "y1", "width", "height", "obj_label"],
                        outputs=["x1", "y1", "width", "height", "obj_label"]),
            ResizeImageAndBbox(target_size=(1024, 1024),
                               keep_ratio=True,
                               inputs=["image", "x1", "y1", "width", "height"],
                               outputs=["image", "x1", "y1", "width", "height"]),
            FlipImageAndBbox(inputs=["image", "x1", "y1", "width", "height", "obj_label", "id"],
                             outputs=["image", "x1", "y1", "width", "height", "obj_label", "id"]),
            GenerateTarget(inputs=("obj_label", "x1", "y1", "width", "height"),
                           outputs=("cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt"))
        ],
        expand_dims=True,
        compression="GZIP",
        write_feature=[
            "image", "id", "cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt", "obj_label", "x1", "y1", "width", "height"
        ])

    # Prepare pipeline: rescale images and pad the variable-length box features to a fixed shape.
    pipeline = fe.Pipeline(
        batch_size=batch_size,
        data=writer,
        ops=[
            Rescale(inputs="image", outputs="image"),
            Pad(padded_shape=[2051],
                inputs=["x1_gt", "y1_gt", "w_gt", "h_gt", "obj_label", "x1", "y1", "width", "height"],
                outputs=["x1_gt", "y1_gt", "w_gt", "h_gt", "obj_label", "x1", "y1", "width", "height"])
        ])

    # Prepare network: RetinaNet produces class and box predictions, trained with focal + L1 losses;
    # PredictBox decodes predictions during evaluation for the mAP trace.
    model = fe.build(model_def=lambda: RetinaNet(input_shape=(1024, 1024, 3), num_classes=90),
                     model_name="retinanet",
                     optimizer=tf.optimizers.SGD(momentum=0.9),
                     loss_name="total_loss")
    network = fe.Network(ops=[
        ModelOp(inputs="image", model=model, outputs=["cls_pred", "loc_pred"]),
        RetinaLoss(inputs=("cls_gt", "x1_gt", "y1_gt", "w_gt", "h_gt", "cls_pred", "loc_pred"),
                   outputs=("total_loss", "focal_loss", "l1_loss")),
        PredictBox(inputs=["cls_pred", "loc_pred", "obj_label", "x1", "y1", "width", "height"],
                   outputs=("pred", "gt"),
                   mode="eval",
                   input_shape=(1024, 1024, 3))
    ])

    # Prepare estimator: track COCO mAP, save the best model by mAP, and schedule the learning rate.
    estimator = fe.Estimator(
        network=network,
        pipeline=pipeline,
        epochs=7,
        traces=[
            MeanAvgPrecision(90, (1024, 1024, 3), 'pred', 'gt', output_name=("mAP", "AP50", "AP75")),
            ModelSaver(model_name="retinanet", save_dir=model_dir, save_best='mAP', save_best_mode='max'),
            LRController(model_name="retinanet", lr_schedule=MyLRSchedule(schedule_mode="step"))
        ])
    return estimator
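# Minimal usage sketch for the RetinaNet example above, again assuming Estimator.fit()
# drives training; the data_path and batch_size values shown are illustrative placeholders.
if __name__ == "__main__":
    est = get_estimator(data_path=None, batch_size=2)
    est.fit()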