Ejemplo n.º 1
0
def run_train():

    # output dir, etc
    out_dir = '/root/sharefolder/sdcnd/didi1/output'
    # makedirs(out_dir +'/tf')
    # makedirs(out_dir +'/check_points')
    log = Logger(out_dir+'/log.txt',mode='a')
    # log.write(unicode('aaa {}'.format('aaa')))
    #lidar data -----------------
    if 1:
        ratios=np.array([0.5,1,2], dtype=np.float32)
        scales=np.array([1,2,3],   dtype=np.float32)
        bases = make_bases(
            base_size = 16,
            ratios=ratios,
            scales=scales
        )
        num_bases = len(bases)
        stride = 8

        num_frames = 154
        # num_frames = 2
        rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, lidars = load_dummy_datas(num_frames)
        num_frames = len(rgbs)

        top_shape   = tops[0].shape
        front_shape = fronts[0].shape
        rgb_shape   = rgbs[0].shape
        top_feature_shape = (top_shape[0]//stride, top_shape[1]//stride)
        out_shape=(8,3)


        #-----------------------
        #check data
        if 0:
            fig = mlab.figure(figure=None, bgcolor=(0,0,0), fgcolor=None, engine=None, size=(1000, 500))
            draw_lidar(lidars[0], fig=fig)
            draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
            mlab.show(1)
            cv2.waitKey(1)




    # set anchor boxes
    num_class = 2 #incude background
    anchors, inside_inds =  make_anchors(bases, stride, top_shape[0:2], top_feature_shape[0:2])
    inside_inds = np.arange(0,len(anchors),dtype=np.int32)  #use all  #<todo>
    print ('out_shape=%s'%str(out_shape))
    print ('num_frames=%d'%num_frames)


    #load model ####################################################################################################
    top_anchors     = tf.placeholder(shape=[None, 4], dtype=tf.int32,   name ='anchors'    )
    top_inside_inds = tf.placeholder(shape=[None   ], dtype=tf.int32,   name ='inside_inds')

    top_images   = tf.placeholder(shape=[None, 400, 400, 8 ], dtype=tf.float32, name='input_top'  )
    front_images = tf.placeholder(shape=[None, 1, 1], dtype=tf.float32, name='front')
    rgb_images   = tf.placeholder(shape=[None, 375, 1242, 3  ], dtype=tf.float32, name='rgb'  )
    top_rois     = tf.placeholder(shape=[None, 5], dtype=tf.float32,   name ='top_rois'   ) #<todo> change to int32???
    front_rois   = tf.placeholder(shape=[None, 5], dtype=tf.float32,   name ='front_rois' )
    rgb_rois     = tf.placeholder(shape=[None, 5], dtype=tf.float32,   name ='rgb_rois'   )

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    front_features = front_feature_net(front_images)
    rgb_features   = rgb_feature_net(rgb_images)

    # import pdb; pdb.set_trace()
    fuse_scores, fuse_probs, fuse_deltas, aux_fuse_scores, aux_fuse_probs, aux_fuse_deltas = \
        fusion_net(
			( [top_features,     top_rois,     6,6,1./stride],
			  [front_features,   front_rois,   0,0,1./stride],  #disable by 0,0
			  [rgb_features,     rgb_rois,     6,6,1./stride],),
            num_class, out_shape) #<todo>  add non max suppression

    # import pdb; pdb.set_trace()


    #loss ########################################################################################################
    top_inds     = tf.placeholder(shape=[None   ], dtype=tf.int32,   name='top_ind'    )
    top_pos_inds = tf.placeholder(shape=[None   ], dtype=tf.int32,   name='top_pos_ind')
    top_labels   = tf.placeholder(shape=[None   ], dtype=tf.int32,   name='top_label'  )
    top_targets  = tf.placeholder(shape=[None, 4], dtype=tf.float32, name='top_target' )
    with tf.variable_scope('rpn-loss') as scope:
        top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds, top_pos_inds, top_labels, top_targets)
    tf.summary.scalar('top_cls_loss', top_cls_loss)
    tf.summary.scalar('top_reg_loss', top_reg_loss)


    fuse_labels  = tf.placeholder(shape=[None            ], dtype=tf.int32,   name='fuse_label' )
    fuse_targets = tf.placeholder(shape=[None, 8, 3], dtype=tf.float32, name='fuse_target')
    with tf.variable_scope('rcnn-loss') as scope:
        fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas, fuse_labels, fuse_targets)
    tf.summary.scalar('fuse_cls_loss', fuse_cls_loss)
    tf.summary.scalar('fuse_reg_loss', fuse_reg_loss)

    with tf.variable_scope('aux_rcnn_loss') as scope:
        with tf.variable_scope('aux_loss_1') as scope:
            aux_fuse_cls_loss_1, aux_fuse_reg_loss_1 = rcnn_loss(aux_fuse_scores[0], aux_fuse_deltas[0],
             fuse_labels, fuse_targets)
        tf.summary.scalar('aux_fuse_cls_loss_1', aux_fuse_cls_loss_1)
        tf.summary.scalar('aux_fuse_reg_loss_1', aux_fuse_reg_loss_1)
        with tf.variable_scope('aux_loss_2') as scope:
            aux_fuse_cls_loss_2, aux_fuse_reg_loss_2 = rcnn_loss(aux_fuse_scores[1], aux_fuse_deltas[1],
             fuse_labels, fuse_targets)
        tf.summary.scalar('aux_fuse_cls_loss_2', aux_fuse_cls_loss_2)
        tf.summary.scalar('aux_fuse_reg_loss_2', aux_fuse_reg_loss_1)
    #solver
    # with tf.variable_scope('l2-reg') as scope:
    #     l2 = l2_regulariser(decay=0.0005)
    # tf.summary.scalar('total_l2reg', l2)

    with tf.variable_scope('total_loss') as scope:
        total_loss = top_cls_loss+top_reg_loss+fuse_cls_loss+0.1*fuse_reg_loss \
                    + aux_fuse_cls_loss_1 + aux_fuse_reg_loss_1 + aux_fuse_cls_loss_2 + aux_fuse_reg_loss_2
    tf.summary.scalar('total_loss', total_loss)

    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    #solver_step = solver.minimize(top_cls_loss+top_reg_loss+l2)
    solver_step = solver.minimize(total_loss)

    max_iter = 10000
    iter_debug=8

    # start training here  #########################################################################################
    log.write(unicode('epoch     iter    rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'))
    log.write(unicode('-------------------------------------------------------------------------------------\n'))

    num_ratios=len(ratios)
    num_scales=len(scales)
    # fig, axs = plt.subplots(num_ratios,num_scales)

    sess = tf.InteractiveSession()
    # with sess.as_default():
    merged = tf.summary.merge_all()

    log_dir = out_dir+'/train'
    if tf.gfile.Exists(log_dir):
        #gotta be careful
        tf.gfile.DeleteRecursively(log_dir)
        print 'Removed files in {}'.format(log_dir)
    train_writer = tf.summary.FileWriter(log_dir, sess.graph)

    saver  = tf.train.Saver()
    tf.global_variables_initializer().run()
    # sess.run( tf.global_variables_initializer(), { IS_TRAIN_PHASE : True } )

    #option: loading pretrained model
    saver.restore(sess, '/root/sharefolder/sdcnd/didi1/output/check_points/snap.ckpt')

    batch_top_cls_loss =0
    batch_top_reg_loss =0
    batch_fuse_cls_loss=0
    batch_fuse_reg_loss=0
    for iter in range(max_iter):
        epoch=1.0*iter
        rate=0.05


        ## generate train image -------------
        idx = np.random.choice(num_frames)     #*10   #num_frames)  #0
        idx = 87
        batch_top_images    = tops[idx].reshape(1,*top_shape)
        batch_front_images  = fronts[idx].reshape(1,*front_shape)
        batch_rgb_images    = rgbs[idx].reshape(1,*rgb_shape)

        batch_gt_labels    = gt_labels[idx]
        batch_gt_boxes3d   = gt_boxes3d[idx]
        batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)


		## run propsal generation ------------
        fd1={
            top_images:      batch_top_images,
            top_anchors:     anchors,
            top_inside_inds: inside_inds,

            learning_rate:   rate,
            IS_TRAIN_PHASE:  True
        }
        batch_proposals, batch_proposal_scores, batch_top_features = sess.run([proposals, proposal_scores, top_features],fd1)

        ## generate  train rois  ------------
        batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets  = \
            rpn_target ( anchors, inside_inds, batch_gt_labels,  batch_gt_top_boxes)

        batch_top_rois, batch_fuse_labels, batch_fuse_targets  = \
             rcnn_target(  batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d )

        batch_rois3d	 = project_to_roi3d    (batch_top_rois)
        batch_front_rois = project_to_front_roi(batch_rois3d  )
        batch_rgb_rois   = project_to_rgb_roi  (batch_rois3d  )


        ##debug gt generation
        if False:
        # if 1 and iter%iter_debug==0:
            top_image = top_imgs[idx]
            rgb       = rgbs[idx]

            img_gt     = draw_rpn_gt(top_image, batch_gt_top_boxes, batch_gt_labels)
            img_label  = draw_rpn_labels (top_image, anchors, batch_top_inds, batch_top_labels )
            img_target = draw_rpn_targets(top_image, anchors, batch_top_pos_inds, batch_top_targets)
            #imshow('img_rpn_gt',img_gt)
            #imshow('img_rpn_label',img_label)
            #imshow('img_rpn_target',img_target)

            img_label  = draw_rcnn_labels (top_image, batch_top_rois, batch_fuse_labels )
            img_target = draw_rcnn_targets(top_image, batch_top_rois, batch_fuse_labels, batch_fuse_targets)
            #imshow('img_rcnn_label',img_label)
            imshow('img_rcnn_target',img_target)


            img_rgb_rois = draw_boxes(rgb, batch_rgb_rois[:,1:5], color=(255,0,255), thickness=1)
            imshow('img_rgb_rois',img_rgb_rois)

            cv2.waitKey(1)

        ## run classification and regression loss -----------
        fd2={
			top_images:      batch_top_images,
            top_anchors:     anchors,
            top_inside_inds: inside_inds,

            learning_rate:   rate,
            IS_TRAIN_PHASE:  True,

            top_images: batch_top_images,
            front_images: batch_front_images,
            rgb_images: batch_rgb_images,

			top_rois:   batch_top_rois,
            front_rois: batch_front_rois,
            rgb_rois:   batch_rgb_rois,

            top_inds:     batch_top_inds,
            top_pos_inds: batch_top_pos_inds,
            top_labels:   batch_top_labels,
            top_targets:  batch_top_targets,

            fuse_labels:  batch_fuse_labels,
            fuse_targets: batch_fuse_targets,
        }
        #_, batch_top_cls_loss, batch_top_reg_loss = sess.run([solver_step, top_cls_loss, top_reg_loss],fd2)

        # import pdb; pdb.set_trace()
        # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        # run_metadata = tf.RunMetadata()
        run_options = None
        run_metadata = None
        _, summary, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
           sess.run([solver_step, merged, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],
                    feed_dict = fd2,
                    options = run_options,
                    run_metadata = run_metadata)
        # train_writer.add_run_metadata(run_metadata, 'step%03d' % iter)
        train_writer.add_summary(summary, iter)
        train_writer.flush()

        log.write(unicode('%3.1f   %d   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %\
			(epoch, iter, rate, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss)))






        #print('ok')
        # debug: ------------------------------------
        if iter%10==0:
            top_image = top_imgs[idx]
            rgb       = rgbs[idx]


            batch_fuse_probs, batch_fuse_deltas = \
                sess.run([ fuse_probs, fuse_deltas ],fd2)

            #batch_fuse_deltas=0*batch_fuse_deltas #disable 3d box prediction
            probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.5)


            ## show rcnn(fuse) nms
            gt_2d_box = batch_gt_top_boxes
            img_rcnn     = draw_rcnn (top_image, batch_fuse_probs, batch_fuse_deltas, batch_top_rois, batch_rois3d, gt_2d_box,darker=1)
            boxes_3d = rcnn_result( batch_fuse_probs, batch_fuse_deltas,  batch_top_rois, batch_rois3d, gt_2d_box)
            img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
            imshow('img_rcnn',img_rcnn)
            imshow('img_rcnn_nms',img_rcnn_nms)
            # cv2.imwrite('result.png', img_rcnn_nms)
            cv2.waitKey(1)

        if False:
        # if iter%100==0:
        # if iter%iter_debug==0:
            top_image = top_imgs[idx]
            rgb       = rgbs[idx]

            batch_top_probs, batch_top_scores, batch_top_deltas  = \
                sess.run([ top_probs, top_scores, top_deltas ],fd2)

            batch_fuse_probs, batch_fuse_deltas = \
                sess.run([ fuse_probs, fuse_deltas ],fd2)

            #batch_fuse_deltas=0*batch_fuse_deltas #disable 3d box prediction
            probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.5)


            ## show rpn score maps
            # import pdb; pdb.set_trace()
            fig, axs = plt.subplots(num_ratios,num_scales)
            p = batch_top_probs.reshape( 50, 50, 2*num_bases)
            for n in range(num_bases):
                r=n%num_scales
                s=n//num_scales
                pn = p[:,:,2*n+1]*255
                axs[s,r].cla()
                axs[s,r].imshow(pn, cmap='gray', vmin=0, vmax=255)
            plt.pause(0.01)

			## show rpn(top) nms
            img_rpn     = draw_rpn    (top_image, batch_top_probs, batch_top_deltas, anchors, inside_inds)
            img_rpn_nms = draw_rpn_nms(top_image, batch_proposals, batch_proposal_scores)
            #imshow('img_rpn',img_rpn)
            imshow('img_rpn_nms',img_rpn_nms)
            cv2.waitKey(1)

            ## show rcnn(fuse) nms
            img_rcnn     = draw_rcnn (top_image, batch_fuse_probs, batch_fuse_deltas, batch_top_rois, batch_rois3d,darker=1)
            img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
            imshow('img_rcnn',img_rcnn)
            imshow('img_rcnn_nms',img_rcnn_nms)
            cv2.waitKey(1)

        # save: ------------------------------------
        if iter%500==0:
            #saver.save(sess, out_dir + '/check_points/%06d.ckpt'%iter)  #iter
            saver.save(sess, out_dir + '/check_points/snap.ckpt')  #iter

    train_writer.close()
