Пример #1
0
def run_train():
    """Train the top-view RPN and the fused RCNN head on ``load_dummy_datas()``.

    Builds a TF1 graph (input placeholders, top/front/rgb feature nets, the
    fusion net, RPN + RCNN losses and a plain gradient-descent solver with L2
    regularisation), then runs ``max_iter`` training steps, each on one
    randomly chosen frame.  Frames without ground-truth labels are skipped
    without consuming a step.

    Side effects: creates ``<out_dir>/tf`` and ``<out_dir>/check_points``,
    appends per-step losses to ``<out_dir>/log.txt``, writes the graph to a
    TensorBoard ``FileWriter``, plots the RPN score maps every ``iter_debug``
    steps, and saves ``<out_dir>/check_points/snap.ckpt`` every 500 steps and
    once more after the last step.  The TF session is closed on exit.

    Raises:
        ValueError: if no frame has any ground-truth label (including an empty
            dataset), since the sampling loop could then never complete a step.
    """
    # output dir, etc
    out_dir = '/home/dongwoo/Project/MV3D/data/out'
    checkpoint_path = out_dir + '/check_points/snap.ckpt'
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    log = Logger(out_dir + '/log.txt', mode='a')

    # lidar data -----------------
    ratios = np.array([0.5, 1, 2], dtype=np.float32)
    scales = np.array([1, 2, 3], dtype=np.float32)
    bases = make_bases(base_size=16, ratios=ratios, scales=scales)
    num_bases = len(bases)
    stride = 8

    rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, lidars = load_dummy_datas()
    num_frames = len(rgbs)
    # The training loop resamples (without advancing) on label-less frames;
    # if every frame is label-less it would never terminate.
    if not any(len(labels) > 0 for labels in gt_labels):
        raise ValueError('run_train: no frame has ground-truth labels')

    top_shape = tops[0].shape
    front_shape = fronts[0].shape
    rgb_shape = rgbs[0].shape
    top_feature_shape = (top_shape[0] // stride, top_shape[1] // stride)
    out_shape = (8, 3)

    # check data: flip to True to visualise frame 0 with mayavi
    if False:
        fig = mlab.figure(figure=None, bgcolor=(0, 0, 0), fgcolor=None,
                          engine=None, size=(1000, 500))
        draw_lidar(lidars[0], fig=fig)
        draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
        mlab.show(1)
        cv2.waitKey(1)

    # set anchor boxes
    num_class = 2  # includes background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2],
                                        top_feature_shape[0:2])
    inside_inds = np.arange(0, len(anchors), dtype=np.int32)  # use all anchors  #<todo>
    print('out_shape=%s' % str(out_shape))
    print('num_frames=%d' % num_frames)

    # load model ##################################################################
    top_anchors = tf.placeholder(shape=[None, 4], dtype=tf.int32, name='anchors')
    top_inside_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='inside_inds')

    top_images = tf.placeholder(shape=[None, *top_shape], dtype=tf.float32, name='top')
    front_images = tf.placeholder(shape=[None, *front_shape], dtype=tf.float32, name='front')
    rgb_images = tf.placeholder(shape=[None, *rgb_shape], dtype=tf.float32, name='rgb')
    top_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='top_rois')  #<todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='rgb_rois')

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    fuse_scores, fuse_probs, fuse_deltas = fusion_net(
        (
            [top_features, top_rois, 6, 6, 1. / stride],
            [front_features, front_rois, 0, 0, 1. / stride],  # disabled by 0,0
            [rgb_features, rgb_rois, 6, 6, 1. / stride],
        ),
        num_class, out_shape)  #<todo> add non max suppression

    # loss ########################################################################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4], dtype=tf.float32, name='top_target')
    top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds,
                                          top_pos_inds, top_labels, top_targets)

    fuse_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, *out_shape], dtype=tf.float32, name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas,
                                             fuse_labels, fuse_targets)

    # solver
    l2 = l2_regulariser(decay=0.001)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.GradientDescentOptimizer(learning_rate=learning_rate,
                                               use_locking=False,
                                               name='GradientDescent')
    solver_step = solver.minimize(top_cls_loss + top_reg_loss + fuse_cls_loss +
                                  0.1 * fuse_reg_loss + l2)

    max_iter = 10000
    iter_debug = 8

    # start training here #########################################################
    log.write(
        'epoch     iter    rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'
    )
    log.write(
        '-------------------------------------------------------------------------------------\n'
    )

    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)

    sess = tf.InteractiveSession()
    try:
        with sess.as_default():
            sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
            summary_writer = tf.summary.FileWriter(out_dir + '/tf', sess.graph)
            saver = tf.train.Saver()

            it = 0
            while it < max_iter:
                epoch = 1.0 * it
                rate = 0.05

                ## generate train image -------------
                idx = np.random.choice(num_frames)
                batch_gt_labels = gt_labels[idx]
                # No objects -> no RPN/RCNN targets: resample without
                # advancing the step counter.
                if len(batch_gt_labels) == 0:
                    continue
                batch_gt_boxes3d = gt_boxes3d[idx]
                batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)

                batch_top_images = tops[idx].reshape(1, *top_shape)
                batch_front_images = fronts[idx].reshape(1, *front_shape)
                batch_rgb_images = rgbs[idx].reshape(1, *rgb_shape)

                ## run proposal generation ------------
                fd1 = {
                    top_images: batch_top_images,
                    top_anchors: anchors,
                    top_inside_inds: inside_inds,
                    learning_rate: rate,
                    IS_TRAIN_PHASE: True,
                }
                batch_proposals, batch_proposal_scores, batch_top_features = sess.run(
                    [proposals, proposal_scores, top_features], fd1)

                ## generate train rois ------------
                batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets = \
                    rpn_target(anchors, inside_inds, batch_gt_labels, batch_gt_top_boxes)

                batch_top_rois, batch_fuse_labels, batch_fuse_targets = \
                    rcnn_target(batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d)

                batch_rois3d = project_to_roi3d(batch_top_rois)
                batch_front_rois = project_to_front_roi(batch_rois3d)
                batch_rgb_rois = project_to_rgb_roi(batch_rois3d)

                ## debug gt generation (images are built; display calls are disabled)
                if it % iter_debug == 1:
                    top_image = top_imgs[idx]
                    rgb = rgbs[idx]

                    img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes, batch_gt_labels)
                    img_label = draw_rpn_labels(top_image, anchors, batch_top_inds,
                                                batch_top_labels)
                    img_target = draw_rpn_targets(top_image, anchors, batch_top_pos_inds,
                                                  batch_top_targets)
                    #imshow('img_rpn_gt',img_gt)
                    #imshow('img_rpn_label',img_label)
                    #imshow('img_rpn_target',img_target)

                    img_label = draw_rcnn_labels(top_image, batch_top_rois, batch_fuse_labels)
                    img_target = draw_rcnn_targets(top_image, batch_top_rois,
                                                   batch_fuse_labels, batch_fuse_targets)
                    #imshow('img_rcnn_label',img_label)
                    #imshow('img_rcnn_target',img_target)

                    img_rgb_rois = draw_boxes(rgb, batch_rgb_rois[:, 1:5],
                                              color=(255, 0, 255), thickness=1)
                    #imshow('img_rgb_rois',img_rgb_rois)
                    #cv2.waitKey(1)

                ## run classification and regression loss -----------
                # fd1 already carries the top-view image, anchors, rate and phase.
                fd2 = {
                    **fd1,
                    front_images: batch_front_images,
                    rgb_images: batch_rgb_images,
                    top_rois: batch_top_rois,
                    front_rois: batch_front_rois,
                    rgb_rois: batch_rgb_rois,
                    top_inds: batch_top_inds,
                    top_pos_inds: batch_top_pos_inds,
                    top_labels: batch_top_labels,
                    top_targets: batch_top_targets,
                    fuse_labels: batch_fuse_labels,
                    fuse_targets: batch_fuse_targets,
                }

                _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
                    sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss], fd2)

                log.write('%3.1f   %d   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %
                          (epoch, it, rate, batch_top_cls_loss, batch_top_reg_loss,
                           batch_fuse_cls_loss, batch_fuse_reg_loss))

                # debug: ------------------------------------
                if it % iter_debug == 0:
                    top_image = top_imgs[idx]

                    # One run for all debug tensors actually used below.
                    batch_top_probs, batch_fuse_probs, batch_fuse_deltas = sess.run(
                        [top_probs, fuse_probs, fuse_deltas], fd2)

                    #batch_fuse_deltas=0*batch_fuse_deltas #disable 3d box prediction
                    probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas,
                                              batch_rois3d, threshold=0.5)

                    ## show rpn score maps: one subplot per anchor base
                    p = batch_top_probs.reshape(*(top_feature_shape[0:2]), 2 * num_bases)
                    for n in range(num_bases):
                        r = n % num_scales
                        s = n // num_scales
                        pn = p[:, :, 2 * n + 1] * 255  # channel 2n+1, presumably the positive-class prob
                        axs[s, r].cla()
                        axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                    plt.pause(0.01)

                    ## show rpn(top) nms
                    img_rpn_nms = draw_rpn_nms(top_image, batch_proposals, batch_proposal_scores)
                    #imshow('img_rpn_nms',img_rpn_nms)
                    #cv2.waitKey(1)

                    ## show rcnn(fuse) nms
                    #img_rcnn_nms = draw_rcnn_nms(rgbs[idx], boxes3d, probs)
                    #imshow('img_rcnn_nms',img_rcnn_nms)

                # save: ------------------------------------
                if it % 500 == 0:
                    saver.save(sess, checkpoint_path)

                it += 1

            # The periodic save would otherwise drop up to the last 499 steps.
            saver.save(sess, checkpoint_path)
            summary_writer.close()
    finally:
        sess.close()
