Exemplo n.º 1
0
def _train_loss_summary(lr_, cl_loss, re_loss, tt_loss, cn_loss=None):
    """Build the scalar tf.Summary of learning rate and losses for one step.

    The classification, cnt and regression losses are logged scaled by
    `cfg.class_weight`, `cfg.cnt_weight` and `cfg.regress_weight`
    respectively. The cnt entry is emitted only when `cn_loss` is not None,
    i.e. when the model was built with `cfg.cnt_branch`; previously the
    non-cnt path referenced an unbound `cn_loss` and crashed.
    """
    values = [
        tf.Summary.Value(tag="config/learning rate", simple_value=lr_),
        tf.Summary.Value(tag="train/classification/focal_loss",
                         simple_value=cfg.class_weight * cl_loss),
    ]
    if cn_loss is not None:
        values.append(
            tf.Summary.Value(tag="train/classification/cnt_loss",
                             simple_value=cfg.cnt_weight * cn_loss))
    values.append(
        tf.Summary.Value(tag="train/regress_loss",
                         simple_value=cfg.regress_weight * re_loss))
    values.append(
        tf.Summary.Value(tag="train/total_loss", simple_value=tt_loss))
    return tf.Summary(value=values)


def train(max_epoch=30):
    """Run the TF1 training loop for `max_epoch` epochs.

    This function reads module-level names defined elsewhere:
    `cfg`, `sess`, `data`, the feeds `input_`, `get_boxes`, `get_classes`,
    the fetches `train_op`, `global_step`, `total_loss`, `cls_loss`,
    `cnt_loss`, `reg_loss`, `lr`, `summary_op`, plus `summary_writer`,
    `saver` and `Timer`.

    Every 50 global steps it writes a progress line, timing info and
    TensorBoard summaries. Every 10000 global steps, and once after the last
    epoch, it saves a checkpoint to `cfg.ckecpoint_file + '/model.ckpt'`.

    Args:
        max_epoch: number of epochs. Each epoch runs
            `cfg.train_num // cfg.batch_size` steps. Defaults to 30, the
            previously hard-coded value.

    Raises:
        ValueError: if an epoch would contain no steps
            (`cfg.train_num < cfg.batch_size`).
    """
    total_timer = Timer()
    epoch_step = int(cfg.train_num // cfg.batch_size)
    if epoch_step < 1:
        # Without this check the loop body never runs and `g_step_` is
        # unbound when the epoch summary / final checkpoint use it.
        raise ValueError('cfg.train_num (%s) must be >= cfg.batch_size (%s)'
                         % (cfg.train_num, cfg.batch_size))
    total_steps = epoch_step * max_epoch

    for epoch in range(1, max_epoch + 1):
        print('-' * 25, 'epoch', epoch, '/', str(max_epoch), '-' * 25)
        t_loss = 0  # sum of total_loss over this epoch's steps

        for _step in range(1, epoch_step + 1):
            total_timer.tic()

            images, labels, _, _, _ = data.get()
            # Column 0 of `labels` feeds the class placeholder and columns
            # 1:5 feed the box placeholder; axis 1 is reversed for both, as
            # in the original code. Presumably labels is
            # (batch, max_boxes, 5) -- TODO confirm against `data.get`.
            feed_dict = {
                input_: images,
                get_boxes: labels[..., 1:5][:, ::-1, :],
                get_classes: labels[..., 0].reshape((cfg.batch_size, -1))[:, ::-1],
            }
            if cfg.cnt_branch:
                _, g_step_, tt_loss, cl_loss, cn_loss, re_loss, lr_ = sess.run(
                    [train_op, global_step, total_loss, cls_loss, cnt_loss,
                     reg_loss, lr],
                    feed_dict=feed_dict)
            else:
                _, g_step_, tt_loss, cl_loss, re_loss, lr_ = sess.run(
                    [train_op, global_step, total_loss, cls_loss, reg_loss,
                     lr],
                    feed_dict=feed_dict)
                cn_loss = None  # no cnt branch: nothing to log for it
            t_loss += tt_loss
            total_timer.toc()

            if g_step_ % 50 == 0:
                sys.stdout.write('\r>> ' + 'iters ' + str(g_step_) + '/'
                                 + str(total_steps) + ' loss ' + str(tt_loss)
                                 + ' ')
                sys.stdout.flush()
                # Second run after the update: evaluates the merged summaries
                # on the same batch.
                summary_str = sess.run(summary_op, feed_dict=feed_dict)
                train_total_summary = _train_loss_summary(
                    lr_, cl_loss, re_loss, tt_loss, cn_loss)
                print('curent speed: ', total_timer.diff, 'remain time: ',
                      total_timer.remain(g_step_, total_steps))
                summary_writer.add_summary(summary_str, g_step_)
                summary_writer.add_summary(train_total_summary, g_step_)
            if g_step_ % 10000 == 0:
                print('saving checkpoint')
                saver.save(sess, cfg.ckecpoint_file + '/model.ckpt', g_step_)

        sys.stdout.write('\n')
        print('>> mean loss', t_loss / epoch_step)
        print('curent speed: ', total_timer.average_time, 'remain time: ',
              total_timer.remain(g_step_, total_steps))

    print('saving checkpoint')
    saver.save(sess, cfg.ckecpoint_file + '/model.ckpt', g_step_)
Exemplo n.º 2
0
                        imnm[i], pred_b[j][0], pred_b[j][1], pred_b[j][2],
                        pred_b[j][3], pred_s[j], pred_l[j] + 1
                    ])
            single_gt_num = np.where(labels[i][:, 0] > 0)[0].shape[0]
            box = np.hstack((labels[i][:single_gt_num, 1:],
                             np.reshape(labels[i][:single_gt_num, 0],
                                        (-1, 1)))).tolist()
            gt_dict[imnm[i]] = box

        val_timer.toc()
        sys.stdout.write('\r>> ' + 'val_nums ' + str(val_step) + str('/') +
                         str(cfg.test_num + 1))
        sys.stdout.flush()

    print('curent val speed: ', val_timer.average_time, 'val remain time: ',
          val_timer.remain(val_step, cfg.test_num + 1))
    if cfg.cnt_branch:
        print('val mean regress loss: ', val_rloss, 'val mean class loss: ',
              val_closs, 'val mean cnt loss: ', val_cnt_loss)
    else:
        print('val mean regress loss: ', val_rloss, 'val mean class loss: ',
              val_closs)
    mean_rec = 0
    mean_prec = 0
    mAP = 0
    for classidx in range(1, cfg.class_num):  #从1到21,对应[bg,...]21个类(除bg)
        rec, prec, ap = voc_eval(gt_dict,
                                 val_pred,
                                 classidx,
                                 iou_thres=0.5,
                                 use_07_metric=False)