def main(args, train_set, class_num, pre_ckpt, model_def, depth_multiplier,
         is_augmenter, image_size, output_size, batch_size, rand_seed,
         max_nrof_epochs, init_learning_rate, learning_rate_decay_factor,
         obj_weight, noobj_weight, wh_weight, obj_thresh, iou_thresh,
         vaildation_split, log_dir, is_prune, initial_sparsity, final_sparsity,
         end_epoch, frequency):
    """Train a YOLO model built by ``convert.make_model``, optionally pruning it.

    A run directory ``<log_dir>/<YYYYmmdd-HHMMSS>`` is created, ``args`` is
    written to ``args.txt`` inside it, and TensorBoard (or pruning-summary)
    logs are written there as well.

    Args:
        args: Parsed command-line arguments; dumped to ``args.txt``.
        train_set: Dataset prefix; ``data/{train_set}_img_ann.npy`` and
            ``data/{train_set}_anchor.npy`` are used.
        class_num: Number of object classes.
        pre_ckpt: Optional ``.h5`` weight file loaded into the training model
            before training. ``None``, ``'None'`` and ``''`` mean "none".
        model_def: Model name passed to ``convert.make_model``; weights are
            read from ``<model_def>.weights``.
        depth_multiplier: Currently unused (it was only consumed by a
            disabled, ``eval``-based model builder).
        is_augmenter: The string ``'True'`` passes ``is_training=True`` to
            ``Helper.set_dataset``.
        image_size, output_size: Flattened ``(h, w)`` pairs, reshaped to
            ``(-1, 2)`` before being handed to ``Helper``.
        batch_size, rand_seed: Forwarded to ``Helper.set_dataset``.
        max_nrof_epochs: Number of epochs passed to ``fit``.
        init_learning_rate, learning_rate_decay_factor: Adam ``lr``/``decay``.
        obj_weight, noobj_weight, wh_weight, obj_thresh, iou_thresh: Loss and
            metric hyper-parameters forwarded to ``create_loss_fn`` and the
            precision/recall metrics.
        vaildation_split: Forwarded to ``Helper``; also scales the number of
            validation steps.
        log_dir: Root directory for run outputs.
        is_prune: The string ``'True'`` selects pruning callbacks and saves a
            pruning-stripped model to ``sparse1.h5`` and ``<model_def>.tf``;
            otherwise the model is saved to ``<run dir>/yolo_model.h5``.
        initial_sparsity, final_sparsity, end_epoch, frequency: Pruning
            schedule settings forwarded to ``convert.make_model`` (the end
            step is ``h.train_epoch_step * end_epoch``).
    """
    # Build the run directory: <log_dir>/<timestamp>
    log_dir = (Path(log_dir) /
               datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
               )  # type: Path
    ckpt = log_dir / 'yolo_model.h5'
    log_dir.mkdir(parents=True, exist_ok=True)
    write_arguments_to_file(args, str(log_dir / 'args.txt'))

    # Build the dataset helper
    h = Helper(f'data/{train_set}_img_ann.npy', class_num,
               f'data/{train_set}_anchor.npy',
               np.reshape(np.array(image_size), (-1, 2)),
               np.reshape(np.array(output_size), (-1, 2)), vaildation_split)
    h.set_dataset(batch_size, rand_seed, is_training=(is_augmenter == 'True'))

    # Build the network. The pruning schedule settings are forwarded to
    # make_model; no separate schedule is built here.
    # NOTE(review): make_model also returns an output_size that is not used
    # after this point, while Helper above was built from the caller-supplied
    # output_size -- confirm the two agree.
    yolo_model, yolo_model_wrapper, output_size = convert.make_model(
        model_def, model_def + '.weights', f'data/{train_set}_anchor.npy',
        h.train_epoch_step * end_epoch, initial_sparsity, final_sparsity,
        frequency)
    tf.keras.models.save_model(yolo_model,
                               f'pre_prun{model_def}',
                               include_optimizer=False)

    if pre_ckpt is not None and pre_ckpt != 'None' and pre_ckpt != '':
        if 'h5' in pre_ckpt:
            yolo_model_wrapper.load_weights(str(pre_ckpt))
            print(INFO, f' Load CKPT {str(pre_ckpt)}')
        else:
            print(ERROR, ' Pre CKPT path is unvalid')

    train_model = yolo_model_wrapper

    # One loss function per model output head.
    num_outputs = (len(train_model.output)
                   if isinstance(train_model.output, list) else 1)
    train_model.compile(
        keras.optimizers.Adam(lr=init_learning_rate,
                              decay=learning_rate_decay_factor),
        loss=[
            create_loss_fn(h, obj_thresh, iou_thresh, obj_weight, noobj_weight,
                           wh_weight, layer) for layer in range(num_outputs)
        ],
        metrics=[
            Yolo_Precision(obj_thresh, name='p'),
            Yolo_Recall(obj_thresh, name='r')
        ])

    # NOTE: pin the dataset element shapes to the model's input shape and the
    # helper's output shapes.
    shapes = (train_model.input.shape, tuple(h.output_shapes))
    h.train_dataset = h.train_dataset.apply(assert_element_shape(shapes))
    h.test_dataset = h.test_dataset.apply(assert_element_shape(shapes))

    # Callbacks
    if is_prune == 'True':
        cbs = [
            sparsity.UpdatePruningStep(),
            sparsity.PruningSummaries(log_dir=str(log_dir), profile_batch=0)
        ]
    else:
        cbs = [TensorBoard(str(log_dir), update_freq='batch', profile_batch=3)]

    # Training. Ctrl-C stops training early; the model summary is printed and
    # execution falls through to saving whatever was trained so far.
    try:
        train_model.fit(h.train_dataset,
                        epochs=max_nrof_epochs,
                        steps_per_epoch=h.train_epoch_step,
                        callbacks=cbs,
                        validation_data=h.test_dataset,
                        validation_steps=int(h.test_epoch_step *
                                             h.validation_split))
    except KeyboardInterrupt:
        train_model.summary()

    if is_prune == 'True':
        final_model = tmot.sparsity.keras.strip_pruning(train_model)
        final_model.summary()
        model_name = 'sparse1.h5'
        yolo_model = tmot.sparsity.keras.strip_pruning(yolo_model)
        tf.keras.models.save_model(yolo_model,
                                   model_name,
                                   include_optimizer=False)
        tf.keras.models.save_model(yolo_model,
                                   f'{model_def}.tf',
                                   include_optimizer=False)
    else:
        keras.models.save_model(yolo_model, str(ckpt))
        print()
        print(INFO, f' Save Model as {str(ckpt)}')