Пример #2
0
def run_train():
    """Train the top-view RPN and the fused RCNN head (plus two auxiliary heads).

    Python 2 code: log lines are wrapped in ``unicode``.  Builds the TF1 graph
    (placeholders sized from the loaded data, feature nets, ``fusion_net`` with
    auxiliary outputs, RPN / RCNN / auxiliary losses and a momentum solver),
    restores ``<out_dir>/check_points/snap.ckpt`` when it exists, then runs
    ``max_iter`` training steps.

    Side effects: appends per-step losses to ``<out_dir>/log.txt``; deletes and
    recreates the TensorBoard directory ``<out_dir>/train`` and writes a merged
    summary every step; every 10 steps shows the fused RCNN detections with
    ``imshow``; saves the checkpoint every 500 steps and once more after the
    last step.  The summary writer and session are closed on exit.
    """
    # output dir, etc
    out_dir = '/root/sharefolder/sdcnd/didi1/output'
    checkpoint_path = out_dir + '/check_points/snap.ckpt'
    # makedirs(out_dir +'/tf')
    # makedirs(out_dir +'/check_points')
    log = Logger(out_dir + '/log.txt', mode='a')

    # lidar data -----------------
    ratios = np.array([0.5, 1, 2], dtype=np.float32)
    scales = np.array([1, 2, 3], dtype=np.float32)
    bases = make_bases(base_size=16, ratios=ratios, scales=scales)
    num_bases = len(bases)
    stride = 8

    num_frames = 154
    rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, lidars = \
        load_dummy_datas(num_frames)
    num_frames = len(rgbs)

    top_shape = tops[0].shape
    front_shape = fronts[0].shape
    rgb_shape = rgbs[0].shape
    top_feature_shape = (top_shape[0] // stride, top_shape[1] // stride)
    out_shape = (8, 3)

    # check data: flip to True to visualise frame 0 with mayavi
    if False:
        fig = mlab.figure(figure=None, bgcolor=(0, 0, 0), fgcolor=None,
                          engine=None, size=(1000, 500))
        draw_lidar(lidars[0], fig=fig)
        draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
        mlab.show(1)
        cv2.waitKey(1)

    # set anchor boxes
    num_class = 2  # includes background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2], top_feature_shape[0:2])
    inside_inds = np.arange(0, len(anchors), dtype=np.int32)  # use all anchors  #<todo>
    print('out_shape=%s' % str(out_shape))
    print('num_frames=%d' % num_frames)

    # load model ##################################################################
    top_anchors = tf.placeholder(shape=[None, 4], dtype=tf.int32, name='anchors')
    top_inside_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='inside_inds')

    # Input shapes come from the data (as in the Python 3 variant) instead of
    # hard-coded literals; a mismatch would make every feed fail anyway.
    # `[None] + list(...)` keeps this Python 2 compatible.
    top_images = tf.placeholder(shape=[None] + list(top_shape), dtype=tf.float32, name='input_top')
    front_images = tf.placeholder(shape=[None] + list(front_shape), dtype=tf.float32, name='front')
    rgb_images = tf.placeholder(shape=[None] + list(rgb_shape), dtype=tf.float32, name='rgb')
    top_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='top_rois')  #<todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='rgb_rois')

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    fuse_scores, fuse_probs, fuse_deltas, aux_fuse_scores, aux_fuse_probs, aux_fuse_deltas = \
        fusion_net(
            (
                [top_features, top_rois, 6, 6, 1. / stride],
                [front_features, front_rois, 0, 0, 1. / stride],  # disabled by 0,0
                [rgb_features, rgb_rois, 6, 6, 1. / stride],
            ),
            num_class, out_shape)  #<todo> add non max suppression

    # loss ########################################################################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4], dtype=tf.float32, name='top_target')
    with tf.variable_scope('rpn-loss'):
        top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds,
                                              top_pos_inds, top_labels, top_targets)
    tf.summary.scalar('top_cls_loss', top_cls_loss)
    tf.summary.scalar('top_reg_loss', top_reg_loss)

    fuse_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None] + list(out_shape), dtype=tf.float32, name='fuse_target')
    with tf.variable_scope('rcnn-loss'):
        fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas, fuse_labels, fuse_targets)
    tf.summary.scalar('fuse_cls_loss', fuse_cls_loss)
    tf.summary.scalar('fuse_reg_loss', fuse_reg_loss)

    # Auxiliary heads share the main RCNN targets.
    with tf.variable_scope('aux_rcnn_loss'):
        with tf.variable_scope('aux_loss_1'):
            aux_fuse_cls_loss_1, aux_fuse_reg_loss_1 = rcnn_loss(
                aux_fuse_scores[0], aux_fuse_deltas[0], fuse_labels, fuse_targets)
        tf.summary.scalar('aux_fuse_cls_loss_1', aux_fuse_cls_loss_1)
        tf.summary.scalar('aux_fuse_reg_loss_1', aux_fuse_reg_loss_1)
        with tf.variable_scope('aux_loss_2'):
            aux_fuse_cls_loss_2, aux_fuse_reg_loss_2 = rcnn_loss(
                aux_fuse_scores[1], aux_fuse_deltas[1], fuse_labels, fuse_targets)
        tf.summary.scalar('aux_fuse_cls_loss_2', aux_fuse_cls_loss_2)
        # The _2 tag must report the second auxiliary head's regression loss.
        tf.summary.scalar('aux_fuse_reg_loss_2', aux_fuse_reg_loss_2)

    # solver
    # with tf.variable_scope('l2-reg') as scope:
    #     l2 = l2_regulariser(decay=0.0005)
    # tf.summary.scalar('total_l2reg', l2)

    with tf.variable_scope('total_loss'):
        total_loss = top_cls_loss + top_reg_loss + fuse_cls_loss + 0.1 * fuse_reg_loss \
            + aux_fuse_cls_loss_1 + aux_fuse_reg_loss_1 + aux_fuse_cls_loss_2 + aux_fuse_reg_loss_2
    tf.summary.scalar('total_loss', total_loss)

    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    solver_step = solver.minimize(total_loss)

    max_iter = 10000
    iter_debug = 8

    # start training here #########################################################
    log.write(unicode('epoch     iter    rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'))
    log.write(unicode('-------------------------------------------------------------------------------------\n'))

    num_ratios = len(ratios)
    num_scales = len(scales)

    sess = tf.InteractiveSession()
    merged = tf.summary.merge_all()

    log_dir = out_dir + '/train'
    if tf.gfile.Exists(log_dir):
        # gotta be careful: this wipes all previous TensorBoard logs
        tf.gfile.DeleteRecursively(log_dir)
        print('Removed files in {}'.format(log_dir))
    train_writer = tf.summary.FileWriter(log_dir, sess.graph)

    saver = tf.train.Saver()
    tf.global_variables_initializer().run()

    # option: loading pretrained model (skipped when no checkpoint exists yet)
    if tf.train.checkpoint_exists(checkpoint_path):
        saver.restore(sess, checkpoint_path)
    else:
        print('No checkpoint at {}; training from scratch'.format(checkpoint_path))

    try:
        for it in range(max_iter):
            epoch = 1.0 * it
            rate = 0.05

            ## generate train image -------------
            idx = np.random.choice(num_frames)
            # NOTE(review): debug override - always trains on frame 87 and would
            # fail with IndexError if fewer than 88 frames are loaded. Remove
            # this line to sample from the whole set.
            idx = 87
            batch_top_images = tops[idx].reshape(1, *top_shape)
            batch_front_images = fronts[idx].reshape(1, *front_shape)
            batch_rgb_images = rgbs[idx].reshape(1, *rgb_shape)

            batch_gt_labels = gt_labels[idx]
            batch_gt_boxes3d = gt_boxes3d[idx]
            batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)

            ## run proposal generation ------------
            fd1 = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds,
                learning_rate: rate,
                IS_TRAIN_PHASE: True,
            }
            batch_proposals, batch_proposal_scores, batch_top_features = sess.run(
                [proposals, proposal_scores, top_features], fd1)

            ## generate train rois ------------
            batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets = \
                rpn_target(anchors, inside_inds, batch_gt_labels, batch_gt_top_boxes)

            batch_top_rois, batch_fuse_labels, batch_fuse_targets = \
                rcnn_target(batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d)

            batch_rois3d = project_to_roi3d(batch_top_rois)
            batch_front_rois = project_to_front_roi(batch_rois3d)
            batch_rgb_rois = project_to_rgb_roi(batch_rois3d)

            ## debug gt generation (flip to True to enable)
            if False:
                top_image = top_imgs[idx]
                rgb = rgbs[idx]

                img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes, batch_gt_labels)
                img_label = draw_rpn_labels(top_image, anchors, batch_top_inds, batch_top_labels)
                img_target = draw_rpn_targets(top_image, anchors, batch_top_pos_inds, batch_top_targets)

                img_label = draw_rcnn_labels(top_image, batch_top_rois, batch_fuse_labels)
                img_target = draw_rcnn_targets(top_image, batch_top_rois, batch_fuse_labels,
                                               batch_fuse_targets)
                imshow('img_rcnn_target', img_target)

                img_rgb_rois = draw_boxes(rgb, batch_rgb_rois[:, 1:5], color=(255, 0, 255), thickness=1)
                imshow('img_rgb_rois', img_rgb_rois)

                cv2.waitKey(1)

            ## run classification and regression loss -----------
            # dict()+update instead of {**fd1, ...}: this file targets Python 2.
            fd2 = dict(fd1)
            fd2.update({
                front_images: batch_front_images,
                rgb_images: batch_rgb_images,
                top_rois: batch_top_rois,
                front_rois: batch_front_rois,
                rgb_rois: batch_rgb_rois,
                top_inds: batch_top_inds,
                top_pos_inds: batch_top_pos_inds,
                top_labels: batch_top_labels,
                top_targets: batch_top_targets,
                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets,
            })

            # For profiling pass options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # and run_metadata=tf.RunMetadata() here.
            _, summary, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
                sess.run([solver_step, merged, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],
                         feed_dict=fd2)
            train_writer.add_summary(summary, it)
            train_writer.flush()

            log.write(unicode('%3.1f   %d   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %
                              (epoch, it, rate, batch_top_cls_loss, batch_top_reg_loss,
                               batch_fuse_cls_loss, batch_fuse_reg_loss)))

            # debug: ------------------------------------
            if it % 10 == 0:
                top_image = top_imgs[idx]
                rgb = rgbs[idx]

                batch_fuse_probs, batch_fuse_deltas = sess.run([fuse_probs, fuse_deltas], fd2)

                #batch_fuse_deltas=0*batch_fuse_deltas #disable 3d box prediction
                probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.5)

                ## show rcnn(fuse) nms
                gt_2d_box = batch_gt_top_boxes
                img_rcnn = draw_rcnn(top_image, batch_fuse_probs, batch_fuse_deltas, batch_top_rois,
                                     batch_rois3d, gt_2d_box, darker=1)
                # NOTE(review): return value was never used; call kept in case
                # rcnn_result has side effects.
                rcnn_result(batch_fuse_probs, batch_fuse_deltas, batch_top_rois, batch_rois3d, gt_2d_box)
                img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
                imshow('img_rcnn', img_rcnn)
                imshow('img_rcnn_nms', img_rcnn_nms)
                cv2.waitKey(1)

            # RPN debug views (flip to True, or use `it % iter_debug == 0`)
            if False:
                top_image = top_imgs[idx]
                rgb = rgbs[idx]

                batch_top_probs, batch_top_scores, batch_top_deltas = \
                    sess.run([top_probs, top_scores, top_deltas], fd2)

                batch_fuse_probs, batch_fuse_deltas = sess.run([fuse_probs, fuse_deltas], fd2)

                probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.5)

                ## show rpn score maps: one subplot per anchor base
                fig, axs = plt.subplots(num_ratios, num_scales)
                p = batch_top_probs.reshape(top_feature_shape[0], top_feature_shape[1], 2 * num_bases)
                for n in range(num_bases):
                    r = n % num_scales
                    s = n // num_scales
                    pn = p[:, :, 2 * n + 1] * 255
                    axs[s, r].cla()
                    axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                plt.pause(0.01)

                ## show rpn(top) nms
                img_rpn = draw_rpn(top_image, batch_top_probs, batch_top_deltas, anchors, inside_inds)
                img_rpn_nms = draw_rpn_nms(top_image, batch_proposals, batch_proposal_scores)
                imshow('img_rpn_nms', img_rpn_nms)
                cv2.waitKey(1)

                ## show rcnn(fuse) nms
                img_rcnn = draw_rcnn(top_image, batch_fuse_probs, batch_fuse_deltas, batch_top_rois,
                                     batch_rois3d, darker=1)
                img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
                imshow('img_rcnn', img_rcnn)
                imshow('img_rcnn_nms', img_rcnn_nms)
                cv2.waitKey(1)

            # save: ------------------------------------
            if it % 500 == 0:
                saver.save(sess, checkpoint_path)

        # The periodic save would otherwise drop up to the last 499 steps.
        saver.save(sess, checkpoint_path)
    finally:
        train_writer.close()
        sess.close()
