Example #1
def main(train_set: str, outfile: str, max_iters: int, in_hw: tuple,
         out_hw: tuple, is_random: bool, is_plot: bool):
    helper = Helper('data/{}_img.list'.format(train_set),
                    'data/{}_ann.list'.format(train_set), None, None, in_hw,
                    out_hw)
    g = helper.generator(is_training=False, is_make_lable=False)
    _, true_box = next(g)
    X = true_box.copy()
    try:
        while True:
            _, true_box = next(g)
            X = np.vstack((X, true_box))
    except StopIteration:
        print('collected all boxes')
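    # keep only the last two columns of each box (presumably width/height) for clustering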
    x = X[:, 3:]
    initial_centroids = np.vstack(
        (np.linspace(0.05, 0.3, num=5), np.linspace(0.05, 0.5, num=5)))
    initial_centroids = initial_centroids.T
    # initial_centroids = np.random.rand(5, 2)
    centroids, idx = runkMeans(x, initial_centroids, 10, is_plot)
    centroids /= np.array([helper.grid_w, helper.grid_h])
    np.savetxt(outfile, centroids, fmt='%f')
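
For reference, a minimal sketch of the same anchor computation using scikit-learn's KMeans in place of the course-style runkMeans helper; sklearn here is an illustrative assumption, not part of the original workflow, and the division by the grid size mirrors the example above:

# hypothetical sklearn equivalent (assumption: runkMeans is plain Euclidean
# k-means over the box width/height columns)
import numpy as np
from sklearn.cluster import KMeans

def cluster_anchors(wh, num_anchors=5, grid_wh=(10, 7)):
    """wh: (N, 2) box widths/heights; grid_wh mirrors (helper.grid_w, helper.grid_h)."""
    km = KMeans(n_clusters=num_anchors, random_state=0).fit(wh)
    return km.cluster_centers_ / np.array(grid_wh)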
Example #2
def main(pb_path, class_num, anchor_file, image_size, image_path):
    g = tf.get_default_graph()
    helper = Helper(None, None, class_num, anchor_file, image_size, (7, 10))

    test_img = helper._read_img(image_path, True)
    test_img = helper._process_img(test_img, None, is_training=False)[0]

    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    inputs = g.get_tensor_by_name('Input_image:0')
    pred_label = g.get_tensor_by_name('Yolo/Final/conv2d/BiasAdd:0')
    """ reshape the model output """
    pred_label = tf.reshape(pred_label, [-1, helper.out_h, helper.out_w, len(helper.anchors), 5+class_num])

    """ split the label """
    pred_xy = pred_label[..., 0:2]
    pred_wh = pred_label[..., 2:4]
    pred_confidence = pred_label[..., 4:5]
    pred_cls = pred_label[..., 5:]

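    # standard YOLO decode: sigmoid keeps the xy offset inside its cell, exp scales the wh prior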
    pred_xy = tf.nn.sigmoid(pred_xy)
    pred_wh = tf.exp(pred_wh)
    pred_confidence_sigmoid = tf.nn.sigmoid(pred_confidence)
    obj_mask = pred_confidence_sigmoid[..., 0] > .7
    """ reshape box  """
    pred_xy_A, pred_wh_A = tf_xywh_to_all(pred_xy, pred_wh, helper)

    box = tf.concat([pred_xy_A, pred_wh_A], -1)

    yxyx_box = tf_center_to_corner(box)
    yxyx_box = tf.boolean_mask(yxyx_box, obj_mask)
    """ nms  """
    select = tf.image.non_max_suppression(yxyx_box,
                                          scores=tf.reshape(tf.boolean_mask(pred_confidence_sigmoid, obj_mask), (-1, )),
                                          max_output_size=30)
    vaild_box = tf.gather(yxyx_box, select)
    vaild_box = vaild_box[tf.newaxis, :, :]
    """ draw box """
    img_box = tf.image.draw_bounding_boxes(inputs, vaild_box)

    """ run """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # todo -----
        test_img = helper._read_img(image_path, is_resize=True)
        test_img, _ = helper._process_img(test_img, None, False)
        test_img = test_img[np.newaxis, :, :, :]
        img_box_, vaild_box_, yxyx_box_, pred_xy_, pred_wh_, pred_confidence_sigmoid_ = sess.run(
            [img_box, vaild_box, yxyx_box, pred_xy, pred_wh, pred_confidence_sigmoid], feed_dict={inputs: test_img})
    skimage.io.imshow(img_box_[0])
    skimage.io.show()
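
tf_xywh_to_all and tf_center_to_corner are project helpers that are not shown in this example; below is a minimal NumPy sketch of the standard YOLOv2-style decode they presumably perform (the grid-offset and anchor handling are illustrative assumptions, not the Helper implementation):

import numpy as np

def decode_boxes(sig_xy, exp_wh, anchors, out_hw):
    """sig_xy = sigmoid(pred_xy), exp_wh = exp(pred_wh), both shaped
    (grid_h, grid_w, num_anchor, 2); anchors given as fractions of the image."""
    grid_h, grid_w = out_hw
    # per-cell offsets, grid[y, x] == (x, y)
    grid = np.stack(np.meshgrid(np.arange(grid_w), np.arange(grid_h)), -1)
    grid = grid[:, :, None, :].astype(np.float32)
    xy = (sig_xy + grid) / np.array([grid_w, grid_h])  # absolute centers in [0, 1]
    wh = exp_wh * anchors                              # absolute sizes in [0, 1]
    # center (x, y, w, h) -> corner (y_min, x_min, y_max, x_max), the layout NMS and draw_bounding_boxes expect
    yx, hw = xy[..., ::-1], wh[..., ::-1]
    return np.concatenate([yx - hw / 2, yx + hw / 2], axis=-1)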
Example #3
def main(args, train_set, class_num, pre_ckpt, model_def, depth_multiplier,
         is_augmenter, image_size, output_size, batch_size, rand_seed,
         max_nrof_epochs, init_learning_rate, learning_rate_decay_factor,
         obj_weight, noobj_weight, wh_weight, obj_thresh, iou_thresh,
         vaildation_split, log_dir, is_prune, initial_sparsity, final_sparsity,
         end_epoch, frequency):
    # Build path
    log_dir = (Path(log_dir) /
               datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
               )  # type: Path
    ckpt_weights = log_dir / 'yolo_weights.h5'
    ckpt = log_dir / 'yolo_model.h5'
    if not log_dir.exists():
        log_dir.mkdir(parents=True)
    write_arguments_to_file(args, str(log_dir / 'args.txt'))

    # Build utils

    h = Helper(f'data/{train_set}_img_ann.npy', class_num,
               f'data/{train_set}_anchor.npy',
               np.reshape(np.array(image_size), (-1, 2)),
               np.reshape(np.array(output_size), (-1, 2)), vaildation_split)
    h.set_dataset(batch_size, rand_seed, is_training=(is_augmenter == 'True'))

    # Build network
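    # NOTE the first branch below is disabled (if False); the model is always built via convert.make_model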
    if False:
        network = eval(model_def)  # type :yolo_mobilev2
        yolo_model, yolo_model_wrapper = network(
            [image_size[0], image_size[1], 3],
            len(h.anchors[0]),
            class_num,
            alpha=depth_multiplier)
    else:
        yolo_model, yolo_model_wrapper,output_size = \
            convert.make_model(model_def, model_def + '.weights',
                               f'data/{train_set}_anchor.npy',
                               h.train_epoch_step * end_epoch,
                               initial_sparsity,
                               final_sparsity,
                               frequency)
        tf.keras.models.save_model(yolo_model,
                                   f'pre_prun{model_def}',
                                   include_optimizer=False)
    if pre_ckpt is not None and pre_ckpt != 'None' and pre_ckpt != '':
        if 'h5' in pre_ckpt:
            yolo_model_wrapper.load_weights(str(pre_ckpt))
            print(INFO, f' Load CKPT {str(pre_ckpt)}')
        else:
            print(ERROR, ' Pre CKPT path is invalid')

    # prune model
    pruning_params = {
        'pruning_schedule':
        sparsity.PolynomialDecay(initial_sparsity=.50,
                                 final_sparsity=.90,
                                 begin_step=0,
                                 end_step=h.train_epoch_step * end_epoch,
                                 frequency=frequency)
    }

    train_model = yolo_model_wrapper

    train_model.compile(
        keras.optimizers.Adam(lr=init_learning_rate,
                              decay=learning_rate_decay_factor),
        loss=[
            create_loss_fn(h, obj_thresh, iou_thresh, obj_weight, noobj_weight,
                           wh_weight, layer)
            for layer in range(
                len(train_model.output) if isinstance(train_model.output, list) else 1)
        ],
        metrics=[
            Yolo_Precision(obj_thresh, name='p'),
            Yolo_Recall(obj_thresh, name='r')
        ])
    """ NOTE fix the dataset output shape """
    shapes = (train_model.input.shape, tuple(h.output_shapes))
    h.train_dataset = h.train_dataset.apply(assert_element_shape(shapes))
    h.test_dataset = h.test_dataset.apply(assert_element_shape(shapes))
    """ Callbacks """
    if is_prune == 'True':
        cbs = [
            sparsity.UpdatePruningStep(),
            sparsity.PruningSummaries(log_dir=str(log_dir), profile_batch=0)
        ]
    else:
        cbs = [TensorBoard(str(log_dir), update_freq='batch', profile_batch=3)]

    # Training
    try:
        train_model.fit(h.train_dataset,
                        epochs=max_nrof_epochs,
                        steps_per_epoch=h.train_epoch_step,
                        callbacks=cbs,
                        validation_data=h.test_dataset,
                        validation_steps=int(h.test_epoch_step *
                                             h.validation_split))
    except KeyboardInterrupt:
        train_model.summary()
    if is_prune == 'True':
        final_model = tmot.sparsity.keras.strip_pruning(train_model)
        final_model.summary()
        model_name = 'sparse1.h5'
        yolo_model = tmot.sparsity.keras.strip_pruning(yolo_model)
        tf.keras.models.save_model(yolo_model,
                                   model_name,
                                   include_optimizer=False)
        tf.keras.models.save_model(yolo_model,
                                   f'{model_def}.tf',
                                   include_optimizer=False)
    else:
        keras.models.save_model(yolo_model, str(ckpt))
        print()
        print(INFO, f' Save Model as {str(ckpt)}')
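
Note that pruning_params is built in this example but never used in the code shown (presumably convert.make_model attaches the pruning wrappers itself). For reference, a minimal self-contained sketch of how such a schedule is normally attached with the tfmot Keras API; the toy model, data, and loss below are placeholders:

import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity

base_model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,)),
                                  tf.keras.layers.Dense(1)])
pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=.50,
                                                 final_sparsity=.90,
                                                 begin_step=0,
                                                 end_step=1000,
                                                 frequency=100)
}
pruned = sparsity.prune_low_magnitude(base_model, **pruning_params)
pruned.compile(optimizer='adam', loss='mse')
pruned.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=2,
           callbacks=[sparsity.UpdatePruningStep()])
final_model = sparsity.strip_pruning(pruned)  # remove the pruning wrappers before saving, as in the example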
Example #4
def runTrainingDetection(uuid,
                         datasetDir,
                         numOfClass,
                         obj_thresh=0.7,
                         iou_thresh=0.5,
                         obj_weight=1.0,
                         noobj_weight=1.0,
                         wh_weight=1.0,
                         max_nrof_epochs=50,
                         batch_size=96,
                         vaildation_split=0.2):
    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    keras.backend.set_session(sess)

    datasetList = [os.path.join(datasetDir, f) for f in os.listdir(datasetDir)]
    image_list = []

    #img_ann_f = open(os.path.join(datasetDir, 'dataset_img_ann.txt'), 'w+')

    for fileName in datasetList:
        if '.jpg' in fileName:
            #        print('/home/m5stack/VTrainingService/' + fileName, file=img_ann_f)
            image_list.append('/home/m5stack/VTrainingService/' + fileName)

    #img_ann_f.close()

    image_path_list = np.array(image_list)  # np.loadtxt(os.path.join(datasetDir, 'dataset_img_ann.txt'), dtype=str)

    ann_list = list(image_path_list)
    ann_list = [re.sub(r'JPEGImages', 'labels', s) for s in ann_list]
    ann_list = [re.sub(r'.jpg', '.txt', s) for s in ann_list]

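    # pack one record per sample: [image path, label array, original (height, width)]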
    lines = np.array([
        np.array([
            image_path_list[i],
            np.loadtxt(ann_list[i], dtype=float, ndmin=2),
            np.array(skimage.io.imread(image_path_list[i]).shape[0:2])
        ]) for i in range(len(ann_list))
    ])

    np.save(os.path.join(datasetDir, 'dataset_img_ann.npy'), lines)

    #print('dataset npu>>>', os.path.join(datasetDir, 'dataset_img_ann.npy'))

    h = Helper(os.path.join(datasetDir, 'dataset_img_ann.npy'), numOfClass,
               'voc_anchor.npy', np.reshape(np.array((224, 320)), (-1, 2)),
               np.reshape(np.array((7, 10, 14, 20)),
                          (-1, 2)), vaildation_split)

    h.set_dataset(batch_size, 6)

    network = eval('yolo_mobilev1')  # type: yolo_mobilev1
    yolo_model, train_model = network([224, 320, 3],
                                      len(h.anchors[0]),
                                      numOfClass,
                                      alpha=0.5)

    train_model.compile(
        RAdam(),
        loss=[
            create_loss_fn(h, obj_thresh, iou_thresh, obj_weight, noobj_weight,
                           wh_weight, layer)
            for layer in range(
                len(train_model.output) if isinstance(train_model.output, list) else 1)
        ],
        metrics=[
            Yolo_Precision(obj_thresh, name='p'),
            Yolo_Recall(obj_thresh, name='r')
        ])

    shapes = (train_model.input.shape, tuple(h.output_shapes))
    h.train_dataset = h.train_dataset.apply(assert_element_shape(shapes))
    h.test_dataset = h.test_dataset.apply(assert_element_shape(shapes))

    #print('train', h.train_dataset, '\n\r\n\rtest', h.test_dataset)

    try:
        train_model.fit(h.train_dataset,
                        epochs=max_nrof_epochs,
                        steps_per_epoch=10,
                        validation_data=h.test_dataset,
                        validation_steps=1)
    except Exception as e:
        return (-45, 'Unexpected error found during training, err:', e)

    keras.models.save_model(
        yolo_model, f'{localSSDLoc}trained_h5_file/{uuid}_mbnet5_yolov3.h5')

    converter = tf.lite.TFLiteConverter.from_keras_model_file(
        f'{localSSDLoc}trained_h5_file/{uuid}_mbnet5_yolov3.h5',
        custom_objects={
            'RAdam':
            RAdam,
            'loss_softmax_cross_entropy_with_logits_v2':
            loss_softmax_cross_entropy_with_logits_v2
        })
    tflite_model = converter.convert()
    with open(f'{localSSDLoc}trained_tflite_file/{uuid}_mbnet5_yolov3_quant.tflite',
              "wb") as tflite_f:
        tflite_f.write(tflite_model)

    subprocess.run([
        f'{nncaseLoc}/ncc',
        f'{localSSDLoc}trained_tflite_file/{uuid}_mbnet5_yolov3_quant.tflite',
        f'{localSSDLoc}trained_kmodel_file/{uuid}_mbnet5_yolov3.kmodel', '-i',
        'tflite', '-o', 'k210model', '--dataset', datasetDir
    ])

    if os.path.isfile(
            f'{localSSDLoc}trained_kmodel_file/{uuid}_mbnet5_yolov3.kmodel'):
        return (
            0, f'{localSSDLoc}trained_kmodel_file/{uuid}_mbnet5_yolov3.kmodel')
    else:
        return (-16,
                'Unexpected Error Found During generating Kendryte k210model.')
Example #5
def main(ckpt_weights, image_size, output_size, model_def, class_num,
         depth_multiplier, obj_thresh, iou_thresh, train_set, test_image):
    h = Helper(None, class_num, f'data/{train_set}_anchor.npy',
               np.reshape(np.array(image_size), (-1, 2)),
               np.reshape(np.array(output_size), (-1, 2)))
    network = eval(model_def)  # type :yolo_mobilev2
    yolo_model, yolo_model_warpper = network([image_size[0], image_size[1], 3],
                                             len(h.anchors[0]),
                                             class_num,
                                             alpha=depth_multiplier)

    yolo_model_warpper.load_weights(str(ckpt_weights))
    print(INFO, f' Load CKPT {str(ckpt_weights)}')
    orig_img = h._read_img(str(test_image))
    image_shape = orig_img.shape[0:2]
    img, _ = h._process_img(orig_img,
                            true_box=None,
                            is_training=False,
                            is_resize=True)
    """ load images """
    img = tf.expand_dims(img, 0)
    y_pred = yolo_model_warpper.predict(img)
    """ box list """
    _yxyx_box = []
    _yxyx_box_scores = []
    """ preprocess label """
    for l, pred_label in enumerate(y_pred):
        """ split the label """
        pred_xy = pred_label[..., 0:2]
        pred_wh = pred_label[..., 2:4]
        pred_confidence = pred_label[..., 4:5]
        pred_cls = pred_label[..., 5:]
        # box_scores = obj_score * class_score
        box_scores = tf.sigmoid(pred_cls) * tf.sigmoid(pred_confidence)
        # obj_mask = pred_confidence_score[..., 0] > obj_thresh
        """ reshape box  """
        # NOTE tf_xywh_to_all will auto use sigmoid function
        pred_xy_A, pred_wh_A = tf_xywh_to_all(pred_xy, pred_wh, l, h)
        boxes = correct_box(pred_xy_A, pred_wh_A, image_size, image_shape)
        boxes = tf.reshape(boxes, (-1, 4))
        box_scores = tf.reshape(box_scores, (-1, class_num))
        """ append box and scores to global list """
        _yxyx_box.append(boxes)
        _yxyx_box_scores.append(box_scores)

    yxyx_box = tf.concat(_yxyx_box, axis=0)
    yxyx_box_scores = tf.concat(_yxyx_box_scores, axis=0)

    mask = yxyx_box_scores >= obj_thresh
    """ do nms for every classes"""
    _boxes = []
    _scores = []
    _classes = []
    for c in range(class_num):
        class_boxes = tf.boolean_mask(yxyx_box, mask[:, c])
        class_box_scores = tf.boolean_mask(yxyx_box_scores[:, c], mask[:, c])
        select = tf.image.non_max_suppression(class_boxes,
                                              scores=class_box_scores,
                                              max_output_size=30,
                                              iou_threshold=iou_thresh)
        class_boxes = tf.gather(class_boxes, select)
        class_box_scores = tf.gather(class_box_scores, select)
        _boxes.append(class_boxes)
        _scores.append(class_box_scores)
        _classes.append(tf.ones_like(class_box_scores) * c)

    boxes = tf.concat(_boxes, axis=0)
    classes = tf.concat(_classes, axis=0)
    scores = tf.concat(_scores, axis=0)
    """ draw box  """
    font = ImageFont.truetype(font='asset/FiraMono-Medium.otf',
                              size=tf.cast(
                                  tf.floor(3e-2 * image_shape[0] + 0.5),
                                  tf.int32).numpy())

    thickness = (image_shape[0] + image_shape[1]) // 300
    """ show result """
    if len(classes) > 0:
        pil_img = Image.fromarray(orig_img)
        print(f'[top\tleft\tbottom\tright\tscore\tclass]')
        for i, c in enumerate(classes):
            box = boxes[i]
            score = scores[i]
            label = '{:2d} {:.2f}'.format(int(c.numpy()), score.numpy())
            draw = ImageDraw.Draw(pil_img)
            label_size = draw.textsize(label, font)
            top, left, bottom, right = box
            print(
                f'[{top:.1f}\t{left:.1f}\t{bottom:.1f}\t{right:.1f}\t{score:.2f}\t{int(c):2d}]'
            )
            top = max(0, tf.cast(tf.floor(top + 0.5), tf.int32))
            left = max(0, tf.cast(tf.floor(left + 0.5), tf.int32))
            bottom = min(image_shape[0],
                         tf.cast(tf.floor(bottom + 0.5), tf.int32))
            right = min(image_shape[1], tf.cast(tf.floor(right + 0.5),
                                                tf.int32))

            # draw the label above the box when there is room, otherwise just inside it
            if top - label_size[1] >= 0:
                text_origin = tf.convert_to_tensor([left, top - label_size[1]])
            else:
                text_origin = tf.convert_to_tensor([left, top + 1])

            for j in range(thickness):
                draw.rectangle([left + j, top + j, right - j, bottom - j],
                               outline=h.colormap[c])
            draw.rectangle(
                [tuple(text_origin),
                 tuple(text_origin + label_size)],
                fill=h.colormap[c])
            draw.text(text_origin, label, fill=(0, 0, 0), font=font)
            del draw
        pil_img.show()
    else:
        print(NOTE, ' no boxes detected')
Example #6
def main(args, train_set, class_num, train_classifier, pre_ckpt, model_def,
         is_augmenter, anchor_file, image_size, output_size, batch_size,
         rand_seed, max_nrof_epochs, init_learning_rate,
         learning_rate_decay_epochs, learning_rate_decay_factor, obj_weight,
         noobj_weight, obj_thresh, iou_thresh, log_dir):
    g = tf.get_default_graph()
    tf.set_random_seed(rand_seed)
    """ import network """
    network = eval(model_def)
    """ generate the dataset """
    # [(0.57273, 0.677385), (1.87446, 2.06253), (3.33843, 5.47434), (7.88282, 3.52778), (9.77052, 9.16828)]
    helper = Helper('data/{}_img.list'.format(train_set),
                    'data/{}_ann.list'.format(train_set), class_num,
                    anchor_file, image_size, output_size)
    helper.set_dataset(batch_size,
                       rand_seed,
                       is_training=(is_augmenter == 'True'))
    next_img, next_label = helper.get_iter()
    """ define the model """
    batch_image = tf.placeholder_with_default(
        next_img,
        shape=[None, image_size[0], image_size[1], 3],
        name='Input_image')
    batch_label = tf.placeholder_with_default(next_label,
                                              shape=[
                                                  None, output_size[0],
                                                  output_size[1],
                                                  len(helper.anchors),
                                                  5 + class_num
                                              ],
                                              name='Input_label')
    training_control = tf.placeholder_with_default(True,
                                                   shape=[],
                                                   name='training_control')
    true_label = tf.identity(batch_label)
    nets, endpoints = network(batch_image,
                              len(helper.anchors),
                              class_num,
                              phase_train=training_control)
    """ reshape the model output """
    pred_label = tf.reshape(nets, [
        -1, output_size[0], output_size[1],
        len(helper.anchors), 5 + class_num
    ],
                            name='predict')
    """ split the label """
    pred_xy = pred_label[..., 0:2]
    pred_wh = pred_label[..., 2:4]
    pred_confidence = pred_label[..., 4:5]
    pred_cls = pred_label[..., 5:]

    pred_xy = tf.nn.sigmoid(pred_xy)
    pred_wh = tf.exp(pred_wh)
    pred_confidence_sigmoid = tf.nn.sigmoid(pred_confidence)

    true_xy = true_label[..., 0:2]
    true_wh = true_label[..., 2:4]
    true_confidence = true_label[..., 4:5]
    true_cls = true_label[..., 5:]

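    # cells whose ground-truth confidence marks an object as present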
    obj_mask = true_confidence[..., 0] > obj_thresh
    """ calc the noobj mask ~ """
    if train_classifier == 'True':
        noobj_mask = tf.logical_not(obj_mask)
    else:
        noobj_mask = calc_noobj_mask(true_xy,
                                     true_wh,
                                     pred_xy,
                                     pred_wh,
                                     obj_mask,
                                     iou_thresh=iou_thresh,
                                     helper=helper)
    """ define loss """
    xy_loss = tf.reduce_sum(
        tf.square(
            tf.boolean_mask(true_xy, obj_mask) -
            tf.boolean_mask(pred_xy, obj_mask))) / batch_size
    wh_loss = tf.reduce_sum(
        tf.square(
            tf.boolean_mask(true_wh, obj_mask) -
            tf.boolean_mask(pred_wh, obj_mask))) / batch_size
    obj_loss = obj_weight * tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.boolean_mask(true_confidence, obj_mask),
            logits=tf.boolean_mask(pred_confidence, obj_mask))) / batch_size
    noobj_loss = noobj_weight * tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.boolean_mask(true_confidence, noobj_mask),
            logits=tf.boolean_mask(pred_confidence, noobj_mask))) / batch_size
    cls_loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.boolean_mask(true_cls, obj_mask),
            logits=tf.boolean_mask(pred_cls, obj_mask))) / batch_size

    # xy_loss = tf.losses.mean_squared_error(tf.boolean_mask(true_xy, obj_mask), tf.boolean_mask(pred_xy, obj_mask))
    # wh_loss = tf.losses.mean_squared_error(tf.boolean_mask(true_wh, obj_mask), tf.boolean_mask(pred_wh, obj_mask))
    # obj_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.boolean_mask(true_confidence, obj_mask), logits=tf.boolean_mask(pred_confidence, obj_mask), weights=5.0)
    # noobj_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.boolean_mask(true_confidence, noobj_mask), logits=tf.boolean_mask(pred_confidence, noobj_mask), weights=.5)
    # cls_loss = tf.losses.softmax_cross_entropy(onehot_labels=tf.boolean_mask(true_cls, obj_mask), logits=tf.boolean_mask(pred_cls, obj_mask))

    if train_classifier == 'True':
        total_loss = obj_loss + noobj_loss + cls_loss
    else:
        total_loss = obj_loss + noobj_loss + cls_loss + xy_loss + wh_loss
    """ define steps """
    global_steps = tf.train.create_global_step()
    """ define learing rate """
    current_learning_rate = tf.train.exponential_decay(
        init_learning_rate,
        global_steps,
        helper.epoch_step // learning_rate_decay_epochs,
        learning_rate_decay_factor,
        staircase=False)
    """ define train_op """
    train_op = slim.learning.create_train_op(
        total_loss, tf.train.AdamOptimizer(current_learning_rate),
        global_steps)
    """ calc the accuracy """
    precision, prec_op = tf.metrics.precision_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    test_precision, test_prec_op = tf.metrics.precision_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    recall, recall_op = tf.metrics.recall_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    test_recall, test_recall_op = tf.metrics.recall_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        """ must save the bn paramter! """
        var_list = tf.global_variables() + tf.local_variables()  # list(set(tf.trainable_variables() + [g for g in tf.global_variables() if 'moving_' in g.name]))
        saver = tf.train.Saver(var_list)

        # init the model and restore the pre-train weight
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # NOTE the metric ops require local variables to be initialized
        restore_ckpt(sess, var_list, pre_ckpt)
        # define the log and saver
        subdir = os.path.join(
            log_dir, datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
        train_writer = tf.summary.FileWriter(subdir, graph=sess.graph)
        write_arguments_to_file(args, os.path.join(subdir, 'arguments.txt'))
        tf.summary.scalar('total_loss', total_loss)
        tf.summary.scalar('obj_loss', obj_loss)
        tf.summary.scalar('noobj_loss', noobj_loss)
        tf.summary.scalar('mse_loss', xy_loss + wh_loss)
        tf.summary.scalar('class_loss', cls_loss)
        tf.summary.scalar('learning_rate', current_learning_rate)
        tf.summary.scalar('precision', precision[0])
        tf.summary.scalar('recall', recall[0])
        merged = tf.summary.merge_all()
        t_prec_summary = tf.summary.scalar('test_precision', test_precision[0])
        t_recall_summary = tf.summary.scalar('test_recall', test_recall[0])

        try:
            for i in range(max_nrof_epochs):
                with tqdm(total=helper.epoch_step,
                          bar_format=
                          '{n_fmt}/{total_fmt} |{bar}| {rate_fmt}{postfix}',
                          unit=' batch',
                          dynamic_ncols=True) as t:
                    for j in range(helper.epoch_step):
                        if j % 30 == 0:
                            summary1, summary2, _, _, step_cnt = sess.run(
                                [
                                    t_prec_summary, t_recall_summary,
                                    test_recall_op, test_prec_op, global_steps
                                ],
                                feed_dict={training_control: False})
                            train_writer.add_summary(summary1, step_cnt)
                            train_writer.add_summary(summary2, step_cnt)
                        else:
                            summary, _, total_l, prec, _, _, lr, step_cnt = sess.run(
                                [
                                    merged, train_op, total_loss, precision,
                                    prec_op, recall_op, current_learning_rate,
                                    global_steps
                                ])
                            t.set_postfix(loss='{:<5.3f}'.format(total_l),
                                          prec='{:<4.2f}%'.format(prec[0] *
                                                                  100),
                                          lr='{:f}'.format(lr))
                            train_writer.add_summary(summary, step_cnt)
                        t.update()
            saver.save(sess,
                       save_path=os.path.join(subdir, 'model.ckpt'),
                       global_step=global_steps)
            print('save over')
        except KeyboardInterrupt as e:
            saver.save(sess,
                       save_path=os.path.join(subdir, 'model.ckpt'),
                       global_step=global_steps)
            print('save over')
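
calc_noobj_mask is referenced above but not shown; below is a minimal NumPy sketch of the usual YOLO rule it presumably follows: a prediction is treated as background only when it is not responsible for an object and its best IoU against the ground-truth boxes stays below iou_thresh. The (x, y, w, h) layout and helper-free signature are illustrative assumptions:

import numpy as np

def iou_xywh(box, boxes):
    """IoU of one (x, y, w, h) box against an (N, 4) array of boxes."""
    a_min, a_max = box[:2] - box[2:] / 2, box[:2] + box[2:] / 2
    b_min, b_max = boxes[:, :2] - boxes[:, 2:] / 2, boxes[:, :2] + boxes[:, 2:] / 2
    inter = np.prod(np.clip(np.minimum(a_max, b_max) - np.maximum(a_min, b_min), 0, None), axis=-1)
    union = np.prod(box[2:]) + np.prod(boxes[:, 2:], axis=-1) - inter
    return inter / np.maximum(union, 1e-10)

def calc_noobj_mask_np(pred_boxes, true_boxes, obj_mask, iou_thresh=0.5):
    """pred_boxes: (M, 4) decoded predictions, true_boxes: (K, 4) ground truth,
    obj_mask: (M,) bool marking the responsible predictions."""
    best_iou = np.array([iou_xywh(p, true_boxes).max() if len(true_boxes) else 0.0
                         for p in pred_boxes])
    return np.logical_and(~obj_mask, best_iou < iou_thresh)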