Ejemplo n.º 2
0
def run_train():

    # output dir, etc
    out_dir = '/root/share/out/didi/xxx'
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    log = Logger(out_dir + '/log.txt', mode='a')

    #one lidar data -----------------
    if 1:
        ratios = np.array([0.5, 1, 2], dtype=np.float32)
        scales = np.array([1, 2, 3], dtype=np.float32)
        bases = make_bases(base_size=16, ratios=ratios, scales=scales)
        num_bases = len(bases)
        stride = 8

        rgb, top, front, gt_labels, gt_boxes3d, top_image, front_image, lidar = load_dummy_data(
        )
        top_shape = top.shape
        front_shape = front.shape
        rgb_shape = rgb.shape

        top_feature_shape = (top_shape[0] // stride, top_shape[1] // stride)
        out_shape = (8, 3)

        #-----------------------
        #check data
        if 1:
            fig = mlab.figure(figure=None,
                              bgcolor=(0, 0, 0),
                              fgcolor=None,
                              engine=None,
                              size=(1000, 500))
            draw_lidar(lidar, fig=fig)
            draw_gt_boxes3d(gt_boxes3d, fig=fig)
            mlab.show(1)
            cv2.waitKey(1)

    # set anchor boxes
    num_class = 2  #incude background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2],
                                        top_feature_shape[0:2])
    inside_inds = np.arange(0, len(anchors), dtype=np.int32)  #use all  #<todo>
    print('out_shape=%s' % str(out_shape))

    #load model ####################################################################################################
    top_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors')
    top_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds')

    top_images = tf.placeholder(shape=[None, *top_shape],
                                dtype=tf.float32,
                                name='top')
    front_images = tf.placeholder(shape=[None, *front_shape],
                                  dtype=tf.float32,
                                  name='front')
    rgb_images = tf.placeholder(shape=[None, *rgb_shape],
                                dtype=tf.float32,
                                name='rgb')
    top_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='top_rois')  #<todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5],
                                dtype=tf.float32,
                                name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_lidar_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    fuse_scores, fuse_probs, fuse_deltas = \
        fusion_net(
    ( [top_features,     top_rois,     6,6,1./stride],
     [front_features,   front_rois,   0,0,1./stride],  #disable by 0,0
     [rgb_features,     rgb_rois,     6,6,1./stride],),
            num_class, out_shape) #<todo>  add non max suppression

    #loss ########################################################################################################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None],
                                  dtype=tf.int32,
                                  name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4],
                                 dtype=tf.float32,
                                 name='top_target')
    top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds,
                                          top_pos_inds, top_labels,
                                          top_targets)

    fuse_labels = tf.placeholder(shape=[None],
                                 dtype=tf.int32,
                                 name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, *out_shape],
                                  dtype=tf.float32,
                                  name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas,
                                             fuse_labels, fuse_targets)

    #put your solver here
    l2 = l2_regulariser(decay=0.0005)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                        momentum=0.9)
    #solver_step = solver.minimize(top_cls_loss+top_reg_loss+l2)
    solver_step = solver.minimize(top_cls_loss + top_reg_loss + fuse_cls_loss +
                                  fuse_reg_loss + l2)

    max_iter = 10000

    # start training here  #########################################################################################
    log.write(
        'epoch     iter    rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'
    )
    log.write(
        '-------------------------------------------------------------------------------------\n'
    )

    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)

    sess = tf.InteractiveSession()
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        summary_writer = tf.summary.FileWriter(out_dir + '/tf', sess.graph)
        saver = tf.train.Saver()

        batch_top_cls_loss = 0
        batch_top_reg_loss = 0
        batch_fuse_cls_loss = 0
        batch_fuse_reg_loss = 0
        for iter in range(max_iter):
            epoch = 1.0 * iter
            rate = 0.1

            ## generate train image -------------
            batch_top_images = top.reshape(1, *top_shape)
            batch_front_images = front.reshape(1, *front_shape)
            batch_rgb_images = rgb.reshape(1, *rgb_shape)

            batch_gt_labels = gt_labels
            batch_gt_boxes3d = gt_boxes3d
            batch_gt_top_boxes = box3d_to_top_box(gt_boxes3d)

            ## run propsal generation ------------
            fd1 = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds,
                learning_rate: rate,
                IS_TRAIN_PHASE: True
            }
            batch_proposals, batch_proposal_scores, batch_top_features = sess.run(
                [proposals, proposal_scores, top_features], fd1)

            ## generate  train rois  ------------
            batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets  = \
                rpn_target ( anchors, inside_inds, batch_gt_labels,  batch_gt_top_boxes)

            batch_top_rois, batch_fuse_labels, batch_fuse_targets  = \
                 rcnn_target(  batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d )

            batch_rois3d = project_to_roi3d(batch_top_rois)
            batch_front_rois = project_to_front_roi(batch_rois3d)
            batch_rgb_rois = project_to_rgb_roi(batch_rois3d)

            ##debug gt generation
            if 1:
                img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes,
                                     batch_gt_labels)
                img_label = draw_rpn_labels(top_image, anchors, batch_top_inds,
                                            batch_top_labels)
                img_target = draw_rpn_targets(top_image, anchors,
                                              batch_top_pos_inds,
                                              batch_top_targets)
                imshow('img_rpn_gt', img_gt)
                imshow('img_rpn_label', img_label)
                imshow('img_rpn_target', img_target)

                img_label = draw_rcnn_labels(top_image, batch_top_rois,
                                             batch_fuse_labels)
                img_target = draw_rcnn_targets(top_image, batch_top_rois,
                                               batch_fuse_labels,
                                               batch_fuse_targets)
                imshow('img_rcnn_label', img_label)
                imshow('img_rcnn_target', img_target)

                img_rgb_rois = draw_boxes(rgb,
                                          batch_rgb_rois[:, 1:5],
                                          thickness=1)
                imshow('img_rgb_rois', img_rgb_rois)

                cv2.waitKey(1)

            ## run classification and regression loss -----------
            fd2 = {
                **fd1,
                top_images: batch_top_images,
                front_images: batch_front_images,
                rgb_images: batch_rgb_images,
                top_rois: batch_top_rois,
                front_rois: batch_front_rois,
                rgb_rois: batch_rgb_rois,
                top_inds: batch_top_inds,
                top_pos_inds: batch_top_pos_inds,
                top_labels: batch_top_labels,
                top_targets: batch_top_targets,
                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets,
            }
            #_, batch_top_cls_loss, batch_top_reg_loss = sess.run([solver_step, top_cls_loss, top_reg_loss],fd2)


            _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
               sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],fd2)

            log.write('%3.1f   %d   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %\
    (epoch, iter, rate, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss))

            #print('ok')
            # debug: ------------------------------------

            if iter % 4 == 0:
                batch_top_probs, batch_top_scores, batch_top_deltas  = \
                    sess.run([ top_probs, top_scores, top_deltas ],fd2)

                batch_fuse_probs, batch_fuse_deltas = \
                    sess.run([ fuse_probs, fuse_deltas ],fd2)

                probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas,
                                          batch_rois3d)

                ## show rpn score maps
                p = batch_top_probs.reshape(*(top_feature_shape[0:2]),
                                            2 * num_bases)
                for n in range(num_bases):
                    r = n % num_scales
                    s = n // num_scales
                    pn = p[:, :, 2 * n + 1] * 255
                    axs[s, r].cla()
                    axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                plt.pause(0.01)

                ## show rpn(top) nms
                img_rpn = draw_rpn(top_image, batch_top_probs,
                                   batch_top_deltas, anchors, inside_inds)
                img_rpn_nms = draw_rpn_nms(top_image, batch_proposals,
                                           batch_proposal_scores)
                imshow('img_rpn', img_rpn)
                imshow('img_rpn_nms', img_rpn_nms)
                cv2.waitKey(1)

                ## show rcnn(fuse) nms
                img_rcnn = draw_rcnn(top_image, batch_fuse_probs,
                                     batch_fuse_deltas, batch_top_rois,
                                     batch_rois3d)
                img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
                imshow('img_rcnn', img_rcnn)
                imshow('img_rcnn_nms', img_rcnn_nms)
                cv2.waitKey(1)

            # save: ------------------------------------
            if iter % 500 == 0:
                #saver.save(sess, out_dir + '/check_points/%06d.ckpt'%iter)  #iter
                saver.save(sess, out_dir + '/check_points/snap.ckpt')  #iter