Пример #3
0
    def fit_iteration(self,
                      is_validation=False,
                      summary_it=False,
                      summary_runmeta=False,
                      log_image=False,
                      summary_iou=False,
                      iou_statistic=False):

        data_set = self.validation_set if is_validation else self.train_set
        self.default_summary_writer = \
            self.val_summary_writer if is_validation else self.train_summary_writer

        self.step_name = 'validation' if is_validation else 'training'

        # load dataset
        self.batch_rgb_images, self.batch_top_view, self.batch_front_view, \
        self.batch_gt_labels, self.batch_gt_boxes3d, self.frame_id, self.calib = \
            data_set.load()

        # print(self.raw_img.files_path_mapping[self.frame_id])

        # resize rgb images
        self.batch_rgb_images = np.array([cv2.resize(x, (self.rgb_shape[1], self.rgb_shape[0])) \
                                          for x in self.batch_rgb_images])

        # fit_iterate log init
        if log_image:
            self.time_str = strftime("%Y_%m_%d_%H_%M", localtime())
            self.frame_info = data_set.get_frame_info()
            self.log_subdir = self.step_name + '/' + self.time_str
            self.top_image = data.draw_top_image(self.batch_top_view[0])
            # self.top_image = self.top_image_padding(top_image)

        net = self.net
        sess = self.sess

        # put tensorboard inside
        top_cls_loss = net['top_cls_loss']
        top_reg_loss = net['top_reg_loss']
        fuse_cls_loss = net['fuse_cls_loss']
        fuse_reg_loss = net['fuse_reg_loss']

        self.batch_gt_top_boxes = box.box3d_to_top_box(
            self.batch_gt_boxes3d[0])

        ## run propsal generation
        fd1 = {
            net['top_view']: self.batch_top_view,
            net['top_anchors']: self.top_view_anchors,
            net['top_inside_inds']: self.anchors_inside_inds,
            blocks.IS_TRAIN_PHASE: True,
            K.learning_phase(): 1
        }
        self.batch_proposals, self.batch_proposal_scores, self.batch_top_features = \
            sess.run([net['train_proposals'], net['train_proposal_scores'], net['top_features']], fd1)

        ## generate  train rois  for RPN
        self.batch_top_inds, self.batch_top_pos_inds, self.batch_top_labels, self.batch_top_targets = \
            rpn_target(self.top_view_anchors, self.anchors_inside_inds, self.batch_gt_labels[0],
                       self.batch_gt_top_boxes)


        self.batch_top_rois, self.batch_fuse_labels, self.batch_fuse_targets = \
           rcnn_target(self.batch_proposals, self.batch_gt_labels[0],
                         self.batch_gt_top_boxes, self.batch_gt_boxes3d[0])

        # print(self.anchors_details())
        # print(self.rpn_poposal_details())

        self.batch_rois3d = project_to_roi3d(self.batch_top_rois)
        self.batch_front_rois = project_to_front_roi(self.batch_rois3d)
        self.batch_rgb_rois = project_to_rgb_roi(self.batch_rois3d,
                                                 self.calib.velo_to_rgb)

        ## run classification and regression loss -----------
        fd2 = {
            **fd1,
            net['top_view']: self.batch_top_view,
            net['front_view']: self.batch_front_view,
            net['rgb_images']: self.batch_rgb_images,
            net['top_rois']: self.batch_top_rois,
            net['front_rois']: self.batch_front_rois,
            net['rgb_rois']: self.batch_rgb_rois,
            net['top_inds']: self.batch_top_inds,
            net['top_pos_inds']: self.batch_top_pos_inds,
            net['top_labels']: self.batch_top_labels,
            net['top_targets']: self.batch_top_targets,
            net['fuse_labels']: self.batch_fuse_labels,
            net['fuse_targets']: self.batch_fuse_targets,
        }
        """
        if self.debug_mode:
            print('\n\nstart debug mode\n\n')
            debug_sess=tf_debug.LocalCLIDebugWrapperSession(sess)
            debug_sess.add_tensor_filter('has_inf_or_nan', tf_debug.has_inf_or_nan)
            t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss = \
                debug_sess.run([top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss], fd2)
        """

        if summary_it:
            run_options = None
            run_metadata = None

            if is_validation:
                t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss, tb_sum_val = \
                    sess.run([top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss, self.summ], fd2)

                self.val_summary_writer.add_summary(tb_sum_val,
                                                    self.n_global_step)
                # print('added validation  summary ')
            else:
                if summary_runmeta:
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                _, t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss, tb_sum_val = \
                    sess.run([self.solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss,
                              self.summ], feed_dict=fd2, options=run_options, run_metadata=run_metadata)

                self.train_summary_writer.add_summary(tb_sum_val,
                                                      self.n_global_step)
                # print('added training  summary ')

                if summary_runmeta:
                    self.train_summary_writer.add_run_metadata(
                        run_metadata, 'step%d' % self.n_global_step)
                    # print('added runtime metadata.')

        else:
            if is_validation:
                t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss = \
                    sess.run([top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss], fd2)
            else:

                _, t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss = \
                    sess.run([self.solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],
                             feed_dict=fd2)

        if log_image:
            # t0 = time.time()
            step_name = 'validation' if is_validation else 'train'
            scope_name = '%s_iter_%06d' % (
                step_name, self.n_global_step -
                (self.n_global_step % self.iter_debug))

            boxes3d, lables = self.predict(self.batch_top_view, self.batch_front_view, \
                                           self.batch_rgb_images, self.calib.velo_to_rgb)

            if 0 and set(self.train_target) != set(mv3d_net.top_view_rpn_name):
                # get iou
                iou = -1
                inds = np.where(self.batch_gt_labels[0] != 0)
                try:
                    iou = box.boxes3d_score_iou(self.batch_gt_boxes3d[0][inds],
                                                boxes3d)
                    if iou_statistic: self.iou_store.append(iou)
                    if summary_iou:
                        iou_aver = sum(self.iou_store) / len(self.iou_store)
                        self.iou_store = []
                        tag = os.path.join('IOU')
                        self.summary_scalar(value=iou_aver,
                                            tag=tag,
                                            step=self.n_global_step)
                        self.log_msg.write('\n %s iou average: %.5f\n' %
                                           (self.step_name, iou_aver))
                except ValueError:
                    # print("waring :", sys.exc_info()[0])
                    pass

                #set scope name
                if iou == -1:
                    scope_name = os.path.join(scope_name,
                                              'iou_error'.format(range(5, 8)))
                else:
                    for iou_range in self.log_iou_range:
                        if int(iou * 100) in iou_range:
                            scope_name = os.path.join(
                                scope_name, 'iou_{}'.format(iou_range))

                # print('Summary log image, scope name: {}'.format(scope_name))

                self.log_fusion_net_target(self.batch_rgb_images[0],
                                           scope_name=scope_name)
            log_info_str = 'frame info: ' + self.frame_info + '\n'
            log_info_str += self.anchors_details()
            log_info_str += self.rpn_poposal_details()
            log_info_str += '\n'
            self.log_info(self.log_subdir, log_info_str)

            self.predict_log(self.log_subdir,
                             self.calib.velo_to_rgb,
                             log_rpn=True,
                             step=self.n_global_step,
                             scope_name=scope_name,
                             loss=(f_cls_loss, f_reg_loss),
                             frame_tag=self.frame_id,
                             is_train_mode=True)

            # self.log_msg.write('Image log  summary use time : {}\n'.format(time.time() - t0))

        return t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss
Пример #4
0
def run_train():
    """Train the fused top/front/RGB detection network end to end.

    Builds the graph (top-view RPN via ``top_feature_net``, front and RGB
    feature nets, ``fusion_net``), the RPN and RCNN losses plus an L2 term,
    and an Adam solver; restores weights from a fixed checkpoint; then runs
    ``max_iter`` training iterations with a batch of one frame each.

    Frames listed in ``data_root + 'seg/train_list.npy'`` are streamed from
    disk in chunks of ``freq`` frames.  Each chunk is iterated twice before
    the next one is loaded, and the frame order is reshuffled every
    ``2 * num_frames`` iterations.  Per-iteration losses go to a timestamped
    log file under ``./outputs``, TensorBoard summaries are written every 10
    iterations and a checkpoint every 2000 iterations.

    Relies on module-level names defined elsewhere in the file (``data_root``,
    ``vis``, ``IS_TRAIN_PHASE``, ``load_dummy_datas`` and the network /
    target / drawing helpers).
    """
    # output dir, etc
    out_dir = './outputs'
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    log = Logger(out_dir + '/log_%s.txt' % (time.strftime('%Y-%m-%d %H:%M:%S')), mode='a')
    index = np.load(data_root + 'seg/train_list.npy')
    index = np.array(sorted(index))
    num_frames = len(index)

    # lidar data -----------------
    if 1:
        # Anchor bases are hard-coded boxes (presumably x1, y1, x2, y2 around
        # the anchor centre -- TODO confirm against make_anchors). The
        # generated alternative was:
        # ratios = np.array([0.4, 0.6, 1.7, 2.4], dtype=np.float32)
        # scales = np.array([0.5, 1, 2, 3], dtype=np.float32)
        # bases = make_bases(base_size=16, ratios=ratios, scales=scales)
        # In this function ratios/scales only size the debug subplot grid.
        ratios = np.array([1.7, 2.4])
        scales = np.array([1.7, 2.4])
        bases = np.array([[-19.5, -8, 19.5, 8],
                          [-8, -19.5, 8, 19.5],
                          [-5, -3, 5, 3],
                          [-3, -5, 3, 5]])
        num_bases = len(bases)
        stride = 4

        out_shape = (8, 3)

        # Small initial load so the input shapes are known when building the
        # graph; the training loop replaces these arrays on its first step.
        rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index = \
            load_dummy_datas(index[:3])
        top_shape = tops[0].shape
        front_shape = fronts[0].shape
        rgb_shape = rgbs[0].shape
        # Ceil division: size of the top feature map after striding.
        top_feature_shape = ((top_shape[0] - 1) // stride + 1, (top_shape[1] - 1) // stride + 1)

        # set anchor boxes
        num_class = 2  # include background
        anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2], top_feature_shape[0:2])
        # inside_inds = np.arange(0, len(anchors), dtype=np.int32)  # use all  #<todo>
        print('out_shape=%s' % str(out_shape))
        print('num_frames=%d' % num_frames)

        # -----------------------
        # check data (disabled). NOTE: `lidars` is not returned by
        # load_dummy_datas() in this variant, so enabling this would fail.
        if 0:
            fig = mlab.figure(figure=None, bgcolor=(0, 0, 0), fgcolor=None, engine=None, size=(1000, 500))
            draw_lidar(lidars[0], fig=fig)
            draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
            mlab.show(1)
            cv2.waitKey(1)

    # load model ####################################################################################################
    top_anchors = tf.placeholder(shape=[None, 4], dtype=tf.int32, name='anchors')
    top_inside_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='inside_inds')

    top_images = tf.placeholder(shape=[None, *top_shape], dtype=tf.float32, name='top')
    front_images = tf.placeholder(shape=[None, *front_shape], dtype=tf.float32, name='front')
    rgb_images = tf.placeholder(shape=[None, None, None, 3], dtype=tf.float32, name='rgb')
    top_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='top_rois')  # <todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='rgb_rois')

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)
    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    fuse_scores, fuse_probs, fuse_deltas = fusion_net(
        ([top_features, top_rois, 6, 6, 1. / stride],
         [front_features, front_rois, 0, 0, 1. / stride],  # disabled by 0,0
         [rgb_features, rgb_rois, 6, 6, 1. / (2 * stride)],),
        num_class, out_shape)  # <todo> add non max suppression

    # loss ########################################################################################################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4], dtype=tf.float32, name='top_target')
    top_cls_loss, top_reg_loss = rpn_loss(2 * top_scores, top_deltas, top_inds, top_pos_inds, top_labels, top_targets)

    fuse_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, *out_shape], dtype=tf.float32, name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas, fuse_labels, fuse_targets)
    tf.summary.scalar('rpn_cls_loss', top_cls_loss)
    tf.summary.scalar('rpn_reg_loss', top_reg_loss)
    tf.summary.scalar('rcnn_cls_loss', fuse_cls_loss)
    tf.summary.scalar('rcnn_reg_loss', fuse_reg_loss)

    # solver
    l2 = l2_regulariser(decay=0.000005)
    tf.summary.scalar('l2', l2)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.AdamOptimizer(learning_rate)
    # solver = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    # solver_step = solver.minimize(top_cls_loss+top_reg_loss+l2)
    solver_step = solver.minimize(1 * top_cls_loss + 1 * top_reg_loss + 1.5 * fuse_cls_loss + 2 * fuse_reg_loss + l2)

    max_iter = 200000
    iter_debug = 1

    # start training here  #########################################################################################
    log.write('epoch     iter    speed   rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n')
    log.write('-------------------------------------------------------------------------------------\n')

    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)

    merged = tf.summary.merge_all()

    sess = tf.InteractiveSession()
    train_writer = tf.summary.FileWriter('./outputs/tensorboard/Res_Vgg_up', sess.graph)
    try:
        with sess.as_default():
            sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            # summary_writer = tf.summary.FileWriter(out_dir+'/tf', sess.graph)
            saver = tf.train.Saver()
            saver.restore(sess, './outputs/check_points/snap_ResNet_vgg_up_NGT_060000.ckpt')
            # saver.restore(sess, './outputs/check_points/MobileNet.ckpt')
            #
            # Alternative initialisation from pretrained backbone checkpoints:
            # var_lt_res = [v for v in tf.trainable_variables() if v.name.startswith('res')]  # resnet_v1_50
            # var_lt = [v for v in tf.trainable_variables() if not(v.name.startswith('fuse-block-1')) and not(v.name.startswith('fuse')) and not(v.name.startswith('fuse-input'))]
            # saver_0 = tf.train.Saver(var_lt_res)
            # saver_0.restore(sess, './outputs/check_points/resnet_v1_50.ckpt')
            # top_lt = [v for v in tf.trainable_variables() if v.name.startswith('top_base')]
            # top_lt.pop(0)
            # for v in top_lt:
            #     for v_rgb in var_lt:
            #         if v.name[9:] == v_rgb.name:
            #             print("assign weights:%s" % v.name)
            #             v.assign(v_rgb)
            # var_lt_vgg = [v for v in tf.trainable_variables() if v.name.startswith('vgg')]
            # var_lt_vgg.pop(0)
            # saver_1 = tf.train.Saver(var_lt_vgg)
            # saver_1.restore(sess, './outputs/check_points/vgg_16.ckpt')

            batch_top_cls_loss = 0
            batch_top_reg_loss = 0
            batch_fuse_cls_loss = 0
            batch_fuse_reg_loss = 0
            rate = 0.000005
            frame_range = np.arange(num_frames)
            # Frames per on-disk chunk; each chunk is iterated twice before
            # the next one is loaded (i.e. a reload every 2*freq frames).
            freq = 200
            idx = 0       # position inside the currently loaded chunk
            frame = 0     # offset of the current chunk inside frame_range
            count = 0     # frames consumed at chunk boundaries since the last reshuffle
            end_flag = 0  # 1 once the final (possibly partial) chunk is loaded
            for step in range(max_iter):
                epoch = step // num_frames + 1
                start_time = time.time()

                # generate train image -------------
                # Reshuffle every 2*num_frames iterations (two passes over the data).
                if step % (num_frames * 2) == 0:
                    idx = 0
                    frame = 0
                    count = 0
                    end_flag = 0
                    new_order = np.random.permutation(num_frames)
                    # Re-draw if the shuffle repeats the previous order. With a
                    # single frame every permutation is identical, so accept it.
                    while num_frames > 1 and np.array_equal(new_order, frame_range):
                        new_order = np.random.permutation(num_frames)
                    frame_range = new_order

                # At each chunk boundary, load the next chunk once the current
                # one has been iterated twice; otherwise just rewind into it.
                if idx % freq == 0:
                    count += idx
                    if count % (2 * freq) == 0:
                        frame += idx
                        frame_end = min(frame + freq, num_frames)
                        if frame_end == num_frames:
                            end_flag = 1
                        # Release the previous chunk before loading the next one.
                        del rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index
                        rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index = \
                            load_dummy_datas(index[frame_range[frame:frame_end]])
                    idx = 0
                # The final chunk may be shorter than `freq`: wrap around inside it.
                if end_flag == 1 and (idx + frame) == num_frames:
                    idx = 0
                print('processing image : %s' % image_index[idx])

                if (step + 1) % 10000 == 0:
                    rate = 0.8 * rate

                rgb_shape = rgbs[idx].shape
                batch_top_images = tops[idx].reshape(1, *top_shape)
                batch_front_images = fronts[idx].reshape(1, *front_shape)
                batch_rgb_images = rgbs_norm[idx].reshape(1, *rgb_shape)
                # batch_rgb_images = rgbs[idx].reshape(1, *rgb_shape)

                top_img = tops[idx]
                # Filter anchors using the last channel of the top-view image.
                inside_inds_filtered = anchor_filter(top_img[:, :, -1], anchors, inside_inds)

                batch_gt_labels = gt_labels[idx]
                if len(batch_gt_labels) == 0:
                    # No ground-truth objects in this frame: skip it.
                    idx = idx + 1
                    continue
                batch_gt_boxes3d = gt_boxes3d[idx]
                batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)

                ## run proposal generation ------------
                fd1 = {
                    top_images: batch_top_images,
                    top_anchors: anchors,
                    top_inside_inds: inside_inds_filtered,

                    learning_rate: rate,
                    IS_TRAIN_PHASE: True
                }
                batch_proposals, batch_proposal_scores, batch_top_features = sess.run(
                    [proposals, proposal_scores, top_features], fd1)
                print(batch_proposal_scores[:50])

                ## generate train rois ------------
                batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets = \
                    rpn_target(anchors, inside_inds_filtered, batch_gt_labels, batch_gt_top_boxes)

                batch_top_rois, batch_fuse_labels, batch_fuse_targets = \
                    rcnn_target(batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d)

                batch_rois3d = project_to_roi3d(batch_top_rois)
                batch_front_rois = project_to_front_roi(batch_rois3d)
                batch_rgb_rois = project_to_rgb_roi(batch_rois3d)

                # keep = np.where((batch_rgb_rois[:,1]>=-200) & (batch_rgb_rois[:,2]>=-200) & (batch_rgb_rois[:,3]<=(rgb_shape[1]+200)) & (batch_rgb_rois[:,4]<=(rgb_shape[0]+200)))[0]
                # batch_rois3d        = batch_rois3d[keep]
                # batch_front_rois    = batch_front_rois[keep]
                # batch_rgb_rois      = batch_rgb_rois[keep]
                # batch_proposal_scores=batch_proposal_scores[keep]
                # batch_top_rois      =batch_top_rois[keep]

                if len(batch_rois3d) == 0:
                    # No RCNN training rois for this frame: skip it.
                    idx = idx + 1
                    continue

                ## debug gt generation
                if vis and step % iter_debug == 0:
                    top_image = top_imgs[idx]
                    rgb = rgbs[idx]

                    img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes, batch_gt_labels)
                    img_label = draw_rpn_labels(img_gt, anchors, batch_top_inds, batch_top_labels)
                    img_target = draw_rpn_targets(top_image, anchors, batch_top_pos_inds, batch_top_targets)
                    # imshow('img_rpn_gt', img_gt)
                    imshow('img_anchor_label', img_label)
                    # imshow('img_rpn_target', img_target)

                    img_label = draw_rcnn_labels(top_image, batch_top_rois, batch_fuse_labels)
                    img_target = draw_rcnn_targets(top_image, batch_top_rois, batch_fuse_labels, batch_fuse_targets)
                    # imshow('img_rcnn_label', img_label)
                    imshow('img_rcnn_target', img_target)

                    img_rgb_rois = draw_boxes(rgb, batch_rgb_rois[:, 1:5], color=(255, 0, 255), thickness=1)
                    imshow('img_rgb_rois', img_rgb_rois)
                    cv2.waitKey(1)

                ## run classification and regression loss -----------
                fd2 = {
                    **fd1,

                    top_images: batch_top_images,
                    front_images: batch_front_images,
                    rgb_images: batch_rgb_images,

                    top_rois: batch_top_rois,
                    front_rois: batch_front_rois,
                    rgb_rois: batch_rgb_rois,

                    top_inds: batch_top_inds,
                    top_pos_inds: batch_top_pos_inds,
                    top_labels: batch_top_labels,
                    top_targets: batch_top_targets,

                    fuse_labels: batch_fuse_labels,
                    fuse_targets: batch_fuse_targets,
                }

                _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
                    sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss], fd2)

                speed = time.time() - start_time
                log.write('%5.1f   %5d    %0.4fs   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %
                          (epoch, step, speed, rate, batch_top_cls_loss, batch_top_reg_loss,
                           batch_fuse_cls_loss, batch_fuse_reg_loss))

                # debug: ------------------------------------
                # Same condition as the gt-debug block above, so img_gt is set.
                if vis and step % iter_debug == 0:
                    top_image = top_imgs[idx]
                    rgb = rgbs[idx]

                    batch_top_probs, batch_top_scores, batch_top_deltas = \
                        sess.run([top_probs, top_scores, top_deltas], fd2)

                    batch_fuse_probs, batch_fuse_deltas = \
                        sess.run([fuse_probs, fuse_deltas], fd2)

                    # batch_fuse_deltas = 0 * batch_fuse_deltas  # disable 3d box prediction
                    probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.05)

                    ## show rpn score maps (channel 2*n+1 per base)
                    p = batch_top_probs.reshape(*(top_feature_shape[0:2]), 2 * num_bases)
                    for n in range(num_bases):
                        r = n % num_scales
                        s = n // num_scales
                        pn = p[:, :, 2 * n + 1] * 255
                        axs[s, r].cla()
                        axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                        plt.pause(0.01)

                    ## show rpn(top) nms
                    img_rpn = draw_rpn(top_image, batch_top_probs, batch_top_deltas, anchors, inside_inds)
                    img_rpn_nms = draw_rpn_nms(img_gt, batch_proposals, batch_proposal_scores)
                    # imshow('img_rpn', img_rpn)
                    imshow('img_rpn_nms', img_rpn_nms)
                    cv2.waitKey(1)

                    ## show rcnn(fuse) nms
                    img_rcnn = draw_rcnn(top_image, batch_fuse_probs, batch_fuse_deltas, batch_top_rois, batch_rois3d, darker=1)
                    img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
                    imshow('img_rcnn', img_rcnn)
                    imshow('img_rcnn_nms', img_rcnn_nms)
                    cv2.waitKey(0)

                if step % 10 == 0:
                    summary = sess.run(merged, fd2)
                    train_writer.add_summary(summary, step)

                # save: ------------------------------------
                if step % 2000 == 0 and step != 0:
                    # saver.save(sess, out_dir + '/check_points/%06d.ckpt' % step)
                    saver.save(sess, out_dir + '/check_points/snap_ResNet_vgg_NGT_%06d.ckpt' % step)
                    # saver.save(sess, out_dir + '/check_points/MobileNet.ckpt')

                idx = idx + 1
    finally:
        # Flush pending TensorBoard events even if training stops early.
        train_writer.close()