def _write_dataset_img_ann(datasetDir,
                           image_root='/home/m5stack/VTrainingService/'):
    """Index the ``.jpg`` files in *datasetDir* and save ``dataset_img_ann.npy``.

    The matching annotation file for every image is located by replacing
    ``JPEGImages`` with ``labels`` in its path and the trailing ``.jpg`` with
    ``.txt``. Each saved row is ``[image_path, annotations, (height, width)]``.
    ``annotations`` is the label file loaded as a 2-D float array (presumably
    one box per row, in the format ``Helper`` expects -- TODO confirm).

    Args:
        datasetDir: Directory whose entries are scanned (non-recursively).
        image_root: Prefix joined in front of each relative image path.
            NOTE(review): hard-coded deployment root carried over from the
            original code; an absolute *datasetDir* is used unchanged.

    Returns:
        Path of the written ``dataset_img_ann.npy`` file.
    """
    # Match on the file name's suffix only, so directory names or files such
    # as "x.jpg.bak" are not picked up. Sorted for a reproducible order.
    image_paths = [
        os.path.join(image_root, os.path.join(datasetDir, name))
        for name in sorted(os.listdir(datasetDir))
        if name.endswith('.jpg')
    ]
    ann_paths = [
        re.sub(r'\.jpg$', '.txt', path.replace('JPEGImages', 'labels'))
        for path in image_paths
    ]

    # Rows are ragged (str, 2-D array, 1-D array), so they must be stored in
    # an explicit object array; np.array() on them fails on newer NumPy.
    lines = np.empty((len(image_paths), 3), dtype=object)
    for row, (img_path, ann_path) in enumerate(zip(image_paths, ann_paths)):
        lines[row, 0] = img_path
        lines[row, 1] = np.loadtxt(ann_path, dtype=float, ndmin=2)
        lines[row, 2] = np.array(skimage.io.imread(img_path).shape[0:2])

    npy_path = os.path.join(datasetDir, 'dataset_img_ann.npy')
    np.save(npy_path, lines)
    return npy_path