Ejemplo n.º 3
0
def run_train():

    # output dir, etc
    out_dir = './outputs'
    makedirs(out_dir +'/tf')
    makedirs(out_dir +'/check_points')
    log = Logger(out_dir+'/log_%s.txt'%(time.strftime('%Y-%m-%d %H:%M:%S')),mode='a')
    index=np.load(data_root+'seg/train_list.npy')
    index=sorted(index)
    index=np.array(index)
    num_frames = len(index)
    # pdb.set_trace()
    #lidar data -----------------
    if 1:
        ###generate anchor base 
        # ratios=np.array([0.4,0.6,1.7,2.4], dtype=np.float32)
        # scales=np.array([0.5,1,2,3],   dtype=np.float32)
        # bases = make_bases(
        #     base_size = 16,
        #     ratios=ratios,
        #     scales=scales
        # )
        ratios=np.array([1.7,2.4])
        scales=np.array([1.7,2.4])
        bases=np.array([[-19.5, -8, 19.5, 8],
                        [-8, -19.5, 8, 19.5],
                        [-5, -3, 5, 3],
                        [-3, -5, 3, 5]
                        ])
        # pdb.set_trace()
        num_bases = len(bases)
        stride = 4

        out_shape=(8,3)


        rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index = load_dummy_datas(index[:3])
        # rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index, lidars = load_dummy_datas()
        top_shape   = tops[0].shape
        front_shape = fronts[0].shape
        rgb_shape   = rgbs[0].shape
        top_feature_shape = ((top_shape[0]-1)//stride+1, (top_shape[1]-1)//stride+1)
        # pdb.set_trace()
        # set anchor boxes
        num_class = 2 #incude background
        anchors, inside_inds =  make_anchors(bases, stride, top_shape[0:2], top_feature_shape[0:2])
        # inside_inds = np.arange(0,len(anchors),dtype=np.int32)  #use all  #<todo>
        print ('out_shape=%s'%str(out_shape))
        print ('num_frames=%d'%num_frames)

        #-----------------------
        #check data
        if 0:
            fig = mlab.figure(figure=None, bgcolor=(0,0,0), fgcolor=None, engine=None, size=(1000, 500))
            draw_lidar(lidars[0], fig=fig)
            draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
            mlab.show(1)
            cv2.waitKey(1)




    #load model ####################################################################################################
    top_anchors     = tf.placeholder(shape=[None, 4], dtype=tf.int32,   name ='anchors'    )
    top_inside_inds = tf.placeholder(shape=[None   ], dtype=tf.int32,   name ='inside_inds')

    top_images   = tf.placeholder(shape=[None, *top_shape  ], dtype=tf.float32, name='top'  )
    front_images = tf.placeholder(shape=[None, *front_shape], dtype=tf.float32, name='front')
    rgb_images   = tf.placeholder(shape=[None, None, None, 3 ], dtype=tf.float32, name='rgb'  )
    top_rois     = tf.placeholder(shape=[None, 5], dtype=tf.float32,   name ='top_rois'   ) #<todo> change to int32???
    front_rois   = tf.placeholder(shape=[None, 5], dtype=tf.float32,   name ='front_rois' )
    rgb_rois     = tf.placeholder(shape=[None, 5], dtype=tf.float32,   name ='rgb_rois'   )

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)
    # pdb.set_trace()
    front_features = front_feature_net(front_images)
    rgb_features   = rgb_feature_net(rgb_images)

    fuse_scores, fuse_probs, fuse_deltas = \
        fusion_net(
			( [top_features,     top_rois,     6,6,1./stride],
			  [front_features,   front_rois,   0,0,1./stride],  #disable by 0,0
			  [rgb_features,     rgb_rois,     6,6,1./(2*stride)],),
            num_class, out_shape) #<todo>  add non max suppression



    #loss ########################################################################################################
    top_inds     = tf.placeholder(shape=[None   ], dtype=tf.int32,   name='top_ind'    )
    top_pos_inds = tf.placeholder(shape=[None   ], dtype=tf.int32,   name='top_pos_ind')
    top_labels   = tf.placeholder(shape=[None   ], dtype=tf.int32,   name='top_label'  )
    top_targets  = tf.placeholder(shape=[None, 4], dtype=tf.float32, name='top_target' )
    top_cls_loss, top_reg_loss = rpn_loss(2*top_scores, top_deltas, top_inds, top_pos_inds, top_labels, top_targets)

    fuse_labels  = tf.placeholder(shape=[None            ], dtype=tf.int32,   name='fuse_label' )
    fuse_targets = tf.placeholder(shape=[None, *out_shape], dtype=tf.float32, name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas, fuse_labels, fuse_targets)
    tf.summary.scalar('rpn_cls_loss', top_cls_loss)
    tf.summary.scalar('rpn_reg_loss', top_reg_loss)
    tf.summary.scalar('rcnn_cls_loss', fuse_cls_loss)
    tf.summary.scalar('rcnn_reg_loss', fuse_reg_loss)

    #solver
    l2 = l2_regulariser(decay=0.000005)
    tf.summary.scalar('l2', l2)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.AdamOptimizer(learning_rate)
    # solver = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    #solver_step = solver.minimize(top_cls_loss+top_reg_loss+l2)
    solver_step = solver.minimize(1*top_cls_loss+1*top_reg_loss+1.5*fuse_cls_loss+2*fuse_reg_loss+l2)

    max_iter = 200000
    iter_debug=1

    # start training here  #########################################################################################
    log.write('epoch     iter    speed   rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n')
    log.write('-------------------------------------------------------------------------------------\n')

    num_ratios=len(ratios)
    num_scales=len(scales)
    fig, axs = plt.subplots(num_ratios,num_scales)

    merged = tf.summary.merge_all()

    sess = tf.InteractiveSession()
    train_writer = tf.summary.FileWriter( './outputs/tensorboard/Res_Vgg_up',
                                      sess.graph)
    with sess.as_default():
        sess.run( tf.global_variables_initializer(), { IS_TRAIN_PHASE : True } )
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        # summary_writer = tf.summary.FileWriter(out_dir+'/tf', sess.graph)
        saver  = tf.train.Saver() 
        saver.restore(sess, './outputs/check_points/snap_ResNet_vgg_up_NGT_060000.ckpt') 
        # # saver.restore(sess, './outputs/check_points/MobileNet.ckpt')  

        # var_lt_res=[v for v in tf.trainable_variables() if v.name.startswith('res')]#resnet_v1_50
        # # pdb.set_trace()
        # ## var_lt=[v for v in tf.trainable_variables() if not(v.name.startswith('fuse-block-1')) and not(v.name.startswith('fuse')) and not(v.name.startswith('fuse-input'))]

        # # # var_lt.pop(0)
        # # # var_lt.pop(0)
        # # # pdb.set_trace()
        # saver_0=tf.train.Saver(var_lt_res)        
        # # # 
        # saver_0.restore(sess, './outputs/check_points/resnet_v1_50.ckpt')
        # # pdb.set_trace()
        # top_lt=[v for v in tf.trainable_variables() if v.name.startswith('top_base')]
        # top_lt.pop(0)
        # # # top_lt.pop(0)
        # for v in top_lt:
        #     # pdb.set_trace()
        #     for v_rgb in var_lt:
        #         if v.name[9:]==v_rgb.name:
        #             print ("assign weights:%s"%v.name)
        #             v.assign(v_rgb)
        # var_lt_vgg=[v for v in tf.trainable_variables() if v.name.startswith('vgg')]
        # var_lt_vgg.pop(0)
        # saver_1=tf.train.Saver(var_lt_vgg)
        
        # # pdb.set_trace()
        # saver_1.restore(sess, './outputs/check_points/vgg_16.ckpt')

        batch_top_cls_loss =0
        batch_top_reg_loss =0
        batch_fuse_cls_loss=0
        batch_fuse_reg_loss=0
        rate=0.000005
        frame_range = np.arange(num_frames)
        idx=0
        frame=0
        for iter in range(max_iter):
            epoch=iter//num_frames+1
            # rate=0.001
            start_time=time.time()


            # generate train image -------------
            # idx = np.random.choice(num_frames)     #*10   #num_frames)  #0
            # shuffle the samples every 4*num_frames
            if iter%(num_frames*2)==0:
                idx=0
                frame=0
                count=0
                end_flag=0
                frame_range1 = np.random.permutation(num_frames)
                if np.all(frame_range1==frame_range):
                    raise Exception("Invalid level!", permutation)
                frame_range=frame_range1

            #load 500 samples every 2000 iterations
            freq=int(200)
            if idx%freq==0 :
                count+=idx
                if count%(2*freq)==0:
                    frame+=idx
                    frame_end=min(frame+freq,num_frames)
                    if frame_end==num_frames:
                        end_flag=1
                    # pdb.set_trace()
                    del rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index
                    rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index = load_dummy_datas(index[frame_range[frame:frame_end]])
                idx=0
            if (end_flag==1) and (idx+frame)==num_frames:
                idx=0
            print('processing image : %s'%image_index[idx])

            if (iter+1)%(10000)==0:
                rate=0.8*rate

            rgb_shape   = rgbs[idx].shape
            batch_top_images    = tops[idx].reshape(1,*top_shape)
            batch_front_images  = fronts[idx].reshape(1,*front_shape)
            batch_rgb_images    = rgbs_norm[idx].reshape(1,*rgb_shape)
            # batch_rgb_images    = rgbs[idx].reshape(1,*rgb_shape)

            top_img=tops[idx]
            # pdb.set_trace()
            inside_inds_filtered=anchor_filter(top_img[:,:,-1], anchors, inside_inds)

            # pdb.set_trace()
            batch_gt_labels    = gt_labels[idx]
            if len(batch_gt_labels)==0:
                # pdb.set_trace()
                idx=idx+1
                continue
            batch_gt_boxes3d   = gt_boxes3d[idx]
            # pdb.set_trace()
            batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)




			## run propsal generation ------------
            fd1={
                top_images:      batch_top_images,
                top_anchors:     anchors,
                top_inside_inds: inside_inds_filtered,

                learning_rate:   rate,
                IS_TRAIN_PHASE:  True
            }
            batch_proposals, batch_proposal_scores, batch_top_features = sess.run([proposals, proposal_scores, top_features],fd1)
            print(batch_proposal_scores[:50])
            # pdb.set_trace()
            ## generate  train rois  ------------
            batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets  = \
                rpn_target ( anchors, inside_inds_filtered, batch_gt_labels,  batch_gt_top_boxes)

            batch_top_rois, batch_fuse_labels, batch_fuse_targets  = \
                 rcnn_target(  batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d )

            batch_rois3d	 = project_to_roi3d    (batch_top_rois)
            batch_front_rois = project_to_front_roi(batch_rois3d  )
            batch_rgb_rois   = project_to_rgb_roi  (batch_rois3d  )


            # keep = np.where((batch_rgb_rois[:,1]>=-200) & (batch_rgb_rois[:,2]>=-200) & (batch_rgb_rois[:,3]<=(rgb_shape[1]+200)) & (batch_rgb_rois[:,4]<=(rgb_shape[0]+200)))[0]
            # batch_rois3d        = batch_rois3d[keep]      
            # batch_front_rois    = batch_front_rois[keep]
            # batch_rgb_rois      = batch_rgb_rois[keep]  
            # batch_proposal_scores=batch_proposal_scores[keep]
            # batch_top_rois      =batch_top_rois[keep]

            if len(batch_rois3d)==0:
                # pdb.set_trace()
                idx=idx+1
                continue




            ##debug gt generation
            if vis and iter%iter_debug==0:
                top_image = top_imgs[idx]
                rgb       = rgbs[idx]

                img_gt     = draw_rpn_gt(top_image, batch_gt_top_boxes, batch_gt_labels)
                img_label  = draw_rpn_labels (img_gt, anchors, batch_top_inds, batch_top_labels )
                img_target = draw_rpn_targets(top_image, anchors, batch_top_pos_inds, batch_top_targets)
                #imshow('img_rpn_gt',img_gt)
                imshow('img_anchor_label',img_label)
                #imshow('img_rpn_target',img_target)

                img_label  = draw_rcnn_labels (top_image, batch_top_rois, batch_fuse_labels )
                img_target = draw_rcnn_targets(top_image, batch_top_rois, batch_fuse_labels, batch_fuse_targets)
                #imshow('img_rcnn_label',img_label)
                if vis :
                    imshow('img_rcnn_target',img_target)


                img_rgb_rois = draw_boxes(rgb, batch_rgb_rois[:,1:5], color=(255,0,255), thickness=1)
                if vis :
                    imshow('img_rgb_rois',img_rgb_rois)
                    cv2.waitKey(1)

            ## run classification and regression loss -----------
            fd2={
				**fd1,

                top_images: batch_top_images,
                front_images: batch_front_images,
                rgb_images: batch_rgb_images,

				top_rois:   batch_top_rois,
                front_rois: batch_front_rois,
                rgb_rois:   batch_rgb_rois,

                top_inds:     batch_top_inds,
                top_pos_inds: batch_top_pos_inds,
                top_labels:   batch_top_labels,
                top_targets:  batch_top_targets,

                fuse_labels:  batch_fuse_labels,
                fuse_targets: batch_fuse_targets,
            }
            #_, batch_top_cls_loss, batch_top_reg_loss = sess.run([solver_step, top_cls_loss, top_reg_loss],fd2)


            _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
               sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],fd2)

            speed=time.time()-start_time
            log.write('%5.1f   %5d    %0.4fs   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %\
				(epoch, iter, speed, rate, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss))



            #print('ok')
            # debug: ------------------------------------

            if vis and iter%iter_debug==0:
                top_image = top_imgs[idx]
                rgb       = rgbs[idx]

                batch_top_probs, batch_top_scores, batch_top_deltas  = \
                    sess.run([ top_probs, top_scores, top_deltas ],fd2)

                batch_fuse_probs, batch_fuse_deltas = \
                    sess.run([ fuse_probs, fuse_deltas ],fd2)

                #batch_fuse_deltas=0*batch_fuse_deltas #disable 3d box prediction
                probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.05)


                ## show rpn score maps
                p = batch_top_probs.reshape( *(top_feature_shape[0:2]), 2*num_bases)
                for n in range(num_bases):
                    r=n%num_scales
                    s=n//num_scales
                    pn = p[:,:,2*n+1]*255
                    axs[s,r].cla()
                    if vis :
                        axs[s,r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                        plt.pause(0.01)

				## show rpn(top) nms
                img_rpn     = draw_rpn    (top_image, batch_top_probs, batch_top_deltas, anchors, inside_inds)
                img_rpn_nms = draw_rpn_nms(img_gt, batch_proposals, batch_proposal_scores)
                #imshow('img_rpn',img_rpn)
                if vis :
                    imshow('img_rpn_nms',img_rpn_nms)
                    cv2.waitKey(1)

                ## show rcnn(fuse) nms
                img_rcnn     = draw_rcnn (top_image, batch_fuse_probs, batch_fuse_deltas, batch_top_rois, batch_rois3d,darker=1)
                img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
                if vis :
                    imshow('img_rcnn',img_rcnn)
                    imshow('img_rcnn_nms',img_rcnn_nms)
                    cv2.waitKey(0)
            if (iter)%10==0:
                summary = sess.run(merged,fd2)
                train_writer.add_summary(summary, iter)
            # save: ------------------------------------
            if (iter)%2000==0 and (iter!=0):
                #saver.save(sess, out_dir + '/check_points/%06d.ckpt'%iter)  #iter
                saver.save(sess, out_dir + '/check_points/snap_ResNet_vgg_NGT_%06d.ckpt'%iter)  #iter
                # saver.save(sess, out_dir + '/check_points/MobileNet.ckpt')  #iter
                # pdb.set_trace()
                pass

            idx=idx+1