Пример #5
0
def run_train():

    # output dir, etc
    out_dir = '/root/share/out/didi/xxx'
    makedirs(out_dir + '/tf')
    log = Logger(out_dir + '/log.txt', mode='a')

    #one lidar data -----------------
    if 1:
        ratios = np.array([0.5, 1, 2], dtype=np.float32)
        scales = np.array([1, 2, 3], dtype=np.float32)
        bases = make_bases(base_size=16, ratios=ratios, scales=scales)
        num_bases = len(bases)
        stride = 8

        rgb, top, top_image, lidar, gt_labels, gt_boxes3d, gt_top_boxes = load_dummy_data(
        )
        top_shape = top.shape
        top_feature_shape = (top_shape[0] // stride, top_shape[1] // stride)

        rgb_shape = rgb.shape
        out_shape = (8, 3)

        #-----------------------
        #check data
        if 0:
            fig = mlab.figure(figure=None,
                              bgcolor=(0, 0, 0),
                              fgcolor=None,
                              engine=None,
                              size=(1000, 500))
            draw_lidar(lidar, fig=fig)
            draw_gt_boxes3d(gt_boxes3d, fig=fig)
            mlab.show(1)

            draw_gt_boxes(top_image, gt_top_boxes)
            draw_projected_gt_boxes3d(rgb, gt_boxes3d)

            #imshow('top_image',top_image)
            #imshow('rgb',rgb)
            cv2.waitKey(1)

    #one dummy data -----------------
    if 0:
        ratios = [0.5, 1, 2]
        scales = 2**np.arange(3, 6)
        bases = make_bases(base_size=16, ratios=ratios, scales=scales)
        num_bases = len(bases)
        stride = 8

        rgb, top, top_image, lidar, gt_labels, gt_boxes3d, gt_top_boxes = load_dummy_data1(
        )
        top_shape = top.shape
        top_feature_shape = (54, 72
                             )  #(top_shape[0]//stride, top_shape[1]//stride)

        rgb_shape = rgb.shape
        out_shape = (4, )

        # img_gt =draw_gt_boxes(top_image, gt_top_boxes)
        # imshow('img_gt',img_gt)
        # cv2.waitKey(1)

    # set anchor boxes
    dim = np.prod(out_shape)
    num_class = 2  #incude background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2],
                                        top_feature_shape[0:2])
    inside_inds = np.arange(0, len(anchors), dtype=np.int32)  #use all
    print('dim=%d' % dim)

    #load model ##############
    top_images = tf.placeholder(shape=[None, *top_shape],
                                dtype=tf.float32,
                                name='top')
    top_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors')
    top_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds')

    top_features, top_scores, top_probs, top_deltas, top_rois1, top_roi_scores1 = \
        top_lidar_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    rgb_images = tf.placeholder(shape=[None, *rgb_shape],
                                dtype=tf.float32,
                                name='rgb')
    rgb_features = rgb_feature_net(rgb_images)

    top_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='top_rois')  #<todo> change to int32???
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')
    fuse_scores, fuse_probs, fuse_deltas = \
        fusion_net(
            (top_features,   rgb_features,),
            (top_rois,       rgb_rois,),
            ([6,6,1./stride],[6,6,1./stride],),
            num_class, out_shape)

    #loss ####################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None],
                                  dtype=tf.int32,
                                  name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4],
                                 dtype=tf.float32,
                                 name='top_target')
    top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds,
                                          top_pos_inds, top_labels,
                                          top_targets)

    fuse_labels = tf.placeholder(shape=[None],
                                 dtype=tf.int32,
                                 name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, *out_shape],
                                  dtype=tf.float32,
                                  name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas,
                                             fuse_labels, fuse_targets)

    #put your solver here
    l2 = l2_regulariser(decay=0.0005)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                        momentum=0.9)
    #solver_step = solver.minimize(top_cls_loss+top_reg_loss+l2)
    solver_step = solver.minimize(top_cls_loss + top_reg_loss + fuse_cls_loss +
                                  fuse_reg_loss + l2)

    max_iter = 10000

    # start training here ------------------------------------------------
    log.write('epoch        iter      rate     |  train_mse   valid_mse  |\n')
    log.write(
        '----------------------------------------------------------------------------\n'
    )

    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)

    sess = tf.InteractiveSession()
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        summary_writer = tf.summary.FileWriter(out_dir + '/tf', sess.graph)
        rate = 0.1

        batch_top_cls_loss = 0
        batch_top_reg_loss = 0
        batch_fuse_cls_loss = 0
        batch_fuse_reg_loss = 0
        for iter in range(max_iter):

            #random sample train data
            batch_top_images = top.reshape(1, *top_shape)
            batch_top_gt_labels = gt_labels
            batch_top_gt_boxes = gt_top_boxes

            batch_rgb_images = rgb.reshape(1, *rgb_shape)

            batch_fuse_gt_labels = gt_labels
            batch_fuse_gt_boxes = gt_top_boxes
            batch_fuse_gt_boxes3d = gt_boxes3d

            ##-------------------------------
            fd = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds,
                learning_rate: rate,
                IS_TRAIN_PHASE: True
            }
            batch_top_rois1, batch_top_roi_scores1, batch_top_features = sess.run(
                [top_rois1, top_roi_scores1, top_features], fd)

            ## generate ground truth
            batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets  = \
                rpn_target ( anchors, inside_inds, batch_top_gt_labels,  batch_top_gt_boxes)

            batch_top_rois, batch_fuse_labels, batch_fuse_targets  = \
                 rcnn_target(  batch_top_rois1, batch_fuse_gt_labels, batch_fuse_gt_boxes, batch_fuse_gt_boxes3d )

            #project to rgb roi -------------------------------------------------
            batch_rgb_rois = batch_top_rois.copy()
            num = len(batch_top_rois)
            for n in range(num):
                box3d = box_to_box3d(batch_top_rois[n, 1:5].reshape(
                    1, 4)).reshape(8, 3)
                qs = make_projected_box3d(box3d)

                minx = np.min(qs[:, 0])
                maxx = np.max(qs[:, 0])
                miny = np.min(qs[:, 1])
                maxy = np.max(qs[:, 1])
                batch_rgb_rois[n, 1:5] = minx, miny, maxx, maxy

            darken = 0.7
            img_rgb_roi = rgb.copy() * darken
            for n in range(num):
                b = batch_rgb_rois[n, 1:5]
                cv2.rectangle(img_rgb_roi, (b[0], b[1]), (b[2], b[3]),
                              (0, 255, 255), 1)

            imshow('img_rgb_roi', img_rgb_roi)
            #--------------------------------------------------------------------

            ##debug
            if 1:
                img_gt = draw_rpn_gt(top_image, batch_top_gt_boxes,
                                     batch_top_gt_labels)
                img_label = draw_rpn_labels(top_image, anchors, batch_top_inds,
                                            batch_top_labels)
                img_target = draw_rpn_targets(top_image, anchors,
                                              batch_top_pos_inds,
                                              batch_top_targets)
                imshow('img_rpn_gt', img_gt)
                imshow('img_rpn_label', img_label)
                imshow('img_rpn_target', img_target)

                img_label = draw_rcnn_labels(top_image, batch_top_rois,
                                             batch_fuse_labels)
                img_target = draw_rcnn_targets(top_image, batch_top_rois,
                                               batch_fuse_labels,
                                               batch_fuse_targets)
                imshow('img_rcnn_label', img_label)
                imshow('img_rcnn_target', img_target)
                cv2.waitKey(1)

            #---------------------------------------------------
            fd = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds,
                top_inds: batch_top_inds,
                top_pos_inds: batch_top_pos_inds,
                top_labels: batch_top_labels,
                top_targets: batch_top_targets,
                top_rois: batch_top_rois,
                #front_rois1: batch_front_rois,
                rgb_images: batch_rgb_images,
                rgb_rois: batch_rgb_rois,
                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets,
                learning_rate: rate,
                IS_TRAIN_PHASE: True
            }
            #_, batch_top_cls_loss, batch_top_reg_loss = sess.run([solver_step, top_cls_loss, top_reg_loss],fd)


            _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
               sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],fd)

            #print('ok')
            # debug: ------------------------------------

            if iter % 4 == 0:
                batch_top_probs, batch_top_scores, batch_top_deltas  = \
                    sess.run([ top_probs, top_scores, top_deltas ],fd)

                batch_fuse_probs, batch_fuse_deltas = \
                    sess.run([ fuse_probs, fuse_deltas ],fd)

                probs, boxes3d, priors, priors3d, deltas = rcnn_nms(
                    batch_fuse_probs, batch_fuse_deltas, batch_top_rois)

                ## show rpn score maps
                p = batch_top_probs.reshape(*(top_feature_shape[0:2]),
                                            2 * num_bases)
                for n in range(num_bases):
                    r = n % num_scales
                    s = n // num_scales
                    pn = p[:, :, 2 * n + 1] * 255
                    axs[s, r].cla()
                    axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                plt.pause(0.01)

                img_rpn = draw_rpn(top_image, batch_top_probs,
                                   batch_top_deltas, anchors, inside_inds)
                img_rpn_nms = draw_rpn_nms(
                    top_image, batch_top_rois1,
                    batch_top_roi_scores1)  # estimat after non-max
                imshow('img_rpn', img_rpn)
                imshow('img_rpn_nms', img_rpn_nms)
                cv2.waitKey(1)

                #draw rcnn results --------------------------------
                img_rcnn = draw_rcnn(top_image, batch_fuse_probs,
                                     batch_fuse_deltas, batch_top_rois)
                draw_projected_gt_boxes3d(rgb,
                                          boxes3d,
                                          color=(255, 255, 255),
                                          thickness=1)

                imshow('img_rcnn', img_rcnn)
                cv2.waitKey(1)

            # debug: ------------------------------------

            log.write('%d   | %0.5f   %0.5f  %0.5f   %0.5f : \n' %
                      (iter, batch_top_cls_loss, batch_top_reg_loss,
                       batch_fuse_cls_loss, batch_fuse_reg_loss))