def runTrainingDetection(uuid,
                         datasetDir,
                         numOfClass,
                         obj_thresh=0.7,
                         iou_thresh=0.5,
                         obj_weight=1.0,
                         noobj_weight=1.0,
                         wh_weight=1.0,
                         max_nrof_epochs=50,
                         batch_size=96,
                         vaildation_split=0.2):
    """Train a ``yolo_mobilev1`` detector on *datasetDir* and export a kmodel.

    The pipeline runs in four steps:

    1. Index the dataset into ``dataset_img_ann.npy``.
    2. Train ``yolo_mobilev1`` (224x320x3 input, ``alpha=0.5``) with RAdam.
    3. Save the model as ``.h5`` and convert it to ``.tflite``.
    4. Run ``ncc`` to produce a Kendryte K210 ``.kmodel``.

    Args:
        uuid: Job identifier used to name every output file.
        datasetDir: Directory containing the ``.jpg`` images. It is also
            passed to ``ncc`` as ``--dataset``.
        numOfClass: Number of object classes.
        obj_thresh, iou_thresh, obj_weight, noobj_weight, wh_weight: Loss and
            metric hyper-parameters forwarded to ``create_loss_fn`` and the
            precision/recall metrics.
        max_nrof_epochs: Training epochs (each epoch is fixed at 10 steps).
        batch_size: Batch size forwarded to ``Helper.set_dataset``.
        vaildation_split: Validation fraction forwarded to ``Helper``.

    Returns:
        A ``(code, message)`` tuple:
        ``(0, kmodel_path)`` on success;
        ``(-45, error message)`` if training raised;
        ``(-16, error message)`` if ``ncc`` failed or produced no kmodel.
    """
    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    keras.backend.set_session(sess)

    h = Helper(_write_dataset_img_ann(datasetDir), numOfClass,
               'voc_anchor.npy', np.reshape(np.array((224, 320)), (-1, 2)),
               np.reshape(np.array((7, 10, 14, 20)), (-1, 2)),
               vaildation_split)
    h.set_dataset(batch_size, 6)  # 6: fixed random seed

    yolo_model, train_model = yolo_mobilev1([224, 320, 3],
                                            len(h.anchors[0]),
                                            numOfClass,
                                            alpha=0.5)

    # One loss function per model output head.
    num_outputs = (len(train_model.output)
                   if isinstance(train_model.output, list) else 1)
    train_model.compile(
        RAdam(),
        loss=[
            create_loss_fn(h, obj_thresh, iou_thresh, obj_weight, noobj_weight,
                           wh_weight, layer) for layer in range(num_outputs)
        ],
        metrics=[
            Yolo_Precision(obj_thresh, name='p'),
            Yolo_Recall(obj_thresh, name='r')
        ])

    # Pin the dataset element shapes to the model's input / output shapes.
    shapes = (train_model.input.shape, tuple(h.output_shapes))
    h.train_dataset = h.train_dataset.apply(assert_element_shape(shapes))
    h.test_dataset = h.test_dataset.apply(assert_element_shape(shapes))

    try:
        train_model.fit(h.train_dataset,
                        epochs=max_nrof_epochs,
                        steps_per_epoch=10,
                        validation_data=h.test_dataset,
                        validation_steps=1)
    except Exception as e:
        # Keep the same (code, message) shape as every other return path.
        return (-45, f'Unexpected error found during training, err: {e}')

    h5_path = f'{localSSDLoc}trained_h5_file/{uuid}_mbnet5_yolov3.h5'
    tflite_path = (f'{localSSDLoc}trained_tflite_file/'
                   f'{uuid}_mbnet5_yolov3_quant.tflite')
    kmodel_path = (f'{localSSDLoc}trained_kmodel_file/'
                   f'{uuid}_mbnet5_yolov3.kmodel')

    keras.models.save_model(yolo_model, h5_path)

    converter = tf.lite.TFLiteConverter.from_keras_model_file(
        h5_path,
        custom_objects={
            'RAdam':
            RAdam,
            'loss_softmax_cross_entropy_with_logits_v2':
            loss_softmax_cross_entropy_with_logits_v2
        })
    tflite_model = converter.convert()
    with open(tflite_path, "wb") as tflite_file:
        tflite_file.write(tflite_model)

    result = subprocess.run([
        f'{nncaseLoc}/ncc', tflite_path, kmodel_path, '-i', 'tflite', '-o',
        'k210model', '--dataset', datasetDir
    ])

    # A kmodel left over from an earlier run with the same uuid must not be
    # mistaken for success, so also require ncc to exit cleanly.
    if result.returncode == 0 and os.path.isfile(kmodel_path):
        return (0, kmodel_path)
    return (-16,
            'Unexpected Error Found During generating Kendryte k210model.')