예제 #1
0
def main(args, train_set, class_num, pre_ckpt, model_def, depth_multiplier,
         is_augmenter, image_size, output_size, batch_size, rand_seed,
         max_nrof_epochs, init_learning_rate, learning_rate_decay_factor,
         obj_weight, noobj_weight, wh_weight, obj_thresh, iou_thresh,
         vaildation_split, log_dir, is_prune, initial_sparsity, final_sparsity,
         end_epoch, frequency):
    """Train a YOLO model (optionally with magnitude pruning) and save it.

    A timestamped run directory is created under ``log_dir``; the arguments,
    the TensorBoard/pruning summaries and the final ``.h5`` model are written
    there.

    Args:
        args: Parsed arguments; dumped to ``args.txt`` in the run directory.
        train_set: Dataset name used to locate ``data/{train_set}_img_ann.npy``
            and ``data/{train_set}_anchor.npy``.
        class_num: Number of classes, forwarded to ``Helper`` and the network
            builder.
        pre_ckpt: Optional path of pre-trained weights. ``None``, ``'None'``
            and ``''`` mean "no checkpoint"; any other value must contain
            ``'h5'`` to be loaded.
        model_def: Python expression evaluated to obtain the network builder.
        depth_multiplier: Forwarded to the network builder as ``alpha``.
        is_augmenter: String flag; the dataset is built with
            ``is_training=True`` only when it equals ``'True'``.
        image_size: Flat sequence reshaped to ``(-1, 2)`` for ``Helper``; its
            first two entries also define the network input shape.
        output_size: Flat sequence reshaped to ``(-1, 2)`` for ``Helper``.
        batch_size: Forwarded to ``Helper.set_dataset``.
        rand_seed: Forwarded to ``Helper.set_dataset``.
        max_nrof_epochs: Number of epochs passed to ``fit``.
        init_learning_rate: Adam learning rate.
        learning_rate_decay_factor: Adam ``decay`` argument.
        obj_weight, noobj_weight, wh_weight: Forwarded to ``create_loss_fn``.
        obj_thresh: Forwarded to ``create_loss_fn`` and to the precision /
            recall metrics.
        iou_thresh: Forwarded to ``create_loss_fn``.
        vaildation_split: Forwarded to ``Helper`` (name kept as spelled by
            callers).
        log_dir: Parent directory of the timestamped run directory.
        is_prune: String flag; pruning is enabled only when it equals
            ``'True'``.
        initial_sparsity, final_sparsity: Bounds of the polynomial pruning
            schedule (only used when pruning).
        end_epoch: Epoch at which the pruning schedule ends (only used when
            pruning).
        frequency: Pruning update frequency in steps (only used when pruning).
    """
    # Build path: every run gets its own timestamped sub-directory.
    log_dir = (Path(log_dir) /
               datetime.now().strftime('%Y%m%d-%H%M%S'))  # type: Path
    ckpt = log_dir / 'yolo_model.h5'
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    log_dir.mkdir(parents=True, exist_ok=True)
    write_arguments_to_file(args, str(log_dir / 'args.txt'))

    # Build utils
    h = Helper(f'data/{train_set}_img_ann.npy', class_num,
               f'data/{train_set}_anchor.npy',
               np.reshape(np.array(image_size), (-1, 2)),
               np.reshape(np.array(output_size), (-1, 2)), vaildation_split)
    h.set_dataset(batch_size, rand_seed, is_training=(is_augmenter == 'True'))

    # Build network
    # SECURITY NOTE: eval() executes arbitrary code; `model_def` must come
    # from a trusted source (e.g. a fixed set of CLI choices).
    network = eval(model_def)  # type :yolo_mobilev2
    yolo_model, yolo_model_warpper = network([image_size[0], image_size[1], 3],
                                             len(h.anchors[0]),
                                             class_num,
                                             alpha=depth_multiplier)

    if pre_ckpt is not None and pre_ckpt != 'None' and pre_ckpt != '':
        # str() first so that a pathlib.Path checkpoint does not break the
        # substring test below.
        ckpt_path = str(pre_ckpt)
        if 'h5' in ckpt_path:
            yolo_model_warpper.load_weights(ckpt_path)
            print(INFO, f' Load CKPT {ckpt_path}')
        else:
            print(ERROR, ' Pre CKPT path is unvalid')

    # Prune model. The schedule is only built when pruning is requested, so
    # unused/invalid sparsity arguments cannot abort a plain training run.
    prune_enabled = is_prune == 'True'
    if prune_enabled:
        pruning_schedule = sparsity.PolynomialDecay(
            initial_sparsity=initial_sparsity,
            final_sparsity=final_sparsity,
            begin_step=0,
            end_step=h.train_epoch_step * end_epoch,
            frequency=frequency)
        train_model = sparsity.prune_low_magnitude(
            yolo_model_warpper, pruning_schedule=pruning_schedule)
    else:
        train_model = yolo_model_warpper

    # One loss function per model output (a single-output model exposes a
    # tensor instead of a list).
    outputs = train_model.output
    num_outputs = len(outputs) if isinstance(outputs, list) else 1
    losses = [
        create_loss_fn(h, obj_thresh, iou_thresh, obj_weight, noobj_weight,
                       wh_weight, layer) for layer in range(num_outputs)
    ]
    train_model.compile(
        keras.optimizers.Adam(lr=init_learning_rate,
                              decay=learning_rate_decay_factor),
        loss=losses,
        metrics=[
            Yolo_Precision(obj_thresh, name='p'),
            Yolo_Recall(obj_thresh, name='r')
        ])

    # NOTE fix the dataset output shape
    shapes = (train_model.input.shape, tuple(h.output_shapes))
    h.train_dataset = h.train_dataset.apply(assert_element_shape(shapes))
    h.test_dataset = h.test_dataset.apply(assert_element_shape(shapes))

    # Callbacks
    if prune_enabled:
        cbs = [
            sparsity.UpdatePruningStep(),
            sparsity.PruningSummaries(log_dir=str(log_dir), profile_batch=0)
        ]
    else:
        cbs = [TensorBoard(str(log_dir), update_freq='batch', profile_batch=3)]

    # Training
    try:
        train_model.fit(h.train_dataset,
                        epochs=max_nrof_epochs,
                        steps_per_epoch=h.train_epoch_step,
                        callbacks=cbs,
                        validation_data=h.test_dataset,
                        validation_steps=int(h.test_epoch_step *
                                             h.validation_split))
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C stops training early but still saves the model.
        pass

    if prune_enabled:
        # NOTE(review): the stripped model is not the one that gets saved;
        # `yolo_model` is written instead (as in the non-pruned branch).
        # Presumably it shares its layers with the pruned wrapper — confirm.
        final_model = sparsity.strip_pruning(train_model)
        prune_ckpt = log_dir / 'yolo_prune_model.h5'
        keras.models.save_model(yolo_model,
                                str(prune_ckpt),
                                include_optimizer=False)
        print()
        print(INFO, f' Save Pruned Model as {str(prune_ckpt)}')
    else:
        keras.models.save_model(yolo_model, str(ckpt))
        print()
        print(INFO, f' Save Model as {str(ckpt)}')