Пример #6
0
    def fit_iteration(self, batch_rgb_images, batch_top_view, batch_front_view,
                      batch_gt_labels, batch_gt_boxes3d, frame_id, is_validation=False,
                      summary_it=False, summary_runmeta=False, log=False):
        """Run one training (or validation) step of the proposal + fusion network.

        The step performs two graph executions.
        1. The proposal sub-graph is evaluated on ``batch_top_view``.
        2. RPN and fusion targets are then built on the host from the first
           ground-truth entry of the batch.
        3. Finally the loss graph is evaluated, together with
           ``self.solver_step`` unless ``is_validation`` is set.

        Args:
            batch_rgb_images: fed to ``net['rgb_images']``; element 0 is also
                handed to ``log_fusion_net_target`` when logging.
            batch_top_view: fed to ``net['top_view']``.
            batch_front_view: fed to ``net['front_view']``.
            batch_gt_labels: only element 0 is used.
            batch_gt_boxes3d: only element 0 is used.
            frame_id: not used by this method.
            is_validation: if true, no optimiser step is run and summaries go to
                ``self.val_summary_writer``.
            summary_it: additionally fetch ``self.summ`` and write it out.
            summary_runmeta: on a training summary step, record a FULL_TRACE and
                attach it with ``add_run_metadata``.
            log: emit the RPN, fusion-target, info and prediction logs.

        Returns:
            ``(top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss)`` as
            returned by ``sess.run``.

        Side effects:
            Stores proposals, targets and projected ROIs on ``self`` as
            ``batch_*`` attributes (read by the logging helpers).
        """
        net = self.net
        sess = self.sess

        # Every loss evaluation below fetches these four tensors in this order.
        loss_fetches = [net[key] for key in
                        ('top_cls_loss', 'top_reg_loss', 'fuse_cls_loss', 'fuse_reg_loss')]

        self.batch_gt_top_boxes = data.box3d_to_top_box(batch_gt_boxes3d[0])

        # ---- stage 1: proposal generation --------------------------------
        # NOTE(review): the train-phase flags are fed even when is_validation is
        # true, so validation runs with training-mode layer behaviour — confirm
        # that this is intended.
        proposal_feed = {
            net['top_view']: batch_top_view,
            net['top_anchors']: self.top_view_anchors,
            net['top_inside_inds']: self.anchors_inside_inds,
            blocks.IS_TRAIN_PHASE: True,
            K.learning_phase(): 1,
        }
        proposal_out = sess.run(
            [net['proposals'], net['proposal_scores'], net['top_features']],
            proposal_feed)
        self.batch_proposals, self.batch_proposal_scores, self.batch_top_features = proposal_out

        # ---- RPN training targets ----------------------------------------
        rpn_out = rpn_target(self.top_view_anchors, self.anchors_inside_inds,
                             batch_gt_labels[0], self.batch_gt_top_boxes)
        (self.batch_top_inds, self.batch_top_pos_inds,
         self.batch_top_labels, self.batch_top_targets) = rpn_out

        scope_name = None
        if log:
            phase = 'validation' if is_validation else 'train'
            # The scope name uses the step rounded down to a multiple of iter_debug.
            step_floor = self.n_global_step - (self.n_global_step % self.iter_debug)
            scope_name = '%s_iter_%06d' % (phase, step_floor)
            self.log_rpn(step=self.n_global_step, scope_name=scope_name)

        # ---- fusion-net targets and ROI projections into each view --------
        rcnn_out = rcnn_target(self.batch_proposals, batch_gt_labels[0],
                               self.batch_gt_top_boxes, batch_gt_boxes3d[0])
        self.batch_top_rois, self.batch_fuse_labels, self.batch_fuse_targets = rcnn_out

        self.batch_rois3d = project_to_roi3d(self.batch_top_rois)
        self.batch_front_rois = project_to_front_roi(self.batch_rois3d)
        self.batch_rgb_rois = project_to_rgb_roi(self.batch_rois3d)

        if log:
            self.log_fusion_net_target(batch_rgb_images[0], scope_name=scope_name)
            info_text = 'frame info: ' + self.frame_info + '\n'
            info_text += self.anchors_details()
            info_text += self.rpn_poposal_details()
            self.log_info(self.log_subdir, info_text)

        # ---- stage 2: losses (plus optimiser step when training) ----------
        loss_feed = dict(proposal_feed)
        loss_feed.update({
            net['top_view']: batch_top_view,
            net['front_view']: batch_front_view,
            net['rgb_images']: batch_rgb_images,

            net['top_rois']: self.batch_top_rois,
            net['front_rois']: self.batch_front_rois,
            net['rgb_rois']: self.batch_rgb_rois,

            net['top_inds']: self.batch_top_inds,
            net['top_pos_inds']: self.batch_top_pos_inds,
            net['top_labels']: self.batch_top_labels,
            net['top_targets']: self.batch_top_targets,

            net['fuse_labels']: self.batch_fuse_labels,
            net['fuse_targets']: self.batch_fuse_targets,
        })

        if self.debug_mode:
            print('\n\nstart debug mode\n\n')
            debug_sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            # The values are discarded: the regular run below always recomputes them.
            debug_sess.run(list(loss_fetches), loss_feed)

        if summary_it and is_validation:
            *losses, tb_sum_val = sess.run(loss_fetches + [self.summ], loss_feed)
            self.val_summary_writer.add_summary(tb_sum_val, self.n_global_step)
            print('added validation  summary ')
        elif summary_it:
            trace_options = None
            trace_metadata = None
            if summary_runmeta:
                trace_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                trace_metadata = tf.RunMetadata()

            _, *losses, tb_sum_val = sess.run(
                [self.solver_step] + loss_fetches + [self.summ],
                feed_dict=loss_feed, options=trace_options, run_metadata=trace_metadata)
            self.train_summary_writer.add_summary(tb_sum_val, self.n_global_step)
            print('added training  summary ')

            if summary_runmeta:
                self.train_summary_writer.add_run_metadata(
                    trace_metadata, 'step%d' % self.n_global_step)
                print('added runtime metadata.')
        elif is_validation:
            losses = sess.run(loss_fetches, loss_feed)
        else:
            _, *losses = sess.run([self.solver_step] + loss_fetches, feed_dict=loss_feed)

        t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss = losses

        if log:
            self.log_prediction(batch_top_view, batch_front_view, batch_rgb_images,
                                step=self.n_global_step, scope_name=scope_name)

        return t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss
Пример #7
0
    def fit_iteration(self,
                      batch_rgb_images,
                      batch_top_view,
                      batch_front_view,
                      batch_gt_labels,
                      batch_gt_boxes3d,
                      frame_id,
                      is_validation=False,
                      summary_it=False,
                      summary_runmeta=False,
                      log=False):
        """Run one training (or validation) step of the proposal + fusion network.

        The step performs two graph executions.
        1. The proposal sub-graph is evaluated on ``batch_top_view``.
        2. RPN and fusion targets are then built on the host from the first
           ground-truth entry of the batch.
        3. Finally the loss graph is evaluated, together with
           ``self.solver_step`` unless ``is_validation`` is set.

        Args:
            batch_rgb_images: fed to ``net['rgb_images']``; element 0 is also
                handed to ``log_fusion_net_target`` when logging.
            batch_top_view: fed to ``net['top_view']``.
            batch_front_view: fed to ``net['front_view']``.
            batch_gt_labels: only element 0 is used.
            batch_gt_boxes3d: only element 0 is used.
            frame_id: not used by this method.
            is_validation: if true, no optimiser step is run and summaries go to
                ``self.val_summary_writer``.
            summary_it: additionally fetch ``self.summ`` and write it out.
            summary_runmeta: on a training summary step, record a FULL_TRACE and
                attach it with ``add_run_metadata``.
            log: emit the RPN, fusion-target, info and prediction logs.

        Returns:
            ``(top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss)`` as
            returned by ``sess.run``.

        Side effects:
            Stores proposals, targets and projected ROIs on ``self`` as
            ``batch_*`` attributes (read by the logging helpers).
        """
        net = self.net
        sess = self.sess

        # Every loss evaluation below fetches these four tensors in this order.
        loss_fetches = [net[key] for key in
                        ('top_cls_loss', 'top_reg_loss', 'fuse_cls_loss', 'fuse_reg_loss')]

        self.batch_gt_top_boxes = data.box3d_to_top_box(batch_gt_boxes3d[0])

        # ---- stage 1: proposal generation --------------------------------
        # NOTE(review): the train-phase flags are fed even when is_validation is
        # true, so validation runs with training-mode layer behaviour — confirm
        # that this is intended.
        proposal_feed = {
            net['top_view']: batch_top_view,
            net['top_anchors']: self.top_view_anchors,
            net['top_inside_inds']: self.anchors_inside_inds,
            blocks.IS_TRAIN_PHASE: True,
            K.learning_phase(): 1,
        }
        proposal_out = sess.run(
            [net['proposals'], net['proposal_scores'], net['top_features']],
            proposal_feed)
        self.batch_proposals, self.batch_proposal_scores, self.batch_top_features = proposal_out

        # ---- RPN training targets ----------------------------------------
        rpn_out = rpn_target(self.top_view_anchors, self.anchors_inside_inds,
                             batch_gt_labels[0], self.batch_gt_top_boxes)
        (self.batch_top_inds, self.batch_top_pos_inds,
         self.batch_top_labels, self.batch_top_targets) = rpn_out

        scope_name = None
        if log:
            phase = 'validation' if is_validation else 'train'
            # The scope name uses the step rounded down to a multiple of iter_debug.
            step_floor = self.n_global_step - (self.n_global_step % self.iter_debug)
            scope_name = '%s_iter_%06d' % (phase, step_floor)
            self.log_rpn(step=self.n_global_step, scope_name=scope_name)

        # ---- fusion-net targets and ROI projections into each view --------
        rcnn_out = rcnn_target(self.batch_proposals, batch_gt_labels[0],
                               self.batch_gt_top_boxes, batch_gt_boxes3d[0])
        self.batch_top_rois, self.batch_fuse_labels, self.batch_fuse_targets = rcnn_out

        self.batch_rois3d = project_to_roi3d(self.batch_top_rois)
        self.batch_front_rois = project_to_front_roi(self.batch_rois3d)
        self.batch_rgb_rois = project_to_rgb_roi(self.batch_rois3d)

        if log:
            self.log_fusion_net_target(batch_rgb_images[0], scope_name=scope_name)
            info_text = 'frame info: ' + self.frame_info + '\n'
            info_text += self.anchors_details()
            info_text += self.rpn_poposal_details()
            self.log_info(self.log_subdir, info_text)

        # ---- stage 2: losses (plus optimiser step when training) ----------
        loss_feed = dict(proposal_feed)
        loss_feed.update({
            net['top_view']: batch_top_view,
            net['front_view']: batch_front_view,
            net['rgb_images']: batch_rgb_images,
            net['top_rois']: self.batch_top_rois,
            net['front_rois']: self.batch_front_rois,
            net['rgb_rois']: self.batch_rgb_rois,
            net['top_inds']: self.batch_top_inds,
            net['top_pos_inds']: self.batch_top_pos_inds,
            net['top_labels']: self.batch_top_labels,
            net['top_targets']: self.batch_top_targets,
            net['fuse_labels']: self.batch_fuse_labels,
            net['fuse_targets']: self.batch_fuse_targets,
        })

        if self.debug_mode:
            print('\n\nstart debug mode\n\n')
            debug_sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            # The values are discarded: the regular run below always recomputes them.
            debug_sess.run(list(loss_fetches), loss_feed)

        if summary_it and is_validation:
            *losses, tb_sum_val = sess.run(loss_fetches + [self.summ], loss_feed)
            self.val_summary_writer.add_summary(tb_sum_val, self.n_global_step)
            print('added validation  summary ')
        elif summary_it:
            trace_options = None
            trace_metadata = None
            if summary_runmeta:
                trace_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                trace_metadata = tf.RunMetadata()

            _, *losses, tb_sum_val = sess.run(
                [self.solver_step] + loss_fetches + [self.summ],
                feed_dict=loss_feed, options=trace_options, run_metadata=trace_metadata)
            self.train_summary_writer.add_summary(tb_sum_val, self.n_global_step)
            print('added training  summary ')

            if summary_runmeta:
                self.train_summary_writer.add_run_metadata(
                    trace_metadata, 'step%d' % self.n_global_step)
                print('added runtime metadata.')
        elif is_validation:
            losses = sess.run(loss_fetches, loss_feed)
        else:
            _, *losses = sess.run([self.solver_step] + loss_fetches, feed_dict=loss_feed)

        t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss = losses

        if log:
            self.log_prediction(batch_top_view,
                                batch_front_view,
                                batch_rgb_images,
                                step=self.n_global_step,
                                scope_name=scope_name)

        return t_cls_loss, t_reg_loss, f_cls_loss, f_reg_loss
