Exemplo n.º 1
0
def eger(cfg):
    """Run a single eager-mode batch through target encoding and the loss.

    Pulls one batch from the data generator, prints its labels, encodes the
    ground-truth boxes/labels into per-anchor targets, and evaluates the loss.
    """
    batch_source = data_gen.get_batch(batch_size=cfg.batch_size)
    batch_images, gt_boxes, gt_labels = next(batch_source)
    print(gt_labels)
    target_loc, target_conf = np_utils.get_loc_conf(
        gt_boxes, gt_labels, batch_size=cfg.batch_size)
    get_loss(batch_images, target_conf, target_loc)
Exemplo n.º 2
0
def train():
    """Build the SSD-style detection training graph and run the train loop.

    Relies on module-level globals: `config` (run configuration with
    `batch_size`), `cfg` (dict of SSD layout constants: 'min_dim',
    'feature_maps', 'aspect_num'), plus `model`, `data_gen`, `np_utils`,
    `get_loss`, `tf`, and `slim`.  NOTE(review): mixes `config` and `cfg`
    globals — presumably intentional, but confirm both are defined.
    """
    # Input image batch placeholder (NHWC, float32).
    img = tf.placeholder(
        shape=[config.batch_size, cfg['min_dim'], cfg['min_dim'], 3],
        dtype=tf.float32)
    # Total anchor count over 6 feature maps:
    # (map side)^2 cells * aspect ratios per cell.
    anchors_num = sum(
        [cfg['feature_maps'][s]**2 * cfg['aspect_num'][s] for s in range(6)])

    # Ground-truth per-anchor regression targets and class targets.
    loc = tf.placeholder(shape=[config.batch_size, anchors_num, 4],
                         dtype=tf.float32)
    conf = tf.placeholder(shape=[config.batch_size, anchors_num],
                          dtype=tf.float32)

    pred_loc, pred_confs, vbs = model.model(img, config)

    train_tensors, sum_op = get_loss(conf, loc, pred_loc, pred_confs, config)

    gen = data_gen.get_batch(batch_size=config.batch_size,
                             image_size=cfg['min_dim'],
                             max_detect=50)
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
    train_op = slim.learning.create_train_op(train_tensors, optimizer)

    # Saver restores only the backbone variables (vbs) from the checkpoint.
    saver = tf.train.Saver(vbs)

    def restore(sess):
        # Warm-start from a pretrained VGG-16 checkpoint.
        saver.restore(sess, '/home/dsl/all_check/vgg_16.ckpt')

    sv = tf.train.Supervisor(
        logdir='/home/dsl/all_check/face_detect/vgg500_nn',
        summary_op=None,
        init_fn=restore)

    with sv.managed_session() as sess:
        # Effectively an infinite loop; intended to be stopped externally.
        for step in range(1000000000):

            images, true_box, true_label = next(gen)
            # Encode ground-truth boxes/labels into per-anchor targets.
            loct, conft = np_utils.get_loc_conf(true_box,
                                                true_label,
                                                batch_size=config.batch_size,
                                                cfg=cfg)
            feed_dict = {img: images, loc: loct, conf: conft}

            ls = sess.run(train_op, feed_dict=feed_dict)
            if step % 10 == 0:
                # Periodically emit summaries and print the current loss.
                summaries = sess.run(sum_op, feed_dict=feed_dict)
                sv.summary_computed(sess, summaries)
                print(ls)
