Code Example #1
File: train.py  Project: steven0129/TF-yolact
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        self.valid_loss = tf.keras.metrics.Mean('valid_loss', dtype=tf.float32)
        self.criterion = loss_yolact.YOLACTLoss()
        self.loc = tf.keras.metrics.Mean('loc_loss', dtype=tf.float32)
        self.conf = tf.keras.metrics.Mean('conf_loss', dtype=tf.float32)
        self.mask = tf.keras.metrics.Mean('mask_loss', dtype=tf.float32)
        self.seg = tf.keras.metrics.Mean('seg_loss', dtype=tf.float32)
        self.v_loc = tf.keras.metrics.Mean('vloc_loss', dtype=tf.float32)
        self.v_conf = tf.keras.metrics.Mean('vconf_loss', dtype=tf.float32)
        self.v_mask = tf.keras.metrics.Mean('vmask_loss', dtype=tf.float32)
        self.v_seg = tf.keras.metrics.Mean('vseg_loss', dtype=tf.float32)
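Example #1 is the __init__ of a trainer class; the training step itself, like the train_step helper called in examples #2 and #4, is not included in these listings. A minimal sketch of what such a step plausibly looks like, modeled on the inlined loop in example #3 below (the criterion signature, the returned loss tuple, and the hard-coded class count of 91 are assumptions carried over from that example):

@tf.function
def train_step(model, criterion, train_loss, optimizer, image, labels):
    # forward pass under the gradient tape, mirroring example #3's loop
    with tf.GradientTape() as tape:
        output = model(image, training=True)
        # assumed signature: YOLACTLoss returns the four partial losses
        # plus the weighted total; 91 = COCO classes incl. background
        loc_loss, conf_loss, mask_loss, seg_loss, total_loss = criterion(
            output, labels, 91)
    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss.update_state(total_loss)
    return loc_loss, conf_loss, mask_loss, seg_loss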
Code Example #2
def main(argv):
    # set up Grappler for graph optimization
    # Ref: https://www.tensorflow.org/guide/graph_optimization
    @contextlib.contextmanager
    def options(options):
        old_opts = tf.config.optimizer.get_experimental_options()
        tf.config.optimizer.set_experimental_options(options)
        try:
            yield
        finally:
            tf.config.optimizer.set_experimental_options(old_opts)

    # -----------------------------------------------------------------
    # Creating dataloaders for training and validation
    logging.info("Creating the dataloader from: %s..." % FLAGS.tfrecord_dir)
    train_dataset = dataset_coco.prepare_dataloader(
        tfrecord_dir=FLAGS.tfrecord_dir,
        batch_size=FLAGS.batch_size,
        subset='train')

    valid_dataset = dataset_coco.prepare_dataloader(
        tfrecord_dir=FLAGS.tfrecord_dir, batch_size=1, subset='val')
    # -----------------------------------------------------------------
    # Creating the instance of the model specified.
    logging.info("Creating the model instance of YOLACT")
    model = yolact.Yolact(input_size=550,
                          fpn_channels=256,
                          feature_map_size=[69, 35, 18, 9, 5],
                          num_class=91,
                          num_mask=32,
                          aspect_ratio=[1, 0.5, 2],
                          scales=[24, 48, 96, 192, 384])

    # add weight decay (bind `layer` through a default argument; a bare
    # closure would late-bind and regularize only the last layer in the loop)
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Conv2D) or isinstance(
                layer, tf.keras.layers.Dense):
            layer.add_loss(lambda layer=layer: tf.keras.regularizers.l2(
                FLAGS.weight_decay)(layer.kernel))
        if hasattr(layer, 'bias_regularizer') and layer.use_bias:
            layer.add_loss(lambda layer=layer: tf.keras.regularizers.l2(
                FLAGS.weight_decay)(layer.bias))

    # -----------------------------------------------------------------
    # Choose the optimizer, loss function, metrics, and learning rate schedule
    lr_schedule = learning_rate_schedule.Yolact_LearningRateSchedule(
        warmup_steps=500, warmup_lr=1e-4, initial_lr=FLAGS.lr)
    logging.info("Initiate the Optimizer and Loss function...")
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule,
                                        momentum=FLAGS.momentum)
    criterion = loss_yolact.YOLACTLoss()
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    valid_loss = tf.keras.metrics.Mean('valid_loss', dtype=tf.float32)
    loc = tf.keras.metrics.Mean('loc_loss', dtype=tf.float32)
    conf = tf.keras.metrics.Mean('conf_loss', dtype=tf.float32)
    mask = tf.keras.metrics.Mean('mask_loss', dtype=tf.float32)
    seg = tf.keras.metrics.Mean('seg_loss', dtype=tf.float32)
    v_loc = tf.keras.metrics.Mean('vloc_loss', dtype=tf.float32)
    v_conf = tf.keras.metrics.Mean('vconf_loss', dtype=tf.float32)
    v_mask = tf.keras.metrics.Mean('vmask_loss', dtype=tf.float32)
    v_seg = tf.keras.metrics.Mean('vseg_loss', dtype=tf.float32)

    # -----------------------------------------------------------------

    # Setup the TensorBoard for better visualization
    # Ref: https://www.tensorflow.org/tensorboard/get_started
    logging.info("Setup the TensorBoard...")
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = './logs/gradient_tape/' + current_time + '/train'
    test_log_dir = './logs/gradient_tape/' + current_time + '/test'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    # -----------------------------------------------------------------
    # Start the Training and Validation Process
    logging.info("Start the training process...")

    # setup checkpoints manager
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                     optimizer=optimizer,
                                     model=model)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory="./checkpoints",
                                         max_to_keep=5)
    # restore from latest checkpoint and iteration
    status = checkpoint.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        logging.info("Restored from {}".format(manager.latest_checkpoint))
    else:
        logging.info("Initializing from scratch.")

    best_val = 1e10
    iterations = checkpoint.step.numpy()

    for image, labels in train_dataset:
        # stop once the requested number of training iterations is reached
        if iterations > FLAGS.train_iter:
            break

        checkpoint.step.assign_add(1)
        iterations += 1
        with options({
                'constant_folding': True,
                'layout_optimizer': True,
                'loop_optimization': True,
                'arithmetic_optimization': True,
                'remapping': True
        }):
            loc_loss, conf_loss, mask_loss, seg_loss = train_step(
                model, criterion, train_loss, optimizer, image, labels)
        loc.update_state(loc_loss)
        conf.update_state(conf_loss)
        mask.update_state(mask_loss)
        seg.update_state(seg_loss)
        with train_summary_writer.as_default():
            tf.summary.scalar('Total loss',
                              train_loss.result(),
                              step=iterations)
            tf.summary.scalar('Loc loss', loc.result(), step=iterations)
            tf.summary.scalar('Conf loss', conf.result(), step=iterations)
            tf.summary.scalar('Mask loss', mask.result(), step=iterations)
            tf.summary.scalar('Seg loss', seg.result(), step=iterations)

        if iterations and iterations % FLAGS.print_interval == 0:
            logging.info(
                "Iteration {}, LR: {}, Total Loss: {}, B: {},  C: {}, M: {}, S:{} "
                .format(iterations,
                        optimizer._decayed_lr(var_dtype=tf.float32),
                        train_loss.result(), loc.result(), conf.result(),
                        mask.result(), seg.result()))

        if iterations and iterations % FLAGS.save_interval == 0:
            # save checkpoint
            save_path = manager.save()
            logging.info("Saved checkpoint for step {}: {}".format(
                int(checkpoint.step), save_path))
            # validation
            valid_iter = 0
            for valid_image, valid_labels in valid_dataset:
                if valid_iter > FLAGS.valid_iter:
                    break
                # calculate validation loss
                with options({
                        'constant_folding': True,
                        'layout_optimizer': True,
                        'loop_optimization': True,
                        'arithmetic_optimization': True,
                        'remapping': True
                }):
                    valid_loc_loss, valid_conf_loss, valid_mask_loss, valid_seg_loss = valid_step(
                        model, criterion, valid_loss, valid_image,
                        valid_labels)
                v_loc.update_state(valid_loc_loss)
                v_conf.update_state(valid_conf_loss)
                v_mask.update_state(valid_mask_loss)
                v_seg.update_state(valid_seg_loss)
                valid_iter += 1

            with test_summary_writer.as_default():
                tf.summary.scalar('V Total loss',
                                  valid_loss.result(),
                                  step=iterations)
                tf.summary.scalar('V Loc loss',
                                  v_loc.result(),
                                  step=iterations)
                tf.summary.scalar('V Conf loss',
                                  v_conf.result(),
                                  step=iterations)
                tf.summary.scalar('V Mask loss',
                                  v_mask.result(),
                                  step=iterations)
                tf.summary.scalar('V Seg loss',
                                  v_seg.result(),
                                  step=iterations)

            train_template = 'Iteration {}, Train Loss: {}, Loc Loss: {},  Conf Loss: {}, Mask Loss: {}, Seg Loss: {}'
            valid_template = 'Iteration {}, Valid Loss: {}, V Loc Loss: {},  V Conf Loss: {}, V Mask Loss: {}, V Seg Loss: {}'
            logging.info(
                train_template.format(iterations + 1, train_loss.result(),
                                      loc.result(), conf.result(),
                                      mask.result(), seg.result()))
            logging.info(
                valid_template.format(iterations + 1, valid_loss.result(),
                                      v_loc.result(), v_conf.result(),
                                      v_mask.result(), v_seg.result()))
            if valid_loss.result() < best_val:
                # Saving the weights:
                best_val = valid_loss.result()
                model.save_weights('./weights/weights_' +
                                   str(valid_loss.result().numpy()) + '.h5')

            # reset the metrics
            train_loss.reset_states()
            loc.reset_states()
            conf.reset_states()
            mask.reset_states()
            seg.reset_states()

            valid_loss.reset_states()
            v_loc.reset_states()
            v_conf.reset_states()
            v_mask.reset_states()
            v_seg.reset_states()
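Example #2 also calls a valid_step helper that is not shown. It plausibly mirrors the training step without a gradient tape, as in example #3's validation path; a minimal sketch under the same assumed criterion signature:

@tf.function
def valid_step(model, criterion, valid_loss, image, labels):
    # forward pass only; no gradients are needed for validation
    output = model(image, training=False)
    loc_loss, conf_loss, mask_loss, seg_loss, total_loss = criterion(
        output, labels, 91)  # assumed criterion signature, as above
    valid_loss.update_state(total_loss)
    return loc_loss, conf_loss, mask_loss, seg_loss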
Code Example #3
File: train.py  Project: pjvazquez/yolact-1
def main(argv):
    # set up Grappler for graph optimization
    # Ref: https://www.tensorflow.org/guide/graph_optimization
    @contextlib.contextmanager
    def options(options):
        old_opts = tf.config.optimizer.get_experimental_options()
        tf.config.optimizer.set_experimental_options(options)
        try:
            yield
        finally:
            tf.config.optimizer.set_experimental_options(old_opts)

    # -----------------------------------------------------------------
    # Creating the instance of the model specified.
    logging.info("Creating the model instance of YOLACT")
    model = yolact.Yolact(
        img_h=FLAGS.img_h,
        img_w=FLAGS.img_w,
        fpn_channels=256,
        num_class=FLAGS.num_class + 1,  # adding background class
        num_mask=64,
        aspect_ratio=[float(i) for i in FLAGS.aspect_ratio],
        scales=[int(i) for i in FLAGS.scale],
        use_dcn=FLAGS.use_dcn)
    if FLAGS.model_quantization:
        logging.info("Quantization aware training")
        quantize_model = tfmot.quantization.keras.quantize_model
        model = quantize_model(model)
    # -----------------------------------------------------------------
    # Creating dataloaders for training and validation
    logging.info("Creating the dataloader from: %s..." % FLAGS.tfrecord_dir)
    train_dataset = dataset_coco.prepare_dataloader(
        img_h=FLAGS.img_h,
        img_w=FLAGS.img_w,
        feature_map_size=model.feature_map_size,
        protonet_out_size=model.protonet_out_size,
        aspect_ratio=[float(i) for i in FLAGS.aspect_ratio],
        scale=[int(i) for i in FLAGS.scale],
        tfrecord_dir=FLAGS.tfrecord_dir,
        batch_size=FLAGS.batch_size,
        subset='train')

    valid_dataset = dataset_coco.prepare_dataloader(
        img_h=FLAGS.img_h,
        img_w=FLAGS.img_w,
        feature_map_size=model.feature_map_size,
        protonet_out_size=model.protonet_out_size,
        aspect_ratio=[float(i) for i in FLAGS.aspect_ratio],
        scale=[int(i) for i in FLAGS.scale],
        tfrecord_dir=FLAGS.tfrecord_dir,
        batch_size=1,
        subset='val')

    # add weight decay
    def add_weight_decay(model, weight_decay):
        # https://github.com/keras-team/keras/issues/12053
        if (weight_decay is None) or (weight_decay == 0.0):
            return

        # recursion inside the model
        def add_decay_loss(m, factor):
            if isinstance(m, tf.keras.Model):
                for layer in m.layers:
                    add_decay_loss(layer, factor)
            else:
                for param in m.trainable_weights:
                    with tf.keras.backend.name_scope('weight_regularizer'):
                        # bind `param` via a default argument to avoid the
                        # late-binding closure bug
                        regularizer = (lambda param=param:
                                       tf.keras.regularizers.l2(factor)(param))
                        m.add_loss(regularizer)

        # weight decay and L2 regularization differ by a factor of 2
        add_decay_loss(model, weight_decay / 2.0)
        return

    add_weight_decay(model, FLAGS.weight_decay)

    # -----------------------------------------------------------------
    # Choose the optimizer, loss function, metrics, and learning rate schedule
    lr_schedule = learning_rate_schedule.Yolact_LearningRateSchedule(
        warmup_steps=5000,
        warmup_lr=1e-4,
        initial_lr=FLAGS.lr,
        total_steps=FLAGS.total_steps)
    logging.info("Initiate the Optimizer and Loss function...")
    # optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=FLAGS.momentum)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    criterion = loss_yolact.YOLACTLoss()
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    valid_loss = tf.keras.metrics.Mean('valid_loss', dtype=tf.float32)
    loc = tf.keras.metrics.Mean('loc_loss', dtype=tf.float32)
    conf = tf.keras.metrics.Mean('conf_loss', dtype=tf.float32)
    mask = tf.keras.metrics.Mean('mask_loss', dtype=tf.float32)
    seg = tf.keras.metrics.Mean('seg_loss', dtype=tf.float32)
    v_loc = tf.keras.metrics.Mean('vloc_loss', dtype=tf.float32)
    v_conf = tf.keras.metrics.Mean('vconf_loss', dtype=tf.float32)
    v_mask = tf.keras.metrics.Mean('vmask_loss', dtype=tf.float32)
    v_seg = tf.keras.metrics.Mean('vseg_loss', dtype=tf.float32)

    # -----------------------------------------------------------------

    # Setup the TensorBoard for better visualization
    # Ref: https://www.tensorflow.org/tensorboard/get_started
    logging.info("Setup the TensorBoard...")
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = './logs/gradient_tape/' + current_time + '/train'
    test_log_dir = './logs/gradient_tape/' + current_time + '/test'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    # -----------------------------------------------------------------
    # Start the Training and Validation Process
    logging.info("Start the training process...")

    # setup checkpoints manager
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                     optimizer=optimizer,
                                     model=model)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory="./checkpoints",
                                         max_to_keep=5)
    # restore from latest checkpoint and iteration
    status = checkpoint.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        logging.info("Restored from {}".format(manager.latest_checkpoint))
    else:
        logging.info("Initializing from scratch.")

    # COCO evaluator for computing mAP
    coco_evaluator = coco_evaluation.CocoMaskEvaluator(
        _get_categories_list(FLAGS.label_map))

    best_val = 1e10
    iterations = checkpoint.step.numpy()

    for image, labels in train_dataset:
        # stop once the requested number of training iterations is reached
        if iterations > FLAGS.train_iter:
            break

        checkpoint.step.assign_add(1)
        iterations += 1
        with options({
                'constant_folding': True,
                'layout_optimizer': True,
                'loop_optimization': True,
                'arithmetic_optimization': True,
                'remapping': True
        }):
            with tf.GradientTape() as tape:
                output = model(image, training=True)
                loc_loss, conf_loss, mask_loss, seg_loss, total_loss = criterion(
                    output, labels, FLAGS.num_class + 1)
            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            train_loss.update_state(total_loss)

        loc.update_state(loc_loss)
        conf.update_state(conf_loss)
        mask.update_state(mask_loss)
        seg.update_state(seg_loss)
        with train_summary_writer.as_default():
            tf.summary.scalar('Total loss',
                              train_loss.result(),
                              step=iterations)
            tf.summary.scalar('Loc loss', loc.result(), step=iterations)
            tf.summary.scalar('Conf loss', conf.result(), step=iterations)
            tf.summary.scalar('Mask loss', mask.result(), step=iterations)
            tf.summary.scalar('Seg loss', seg.result(), step=iterations)

        if iterations and iterations % FLAGS.print_interval == 0:
            logging.info(
                "Iteration {}, LR: {}, Total Loss: {}, B: {},  C: {}, M: {}, S:{} "
                .format(iterations,
                        optimizer._decayed_lr(var_dtype=tf.float32),
                        train_loss.result(), loc.result(), conf.result(),
                        mask.result(), seg.result()))

        if iterations and iterations % FLAGS.save_interval == 0:
            # save checkpoint
            save_path = manager.save()
            logging.info("Saved checkpoint for step {}: {}".format(
                int(checkpoint.step), save_path))
            # validation
            valid_iter = 0
            for valid_image, valid_labels in valid_dataset:
                if valid_iter > FLAGS.valid_iter:
                    break
                # calculate validation loss
                with options({
                        'constant_folding': True,
                        'layout_optimizer': True,
                        'loop_optimization': True,
                        'arithmetic_optimization': True,
                        'remapping': True
                }):
                    output = model(valid_image, training=False)
                    valid_loc_loss, valid_conf_loss, valid_mask_loss, valid_seg_loss, valid_total_loss = criterion(
                        output, valid_labels, FLAGS.num_class + 1)
                    valid_loss.update_state(valid_total_loss)

                    _h = valid_image.shape[1]
                    _w = valid_image.shape[2]

                    gt_num_box = valid_labels['num_obj'][0].numpy()
                    gt_boxes = valid_labels['boxes_norm'][0][:gt_num_box]
                    gt_boxes = gt_boxes.numpy() * np.array([_h, _w, _h, _w])
                    gt_classes = valid_labels['classes'][0][:gt_num_box].numpy()
                    gt_masks = valid_labels['mask_target'][0][:gt_num_box].numpy()

                    gt_masked_image = np.zeros((gt_num_box, _h, _w))
                    for _b in range(gt_num_box):
                        _mask = gt_masks[_b].astype("uint8")
                        _mask = cv2.resize(_mask, (_w, _h))
                        gt_masked_image[_b] = _mask

                    coco_evaluator.add_single_ground_truth_image_info(
                        image_id='image' + str(valid_iter),
                        groundtruth_dict={
                            standard_fields.InputDataFields.groundtruth_boxes:
                            gt_boxes,
                            standard_fields.InputDataFields.groundtruth_classes:
                            gt_classes,
                            standard_fields.InputDataFields.groundtruth_instance_masks:
                            gt_masked_image
                        })

                    det_num = output['num_detections'][0].numpy()
                    det_boxes = output['detection_boxes'][0][:det_num]
                    det_boxes = det_boxes.numpy() * np.array([_h, _w, _h, _w])
                    det_masks = output['detection_masks'][0][:det_num].numpy()
                    det_masks = (det_masks > 0.5)

                    det_scores = output['detection_scores'][0][:det_num].numpy()
                    det_classes = output['detection_classes'][0][:det_num].numpy()

                    det_masked_image = np.zeros((det_num, _h, _w))
                    for _b in range(det_num):
                        _mask = det_masks[_b].astype("uint8")
                        _mask = cv2.resize(_mask, (_w, _h))
                        det_masked_image[_b] = _mask

                    coco_evaluator.add_single_detected_image_info(
                        image_id='image' + str(valid_iter),
                        detections_dict={
                            standard_fields.DetectionResultFields.detection_boxes:
                            det_boxes,
                            standard_fields.DetectionResultFields.detection_scores:
                            det_scores,
                            standard_fields.DetectionResultFields.detection_classes:
                            det_classes,
                            standard_fields.DetectionResultFields.detection_masks:
                            det_masked_image
                        })

                v_loc.update_state(valid_loc_loss)
                v_conf.update_state(valid_conf_loss)
                v_mask.update_state(valid_mask_loss)
                v_seg.update_state(valid_seg_loss)
                valid_iter += 1

            metrics = coco_evaluator.evaluate()
            coco_evaluator.clear()

            with test_summary_writer.as_default():
                tf.summary.scalar('V Total loss',
                                  valid_loss.result(),
                                  step=iterations)
                tf.summary.scalar('V Loc loss',
                                  v_loc.result(),
                                  step=iterations)
                tf.summary.scalar('V Conf loss',
                                  v_conf.result(),
                                  step=iterations)
                tf.summary.scalar('V Mask loss',
                                  v_mask.result(),
                                  step=iterations)
                tf.summary.scalar('V Seg loss',
                                  v_seg.result(),
                                  step=iterations)

            train_template = 'Iteration {}, Train Loss: {}, Loc Loss: {},  Conf Loss: {}, Mask Loss: {}, Seg Loss: {}'
            valid_template = 'Iteration {}, Valid Loss: {}, V Loc Loss: {},  V Conf Loss: {}, V Mask Loss: {}, V Seg Loss: {}'
            logging.info(
                train_template.format(iterations + 1, train_loss.result(),
                                      loc.result(), conf.result(),
                                      mask.result(), seg.result()))
            logging.info(
                valid_template.format(iterations + 1, valid_loss.result(),
                                      v_loc.result(), v_conf.result(),
                                      v_mask.result(), v_seg.result()))
            if valid_loss.result() < best_val:
                best_val = valid_loss.result()
                if FLAGS.tflite_export:
                    detection_module = YOLACTModule(model, True)
                    # Getting the concrete function traces the graph and forces variables to
                    # be constructed; only after this can we save the saved model.
                    concrete_function = detection_module.inference_fn.get_concrete_function(
                        tf.TensorSpec(shape=[
                            FLAGS.batch_size, FLAGS.img_h, FLAGS.img_w, 3
                        ],
                                      dtype=tf.float32,
                                      name='input'))

                    # Export SavedModel.
                    tf.saved_model.save(detection_module,
                                        './saved_models/saved_model_' +
                                        str(valid_loss.result().numpy()),
                                        signatures=concrete_function)
                else:
                    model.save('./saved_models/saved_model_' +
                               str(valid_loss.result().numpy()))

            # reset the metrics
            train_loss.reset_states()
            loc.reset_states()
            conf.reset_states()
            mask.reset_states()
            seg.reset_states()

            valid_loss.reset_states()
            v_loc.reset_states()
            v_conf.reset_states()
            v_mask.reset_states()
            v_seg.reset_states()
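Example #3 exports a SavedModel when FLAGS.tflite_export is set but stops short of the actual conversion. A hedged sketch of the follow-up step using the standard TFLite converter (the loss-tagged directory name below is hypothetical, matching the naming scheme above):

import tensorflow as tf

# convert the exported SavedModel into a .tflite flatbuffer
converter = tf.lite.TFLiteConverter.from_saved_model(
    './saved_models/saved_model_0.5')  # hypothetical export directory
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # optional optimization
tflite_model = converter.convert()
with open('yolact.tflite', 'wb') as f:
    f.write(tflite_model)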
Code Example #4
def main(argv):
    # set fixed random seed, load config files
    tf.random.set_seed(RANDOM_SEED)

    # enable mixed precision if configured
    if MIXPRECISION:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_policy(policy)

    # get params for model
    train_iter, input_size, num_cls, lrs_schedule_params, loss_params, parser_params, model_params = get_params(
        FLAGS.name)

    # -----------------------------------------------------------------
    # set up Grappler for graph optimization
    # Ref: https://www.tensorflow.org/guide/graph_optimization
    @contextlib.contextmanager
    def options(opts):
        old_opts = tf.config.optimizer.get_experimental_options()
        tf.config.optimizer.set_experimental_options(opts)
        try:
            yield
        finally:
            tf.config.optimizer.set_experimental_options(old_opts)

    # -----------------------------------------------------------------
    # Creating the instance of the model specified.
    logging.info("Creating the model instance of YOLACT")
    model = Yolact(**model_params)

    # add weight decay (bind `layer` through a default argument; a bare
    # closure would late-bind and regularize only the last layer in the loop)
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Conv2D) or isinstance(
                layer, tf.keras.layers.Dense):
            layer.add_loss(lambda layer=layer: tf.keras.regularizers.l2(
                FLAGS.weight_decay)(layer.kernel))
        if hasattr(layer, 'bias_regularizer') and layer.use_bias:
            layer.add_loss(lambda layer=layer: tf.keras.regularizers.l2(
                FLAGS.weight_decay)(layer.bias))

    # -----------------------------------------------------------------
    # Creating dataloaders for training and validation
    logging.info("Creating the dataloader from: %s..." % FLAGS.tfrecord_dir)
    dataset = ObjectDetectionDataset(dataset_name=FLAGS.name,
                                     tfrecord_dir=os.path.join(
                                         FLAGS.tfrecord_dir, FLAGS.name),
                                     anchor_instance=model.anchor_instance,
                                     **parser_params)
    train_dataset = dataset.get_dataloader(subset='train',
                                           batch_size=FLAGS.batch_size)
    valid_dataset = dataset.get_dataloader(subset='val', batch_size=1)
    # count the number of validation examples for the progress bar
    # TODO: is there a better way to do this?
    num_val = 0
    for _ in valid_dataset:
        num_val += 1
    # -----------------------------------------------------------------
    # Choose the optimizer, loss function, metrics, and learning rate schedule
    lr_schedule = learning_rate_schedule.Yolact_LearningRateSchedule(
        **lrs_schedule_params)
    logging.info("Initiate the Optimizer and Loss function...")
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule,
                                        momentum=FLAGS.momentum)
    criterion = loss_yolact.YOLACTLoss(**loss_params)
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    loc = tf.keras.metrics.Mean('loc_loss', dtype=tf.float32)
    conf = tf.keras.metrics.Mean('conf_loss', dtype=tf.float32)
    mask = tf.keras.metrics.Mean('mask_loss', dtype=tf.float32)
    seg = tf.keras.metrics.Mean('seg_loss', dtype=tf.float32)
    # -----------------------------------------------------------------

    # Setup the TensorBoard for better visualization
    # Ref: https://www.tensorflow.org/tensorboard/get_started
    logging.info("Setup the TensorBoard...")
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = './logs/gradient_tape/' + current_time + '/train'
    test_log_dir = './logs/gradient_tape/' + current_time + '/test'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    # -----------------------------------------------------------------
    # Start the Training and Validation Process
    logging.info("Start the training process...")

    # setup checkpoints manager
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                     optimizer=optimizer,
                                     model=model)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory="./checkpoints",
                                         max_to_keep=5)
    # restore from latest checkpoint and iteration
    status = checkpoint.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        logging.info("Restored from {}".format(manager.latest_checkpoint))
    else:
        logging.info("Initializing from scratch.")

    best_masks_map = 0.
    iterations = checkpoint.step.numpy()

    for image, labels in train_dataset:
        # stop once the requested number of training iterations is reached
        if iterations > train_iter:
            break

        checkpoint.step.assign_add(1)
        iterations += 1
        with options({
                'constant_folding': True,
                'layout_optimizer': True,
                'loop_optimization': True,
                'arithmetic_optimization': True,
                'remapping': True
        }):
            loc_loss, conf_loss, mask_loss, seg_loss = train_step(
                model, criterion, train_loss, optimizer, image, labels,
                num_cls)
        loc.update_state(loc_loss)
        conf.update_state(conf_loss)
        mask.update_state(mask_loss)
        seg.update_state(seg_loss)
        with train_summary_writer.as_default():
            tf.summary.scalar('Total loss',
                              train_loss.result(),
                              step=iterations)
            tf.summary.scalar('Loc loss', loc.result(), step=iterations)
            tf.summary.scalar('Conf loss', conf.result(), step=iterations)
            tf.summary.scalar('Mask loss', mask.result(), step=iterations)
            tf.summary.scalar('Seg loss', seg.result(), step=iterations)

        if iterations and iterations % FLAGS.print_interval == 0:
            tf.print(
                "Iteration {}, LR: {}, Total Loss: {}, B: {},  C: {}, M: {}, S:{} "
                .format(iterations,
                        optimizer._decayed_lr(var_dtype=tf.float32),
                        train_loss.result(), loc.result(), conf.result(),
                        mask.result(), seg.result()))

        if iterations and iterations % FLAGS.save_interval == 0:
            # save checkpoint
            save_path = manager.save()
            logging.info("Saved checkpoint for step {}: {}".format(
                int(checkpoint.step), save_path))

            # validation and print mAP table
            all_map = evaluate(model, valid_dataset, num_val, num_cls)
            box_map, mask_map = all_map['box']['all'], all_map['mask']['all']
            tf.print(f"box mAP:{box_map}, mask mAP:{mask_map}")

            with test_summary_writer.as_default():
                tf.summary.scalar('Box mAP', box_map, step=iterations)
                tf.summary.scalar('Mask mAP', mask_map, step=iterations)

            # Saving the weights:
            if mask_map > best_masks_map:
                best_masks_map = mask_map
                model.save_weights(
                    f'{FLAGS.weights}/weights_{FLAGS.name}_{str(best_masks_map)}.h5'
                )

            # reset the metrics
            train_loss.reset_states()
            loc.reset_states()
            conf.reset_states()
            mask.reset_states()
            seg.reset_states()
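Example #4 enables the mixed_float16 policy, but a custom training loop under that policy should also scale the loss so float16 gradients do not underflow; its train_step (not shown) would be the place to do it. A minimal sketch of the pattern from TensorFlow's mixed-precision guide, using the non-experimental API (TF 2.4+); compute_loss is a hypothetical stand-in for the YOLACT criterion:

# wrap the optimizer so gradients are taken on a dynamically scaled loss
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=FLAGS.momentum))

with tf.GradientTape() as tape:
    output = model(image, training=True)
    loss = compute_loss(output, labels)  # hypothetical loss helper
    scaled_loss = optimizer.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = optimizer.get_unscaled_gradients(scaled_grads)
optimizer.apply_gradients(zip(grads, model.trainable_variables))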