def run_train():
    """Build the RGB-only proposal + fusion graph and train it.

    The graph has an RPN branch (``top_feature_net`` on ``rgb_images``) and a
    fusion branch (``fusion_net`` on ROIs pooled from ``rgb_features``) with
    classification, 2D box regression and 3dTo2D regression losses.

    Configuration comes from ``CFG`` (paths, learning-rate schedule, NMS sizes,
    visualisation flag). Training frame indices are read from
    ``CFG.PATH.TRAIN.TARGET + '/train.npy'``. Frames are loaded from disk in
    chunks of ``freq`` frames, following a shuffled order that is re-drawn
    every ``num_frames`` iterations.

    Outputs:
        * per-iteration losses, written to a timestamped log under
          ``CFG.PATH.TRAIN.OUTPUT + '/log'``;
        * merged TensorBoard summaries, every 10 iterations;
        * checkpoints, every 5000 iterations.

    Side effects:
        Sets ``CFG.KEEPPROBS = 0.5``. Multiplies ``CFG.TRAIN.LEARNING_RATE`` by
        ``CFG.TRAIN.LEARNING_RATE_DECAY_SCALE`` every
        ``CFG.TRAIN.LEARNING_RATE_DECAY_STEP`` iterations, modifying the config
        in place.
    """
    CFG.KEEPPROBS = 0.5
    # output dir  for tensorboard, checkpoints and log
    out_dir = CFG.PATH.TRAIN.OUTPUT
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    makedirs(out_dir + '/log')
    log = Logger(out_dir + '/log/log_%s.txt' %
                 (time.strftime('%Y-%m-%d %H:%M:%S')),
                 mode='a')

    # Frame indices of the training split, sorted ascending.
    index = np.array(sorted(np.load(CFG.PATH.TRAIN.TARGET + '/train.npy')))
    num_frames = len(index)

    # ---- anchor bases for the RGB image: 3 ratios x 5 scales on a 48px base ----
    ratios_rgb = np.array([0.5, 1, 2], dtype=np.float32)
    scales_rgb = np.array([0.5, 1, 2, 4, 5], dtype=np.float32)
    bases_rgb = make_bases(base_size=48,
                           ratios=ratios_rgb,
                           scales=scales_rgb)

    num_bases_rgb = len(bases_rgb)
    stride = 8
    out_shape = (2, 2)

    # NOTE(review): this initial load is replaced by the chunked loader at
    # iteration 0. It also needs at least 11 frames, otherwise rgbs[0] below fails.
    rgbs, gt_labels, gt_3dTo2Ds, gt_boxes2d, rgbs_norm, image_index = load_dummy_datas(
        index[10:13])

    rgb_shape = rgbs[0].shape

    num_class = 2  # includes background

    # ---- graph: inputs ---------------------------------------------------
    rgb_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors_rgb')
    rgb_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds_rgb')

    rgb_images = tf.placeholder(shape=[None, None, None, 3],
                                dtype=tf.float32,
                                name='rgb')
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')

    # ---- graph: RPN and fusion heads -------------------------------------
    rgb_features, rgb_scores, rgb_probs, rgb_deltas = \
        top_feature_net(rgb_images, rgb_anchors, rgb_inside_inds, num_bases_rgb)

    # Presumably (features, rois, pool_h, pool_w, spatial_scale=1/stride) —
    # confirm against fusion_net.
    fuse_scores, fuse_probs, fuse_deltas, fuse_deltas_3dTo2D = \
        fusion_net(
            ([rgb_features, rgb_rois, 7, 7, 1. / (1 * stride)],), num_class, out_shape)  # <todo> add non max suppression

    # ---- graph: losses -----------------------------------------------------
    rgb_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='rgb_ind')
    rgb_pos_inds = tf.placeholder(shape=[None],
                                  dtype=tf.int32,
                                  name='rgb_pos_ind')
    rgb_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='rgb_label')
    rgb_targets = tf.placeholder(shape=[None, 4],
                                 dtype=tf.float32,
                                 name='rgb_target')
    # NOTE(review): scores are doubled before the RPN loss — kept as-is, confirm intent.
    rgb_cls_loss, rgb_reg_loss = rpn_loss(2 * rgb_scores, rgb_deltas, rgb_inds,
                                          rgb_pos_inds, rgb_labels,
                                          rgb_targets)

    fuse_labels = tf.placeholder(shape=[None],
                                 dtype=tf.int32,
                                 name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, 4],
                                  dtype=tf.float32,
                                  name='fuse_target')
    # NOTE(review): this reuses the name 'fuse_target'; TensorFlow will
    # uniquify the op name.
    fuse_targets_3dTo2Ds = tf.placeholder(shape=[None, 16],
                                          dtype=tf.float32,
                                          name='fuse_target')

    fuse_cls_loss, fuse_reg_loss, fuse_reg_loss_3dTo2D = rcnn_loss(
        fuse_scores, fuse_deltas, fuse_labels, fuse_targets,
        fuse_deltas_3dTo2D, fuse_targets_3dTo2Ds)

    tf.summary.scalar('rcnn_cls_loss', fuse_cls_loss)
    tf.summary.scalar('rcnn_reg_loss', fuse_reg_loss)
    tf.summary.scalar('rcnn_reg_loss_3dTo2D', fuse_reg_loss_3dTo2D)
    tf.summary.scalar('rpn_cls_loss', rgb_cls_loss)
    tf.summary.scalar('rpn_reg_loss', rgb_reg_loss)

    # ---- solver: weighted sum of all losses plus L2 weight decay ---------------
    l2 = l2_regulariser(decay=CFG.TRAIN.WEIGHT_DECAY)
    tf.summary.scalar('l2', l2)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.AdamOptimizer(learning_rate)
    solver_step = solver.minimize(2 * rgb_cls_loss + 1 * rgb_reg_loss +
                                  2 * fuse_cls_loss + 1 * fuse_reg_loss +
                                  0.5 * fuse_reg_loss_3dTo2D + l2)

    max_iter = 200000
    iter_debug = 1

    # start training here  #########################################################################################
    log.write(
        'epoch     iter    speed   rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'
    )
    log.write(
        '-------------------------------------------------------------------------------------\n'
    )

    merged = tf.summary.merge_all()

    sess = tf.InteractiveSession()
    train_writer = tf.summary.FileWriter(
        './outputs/tensorboard/V_2dTo3d_finetune', sess.graph)
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        saver = tf.train.Saver()

        # Optional warm starts (uncomment one):
        #  - from a relatively well trained model:
        # saver.restore(sess, './outputs/check_points/snap_2dTo3D__data_augmentation090000trainval.ckpt')
        #  - from a 2D bounding-box pretrained model (all vars except fuse/3D*):
        # var_lt_res=[v for v in tf.global_variables() if  not v.name.startswith('fuse/3D')]
        # saver_0=tf.train.Saver(var_lt_res)
        # saver_0.restore(sess, './outputs/check_points/snap_2D_pretrain.ckpt')
        #  - from ResNet50:
        # var_lt_res=[v for v in tf.trainable_variables() if v.name.startswith('resnet_v1')]#resnet_v1_50
        # saver_0=tf.train.Saver(var_lt_res)
        # saver_0.restore(sess, './outputs/check_points/resnet_v1_50.ckpt')

        batch_top_cls_loss = 0
        batch_top_reg_loss = 0
        batch_fuse_cls_loss = 0
        batch_fuse_reg_loss = 0
        frame_range = np.arange(num_frames)
        idx = 0
        frame = 0

        for iter in range(max_iter):
            epoch = iter // num_frames + 1
            start_time = time.time()
            if iter % (num_frames * 1) == 0:
                # Start of a new pass: reset the chunk cursor and reshuffle.
                idx = 0
                frame = 0
                count = 0
                end_flag = 0
                # A shuffle equal to the previous order is a legitimate random
                # outcome (certain when num_frames == 1), so it is simply used.
                # The old code raised here and referenced an undefined name.
                frame_range = np.random.permutation(num_frames)

            # Reload the next chunk of `freq` frames (in shuffled order) from
            # disk once `freq` frames of the current chunk have been consumed.
            freq = int(10)
            if idx % freq == 0:
                count += idx
                if count % (1 * freq) == 0:
                    frame += idx
                    frame_end = min(frame + freq, num_frames)
                    if frame_end == num_frames:
                        end_flag = 1
                    rgbs, gt_labels, gt_3dTo2Ds, gt_boxes2d, rgbs_norm, image_index = load_dummy_datas(
                        index[frame_range[frame:frame_end]])
                idx = 0
            # In the last (possibly short) chunk, wrap back to its first frame.
            if (end_flag == 1) and (idx + frame) == num_frames:
                idx = 0
            print('processing image : %s' % image_index[idx])

            # Step-wise learning-rate decay, applied to the shared config.
            if (iter + 1) % (CFG.TRAIN.LEARNING_RATE_DECAY_STEP) == 0:
                CFG.TRAIN.LEARNING_RATE = CFG.TRAIN.LEARNING_RATE_DECAY_SCALE * CFG.TRAIN.LEARNING_RATE

            rgb_shape = rgbs[idx].shape
            batch_rgb_images = rgbs_norm[idx].reshape(1, *rgb_shape)
            batch_gt_labels = gt_labels[idx]
            batch_gt_3dTo2Ds = gt_3dTo2Ds[idx]
            batch_gt_boxes2d = gt_boxes2d[idx]

            # Frames without ground truth are skipped. This still consumes
            # an iteration number.
            if len(batch_gt_labels) == 0:
                idx = idx + 1
                continue

            # Anchors depend on the image size, so they are rebuilt per frame.
            rgb_feature_shape = ((rgb_shape[0] - 1) // stride + 1,
                                 (rgb_shape[1] - 1) // stride + 1)
            anchors_rgb, inside_inds_rgb = make_anchors(
                bases_rgb, stride, rgb_shape[0:2], rgb_feature_shape[0:2])

            ## run propsal generation ------------
            fd1 = {
                rgb_images: batch_rgb_images,
                rgb_anchors: anchors_rgb,
                rgb_inside_inds: inside_inds_rgb,
                learning_rate: CFG.TRAIN.LEARNING_RATE,
                IS_TRAIN_PHASE: True,
            }
            batch_rgb_probs, batch_deltas, batch_rgb_features = sess.run(
                [rgb_probs, rgb_deltas, rgb_features], fd1)

            rpn_nms = rpn_nms_generator(
                stride,
                rgb_shape[1],
                rgb_shape[0],
                img_scale=1,
                nms_thresh=0.7,
                min_size=stride,
                nms_pre_topn=CFG.TRAIN.RPN_NMS_PRE_TOPN,
                nms_post_topn=CFG.TRAIN.RPN_NMS_POST_TOPN)
            batch_proposals, batch_proposal_scores = rpn_nms(
                batch_rgb_probs, batch_deltas, anchors_rgb, inside_inds_rgb)

            ## generate  train rois  ------------
            batch_rgb_inds, batch_rgb_pos_inds, batch_rgb_labels, batch_rgb_targets = \
                rpn_target(anchors_rgb, inside_inds_rgb, batch_gt_labels, batch_gt_boxes2d)

            batch_rgb_rois, batch_fuse_labels, batch_fuse_targets2d, batch_fuse_targets_3dTo2Ds = rcnn_target(
                batch_proposals, batch_gt_labels, batch_gt_boxes2d,
                batch_gt_3dTo2Ds, rgb_shape[1], rgb_shape[0])

            print('nums of rcnn batch: %d' % len(batch_rgb_rois))
            ##debug gt generation
            if CFG.TRAIN.VISUALIZATION and iter % iter_debug == 0:
                rgb = rgbs[idx]

                img_gt = draw_rpn_gt(rgb, batch_gt_boxes2d, batch_gt_labels)
                rgb_label = draw_rpn_labels(img_gt, anchors_rgb,
                                            batch_rgb_inds, batch_rgb_labels)
                rgb_target = draw_rpn_targets(rgb, anchors_rgb,
                                              batch_rgb_pos_inds,
                                              batch_rgb_targets)
                imshow('img_rgb_label', rgb_label)
                imshow('img_rpn_target', rgb_target)

                img_label = draw_rcnn_labels(rgb, batch_rgb_rois,
                                             batch_fuse_labels)
                img_target = draw_rcnn_targets(rgb, batch_rgb_rois,
                                               batch_fuse_labels,
                                               batch_fuse_targets2d)
                imshow('img_rcnn_label', img_label)
                imshow('img_rcnn_target', img_target)

                img_rgb_rois = draw_boxes(rgb,
                                          batch_rgb_rois[:, 1:5],
                                          color=(255, 0, 255),
                                          thickness=1)
                imshow('img_rgb_rois', img_rgb_rois)

                projections = box_transform_3dTo2D_inv(
                    batch_rgb_rois[:, 1:], batch_fuse_targets_3dTo2Ds)
                img_rcnn_3dTo2D = draw_rgb_projections(rgb,
                                                       projections,
                                                       color=(0, 0, 255),
                                                       thickness=1)
                imshow('img_rcnn_3dTo2D', img_rcnn_3dTo2D)
                # Blocks until a key is pressed in an OpenCV window.
                cv2.waitKey(0)

            ## run classification and regression loss -----------
            fd2 = {
                **fd1, rgb_images: batch_rgb_images,
                rgb_rois: batch_rgb_rois,
                rgb_inds: batch_rgb_inds,
                rgb_pos_inds: batch_rgb_pos_inds,
                rgb_labels: batch_rgb_labels,
                rgb_targets: batch_rgb_targets,
                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets2d,
                fuse_targets_3dTo2Ds: batch_fuse_targets_3dTo2Ds
            }

            _, rcnn_probs, batch_rgb_cls_loss, batch_rgb_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss, batch_fuse_reg_loss_dTo2D = \
               sess.run([solver_step, fuse_probs, rgb_cls_loss, rgb_reg_loss, fuse_cls_loss, fuse_reg_loss, fuse_reg_loss_3dTo2D], fd2)

            speed = time.time() - start_time
            log.write('%5.1f   %5d    %0.4fs   %0.6f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  |%0.5f   \n' %\
                (epoch, iter, speed, CFG.TRAIN.LEARNING_RATE, batch_rgb_cls_loss, batch_rgb_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss, batch_fuse_reg_loss_dTo2D))

            if (iter) % 10 == 0:
                summary = sess.run(merged, fd2)
                train_writer.add_summary(summary, iter)

            # save: ------------------------------------
            if (iter) % 5000 == 0 and (iter != 0):
                saver.save(sess, out_dir + '/check_points/' +
                           CFG.PATH.TRAIN.CHECKPOINT_NAME +
                           '%06d.ckpt' % iter)

            idx = idx + 1