Exemplo n.º 3
0
def train(cfg):
    """Train the face detector in a TF Supervisor-managed session.

    Builds placeholders for images plus the 7512 per-anchor localization and
    confidence targets, restores VGG-16 backbone weights, and runs an
    effectively unbounded SGD-with-momentum training loop.
    """
    # Graph inputs.
    img = tf.placeholder(shape=[cfg.batch_size, 300, 300, 3], dtype=tf.float32)
    loc = tf.placeholder(shape=[cfg.batch_size, 7512, 4], dtype=tf.float32)
    conf = tf.placeholder(shape=[cfg.batch_size, 7512], dtype=tf.float32)

    pred_loc, pred_confs, vbs = model(img)
    train_tensors, sum_op = get_loss(conf, loc, pred_loc, pred_confs)

    batch_iter = data_gen.get_batch(batch_size=cfg.batch_size)
    momentum_opt = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
    train_op = slim.learning.create_train_op(train_tensors, momentum_opt)

    saver = tf.train.Saver(vbs)

    def restore(sess):
        # Warm-start the backbone from the pretrained VGG-16 checkpoint.
        saver.restore(sess, '/home/dsl/all_check/vgg_16.ckpt')

    sv = tf.train.Supervisor(logdir='/home/dsl/all_check/face_detect',
                             summary_op=None,
                             init_fn=restore)

    with sv.managed_session() as sess:
        for step in range(1000000000):
            images, true_box, true_label = next(batch_iter)
            # Encode ground truth into per-anchor regression/class targets.
            loct, conft = np_utils.get_loc_conf(true_box,
                                                true_label,
                                                batch_size=cfg.batch_size)
            feed = {img: images, loc: loct, conf: conft}
            step_loss = sess.run(train_op, feed_dict=feed)
            if step % 10 == 0:
                # Emit summaries every 10 steps and print the running loss.
                summaries = sess.run(sum_op, feed_dict=feed)
                sv.summary_computed(sess, summaries)
                print(step_loss)
#train()
#tf.enable_eager_execution()
#eger()