def train(FLAGS):
    """Train yolov3 with different backbone.

    Depending on the flags this either runs a pruning training pass, a
    "frozen" first-stage training, or a fine-tuning stage with every layer
    trainable. Checkpoints, TensorBoard logs and the final weights are
    written to the log directory.

    Args:
        FLAGS: Mapping of run options. Keys read here: ``prune``, ``opt``,
            ``backbone``, ``log_directory``, ``batch_size``,
            ``train_dataset``, ``val_dataset``, ``test_dataset``, ``freeze``,
            ``epochs`` (``[freeze_epochs, train_epochs]``), ``classes_path``,
            ``anchors_path``, ``input_size``, ``model``, ``learning_rate``
            (``[stage_1_lr, fine_tune_lr]``), ``tpu_address`` and ``gpus``.
            ``model_name`` is written when the EfficientNet backbone is
            selected.

    Raises:
        ValueError: If ``FLAGS['backbone']`` is not one of the supported
            backbones.
    """
    prune = FLAGS['prune']
    opt = FLAGS['opt']
    backbone = FLAGS['backbone']
    # Used for the default log directory and every weight file name.
    # Assumes str(backbone) looks like "BACKBONE.<NAME>".
    backbone_name = str(backbone).split('.')[1].lower()
    log_dir = FLAGS['log_directory'] or os.path.join(
        'logs', backbone_name + str(datetime.date.today()))
    # makedirs() also creates missing parents (e.g. "logs/") and succeeds if
    # the directory already exists; mkdir() does neither.
    tf.io.gfile.makedirs(log_dir)
    batch_size = FLAGS['batch_size']
    train_dataset_glob = FLAGS['train_dataset']
    val_dataset_glob = FLAGS['val_dataset']
    test_dataset_glob = FLAGS['test_dataset']
    freeze = FLAGS['freeze']
    freeze_step = FLAGS['epochs'][0]
    train_step = FLAGS['epochs'][1]

    if opt == OPT.DEBUG:
        tf.config.experimental_run_functions_eagerly(True)
        tf.debugging.set_log_device_placement(True)
        tf.get_logger().setLevel(tf.logging.DEBUG)
    elif opt == OPT.XLA:
        config = tf.ConfigProto()
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        sess = tf.Session(config=config)
        tf.keras.backend.set_session(sess)

    class_names = get_classes(FLAGS['classes_path'])
    num_classes = len(class_names)
    anchors = get_anchors(FLAGS['anchors_path'])
    input_shape = FLAGS['input_size']  # multiple of 32, hw
    # NOTE(review): model_path is resolved but no longer passed to
    # ModelFactory (see the commented-out call below) — confirm intent.
    model_path = FLAGS['model']
    if model_path and not model_path.endswith('.h5'):
        model_path = tf.train.latest_checkpoint(model_path)
    lr = FLAGS['learning_rate']
    tpu_address = FLAGS['tpu_address']
    if tpu_address is not None:
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_address)
        tf.config.experimental_connect_to_host(cluster_resolver.master())
        tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        strategy = tf.distribute.MirroredStrategy(devices=FLAGS['gpus'])

    # Scale the per-replica batch size to the global batch size.
    batch_size = batch_size * strategy.num_replicas_in_sync

    train_dataset_builder = Dataset(train_dataset_glob, batch_size, anchors,
                                    num_classes, input_shape)
    train_dataset, train_num = train_dataset_builder.build()
    val_dataset_builder = Dataset(val_dataset_glob,
                                  batch_size,
                                  anchors,
                                  num_classes,
                                  input_shape,
                                  mode=DATASET_MODE.VALIDATE)
    val_dataset, val_num = val_dataset_builder.build()
    steps_per_epoch = max(1, train_num // batch_size)
    validation_steps = max(1, val_num // batch_size)

    map_callback = MAPCallback(test_dataset_glob, input_shape, anchors,
                               class_names)
    # Named tensorboard_cb so the stdlib module name `logging` is not shadowed.
    tensorboard_cb = tf.keras.callbacks.TensorBoard(write_graph=False,
                                                    log_dir=log_dir,
                                                    write_images=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(
            log_dir, 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'),
        monitor='val_loss',
        save_weights_only=False,
        save_best_only=False,
        period=1)
    # Cosine decay over the fine-tuning epochs, counted from the end of the
    # frozen stage.
    cos_lr = tf.keras.callbacks.LearningRateScheduler(
        lambda epoch, _: tf.keras.experimental.CosineDecay(
            lr[1], train_step)(epoch - freeze_step).numpy(), 1)
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=(freeze_step + train_step) // 10,
        verbose=0)
    if tf.version.VERSION.startswith('1.'):
        loss = [
            lambda y_true, yolo_output: YoloLoss(
                y_true, yolo_output, 0, anchors, print_loss=True)
        ]
    else:
        loss = [
            YoloLoss(idx, anchors, print_loss=False)
            for idx in range(len(anchors) // 3)
        ]

    with strategy.scope():
        #factory = ModelFactory(tf.keras.layers.Input(shape=(*input_shape, 3)),
        #                       weights_path=model_path)
        factory = ModelFactory(tf.keras.layers.Input(shape=(*input_shape, 3)))
        # NOTE(review): the anchor divisor differs per backbone (// 1, // 3,
        # // 2) while the loss list above is built for len(anchors) // 3
        # outputs — confirm these values are intended.
        if backbone == BACKBONE.MOBILENETV2:
            model = factory.build(mobilenetv2_yolo_body,
                                  20,
                                  len(anchors) // 1,
                                  num_classes,
                                  alpha=1.0)
        elif backbone == BACKBONE.DARKNET53:
            model = factory.build(darknet_yolo_body, 185,
                                  len(anchors) // 3, num_classes)
        elif backbone == BACKBONE.EFFICIENTNET:
            FLAGS['model_name'] = 'efficientnet-b4'
            model = factory.build(
                efficientnet_yolo_body,
                20,  # todo
                FLAGS['model_name'],
                len(anchors) // 2,
                batch_norm_momentum=0.9,
                batch_norm_epsilon=1e-3,
                num_classes=num_classes,
                drop_connect_rate=0.2,
                data_format="channels_first")
        else:
            # Fail here instead of with an unbound `model` further down.
            raise ValueError('Unsupported backbone: {}'.format(backbone))

    if prune:
        from tensorflow_model_optimization.python.core.api.sparsity import keras as sparsity
        end_step = np.ceil(1.0 * train_num / batch_size).astype(
            np.int32) * train_step
        new_pruning_params = {
            'pruning_schedule':
            sparsity.PolynomialDecay(initial_sparsity=0.5,
                                     final_sparsity=0.9,
                                     begin_step=0,
                                     end_step=end_step,
                                     frequency=1000)
        }
        pruned_model = sparsity.prune_low_magnitude(model,
                                                    **new_pruning_params)
        pruned_model.compile(optimizer=tf.keras.optimizers.Adam(lr[0],
                                                                epsilon=1e-8),
                             loss=loss)
        pruned_model.fit(
            train_dataset,
            epochs=train_step,
            initial_epoch=0,
            steps_per_epoch=steps_per_epoch,
            callbacks=[
                # Required by the pruning wrappers: it advances the pruning
                # step; training a pruned model without it raises an error.
                sparsity.UpdatePruningStep(),
                checkpoint,
                cos_lr,
                tensorboard_cb,
                map_callback,
                early_stopping
            ],
            validation_data=val_dataset,
            validation_steps=validation_steps)
        model = sparsity.strip_pruning(pruned_model)
        pruned_weights_path = os.path.join(
            log_dir, backbone_name + '_trained_weights_pruned.h5')
        model.save_weights(pruned_weights_path)
        # Also store a deflated copy of the stripped weights.
        with zipfile.ZipFile(pruned_weights_path + '.zip',
                             'w',
                             compression=zipfile.ZIP_DEFLATED) as archive:
            archive.write(pruned_weights_path)
        return

    # Train with frozen layers first, to get a stable loss.
    # Adjust num epochs to your dataset. This step is enough to obtain a not bad model.
    # `is True` is kept on purpose: only the literal boolean selects stage 1.
    if freeze is True:
        with strategy.scope():
            model.compile(optimizer=tf.keras.optimizers.Adam(lr[0],
                                                             epsilon=1e-8),
                          loss=loss)
        model.fit(train_dataset,
                  epochs=freeze_step,
                  initial_epoch=0,
                  steps_per_epoch=steps_per_epoch,
                  callbacks=[tensorboard_cb, checkpoint],
                  validation_data=val_dataset,
                  validation_steps=validation_steps)
        model.save_weights(
            os.path.join(log_dir,
                         backbone_name + '_trained_weights_stage_1.h5'))
    # Unfreeze and continue training, to fine-tune.
    # Train longer if the result is not good.
    else:
        for layer in model.layers:
            layer.trainable = True
        with strategy.scope():
            model.compile(optimizer=tf.keras.optimizers.Adam(lr[1],
                                                             epsilon=1e-8),
                          loss=loss)  # recompile to apply the change
        print('Unfreeze all of the layers.')
        model.fit(
            train_dataset,
            epochs=train_step + freeze_step,
            initial_epoch=freeze_step,
            steps_per_epoch=steps_per_epoch,
            callbacks=[
                checkpoint,
                cos_lr,
                tensorboard_cb,
                early_stopping  #map_callback
            ],
            validation_data=val_dataset,
            validation_steps=validation_steps)
        model.save_weights(
            os.path.join(log_dir,
                         backbone_name + '_trained_weights_final.h5'))