#detect('/media/dsl/20d6b919-92e1-4489-b2be-a092290668e4/VOCdevkit/VOCdevkit/VOC2007/JPEGImages/000133.jpg')
#video()
Exemplo n.º 4
0
def main(opts):
    """Train the OCT-E2E-MLT joint text detection + OCR network.

    Runs the combined detection/recognition loop: per step it computes the
    detection loss, optionally adds a box-level CTC loss (after 10k steps),
    adds an OCR CTC loss on a separate OCR batch, and steps Adam.  Losses are
    averaged and printed every `disp_interval` steps, the model is
    checkpointed every `batch_per_epoch` steps, the best model (by windowed
    training loss) is tracked, and training early-stops after `max_patience`
    steps without improvement.

    Relies on module-level globals: `weight_decay`, `norm_height`,
    `disp_interval`, `batch_per_epoch`, and the project helpers `OctMLT`,
    `net_utils`, `data_gen`, `ocr_gen`, `CTCLoss`, `process_boxes`.
    """
    import copy  # local import: deep-copy snapshots of best model state

    model_name = 'OCT-E2E-MLT'
    net = OctMLT(attention=True)
    print("Using {0}".format(model_name))

    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        # BUG FIX: previously referenced the global `args`; this function is
        # parameterized by `opts` (and the existence check above uses opts).
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net)

    if opts.cuda:
        net.cuda()

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_list,
                                        geo_type=opts.geo_type)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True)

    # Running loss accumulators for the current display window.
    train_loss = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 0
    ctc_loss = CTCLoss()

    ctc_loss_val = 0
    box_loss_val = 0
    good_all = 0
    gt_all = 0

    # Best-model tracking for early stopping.
    best_step = step_start
    best_loss = 1000000
    # BUG FIX: state_dict() returns references to the live tensors, so the
    # "best" snapshot would silently track the current weights; deep-copy.
    best_model = copy.deepcopy(net.state_dict())
    best_optimizer = copy.deepcopy(optimizer.state_dict())
    best_learning_rate = learning_rate
    max_patience = 3000
    early_stop = False

    step = step_start  # keep `step` defined even if the loop body never runs
    for step in range(step_start, opts.max_iters):

        # batch
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(
            data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        start = timeit.timeit()
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except Exception:
            # Best-effort: log and skip the batch rather than abort training.
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue
        end = timeit.timeit()

        # backward

        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except Exception:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue

        bbox_loss += net.box_loss_value.data.cpu().numpy()
        seg_loss += net.segm_loss_value.data.cpu().numpy()
        angle_loss += net.angle_loss_value.data.cpu().numpy()

        train_loss += loss.data.cpu().numpy()
        optimizer.zero_grad()

        try:

            if step > 10000:  #this is just extra augumentation step ... in early stage just slows down training
                ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                          im_data,
                                                          seg_pred[0],
                                                          roi_pred[0],
                                                          angle_pred[0],
                                                          score_maps,
                                                          gt_idxs,
                                                          gtso,
                                                          lbso,
                                                          features,
                                                          net,
                                                          ctc_loss,
                                                          opts,
                                                          debug=opts.debug)
                ctc_loss_val += ctcl.data.cpu().numpy()[0]
                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            # Separate OCR-only batch for the recognition head.
            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            features = net.forward_features(im_data_ocr)
            labels_pred = net.forward_ocr(features)

            # CTC expects (T, N, C); probs_sizes repeats T for every sample.
            probs_sizes = torch.IntTensor(
                [(labels_pred.permute(2, 0, 1).size()[0])] *
                (labels_pred.permute(2, 0, 1).size()[1]))
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
            loss_ocr = ctc_loss(labels_pred.permute(2, 0,
                                                    1), labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()
            loss.backward()

            optimizer.step()
        except Exception:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
        cnt += 1
        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                # Undo the (-1, 1) normalization for display.
                x_data += 1
                x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    # Paint detected-text pixels green.
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except Exception:
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            # Average the window accumulators (cnt >= 1 here because it is
            # incremented unconditionally every step).
            train_loss /= cnt
            bbox_loss /= cnt
            seg_loss /= cnt
            angle_loss /= cnt
            ctc_loss_val /= cnt
            box_loss_val /= cnt

            if train_loss < best_loss:
                best_step = step
                # Deep-copy so the snapshot survives further training.
                best_model = copy.deepcopy(net.state_dict())
                best_loss = train_loss
                best_learning_rate = learning_rate
                best_optimizer = copy.deepcopy(optimizer.state_dict())
            # BUG FIX: the patience test was inverted
            # (`best_step - step` is never positive), so early stopping
            # could never trigger.
            if step - best_step > max_patience:
                print("Early stopped criteria achieved.")
                save_name = os.path.join(
                    opts.save_path,
                    'BEST_{}_{}.h5'.format(model_name, best_step))
                state = {
                    'step': best_step,
                    'learning_rate': best_learning_rate,
                    'state_dict': best_model,
                    'optimizer': best_optimizer
                }
                torch.save(state, save_name)
                print('save model: {}'.format(save_name))
                opts.max_iters = step
                early_stop = True
                # BUG FIX: mutating opts.max_iters does not end an already
                # constructed range(); break out explicitly.
                break
            try:
                print(
                    'epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f in %.3f'
                    % (step / batch_per_epoch, step, train_loss, bbox_loss,
                       seg_loss, angle_loss, ctc_loss_val,
                       good_all / max(1, gt_all), end - start))
                print('max_memory_allocated {}'.format(
                    torch.cuda.max_memory_allocated()))
            except Exception:
                import sys, traceback
                traceback.print_exc(file=sys.stdout)

            # Reset window accumulators.
            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            good_all = 0
            gt_all = 0
            box_loss_val = 0

        #if step % valid_interval == 0:
        #  validate(opts.valid_list, net)
        if step > step_start and (step % batch_per_epoch == 0):
            # Per-epoch checkpoint of the current (not best) model.
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'max_memory_allocated': torch.cuda.max_memory_allocated()
            }
            torch.save(state, save_name)
            print('save model: {}\tmax memory: {}'.format(
                save_name, torch.cuda.max_memory_allocated()))
    if not early_stop:
        # Final checkpoint when training ran to completion.
        save_name = os.path.join(opts.save_path, '{}.h5'.format(model_name))
        state = {
            'step': step,
            'learning_rate': learning_rate,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, save_name)
        print('save model: {}'.format(save_name))
Exemplo n.º 5
0
def main(opts):
    """Train the E2E-MLT CRNN text detector + recognizer.

    Joint detection/OCR training with a cyclic learning-rate schedule,
    gradient clipping, NaN/Inf-guarded optimizer steps, periodic loss
    logging to `<save_path>/loss.txt`, and per-epoch checkpointing followed
    by end-to-end evaluation.

    Relies on module-level globals: `norm_height`, `weight_decay`,
    `disp_interval`, `batch_per_epoch`, `device`, and the project helpers
    `ModelResNetSep_crnn`, `net_utils`, `data_gen`, `ocr_gen`,
    `process_boxes`, `evaluate_e2e_crnn`.
    """
    model_name = 'E2E-MLT'
    net = ModelResNetSep_crnn(
        attention=True,
        multi_scale=True,
        num_classes=400,
        fixed_height=norm_height,
        net='densenet',
    )
    print("Using {0}".format(model_name))
    ctc_loss = nn.CTCLoss()
    if opts.cuda:
        net.to(device)
        ctc_loss.to(device)
    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                  base_lr=0.0006,
                                                  max_lr=0.001,
                                                  step_size_up=3000,
                                                  cycle_momentum=False)
    step_start = 0
    if os.path.exists(opts.model):
        # BUG FIX: previously referenced the global `args`; this function is
        # parameterized by `opts` (the existence check above already uses it).
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net,
                                                       optimizer)
    net_utils.adjust_learning_rate(optimizer, learning_rate)

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_path,
                                        geo_type=opts.geo_type,
                                        normalize=opts.normalize)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True,
                               normalize=opts.normalize)

    # BUG FIX: define the loss-log path up front; it was previously created
    # only inside the display block, so the per-epoch block could hit a
    # NameError if it ran first.
    save_log = os.path.join(opts.save_path, 'loss.txt')

    # Running accumulators for the current display window.
    train_loss = 0
    train_loss_temp = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 1

    ctc_loss_val = 0
    ctc_loss_val2 = 0
    ctcl = torch.tensor([0])
    box_loss_val = 0
    good_all = 0
    gt_all = 0
    train_loss_lr = 0
    cntt = 0
    time_total = 0
    now = time.time()

    for step in range(step_start, opts.max_iters):

        # batch
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(
            data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        start = timeit.timeit()
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except Exception:
            # Best-effort: log and skip the batch rather than abort training.
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue
        end = timeit.timeit()

        # backward

        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except Exception:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue

        # Accumulate only finite losses so one bad batch cannot poison stats.
        if not (torch.isnan(loss) or torch.isinf(loss)):
            train_loss_temp += loss.data.cpu().numpy()

        optimizer.zero_grad()

        try:

            if step > 1000 or True:  # this is just extra augumentation step ... in early stage just slows down training
                ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                          im_data,
                                                          seg_pred[0],
                                                          roi_pred[0],
                                                          angle_pred[0],
                                                          score_maps,
                                                          gt_idxs,
                                                          gtso,
                                                          lbso,
                                                          features,
                                                          net,
                                                          ctc_loss,
                                                          opts,
                                                          debug=opts.debug)

                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            # Separate OCR-only batch for the recognition head.
            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            labels_pred = net.forward_ocr(im_data_ocr)

            # nn.CTCLoss expects (T, N, C); probs_sizes repeats T per sample.
            probs_sizes = torch.IntTensor([(labels_pred.size()[0])] *
                                          (labels_pred.size()[1])).long()
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int()).long()
            labels = torch.IntTensor(torch.from_numpy(
                np.array(labels)).int()).long()
            loss_ocr = ctc_loss(labels_pred, labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()

            loss.backward()

            clipping_value = 0.5
            torch.nn.utils.clip_grad_norm_(net.parameters(), clipping_value)
            if opts.d1:
                print('loss_nan', torch.isnan(loss))
                print('loss_inf', torch.isinf(loss))
                print('lossocr_nan', torch.isnan(loss_ocr))
                print('lossocr_inf', torch.isinf(loss_ocr))

            # Only step the optimizer/scheduler when both losses are finite.
            if not (torch.isnan(loss) or torch.isinf(loss)
                    or torch.isnan(loss_ocr) or torch.isinf(loss_ocr)):
                bbox_loss += net.box_loss_value.data.cpu().numpy()
                seg_loss += net.segm_loss_value.data.cpu().numpy()
                angle_loss += net.angle_loss_value.data.cpu().numpy()
                train_loss += train_loss_temp
                ctc_loss_val2 += loss_ocr.item()
                ctc_loss_val += ctcl.data.cpu().numpy()[0]
                optimizer.step()
                scheduler.step()
                train_loss_temp = 0
                cnt += 1

        except Exception:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)

        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                # Undo the (-1, 1) normalization for display if it was applied.
                if opts.normalize:
                    x_data += 1
                    x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    # Paint detected-text pixels green.
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except Exception:
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            # BUG FIX: `cnt` is reset to 0 after each display and only
            # advances on successful (finite-loss) optimizer steps, so it can
            # be 0 here; guard the averaging divisor.
            denom = max(cnt, 1)
            train_loss /= denom
            bbox_loss /= denom
            seg_loss /= denom
            angle_loss /= denom
            ctc_loss_val /= denom
            ctc_loss_val2 /= denom
            box_loss_val /= denom
            train_loss_lr += (train_loss)

            cntt += 1
            time_now = time.time() - now
            time_total += time_now
            now = time.time()
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']

            with open(save_log, 'a') as f:
                f.write(
                    'epoch %d[%d], lr: %f, loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f, lv2: %.3f, time: %.2f s, cnt: %d\n'
                    % (step / batch_per_epoch, step, learning_rate, train_loss,
                       bbox_loss, seg_loss, angle_loss, ctc_loss_val,
                       good_all / max(1, gt_all), ctc_loss_val2, time_now,
                       cnt))
            try:

                print(
                    'epoch %d[%d], lr: %f, loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f, lv2: %.3f, time: %.2f s, cnt: %d\n'
                    %
                    (step / batch_per_epoch, step, learning_rate, train_loss,
                     bbox_loss, seg_loss, angle_loss, ctc_loss_val,
                     good_all / max(1, gt_all), ctc_loss_val2, time_now, cnt))
            except Exception:
                import sys, traceback
                traceback.print_exc(file=sys.stdout)

            # Reset window accumulators.
            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            ctc_loss_val2 = 0
            good_all = 0
            gt_all = 0
            box_loss_val = 0

        # if step % valid_interval == 0:
        #  validate(opts.valid_list, net)
        if step > step_start and (step % batch_per_epoch == 0):
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']
                print('learning_rate', learning_rate)
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            # Evaluate end-to-end recognition on the eval set.
            # BUG FIX: was `args.*`; use the `opts` parameter throughout.
            re_tpe2e, re_tp, re_e1, precision = evaluate_e2e_crnn(
                root=opts.eval_path,
                net=net,
                norm_height=norm_height,
                name_model=save_name,
                normalize=opts.normalize,
                save_dir=opts.save_path)

            # max(cntt, 1) guards against a zero display count this epoch.
            with open(save_log, 'a') as f:
                f.write(
                    'time epoch [%d]: %.2f s, loss_total: %.3f, lr:%f, re_tpe2e = %f, re_tp = %f, re_e1 = %f, precision = %f\n'
                    % (step / batch_per_epoch, time_total,
                       train_loss_lr / max(cntt, 1), learning_rate, re_tpe2e,
                       re_tp, re_e1, precision))
            print(
                'time epoch [%d]: %.2f s, loss_total: %.3f, re_tpe2e = %f, re_tp = %f, re_e1 = %f, precision = %f'
                % (step / batch_per_epoch, time_total,
                   train_loss_lr / max(cntt, 1), re_tpe2e, re_tp, re_e1,
                   precision))
            print('save model: {}'.format(save_name))
            time_total = 0
            cntt = 0
            train_loss_lr = 0
            # evaluate_e2e_crnn switches to eval mode; restore training mode.
            net.train()
Exemplo n.º 6
0
start_epoch = 0  # start from epoch 0 or last epoch

# Loss and anchor-refinement helpers for the two-step RefineDet pipeline.
arm_loss = ArmBoxLoss()
refin_anchors_model = Refin_anchors()
# Detector: 9 anchors per location, 21 classes (presumably 20 VOC classes
# plus background — TODO confirm against the dataset definition).
net = RefinDet(num_anchors=9, cls_num=21)
# Warm-start the FPN backbone from an ImageNet-pretrained ResNet-50
# checkpoint; strict=False tolerates missing/unexpected keys.
net.fpn.load_state_dict(
    torch.load('/home/dsl/all_check/resnet50-19c8e357.pth'), strict=False)
net.cuda()

criterion = MultiBoxLoss(num_classes=21)
# SGD with momentum and weight decay for the whole network.
optimizer = optim.SGD(net.parameters(),
                      lr=0.001,
                      momentum=0.9,
                      weight_decay=1e-4)
# VOC batch generator; up to 100 ground-truth boxes per image.
gen_bdd = data_gen.get_batch(batch_size=config.batch_size,
                             class_name='voc',
                             image_size=config.image_size,
                             max_detect=100)
# Decay LR by 0.9 every 5 scheduler steps.
sch = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()

    train_loss = 0
    for x in range(10000):
        images, true_box, true_label = next(gen_bdd)
        try:
            loc_t, conf_t = get_loc_conf_new(true_box,
                                             true_label,