Beispiel #1
0
    def log_rpn(self, step=None, scope_name=''):
        """Render the RPN debug panels and send each one to the image summary.

        Up to four panels are produced from the batch state cached on ``self``.
        A panel is skipped when the attribute that drives it is ``None``:

        * ``<scope_name>/img_rpn_gt``       -- ``draw_rpn_gt`` on the gt top boxes
        * ``<scope_name>/img_rpn_label``    -- ``draw_rpn_labels`` on anchor indices
        * ``<scope_name>/img_rpn_target``   -- ``draw_rpn_targets`` on positive anchors
        * ``<scope_name>/img_rpn_proposal`` -- ``draw_rpn_proposal`` (``draw_num=20``)

        Args:
            step: step value forwarded to ``self.summary_image``.
            scope_name: prefix prepended to every summary tag.
        """
        image = self.top_image
        subdir = self.log_subdir  # only used by the disabled nud.imsave dumps
        inds, labels = self.batch_top_inds, self.batch_top_labels
        pos_inds, targets = self.batch_top_pos_inds, self.batch_top_targets
        props, prop_scores = self.batch_proposals, self.batch_proposal_scores
        gt_boxes, gt_lbls = self.batch_gt_top_boxes, self.batch_gt_labels

        # (gating value, tag suffix, deferred renderer) in emission order.
        panels = (
            (gt_boxes, '/img_rpn_gt',
             lambda: draw_rpn_gt(image, gt_boxes, gt_lbls)),
            (inds, '/img_rpn_label',
             lambda: draw_rpn_labels(image, self.top_view_anchors, inds, labels)),
            (pos_inds, '/img_rpn_target',
             lambda: draw_rpn_targets(image, self.top_view_anchors, pos_inds,
                                      targets)),
            (props, '/img_rpn_proposal',
             lambda: draw_rpn_proposal(image, props, prop_scores, draw_num=20)),
        )
        for gate, suffix, render in panels:
            if gate is None:
                continue
            rendered = render()
            self.summary_image(rendered, scope_name + suffix, step=step)
Beispiel #2
0
    def log_rpn(self, step=None, scope_name=''):
        """Write up to four RPN debug images as summaries under ``scope_name``.

        Each image is rendered with the matching ``draw_rpn_*`` helper and
        handed to ``self.summary_image`` together with ``step``. An image is
        skipped when its source batch attribute is ``None``.
        """
        def emit(tag, picture):
            # Single place that builds the summary tag from the scope prefix.
            self.summary_image(picture, scope_name + tag, step=step)

        canvas = self.top_image
        subdir = self.log_subdir  # kept for the disabled nud.imsave(...) dumps
        anchor_inds = self.batch_top_inds
        anchor_labels = self.batch_top_labels
        positive_inds = self.batch_top_pos_inds
        regression_targets = self.batch_top_targets
        proposal_boxes = self.batch_proposals
        proposal_conf = self.batch_proposal_scores
        gt_boxes = self.batch_gt_top_boxes
        gt_classes = self.batch_gt_labels

        if gt_boxes is not None:
            emit('/img_rpn_gt', draw_rpn_gt(canvas, gt_boxes, gt_classes))
        if anchor_inds is not None:
            emit('/img_rpn_label',
                 draw_rpn_labels(canvas, self.top_view_anchors,
                                 anchor_inds, anchor_labels))
        if positive_inds is not None:
            emit('/img_rpn_target',
                 draw_rpn_targets(canvas, self.top_view_anchors,
                                  positive_inds, regression_targets))
        if proposal_boxes is not None:
            emit('/img_rpn_proposal',
                 draw_rpn_proposal(canvas, proposal_boxes, proposal_conf,
                                   draw_num=20))
Beispiel #3
0
    def log_rpn(self,
                step=None,
                scope_name='',
                loss=None,
                tensor_board=True,
                draw_rpn_target=False):
        """Compose the RPN debug panels side by side and optionally log them.

        Panels are concatenated horizontally (``np.concatenate(..., 1)``) in
        this order. Each panel is included only when its inputs are available:

        1. ground truth (``draw_rpn_gt``), if ``self.batch_gt_top_boxes`` is set;
        2. anchor labels (``draw_rpn_labels``), if ``draw_rpn_target`` and
           ``self.batch_top_inds`` is set;
        3. positive-anchor targets (``draw_rpn_targets``), if ``draw_rpn_target``
           and ``self.batch_top_pos_inds`` is set;
        4. proposals (``draw_rpn_proposal``), if ``self.batch_proposals`` is set,
           optionally annotated with ``loss``.

        Args:
            step: step value forwarded to ``self.summary_image``.
            scope_name: prefix of the summary tag ``<scope_name>/top_view``.
            loss: optional two-element sequence rendered as
                ``'loss c: %6f r: %6f'`` on the proposal panel (presumably
                classification / regression loss). Tuples, lists and
                2-element arrays are accepted.
            tensor_board: write the composite as an image summary. The summary
                is only written when the proposal panel exists.
            draw_rpn_target: also draw the anchor label / target panels.

        Returns:
            The composite image, or ``None`` if no panel was drawn.
        """
        top_image = self.top_image
        top_inds = self.batch_top_inds
        top_labels = self.batch_top_labels
        top_pos_inds = self.batch_top_pos_inds
        top_targets = self.batch_top_targets
        proposals = self.batch_proposals
        proposal_scores = self.batch_proposal_scores
        gt_top_boxes = self.batch_gt_top_boxes
        gt_labels = self.batch_gt_labels

        def hstack(acc, img):
            # Place ``img`` to the right of the running composite, or start one.
            return img if acc is None else np.concatenate((acc, img), 1)

        total_img = None
        if gt_top_boxes is not None:
            total_img = draw_rpn_gt(top_image, gt_top_boxes, gt_labels)

        if draw_rpn_target:
            # Without RPN targets there is nothing to draw; the helpers would
            # otherwise receive None (the per-panel variant guards the same way).
            if top_inds is not None:
                total_img = hstack(
                    total_img,
                    draw_rpn_labels(top_image, self.top_view_anchors,
                                    top_inds, top_labels))
            if top_pos_inds is not None:
                total_img = hstack(
                    total_img,
                    draw_rpn_targets(top_image, self.top_view_anchors,
                                     top_pos_inds, top_targets))

        if proposals is not None:
            rpn_proposal = draw_rpn_proposal(top_image, proposals,
                                             proposal_scores)
            # Identity test: ``!=`` is elementwise (and ambiguous in ``if``)
            # for NumPy arrays.
            if loss is not None:
                text = 'loss c: %6f r: %6f' % tuple(loss)
                cv2.putText(rpn_proposal, text, (0, 25),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (5, 255, 100), 0,
                            cv2.LINE_AA)
            total_img = hstack(total_img, rpn_proposal)
            if tensor_board:
                self.summary_image(total_img,
                                   scope_name + '/top_view',
                                   step=step)
        return total_img
Beispiel #4
0
    def log_rpn(self, step=None, scope_name='', loss=None, tensor_board=True, draw_rpn_target=False):
        """Build a horizontal strip of RPN debug panels and optionally log it.

        Panel order (each included only when its inputs are not ``None``):
        ground truth, anchor labels and positive-anchor targets (the latter two
        only with ``draw_rpn_target``), then proposals. The proposal panel can
        carry a ``loss`` caption.

        Args:
            step: step value forwarded to ``self.summary_image``.
            scope_name: prefix of the summary tag ``<scope_name>/top_view``.
            loss: optional two-element sequence rendered as
                ``'loss c: %6f r: %6f'`` (presumably classification /
                regression loss). Tuples, lists and 2-element arrays work.
            tensor_board: write the strip as an image summary. This happens
                only when proposals are available.
            draw_rpn_target: include the anchor label / target panels.

        Returns:
            The concatenated image, or ``None`` if no panel was drawn.
        """
        top_image = self.top_image
        top_inds = self.batch_top_inds
        top_labels = self.batch_top_labels
        top_pos_inds = self.batch_top_pos_inds
        top_targets = self.batch_top_targets
        proposals = self.batch_proposals
        proposal_scores = self.batch_proposal_scores
        gt_top_boxes = self.batch_gt_top_boxes
        gt_labels = self.batch_gt_labels

        def hstack(acc, img):
            # Append ``img`` on the right (axis 1), starting a strip if needed.
            return img if acc is None else np.concatenate((acc, img), 1)

        total_img = None
        if gt_top_boxes is not None:
            total_img = draw_rpn_gt(top_image, gt_top_boxes, gt_labels)

        if draw_rpn_target:
            # Skip the panels whose inputs are missing instead of passing None on.
            if top_inds is not None:
                total_img = hstack(total_img, draw_rpn_labels(top_image, self.top_view_anchors, top_inds, top_labels))
            if top_pos_inds is not None:
                total_img = hstack(total_img, draw_rpn_targets(top_image, self.top_view_anchors, top_pos_inds, top_targets))

        if proposals is not None:
            rpn_proposal = draw_rpn_proposal(top_image, proposals, proposal_scores)
            # ``is not None``: ``!= None`` is elementwise for NumPy arrays.
            if loss is not None:
                text = 'loss c: %6f r: %6f' % tuple(loss)
                cv2.putText(rpn_proposal, text, (0, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (5, 255, 100), 0, cv2.LINE_AA)
            total_img = hstack(total_img, rpn_proposal)
            if tensor_board:
                self.summary_image(total_img, scope_name + '/top_view', step=step)
        return total_img
Beispiel #5
0
def run_train():
    """Train the top/front/rgb fusion network on the dummy dataset.

    Steps:
      1. Load all frames with ``load_dummy_datas()`` and derive input shapes.
      2. Build anchors and the TF1 graph: an RPN on the top view
         (``top_feature_net``), ``front_feature_net``, ``rgb_feature_net`` and
         the ``fusion_net`` head.
      3. Run up to ``max_iter`` plain gradient-descent steps (rate 0.05), each
         on one randomly sampled frame. The minimised loss is
         ``top_cls + top_reg + fuse_cls + 0.1 * fuse_reg + l2``.
      4. Append one loss line per step to ``<out_dir>/log.txt``. Every 500
         steps (including step 0), overwrite ``<out_dir>/check_points/snap.ckpt``.

    Frames without ground-truth labels are skipped without advancing the
    iteration counter, so the loop never terminates if every frame is empty.

    Relies on module-level names not visible here (``makedirs``, ``Logger``,
    ``IS_TRAIN_PHASE``, the ``*_net`` builders, ``draw_*`` helpers, ...).
    """

    # output dir, etc
    out_dir = '/home/dongwoo/Project/MV3D/data/out'
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    log = Logger(out_dir + '/log.txt', mode='a')

    #lidar data -----------------
    # `if 1:` is a manual on/off toggle; this block always runs.
    if 1:
        ratios = np.array([0.5, 1, 2], dtype=np.float32)
        scales = np.array([1, 2, 3], dtype=np.float32)
        bases = make_bases(base_size=16, ratios=ratios, scales=scales)
        num_bases = len(bases)
        stride = 8

        rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, lidars = load_dummy_datas(
        )
        num_frames = len(rgbs)

        top_shape = tops[0].shape
        front_shape = fronts[0].shape
        rgb_shape = rgbs[0].shape
        # Top-view feature map size: input height/width floor-divided by stride.
        top_feature_shape = (top_shape[0] // stride, top_shape[1] // stride)
        out_shape = (8, 3)  # per-roi shape of the fusion regression target (fuse_targets)

        #-----------------------
        #check data (disabled: mayavi view of frame 0's lidar and gt boxes)
        if 0:
            fig = mlab.figure(figure=None,
                              bgcolor=(0, 0, 0),
                              fgcolor=None,
                              engine=None,
                              size=(1000, 500))
            draw_lidar(lidars[0], fig=fig)
            draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
            mlab.show(1)
            cv2.waitKey(1)

    # set anchor boxes

    num_class = 2  #incude background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2],
                                        top_feature_shape[0:2])
    # Overrides the inside_inds returned by make_anchors: every anchor is used.
    inside_inds = np.arange(0, len(anchors), dtype=np.int32)  #use all  #<todo>
    print('out_shape=%s' % str(out_shape))
    print('num_frames=%d' % num_frames)

    #load model ####################################################################################################
    top_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors')
    top_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds')

    top_images = tf.placeholder(shape=[None, *top_shape],
                                dtype=tf.float32,
                                name='top')
    front_images = tf.placeholder(shape=[None, *front_shape],
                                  dtype=tf.float32,
                                  name='front')
    rgb_images = tf.placeholder(shape=[None, *rgb_shape],
                                dtype=tf.float32,
                                name='rgb')
    top_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='top_rois')  #<todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5],
                                dtype=tf.float32,
                                name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')

    # Top-view RPN branch: features, anchor scores/probs/deltas and proposals.
    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    # Each entry is presumably [features, rois, pool_h, pool_w, spatial_scale];
    # per the inline note, 0,0 disables the front branch.
    fuse_scores, fuse_probs, fuse_deltas = \
        fusion_net(
    ( [top_features,     top_rois,     6,6,1./stride],
     [front_features,   front_rois,   0,0,1./stride],  #disable by 0,0
     [rgb_features,     rgb_rois,     6,6,1./stride],),
            num_class, out_shape) #<todo>  add non max suppression

    #loss ########################################################################################################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None],
                                  dtype=tf.int32,
                                  name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4],
                                 dtype=tf.float32,
                                 name='top_target')
    top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds,
                                          top_pos_inds, top_labels,
                                          top_targets)

    fuse_labels = tf.placeholder(shape=[None],
                                 dtype=tf.int32,
                                 name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, *out_shape],
                                  dtype=tf.float32,
                                  name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas,
                                             fuse_labels, fuse_targets)

    #solver
    l2 = l2_regulariser(decay=0.001)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    #solver = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    solver = tf.train.GradientDescentOptimizer(learning_rate=learning_rate,
                                               use_locking=False,
                                               name='GradientDescent')
    #solver_step = solver.minimize(top_cls_loss+top_reg_loss+l2)
    # Fusion regression loss is down-weighted by 0.1 relative to the others.
    solver_step = solver.minimize(top_cls_loss + top_reg_loss + fuse_cls_loss +
                                  0.1 * fuse_reg_loss + l2)

    max_iter = 10000
    iter_debug = 8  # period (in iterations) of the debug passes below

    # start training here  #########################################################################################
    log.write(
        'epoch     iter    rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'
    )
    log.write(
        '-------------------------------------------------------------------------------------\n'
    )

    # Subplot grid for the RPN score maps (one axis per ratio x scale base,
    # assuming make_bases yields num_ratios * num_scales bases).
    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)

    sess = tf.InteractiveSession()
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        summary_writer = tf.summary.FileWriter(out_dir + '/tf', sess.graph)
        saver = tf.train.Saver()

        batch_top_cls_loss = 0
        batch_top_reg_loss = 0
        batch_fuse_cls_loss = 0
        batch_fuse_reg_loss = 0
        # NOTE: `iter` shadows the builtin within this function.
        iter = 0
        while iter < max_iter:
            #for iter in range(max_iter):
            # "epoch" is just the iteration count as a float (one frame per step).
            epoch = 1.0 * iter
            rate = 0.05

            ## generate train image -------------
            idx = np.random.choice(num_frames)  #*10   #num_frames)  #0
            #print (idx)
            batch_top_images = tops[idx].reshape(1, *top_shape)
            batch_front_images = fronts[idx].reshape(1, *front_shape)
            batch_rgb_images = rgbs[idx].reshape(1, *rgb_shape)

            batch_gt_labels = gt_labels[idx]
            batch_gt_boxes3d = gt_boxes3d[idx]
            batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)

            # Skip frames with no ground truth; `iter` is NOT advanced here.
            if len(batch_gt_labels) == 0:
                continue

            ## run proposal generation ------------
            fd1 = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds,
                learning_rate: rate,
                IS_TRAIN_PHASE: True
            }
            batch_proposals, batch_proposal_scores, batch_top_features = sess.run(
                [proposals, proposal_scores, top_features], fd1)

            ## generate  train rois  ------------
            #print (anchors)
            #print (inside_inds)
            #print (batch_gt_labels)
            #print (batch_gt_top_boxes)
            batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets  = \
                rpn_target ( anchors, inside_inds, batch_gt_labels,  batch_gt_top_boxes)

            batch_top_rois, batch_fuse_labels, batch_fuse_targets  = \
                 rcnn_target(  batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d )

            # Project the sampled top-view rois to 3D, then to front and rgb views.
            batch_rois3d = project_to_roi3d(batch_top_rois)
            batch_front_rois = project_to_front_roi(batch_rois3d)
            batch_rgb_rois = project_to_rgb_roi(batch_rois3d)

            ##debug gt generation
            # The images below are computed but never displayed: every
            # imshow call is commented out.
            if 1 and iter % iter_debug == 1:
                top_image = top_imgs[idx]
                rgb = rgbs[idx]

                img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes,
                                     batch_gt_labels)
                img_label = draw_rpn_labels(top_image, anchors, batch_top_inds,
                                            batch_top_labels)
                img_target = draw_rpn_targets(top_image, anchors,
                                              batch_top_pos_inds,
                                              batch_top_targets)
                #imshow('img_rpn_gt',img_gt)
                #imshow('img_rpn_label',img_label)
                #imshow('img_rpn_target',img_target)

                img_label = draw_rcnn_labels(top_image, batch_top_rois,
                                             batch_fuse_labels)
                img_target = draw_rcnn_targets(top_image, batch_top_rois,
                                               batch_fuse_labels,
                                               batch_fuse_targets)
                #imshow('img_rcnn_label',img_label)
                #imshow('img_rcnn_target',img_target)

                img_rgb_rois = draw_boxes(rgb,
                                          batch_rgb_rois[:, 1:5],
                                          color=(255, 0, 255),
                                          thickness=1)
                #imshow('img_rgb_rois',img_rgb_rois)

                #cv2.waitKey(1)

            ## run classification and regression loss -----------
            fd2 = {
                **fd1,
                top_images: batch_top_images,
                front_images: batch_front_images,
                rgb_images: batch_rgb_images,
                top_rois: batch_top_rois,
                front_rois: batch_front_rois,
                rgb_rois: batch_rgb_rois,
                top_inds: batch_top_inds,
                top_pos_inds: batch_top_pos_inds,
                top_labels: batch_top_labels,
                top_targets: batch_top_targets,
                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets,
            }
            #_, batch_top_cls_loss, batch_top_reg_loss = sess.run([solver_step, top_cls_loss, top_reg_loss],fd2)


            # One optimisation step; the four loss values are fetched for logging.
            _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
               sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],fd2)

            log.write('%3.1f   %d   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %\
    (epoch, iter, rate, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss))

            # debug: ------------------------------------

            if iter % iter_debug == 0:
                top_image = top_imgs[idx]
                rgb = rgbs[idx]

                batch_top_probs, batch_top_scores, batch_top_deltas  = \
                    sess.run([ top_probs, top_scores, top_deltas ],fd2)

                batch_fuse_probs, batch_fuse_deltas = \
                    sess.run([ fuse_probs, fuse_deltas ],fd2)

                #batch_fuse_deltas=0*batch_fuse_deltas #disable 3d box prediction
                probs, boxes3d = rcnn_nms(batch_fuse_probs,
                                          batch_fuse_deltas,
                                          batch_rois3d,
                                          threshold=0.5)

                ## show rpn score maps
                # Reshaped to (feat_h, feat_w, 2 * num_bases); channel 2n+1 is
                # plotted per base (presumably the foreground probability).
                p = batch_top_probs.reshape(*(top_feature_shape[0:2]),
                                            2 * num_bases)
                for n in range(num_bases):
                    r = n % num_scales
                    s = n // num_scales
                    pn = p[:, :, 2 * n + 1] * 255
                    axs[s, r].cla()
                    axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                plt.pause(0.01)

                ## show rpn(top) nms
                #img_rpn     = draw_rpn    (top_image, batch_top_probs, batch_top_deltas, anchors, inside_inds)
                img_rpn_nms = draw_rpn_nms(top_image, batch_proposals,
                                           batch_proposal_scores)
                #imshow('img_rpn',img_rpn)
                #imshow('img_rpn_nms',img_rpn_nms)
                #cv2.waitKey(1)

                ## show rcnn(fuse) nms
                #img_rcnn     = draw_rcnn (top_image, batch_fuse_probs, batch_fuse_deltas, batch_top_rois, batch_rois3d,darker=1)
                #img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
                #imshow('img_rcnn',img_rcnn)
                #imshow('img_rcnn_nms',img_rcnn_nms)
                #cv2.waitKey(1)

            # save: ------------------------------------
            # Always the same file name, so each save overwrites the last.
            if iter % 500 == 0:
                #saver.save(sess, out_dir + '/check_points/%06d.ckpt'%iter)  #iter
                saver.save(sess, out_dir + '/check_points/snap.ckpt')  #iter

            iter = iter + 1
Beispiel #6
0
def run_test():
    """Run the restored fusion network over every loaded frame and visualise it.

    Restores ``./outputs/check_points/snap_ResNet_vgg_up_NGT_060000.ckpt`` and
    then processes the frames in order:
      1. generate top-view RPN proposals (anchors filtered by ``anchor_filter``);
      2. project them to 3D / front / rgb rois, and keep only proposals whose
         rgb roi stays within a 200-pixel margin around the rgb image;
      3. run the fusion head and ``rcnn_nms`` (threshold 0.05);
      4. when the module-level ``is_show`` flag equals 1, draw the lidar cloud
         with predicted and gt boxes, the RPN proposals and the final
         detections. This step blocks on ``cv2.waitKey(0)`` per frame.

    Relies on module-level names not visible here (``makedirs``, ``Logger``,
    ``IS_TRAIN_PHASE``, ``is_show``, the ``*_net`` builders, ``draw_*`` ...).
    """

    # output dir, etc
    out_dir = './outputs'
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    log = Logger(out_dir + '/log_%s.txt' %
                 (time.strftime('%Y-%m-%d %H:%M:%S')),
                 mode='a')

    #lidar data -----------------
    if 1:
        # Explicit anchor bases. ratios/scales only size the subplot grid below.
        ratios = np.array([1.7, 2.4])
        scales = np.array([1.7, 2.4])
        bases = np.array([[-19.5, -8, 19.5, 8], [-8, -19.5, 8, 19.5],
                          [-5, -3, 5, 3], [-3, -5, 3, 5]])
        num_bases = len(bases)
        stride = 4

        rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, lidars, rgbs_norm0 = load_dummy_datas(
        )
        num_frames = len(rgbs)

        top_shape = tops[0].shape
        front_shape = fronts[0].shape
        rgb_shape = rgbs[0].shape
        # Ceil-division: feature map size of the top-view net at this stride.
        top_feature_shape = ((top_shape[0] - 1) // stride + 1,
                             (top_shape[1] - 1) // stride + 1)
        out_shape = (8, 3)

        #-----------------------
        #check data (disabled)
        if 0:
            fig = mlab.figure(figure=None,
                              bgcolor=(0, 0, 0),
                              fgcolor=None,
                              engine=None,
                              size=(1000, 500))
            draw_lidar(lidars[0], fig=fig)
            draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
            mlab.show(1)
            cv2.waitKey(0)

    # set anchor boxes
    num_class = 2  #incude background
    # inside_inds is refined per frame by anchor_filter() in the loop below.
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2],
                                        top_feature_shape[0:2])
    print('out_shape=%s' % str(out_shape))
    print('num_frames=%d' % num_frames)

    #load model ####################################################################################################
    top_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors')
    top_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds')

    top_images = tf.placeholder(shape=[None, *top_shape],
                                dtype=tf.float32,
                                name='top')
    front_images = tf.placeholder(shape=[None, *front_shape],
                                  dtype=tf.float32,
                                  name='front')
    rgb_images = tf.placeholder(shape=[None, None, None, 3],
                                dtype=tf.float32,
                                name='rgb')
    top_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='top_rois')  #<todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5],
                                dtype=tf.float32,
                                name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    fuse_scores, fuse_probs, fuse_deltas = \
        fusion_net(
            ( [top_features,     top_rois,     6,6,1./stride],
              [front_features,   front_rois,   0,0,1./stride],  #disable by 0,0
              [rgb_features,     rgb_rois,     6,6,1./(2*stride)],),
            num_class, out_shape) #<todo>  add non max suppression

    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)
    mfig = mlab.figure(figure=None,
                       bgcolor=(0, 0, 0),
                       fgcolor=None,
                       engine=None,
                       size=(500, 500))

    sess = tf.InteractiveSession()
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        summary_writer = tf.summary.FileWriter(out_dir + '/tf', sess.graph)
        saver = tf.train.Saver()

        saver.restore(
            sess, './outputs/check_points/snap_ResNet_vgg_up_NGT_060000.ckpt')

        # Frames are visited in order; the loop index is the frame index.
        for idx in range(num_frames):
            rgb_shape = rgbs[idx].shape

            batch_top_images = tops[idx].reshape(1, *top_shape)
            batch_front_images = fronts[idx].reshape(1, *front_shape)
            # assumes rgbs_norm0[idx] has the same shape as rgbs[idx] — TODO confirm
            batch_rgb_images = rgbs_norm0[idx].reshape(1, *rgb_shape)

            batch_gt_labels = gt_labels[idx]
            batch_gt_boxes3d = gt_boxes3d[idx]
            batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)

            inside_inds_filtered = anchor_filter(batch_top_images[0, :, :, -1],
                                                 anchors, inside_inds)

            ## run proposal generation ------------
            # NOTE(review): IS_TRAIN_PHASE is fed True even at test time (kept
            # as-is); confirm this is intended, e.g. for normalisation layers.
            fd1 = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds_filtered,
                IS_TRAIN_PHASE: True
            }
            batch_proposals, batch_proposal_scores, batch_top_features = sess.run(
                [proposals, proposal_scores, top_features], fd1)
            print(batch_proposal_scores[:10])

            ## generate test rois  ------------
            batch_top_rois = batch_proposals
            batch_rois3d = project_to_roi3d(batch_top_rois)
            batch_front_rois = project_to_front_roi(batch_rois3d)
            batch_rgb_rois = project_to_rgb_roi(batch_rois3d)
            # Keep proposals whose rgb roi stays within a 200 px margin of the
            # image. Every per-proposal array is filtered with the same mask so
            # the arrays stay index-aligned.
            keep = np.where((batch_rgb_rois[:, 1] >= -200)
                            & (batch_rgb_rois[:, 2] >= -200)
                            & (batch_rgb_rois[:, 3] <= (rgb_shape[1] + 200)) &
                            (batch_rgb_rois[:, 4] <= (rgb_shape[0] + 200)))[0]
            batch_rois3d = batch_rois3d[keep]
            batch_front_rois = batch_front_rois[keep]
            batch_rgb_rois = batch_rgb_rois[keep]
            batch_proposal_scores = batch_proposal_scores[keep]
            batch_top_rois = batch_top_rois[keep]

            ## run classification and regression  -----------

            fd2 = {
                **fd1,
                top_images: batch_top_images,
                front_images: batch_front_images,
                rgb_images: batch_rgb_images,
                top_rois: batch_top_rois,
                front_rois: batch_front_rois,
                rgb_rois: batch_rgb_rois,
            }
            batch_top_probs, batch_top_deltas = sess.run(
                [top_probs, top_deltas], fd2)
            batch_fuse_probs, batch_fuse_deltas = sess.run(
                [fuse_probs, fuse_deltas], fd2)

            probs, boxes3d = rcnn_nms(batch_fuse_probs,
                                      batch_fuse_deltas,
                                      batch_rois3d,
                                      threshold=0.05)

            # debug: ------------------------------------
            if is_show == 1:
                top_image = top_imgs[idx]
                surround_image = fronts[idx]
                lidar = lidars[idx]
                rgb = rgbs[idx]

                # Re-fetch RPN/fusion outputs (plus top_scores) for display.
                batch_top_probs, batch_top_scores, batch_top_deltas  = \
                    sess.run([ top_probs, top_scores, top_deltas ],fd2)
                batch_fuse_probs, batch_fuse_deltas = \
                    sess.run([ fuse_probs, fuse_deltas ],fd2)
                ## show on lidar
                mlab.clf(mfig)
                draw_lidar(lidar, fig=mfig)
                if len(boxes3d) != 0:
                    draw_target_boxes3d(boxes3d, fig=mfig)
                    draw_gt_boxes3d(batch_gt_boxes3d, fig=mfig)
                mlab.show(1)

                ## rpn score maps (the per-base subplot drawing is disabled)
                p = batch_top_probs.reshape(*(top_feature_shape[0:2]),
                                            2 * num_bases)
                plt.pause(0.01)
                img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes,
                                     batch_gt_labels)
                # Draw the *filtered* proposals: batch_proposal_scores was
                # filtered by `keep` above, so pairing it with the unfiltered
                # batch_proposals would misalign boxes and scores.
                img_rpn_nms = draw_rpn_nms(img_gt, batch_top_rois,
                                           batch_proposal_scores)
                imshow('img_rpn_nms', img_rpn_nms)
                cv2.waitKey(1)

                rgb1 = draw_rcnn_nms(rgb, boxes3d, probs)

                rgb_boxes = batch_rgb_rois
                img_rgb_2d_detection = draw_boxes(rgb,
                                                  rgb_boxes[:, 1:5],
                                                  color=(255, 0, 255),
                                                  thickness=1)

                imshow('draw_rcnn_nms', rgb1)
                cv2.waitKey(0)
Beispiel #7
0
def run_train():
    """Train the top/front/RGB fusion detector (with two auxiliary fusion heads).

    Each iteration:
      1. runs the region-proposal part (``top_feature_net``) on one "top" frame
         to obtain proposals;
      2. builds RPN targets (``rpn_target``) and RCNN targets (``rcnn_target``)
         and projects the sampled ROIs into the front and RGB views;
      3. takes one momentum-SGD step on the summed RPN + fusion + auxiliary
         losses and logs the losses to ``log.txt`` and TensorBoard.

    Side effects:
      * deletes and recreates ``<out_dir>/train`` (TensorBoard logs);
      * appends to ``<out_dir>/log.txt``;
      * restores weights from ``<out_dir>/check_points/snap.ckpt`` before
        training and overwrites that checkpoint every 500 iterations;
      * every 10 iterations opens OpenCV windows with the current detections.

    Relies on module-level names defined elsewhere (``IS_TRAIN_PHASE``,
    ``Logger``, the network/target/draw helpers, ``tf``, ``cv2``, ``mlab``,
    ``plt``). This example is Python 2 code: it uses the ``unicode`` builtin.
    """
    # output dir, etc
    out_dir = '/root/sharefolder/sdcnd/didi1/output'
    log = Logger(out_dir + '/log.txt', mode='a')

    # lidar data -----------------
    ratios = np.array([0.5, 1, 2], dtype=np.float32)
    scales = np.array([1, 2, 3], dtype=np.float32)
    bases = make_bases(base_size=16, ratios=ratios, scales=scales)
    num_bases = len(bases)
    stride = 8

    num_frames = 154
    rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, lidars = \
        load_dummy_datas(num_frames)
    num_frames = len(rgbs)  # use the number of frames the loader actually returned

    top_shape = tops[0].shape
    front_shape = fronts[0].shape
    rgb_shape = rgbs[0].shape
    top_feature_shape = (top_shape[0] // stride, top_shape[1] // stride)
    out_shape = (8, 3)  # per-ROI regression target shape (matches fuse_targets below)

    # check data (debug only; change to `if 1:` to visualise the first frame)
    if 0:
        fig = mlab.figure(figure=None, bgcolor=(0, 0, 0), fgcolor=None, engine=None, size=(1000, 500))
        draw_lidar(lidars[0], fig=fig)
        draw_gt_boxes3d(gt_boxes3d[0], fig=fig)
        mlab.show(1)
        cv2.waitKey(1)

    # set anchor boxes
    num_class = 2  # includes background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2], top_feature_shape[0:2])
    # <todo> the inside_inds returned by make_anchors are discarded: every anchor is used
    inside_inds = np.arange(0, len(anchors), dtype=np.int32)
    print('out_shape=%s' % str(out_shape))
    print('num_frames=%d' % num_frames)

    # load model ######################################################################
    top_anchors = tf.placeholder(shape=[None, 4], dtype=tf.int32, name='anchors')
    top_inside_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='inside_inds')

    top_images = tf.placeholder(shape=[None, 400, 400, 8], dtype=tf.float32, name='input_top')
    front_images = tf.placeholder(shape=[None, 1, 1], dtype=tf.float32, name='front')
    rgb_images = tf.placeholder(shape=[None, 375, 1242, 3], dtype=tf.float32, name='rgb')
    top_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='top_rois')  # <todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='rgb_rois')

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    # Each entry is presumably [features, rois, pool_h, pool_w, spatial_scale]
    # (TODO confirm against fusion_net); 0,0 disables the front branch.
    fuse_scores, fuse_probs, fuse_deltas, aux_fuse_scores, aux_fuse_probs, aux_fuse_deltas = \
        fusion_net(
            ([top_features, top_rois, 6, 6, 1. / stride],
             [front_features, front_rois, 0, 0, 1. / stride],  # disable by 0,0
             [rgb_features, rgb_rois, 6, 6, 1. / stride],),
            num_class, out_shape)  # <todo> add non max suppression

    # loss ############################################################################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4], dtype=tf.float32, name='top_target')
    with tf.variable_scope('rpn-loss'):
        top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds, top_pos_inds, top_labels, top_targets)
    tf.summary.scalar('top_cls_loss', top_cls_loss)
    tf.summary.scalar('top_reg_loss', top_reg_loss)

    fuse_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, 8, 3], dtype=tf.float32, name='fuse_target')
    with tf.variable_scope('rcnn-loss'):
        fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas, fuse_labels, fuse_targets)
    tf.summary.scalar('fuse_cls_loss', fuse_cls_loss)
    tf.summary.scalar('fuse_reg_loss', fuse_reg_loss)

    # Auxiliary heads share the main fusion targets.
    with tf.variable_scope('aux_rcnn_loss'):
        with tf.variable_scope('aux_loss_1'):
            aux_fuse_cls_loss_1, aux_fuse_reg_loss_1 = rcnn_loss(
                aux_fuse_scores[0], aux_fuse_deltas[0], fuse_labels, fuse_targets)
        tf.summary.scalar('aux_fuse_cls_loss_1', aux_fuse_cls_loss_1)
        tf.summary.scalar('aux_fuse_reg_loss_1', aux_fuse_reg_loss_1)
        with tf.variable_scope('aux_loss_2'):
            aux_fuse_cls_loss_2, aux_fuse_reg_loss_2 = rcnn_loss(
                aux_fuse_scores[1], aux_fuse_deltas[1], fuse_labels, fuse_targets)
        tf.summary.scalar('aux_fuse_cls_loss_2', aux_fuse_cls_loss_2)
        # Log the second head's own regression loss (previously logged head 1's by mistake).
        tf.summary.scalar('aux_fuse_reg_loss_2', aux_fuse_reg_loss_2)

    # solver (an l2_regulariser term was previously considered here)
    with tf.variable_scope('total_loss'):
        total_loss = (top_cls_loss + top_reg_loss + fuse_cls_loss + 0.1 * fuse_reg_loss
                      + aux_fuse_cls_loss_1 + aux_fuse_reg_loss_1
                      + aux_fuse_cls_loss_2 + aux_fuse_reg_loss_2)
    tf.summary.scalar('total_loss', total_loss)

    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    solver_step = solver.minimize(total_loss)

    max_iter = 10000

    # start training here #############################################################
    log.write(unicode('epoch     iter    rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'))
    log.write(unicode('-------------------------------------------------------------------------------------\n'))

    num_ratios = len(ratios)  # only used by the disabled score-map debug view
    num_scales = len(scales)

    sess = tf.InteractiveSession()
    merged = tf.summary.merge_all()

    log_dir = out_dir + '/train'
    if tf.gfile.Exists(log_dir):
        # gotta be careful: this wipes all previous TensorBoard logs in log_dir
        tf.gfile.DeleteRecursively(log_dir)
        print('Removed files in {}'.format(log_dir))
    train_writer = tf.summary.FileWriter(log_dir, sess.graph)

    saver = tf.train.Saver()
    tf.global_variables_initializer().run()

    # option: loading pretrained model (overrides the fresh initialisation above)
    saver.restore(sess, out_dir + '/check_points/snap.ckpt')

    for step in range(max_iter):
        epoch = 1.0 * step
        rate = 0.05

        ## generate train image -------------
        idx = np.random.choice(num_frames)
        # NOTE(review): this hard-coded override trains on a single frame and
        # discards the random draw above; it also needs at least 88 loaded frames.
        idx = 87
        batch_top_images = tops[idx].reshape(1, *top_shape)
        batch_front_images = fronts[idx].reshape(1, *front_shape)
        batch_rgb_images = rgbs[idx].reshape(1, *rgb_shape)

        batch_gt_labels = gt_labels[idx]
        batch_gt_boxes3d = gt_boxes3d[idx]
        batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)

        ## run proposal generation ------------
        fd1 = {
            top_images: batch_top_images,
            top_anchors: anchors,
            top_inside_inds: inside_inds,
            learning_rate: rate,
            IS_TRAIN_PHASE: True,
        }
        batch_proposals, batch_proposal_scores, batch_top_features = \
            sess.run([proposals, proposal_scores, top_features], fd1)

        ## generate train rois ------------
        batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets = \
            rpn_target(anchors, inside_inds, batch_gt_labels, batch_gt_top_boxes)

        batch_top_rois, batch_fuse_labels, batch_fuse_targets = \
            rcnn_target(batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d)

        batch_rois3d = project_to_roi3d(batch_top_rois)
        batch_front_rois = project_to_front_roi(batch_rois3d)
        batch_rgb_rois = project_to_rgb_roi(batch_rois3d)

        ## debug gt generation (disabled; switch to `if True:` to inspect targets)
        if False:
            top_image = top_imgs[idx]
            rgb = rgbs[idx]

            img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes, batch_gt_labels)
            img_label = draw_rpn_labels(top_image, anchors, batch_top_inds, batch_top_labels)
            img_target = draw_rpn_targets(top_image, anchors, batch_top_pos_inds, batch_top_targets)

            img_label = draw_rcnn_labels(top_image, batch_top_rois, batch_fuse_labels)
            img_target = draw_rcnn_targets(top_image, batch_top_rois, batch_fuse_labels, batch_fuse_targets)
            imshow('img_rcnn_target', img_target)

            img_rgb_rois = draw_boxes(rgb, batch_rgb_rois[:, 1:5], color=(255, 0, 255), thickness=1)
            imshow('img_rgb_rois', img_rgb_rois)

            cv2.waitKey(1)

        ## run classification and regression loss -----------
        fd2 = {
            top_images: batch_top_images,
            top_anchors: anchors,
            top_inside_inds: inside_inds,
            learning_rate: rate,
            IS_TRAIN_PHASE: True,

            front_images: batch_front_images,
            rgb_images: batch_rgb_images,

            top_rois: batch_top_rois,
            front_rois: batch_front_rois,
            rgb_rois: batch_rgb_rois,

            top_inds: batch_top_inds,
            top_pos_inds: batch_top_pos_inds,
            top_labels: batch_top_labels,
            top_targets: batch_top_targets,

            fuse_labels: batch_fuse_labels,
            fuse_targets: batch_fuse_targets,
        }

        # To profile, set run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        # and run_metadata = tf.RunMetadata(), then call train_writer.add_run_metadata.
        run_options = None
        run_metadata = None
        _, summary, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
            sess.run([solver_step, merged, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],
                     feed_dict=fd2,
                     options=run_options,
                     run_metadata=run_metadata)
        train_writer.add_summary(summary, step)
        train_writer.flush()

        log.write(unicode('%3.1f   %d   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %
                          (epoch, step, rate, batch_top_cls_loss, batch_top_reg_loss,
                           batch_fuse_cls_loss, batch_fuse_reg_loss)))

        # debug: show RCNN detections every 10 iterations --------------------
        if step % 10 == 0:
            top_image = top_imgs[idx]
            rgb = rgbs[idx]

            batch_fuse_probs, batch_fuse_deltas = sess.run([fuse_probs, fuse_deltas], fd2)

            probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.5)

            ## show rcnn(fuse) nms
            gt_2d_box = batch_gt_top_boxes
            img_rcnn = draw_rcnn(top_image, batch_fuse_probs, batch_fuse_deltas,
                                 batch_top_rois, batch_rois3d, gt_2d_box, darker=1)
            # NOTE(review): the return value of rcnn_result is not used here.
            rcnn_result(batch_fuse_probs, batch_fuse_deltas, batch_top_rois, batch_rois3d, gt_2d_box)
            img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
            imshow('img_rcnn', img_rcnn)
            imshow('img_rcnn_nms', img_rcnn_nms)
            cv2.waitKey(1)

        # full debug view (disabled; switch to `if True:` to enable)
        if False:
            top_image = top_imgs[idx]
            rgb = rgbs[idx]

            batch_top_probs, batch_top_scores, batch_top_deltas = \
                sess.run([top_probs, top_scores, top_deltas], fd2)

            batch_fuse_probs, batch_fuse_deltas = \
                sess.run([fuse_probs, fuse_deltas], fd2)

            probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.5)

            ## show rpn score maps
            fig, axs = plt.subplots(num_ratios, num_scales)
            p = batch_top_probs.reshape(50, 50, 2 * num_bases)
            for n in range(num_bases):
                r = n % num_scales
                s = n // num_scales
                pn = p[:, :, 2 * n + 1] * 255
                axs[s, r].cla()
                axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
            plt.pause(0.01)

            ## show rpn(top) nms
            img_rpn = draw_rpn(top_image, batch_top_probs, batch_top_deltas, anchors, inside_inds)
            img_rpn_nms = draw_rpn_nms(top_image, batch_proposals, batch_proposal_scores)
            imshow('img_rpn_nms', img_rpn_nms)
            cv2.waitKey(1)

            ## show rcnn(fuse) nms
            img_rcnn = draw_rcnn(top_image, batch_fuse_probs, batch_fuse_deltas,
                                 batch_top_rois, batch_rois3d, darker=1)
            img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
            imshow('img_rcnn', img_rcnn)
            imshow('img_rcnn_nms', img_rcnn_nms)
            cv2.waitKey(1)

        # save: overwrite the single rolling checkpoint ----------------------
        if step % 500 == 0:
            saver.save(sess, out_dir + '/check_points/snap.ckpt')

    train_writer.close()
def run_train():
    """Train the top/front/RGB fusion detector on the frames in ``train_list.npy``.

    Frame ids are read from ``<data_root>seg/train_list.npy`` and sorted. Frames
    are loaded lazily in chunks of ``freq`` (200) via ``load_dummy_datas``; each
    chunk is iterated over twice before the next one is loaded, and the frame
    order is reshuffled every ``2 * num_frames`` iterations.

    Optimisation: Adam on ``rpn_cls + rpn_reg + 1.5*rcnn_cls + 2*rcnn_reg + l2``
    with a learning rate that starts at 5e-6 and is multiplied by 0.8 every
    10000 iterations.

    Side effects:
      * creates ``./outputs/tf`` and ``./outputs/check_points``;
      * writes a timestamped log under ``./outputs``;
      * restores ``snap_ResNet_vgg_up_NGT_060000.ckpt`` before training and
        saves ``snap_ResNet_vgg_NGT_<iter>.ckpt`` every 2000 iterations;
      * writes TensorBoard summaries every 10 iterations;
      * when the module-level flag ``vis`` is truthy, opens OpenCV/matplotlib
        windows and blocks on ``cv2.waitKey(0)`` on each debug iteration.

    Relies on module-level names defined elsewhere (``data_root``, ``vis``,
    ``IS_TRAIN_PHASE``, ``Logger``, network/target/draw helpers, ``tf``,
    ``cv2``, ``plt``).
    """
    # output dir, etc
    out_dir = './outputs'
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    log = Logger(out_dir + '/log_%s.txt' % (time.strftime('%Y-%m-%d %H:%M:%S')), mode='a')
    index = np.array(sorted(np.load(data_root + 'seg/train_list.npy')))
    num_frames = len(index)

    # lidar data -----------------
    # Anchor bases are hand-picked boxes (make_bases(base_size=16, ratios, scales)
    # was used previously). ratios/scales below only size the debug plot grid.
    ratios = np.array([1.7, 2.4])
    scales = np.array([1.7, 2.4])
    bases = np.array([[-19.5, -8, 19.5, 8],
                      [-8, -19.5, 8, 19.5],
                      [-5, -3, 5, 3],
                      [-3, -5, 3, 5]])
    num_bases = len(bases)
    stride = 4

    out_shape = (8, 3)  # per-ROI regression target shape (see fuse_targets)

    # Load a tiny initial subset just to discover input shapes; the training
    # loop reloads the data chunk by chunk.
    rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index = \
        load_dummy_datas(index[:3])
    top_shape = tops[0].shape
    front_shape = fronts[0].shape
    rgb_shape = rgbs[0].shape
    # ceil(top_shape / stride) in each spatial dimension
    top_feature_shape = ((top_shape[0] - 1) // stride + 1, (top_shape[1] - 1) // stride + 1)

    # set anchor boxes
    num_class = 2  # includes background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2], top_feature_shape[0:2])
    print('out_shape=%s' % str(out_shape))
    print('num_frames=%d' % num_frames)

    # load model ######################################################################
    top_anchors = tf.placeholder(shape=[None, 4], dtype=tf.int32, name='anchors')
    top_inside_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='inside_inds')

    top_images = tf.placeholder(shape=[None, *top_shape], dtype=tf.float32, name='top')
    front_images = tf.placeholder(shape=[None, *front_shape], dtype=tf.float32, name='front')
    rgb_images = tf.placeholder(shape=[None, None, None, 3], dtype=tf.float32, name='rgb')
    top_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='top_rois')  # <todo> change to int32???
    front_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='front_rois')
    rgb_rois = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='rgb_rois')

    top_features, top_scores, top_probs, top_deltas, proposals, proposal_scores = \
        top_feature_net(top_images, top_anchors, top_inside_inds, num_bases)
    front_features = front_feature_net(front_images)
    rgb_features = rgb_feature_net(rgb_images)

    # Each entry is presumably [features, rois, pool_h, pool_w, spatial_scale]
    # (TODO confirm against fusion_net); 0,0 disables the front branch.
    fuse_scores, fuse_probs, fuse_deltas = fusion_net(
        ([top_features, top_rois, 6, 6, 1. / stride],
         [front_features, front_rois, 0, 0, 1. / stride],  # disable by 0,0
         [rgb_features, rgb_rois, 6, 6, 1. / (2 * stride)],),
        num_class, out_shape)  # <todo> add non max suppression

    # loss ############################################################################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4], dtype=tf.float32, name='top_target')
    # NOTE(review): RPN scores are scaled by 2 before the loss — confirm intent.
    top_cls_loss, top_reg_loss = rpn_loss(2 * top_scores, top_deltas, top_inds, top_pos_inds,
                                          top_labels, top_targets)

    fuse_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, *out_shape], dtype=tf.float32, name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas, fuse_labels, fuse_targets)
    tf.summary.scalar('rpn_cls_loss', top_cls_loss)
    tf.summary.scalar('rpn_reg_loss', top_reg_loss)
    tf.summary.scalar('rcnn_cls_loss', fuse_cls_loss)
    tf.summary.scalar('rcnn_reg_loss', fuse_reg_loss)

    # solver
    l2 = l2_regulariser(decay=0.000005)
    tf.summary.scalar('l2', l2)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.AdamOptimizer(learning_rate)
    solver_step = solver.minimize(1 * top_cls_loss + 1 * top_reg_loss
                                  + 1.5 * fuse_cls_loss + 2 * fuse_reg_loss + l2)

    max_iter = 200000
    iter_debug = 1

    # start training here #############################################################
    log.write('epoch     iter    speed   rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n')
    log.write('-------------------------------------------------------------------------------------\n')

    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)  # grid for the RPN score-map debug view

    merged = tf.summary.merge_all()

    sess = tf.InteractiveSession()
    train_writer = tf.summary.FileWriter('./outputs/tensorboard/Res_Vgg_up', sess.graph)
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        saver = tf.train.Saver()
        saver.restore(sess, './outputs/check_points/snap_ResNet_vgg_up_NGT_060000.ckpt')

        rate = 0.000005
        frame_range = np.arange(num_frames)
        freq = 200     # number of frames loaded per chunk
        idx = 0        # position inside the currently loaded chunk
        frame = 0      # offset of the current chunk inside frame_range
        count = 0      # frames consumed since the last reshuffle (drives chunk reloads)
        end_flag = 0   # set once the last (possibly shorter) chunk is loaded
        for step in range(max_iter):
            epoch = step // num_frames + 1
            start_time = time.time()

            # generate train image -------------
            # Reshuffle the frame order every 2 * num_frames iterations and
            # restart chunk loading from the beginning.
            if step % (num_frames * 2) == 0:
                idx = 0
                frame = 0
                count = 0
                end_flag = 0
                new_range = np.random.permutation(num_frames)
                # Make sure the new order differs from the previous one. (This used to
                # raise with an undefined name, crashing whenever the draw repeated
                # the old order — always so for a single frame.)
                while num_frames > 1 and np.array_equal(new_range, frame_range):
                    new_range = np.random.permutation(num_frames)
                frame_range = new_range

            # At the end of each pass over a chunk, `count` advances by `freq`;
            # a new chunk is loaded only every second pass, so each chunk is
            # iterated over twice.
            if idx % freq == 0:
                count += idx
                if count % (2 * freq) == 0:
                    frame += idx
                    frame_end = min(frame + freq, num_frames)
                    if frame_end == num_frames:
                        end_flag = 1
                    del rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index
                    rgbs, tops, fronts, gt_labels, gt_boxes3d, top_imgs, front_imgs, rgbs_norm, image_index = \
                        load_dummy_datas(index[frame_range[frame:frame_end]])
                idx = 0
            # The last chunk may be shorter than `freq`: wrap around inside it.
            if end_flag == 1 and (idx + frame) == num_frames:
                idx = 0
            print('processing image : %s' % image_index[idx])

            if (step + 1) % 10000 == 0:
                rate = 0.8 * rate

            rgb_shape = rgbs[idx].shape
            batch_top_images = tops[idx].reshape(1, *top_shape)
            batch_front_images = fronts[idx].reshape(1, *front_shape)
            batch_rgb_images = rgbs_norm[idx].reshape(1, *rgb_shape)

            # anchor_filter restricts the anchors using the last channel of the top input
            inside_inds_filtered = anchor_filter(tops[idx][:, :, -1], anchors, inside_inds)

            batch_gt_labels = gt_labels[idx]
            if len(batch_gt_labels) == 0:
                # skip frames without ground-truth objects
                idx = idx + 1
                continue
            batch_gt_boxes3d = gt_boxes3d[idx]
            batch_gt_top_boxes = box3d_to_top_box(batch_gt_boxes3d)

            ## run proposal generation ------------
            fd1 = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds_filtered,
                learning_rate: rate,
                IS_TRAIN_PHASE: True,
            }
            batch_proposals, batch_proposal_scores, batch_top_features = \
                sess.run([proposals, proposal_scores, top_features], fd1)
            print(batch_proposal_scores[:50])

            ## generate train rois ------------
            batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets = \
                rpn_target(anchors, inside_inds_filtered, batch_gt_labels, batch_gt_top_boxes)

            batch_top_rois, batch_fuse_labels, batch_fuse_targets = \
                rcnn_target(batch_proposals, batch_gt_labels, batch_gt_top_boxes, batch_gt_boxes3d)

            batch_rois3d = project_to_roi3d(batch_top_rois)
            batch_front_rois = project_to_front_roi(batch_rois3d)
            batch_rgb_rois = project_to_rgb_roi(batch_rois3d)

            if len(batch_rois3d) == 0:
                # nothing to train the fusion head on for this frame
                idx = idx + 1
                continue

            ## debug gt generation
            if vis and step % iter_debug == 0:
                top_image = top_imgs[idx]
                rgb = rgbs[idx]

                img_gt = draw_rpn_gt(top_image, batch_gt_top_boxes, batch_gt_labels)
                img_label = draw_rpn_labels(img_gt, anchors, batch_top_inds, batch_top_labels)
                img_target = draw_rpn_targets(top_image, anchors, batch_top_pos_inds, batch_top_targets)
                imshow('img_anchor_label', img_label)

                img_label = draw_rcnn_labels(top_image, batch_top_rois, batch_fuse_labels)
                img_target = draw_rcnn_targets(top_image, batch_top_rois, batch_fuse_labels, batch_fuse_targets)
                imshow('img_rcnn_target', img_target)

                img_rgb_rois = draw_boxes(rgb, batch_rgb_rois[:, 1:5], color=(255, 0, 255), thickness=1)
                imshow('img_rgb_rois', img_rgb_rois)
                cv2.waitKey(1)

            ## run classification and regression loss -----------
            fd2 = {
                **fd1,

                front_images: batch_front_images,
                rgb_images: batch_rgb_images,

                top_rois: batch_top_rois,
                front_rois: batch_front_rois,
                rgb_rois: batch_rgb_rois,

                top_inds: batch_top_inds,
                top_pos_inds: batch_top_pos_inds,
                top_labels: batch_top_labels,
                top_targets: batch_top_targets,

                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets,
            }

            _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
                sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss], fd2)

            speed = time.time() - start_time
            log.write('%5.1f   %5d    %0.4fs   %0.4f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  \n' %
                      (epoch, step, speed, rate, batch_top_cls_loss, batch_top_reg_loss,
                       batch_fuse_cls_loss, batch_fuse_reg_loss))

            # debug: ------------------------------------
            if vis and step % iter_debug == 0:
                top_image = top_imgs[idx]
                rgb = rgbs[idx]

                batch_top_probs, batch_top_scores, batch_top_deltas = \
                    sess.run([top_probs, top_scores, top_deltas], fd2)

                batch_fuse_probs, batch_fuse_deltas = \
                    sess.run([fuse_probs, fuse_deltas], fd2)

                probs, boxes3d = rcnn_nms(batch_fuse_probs, batch_fuse_deltas, batch_rois3d, threshold=0.05)

                ## show rpn score maps (foreground probability per anchor base)
                p = batch_top_probs.reshape(*top_feature_shape[0:2], 2 * num_bases)
                for n in range(num_bases):
                    r = n % num_scales
                    s = n // num_scales
                    pn = p[:, :, 2 * n + 1] * 255
                    axs[s, r].cla()
                    axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                    plt.pause(0.01)

                ## show rpn(top) nms (img_gt comes from the gt debug view above,
                ## which runs under the same condition)
                img_rpn = draw_rpn(top_image, batch_top_probs, batch_top_deltas, anchors, inside_inds)
                img_rpn_nms = draw_rpn_nms(img_gt, batch_proposals, batch_proposal_scores)
                imshow('img_rpn_nms', img_rpn_nms)
                cv2.waitKey(1)

                ## show rcnn(fuse) nms
                img_rcnn = draw_rcnn(top_image, batch_fuse_probs, batch_fuse_deltas,
                                     batch_top_rois, batch_rois3d, darker=1)
                img_rcnn_nms = draw_rcnn_nms(rgb, boxes3d, probs)
                imshow('img_rcnn', img_rcnn)
                imshow('img_rcnn_nms', img_rcnn_nms)
                cv2.waitKey(0)

            if step % 10 == 0:
                summary = sess.run(merged, fd2)
                train_writer.add_summary(summary, step)

            # save: ------------------------------------
            if step % 2000 == 0 and step != 0:
                saver.save(sess, out_dir + '/check_points/snap_ResNet_vgg_NGT_%06d.ckpt' % step)

            idx = idx + 1
Beispiel #9
0
def _project_top_rois_to_rgb(top_rois):
    """Return a copy of ``top_rois`` whose boxes are replaced by their RGB-image extent.

    For every row, columns 1:5 are lifted to a 3D box with ``box_to_box3d``,
    projected with ``make_projected_box3d``, and overwritten with the
    ``(min x, min y, max x, max y)`` of the projected points. Column 0 is
    carried over unchanged (presumably the batch index — verify against
    ``rcnn_target``).
    """
    rgb_rois = top_rois.copy()
    for n in range(len(top_rois)):
        box3d = box_to_box3d(top_rois[n, 1:5].reshape(1, 4)).reshape(8, 3)
        qs = make_projected_box3d(box3d)
        rgb_rois[n, 1:5] = (np.min(qs[:, 0]), np.min(qs[:, 1]),
                            np.max(qs[:, 0]), np.max(qs[:, 1]))
    return rgb_rois


def _draw_rgb_rois(rgb, rgb_rois, darken=0.7):
    """Return ``rgb * darken`` with every ROI (columns 1:5) drawn as a 1px box.

    Corner coordinates are truncated to ``int`` because the OpenCV Python
    bindings reject non-integer (including NumPy float) points in
    ``cv2.rectangle``.
    """
    img = rgb.copy() * darken
    for x1, y1, x2, y2 in rgb_rois[:, 1:5]:
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)),
                      (0, 255, 255), 1)
    return img


def run_train():
    """Train the top-view RPN and the top/RGB fusion RCNN on a single dummy frame.

    The same frame (from ``load_dummy_data``) is fed on every iteration for
    ``max_iter`` (10000) steps with momentum SGD (learning rate 0.1,
    momentum 0.9). The loss is the RPN cls + reg loss, the fusion cls + reg
    loss and an L2 penalty.

    Each step shows debug windows for targets and projected ROIs. Every 4th
    step it also shows RPN score maps and RPN/RCNN predictions. It appends the
    four loss values to ``<out_dir>/log.txt``.
    """
    # output dir, etc
    out_dir = '/root/share/out/didi/xxx'
    makedirs(out_dir + '/tf')
    log = Logger(out_dir + '/log.txt', mode='a')

    #one lidar data -----------------
    if 1:
        ratios = np.array([0.5, 1, 2], dtype=np.float32)
        scales = np.array([1, 2, 3], dtype=np.float32)
        bases = make_bases(base_size=16, ratios=ratios, scales=scales)
        num_bases = len(bases)
        stride = 8

        rgb, top, top_image, lidar, gt_labels, gt_boxes3d, gt_top_boxes = load_dummy_data(
        )
        top_shape = top.shape
        top_feature_shape = (top_shape[0] // stride, top_shape[1] // stride)

        rgb_shape = rgb.shape
        out_shape = (8, 3)

        #-----------------------
        #check data (debug visualisation, disabled)
        if 0:
            fig = mlab.figure(figure=None,
                              bgcolor=(0, 0, 0),
                              fgcolor=None,
                              engine=None,
                              size=(1000, 500))
            draw_lidar(lidar, fig=fig)
            draw_gt_boxes3d(gt_boxes3d, fig=fig)
            mlab.show(1)

            draw_gt_boxes(top_image, gt_top_boxes)
            draw_projected_gt_boxes3d(rgb, gt_boxes3d)

            cv2.waitKey(1)

    #one dummy data -----------------
    if 0:
        ratios = [0.5, 1, 2]
        scales = 2**np.arange(3, 6)
        bases = make_bases(base_size=16, ratios=ratios, scales=scales)
        num_bases = len(bases)
        stride = 8

        rgb, top, top_image, lidar, gt_labels, gt_boxes3d, gt_top_boxes = load_dummy_data1(
        )
        top_shape = top.shape
        top_feature_shape = (54, 72
                             )  #(top_shape[0]//stride, top_shape[1]//stride)

        rgb_shape = rgb.shape
        out_shape = (4, )

    # set anchor boxes
    dim = np.prod(out_shape)
    num_class = 2  # includes background
    anchors, inside_inds = make_anchors(bases, stride, top_shape[0:2],
                                        top_feature_shape[0:2])
    inside_inds = np.arange(0, len(anchors), dtype=np.int32)  # use all anchors
    print('dim=%d' % dim)

    #load model ##############
    top_images = tf.placeholder(shape=[None, *top_shape],
                                dtype=tf.float32,
                                name='top')
    top_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors')
    top_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds')

    top_features, top_scores, top_probs, top_deltas, top_rois1, top_roi_scores1 = \
        top_lidar_feature_net(top_images, top_anchors, top_inside_inds, num_bases)

    rgb_images = tf.placeholder(shape=[None, *rgb_shape],
                                dtype=tf.float32,
                                name='rgb')
    rgb_features = rgb_feature_net(rgb_images)

    top_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='top_rois')  #<todo> change to int32???
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')
    fuse_scores, fuse_probs, fuse_deltas = \
        fusion_net(
            (top_features,   rgb_features,),
            (top_rois,       rgb_rois,),
            ([6,6,1./stride],[6,6,1./stride],),
            num_class, out_shape)

    #loss ####################
    top_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='top_ind')
    top_pos_inds = tf.placeholder(shape=[None],
                                  dtype=tf.int32,
                                  name='top_pos_ind')
    top_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='top_label')
    top_targets = tf.placeholder(shape=[None, 4],
                                 dtype=tf.float32,
                                 name='top_target')
    top_cls_loss, top_reg_loss = rpn_loss(top_scores, top_deltas, top_inds,
                                          top_pos_inds, top_labels,
                                          top_targets)

    fuse_labels = tf.placeholder(shape=[None],
                                 dtype=tf.int32,
                                 name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, *out_shape],
                                  dtype=tf.float32,
                                  name='fuse_target')
    fuse_cls_loss, fuse_reg_loss = rcnn_loss(fuse_scores, fuse_deltas,
                                             fuse_labels, fuse_targets)

    #solver
    l2 = l2_regulariser(decay=0.0005)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                        momentum=0.9)
    #solver_step = solver.minimize(top_cls_loss+top_reg_loss+l2)
    solver_step = solver.minimize(top_cls_loss + top_reg_loss + fuse_cls_loss +
                                  fuse_reg_loss + l2)

    max_iter = 10000

    # start training here ------------------------------------------------
    log.write('epoch        iter      rate     |  train_mse   valid_mse  |\n')
    log.write(
        '----------------------------------------------------------------------------\n'
    )

    num_ratios = len(ratios)
    num_scales = len(scales)
    fig, axs = plt.subplots(num_ratios, num_scales)

    sess = tf.InteractiveSession()
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        summary_writer = tf.summary.FileWriter(out_dir + '/tf', sess.graph)
        rate = 0.1

        batch_top_cls_loss = 0
        batch_top_reg_loss = 0
        batch_fuse_cls_loss = 0
        batch_fuse_reg_loss = 0
        for step in range(max_iter):

            # The "dataset" is one dummy frame, so every batch is identical.
            batch_top_images = top.reshape(1, *top_shape)
            batch_top_gt_labels = gt_labels
            batch_top_gt_boxes = gt_top_boxes

            batch_rgb_images = rgb.reshape(1, *rgb_shape)

            batch_fuse_gt_labels = gt_labels
            batch_fuse_gt_boxes = gt_top_boxes
            batch_fuse_gt_boxes3d = gt_boxes3d

            ## run proposal generation -------------------------------
            fd = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds,
                learning_rate: rate,
                IS_TRAIN_PHASE: True
            }
            batch_top_rois1, batch_top_roi_scores1, batch_top_features = sess.run(
                [top_rois1, top_roi_scores1, top_features], fd)

            ## generate ground truth
            batch_top_inds, batch_top_pos_inds, batch_top_labels, batch_top_targets  = \
                rpn_target ( anchors, inside_inds, batch_top_gt_labels,  batch_top_gt_boxes)

            batch_top_rois, batch_fuse_labels, batch_fuse_targets  = \
                 rcnn_target(  batch_top_rois1, batch_fuse_gt_labels, batch_fuse_gt_boxes, batch_fuse_gt_boxes3d )

            # project top-view rois into the rgb image ------------------------
            batch_rgb_rois = _project_top_rois_to_rgb(batch_top_rois)
            img_rgb_roi = _draw_rgb_rois(rgb, batch_rgb_rois, darken=0.7)
            imshow('img_rgb_roi', img_rgb_roi)
            #--------------------------------------------------------------------

            ##debug
            if 1:
                img_gt = draw_rpn_gt(top_image, batch_top_gt_boxes,
                                     batch_top_gt_labels)
                img_label = draw_rpn_labels(top_image, anchors, batch_top_inds,
                                            batch_top_labels)
                img_target = draw_rpn_targets(top_image, anchors,
                                              batch_top_pos_inds,
                                              batch_top_targets)
                imshow('img_rpn_gt', img_gt)
                imshow('img_rpn_label', img_label)
                imshow('img_rpn_target', img_target)

                img_label = draw_rcnn_labels(top_image, batch_top_rois,
                                             batch_fuse_labels)
                img_target = draw_rcnn_targets(top_image, batch_top_rois,
                                               batch_fuse_labels,
                                               batch_fuse_targets)
                imshow('img_rcnn_label', img_label)
                imshow('img_rcnn_target', img_target)
                cv2.waitKey(1)

            #---------------------------------------------------
            fd = {
                top_images: batch_top_images,
                top_anchors: anchors,
                top_inside_inds: inside_inds,
                top_inds: batch_top_inds,
                top_pos_inds: batch_top_pos_inds,
                top_labels: batch_top_labels,
                top_targets: batch_top_targets,
                top_rois: batch_top_rois,
                #front_rois1: batch_front_rois,
                rgb_images: batch_rgb_images,
                rgb_rois: batch_rgb_rois,
                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets,
                learning_rate: rate,
                IS_TRAIN_PHASE: True
            }

            _, batch_top_cls_loss, batch_top_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss = \
               sess.run([solver_step, top_cls_loss, top_reg_loss, fuse_cls_loss, fuse_reg_loss],fd)

            # debug: ------------------------------------
            if step % 4 == 0:
                # One forward pass for all debug outputs (same feed dict).
                batch_top_probs, batch_top_scores, batch_top_deltas, \
                    batch_fuse_probs, batch_fuse_deltas = sess.run(
                        [top_probs, top_scores, top_deltas,
                         fuse_probs, fuse_deltas], fd)

                probs, boxes3d, priors, priors3d, deltas = rcnn_nms(
                    batch_fuse_probs, batch_fuse_deltas, batch_top_rois)

                ## show rpn score maps (channel 2*n+1 of each base n)
                p = batch_top_probs.reshape(*(top_feature_shape[0:2]),
                                            2 * num_bases)
                for n in range(num_bases):
                    r = n % num_scales
                    s = n // num_scales
                    pn = p[:, :, 2 * n + 1] * 255
                    axs[s, r].cla()
                    axs[s, r].imshow(pn, cmap='gray', vmin=0, vmax=255)
                plt.pause(0.01)

                img_rpn = draw_rpn(top_image, batch_top_probs,
                                   batch_top_deltas, anchors, inside_inds)
                img_rpn_nms = draw_rpn_nms(
                    top_image, batch_top_rois1,
                    batch_top_roi_scores1)  # estimate after non-max
                imshow('img_rpn', img_rpn)
                imshow('img_rpn_nms', img_rpn_nms)
                cv2.waitKey(1)

                #draw rcnn results --------------------------------
                img_rcnn = draw_rcnn(top_image, batch_fuse_probs,
                                     batch_fuse_deltas, batch_top_rois)
                # NOTE(review): return value is discarded, as in the original.
                draw_projected_gt_boxes3d(rgb,
                                          boxes3d,
                                          color=(255, 255, 255),
                                          thickness=1)

                imshow('img_rcnn', img_rcnn)
                cv2.waitKey(1)

            # debug: ------------------------------------

            log.write('%d   | %0.5f   %0.5f  %0.5f   %0.5f : \n' %
                      (step, batch_top_cls_loss, batch_top_reg_loss,
                       batch_fuse_cls_loss, batch_fuse_reg_loss))
def run_test():
    """Run a restored RGB RPN + fusion (2D and 3D-to-2D) model over the val split.

    For every frame id listed in ``train_data_root + '/val.txt'`` (sorted
    ascending), this function:

    1. loads the frame with ``load_dummy_datas`` and uses its first image,
    2. builds anchors for that image's size and runs the RPN,
    3. applies proposal NMS (``rpn_nms_generator``, pre/post top-n 2000/300),
    4. runs the fusion head on the proposals and filters the result with
       ``rcnn_nms_2d`` (threshold 0.3),
    5. prints the wall-clock time spent on the frame, and
    6. if the module-level ``is_show`` flag equals 1, shows ground truth,
       proposals, projected boxes and 2D detections, then waits for a key.

    Weights are restored from
    ``./outputs/check_points/snap_2dTo3D__data_augmentation090000.ckpt``.
    """
    # output dir, etc
    out_dir = './outputs'
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    # The log file is written into out_dir/log; create it the same way
    # run_train does (presumably Logger does not create missing dirs).
    makedirs(out_dir + '/log')
    log = Logger(out_dir + '/log/log_%s.txt' %
                 (time.strftime('%Y-%m-%d %H:%M:%S')),
                 mode='a')

    # Frame ids to evaluate: one integer per line, blank lines ignored.
    # (alternatives used before: train_data_root + '/train.npy' or '/train.txt')
    with open(train_data_root + '/val.txt') as index_file:
        index = sorted(int(line) for line in index_file if line.strip())

    num_frames = len(index)
    print('len(index):%d' % num_frames)

    # anchor bases for the rgb image -----------------
    ratios_rgb = np.array([0.5, 1, 2], dtype=np.float32)
    scales_rgb = np.array([0.5, 1, 2, 4, 5], dtype=np.float32)
    bases_rgb = make_bases(base_size=48,
                           ratios=ratios_rgb,
                           scales=scales_rgb)
    num_bases_rgb = len(bases_rgb)
    stride = 8
    out_shape = (2, 2)
    num_class = 2  # includes background

    print('out_shape=%s' % str(out_shape))
    print('num_frames=%d' % num_frames)

    #load model ####################################################################################################
    rgb_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors_rgb')
    rgb_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds_rgb')

    rgb_images = tf.placeholder(shape=[None, None, None, 3],
                                dtype=tf.float32,
                                name='rgb')
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')

    rgb_features, rgb_scores, rgb_probs, rgb_deltas= \
        top_feature_net(rgb_images, rgb_anchors, rgb_inside_inds, num_bases_rgb)

    fuse_scores, fuse_probs, fuse_deltas, fuse_deltas_3dTo2D = \
        fusion_net(
            ( [rgb_features,     rgb_rois,     7,7,1./(1*stride)],),num_class, out_shape) #<todo>  add non max suppression

    nms_pre_topn_ = 2000
    nms_post_topn_ = 300
    img_scale = 1

    sess = tf.InteractiveSession()
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        summary_writer = tf.summary.FileWriter(out_dir + '/tf', sess.graph)
        saver = tf.train.Saver()
        # saver.restore(sess, './outputs/check_points/snap_2D_pretrain.ckpt')
        saver.restore(
            sess,
            './outputs/check_points/snap_2dTo3D__data_augmentation090000.ckpt')

        for frame_no in range(num_frames):
            start_time = time.time()
            print('Processing Img: %d  %s' % (frame_no, index[frame_no]))
            rgbs, gt_labels, gt_3dTo2Ds, gt_boxes2d, rgbs_norm, image_index = load_dummy_datas(
                index[frame_no])
            idx = 0  # only the first image returned for the frame is used

            rgb_shape = rgbs[idx].shape
            batch_rgb_images = rgbs_norm[idx].reshape(1, *rgb_shape)
            batch_gt_labels = gt_labels[idx]
            batch_gt_boxes2d = gt_boxes2d[idx]
            if len(batch_gt_labels) == 0:
                # NOTE(review): blocks until a key press on frames without
                # ground truth; looks like a leftover debugging pause.
                cv2.waitKey(0)

            # Anchors depend on this frame's image size. Build them BEFORE the
            # RPN runs so the graph and rpn_nms below use the same anchors.
            rgb_feature_shape = ((rgb_shape[0] - 1) // stride + 1,
                                 (rgb_shape[1] - 1) // stride + 1)
            anchors_rgb, inside_inds_rgb = make_anchors(
                bases_rgb, stride, rgb_shape[0:2], rgb_feature_shape[0:2])

            ## run proposal generation ------------
            fd1 = {
                rgb_images: batch_rgb_images,
                rgb_anchors: anchors_rgb,
                rgb_inside_inds: inside_inds_rgb,
                IS_TRAIN_PHASE: False
            }
            batch_rgb_probs, batch_deltas, batch_rgb_features = sess.run(
                [rgb_probs, rgb_deltas, rgb_features], fd1)

            rpn_nms = rpn_nms_generator(stride,
                                        rgb_shape[1],
                                        rgb_shape[0],
                                        img_scale,
                                        nms_thresh=0.7,
                                        min_size=stride,
                                        nms_pre_topn=nms_pre_topn_,
                                        nms_post_topn=nms_post_topn_)
            batch_proposals, batch_proposal_scores = rpn_nms(
                batch_rgb_probs, batch_deltas, anchors_rgb, inside_inds_rgb)

            ## run classification and regression  -----------
            fd2 = {
                **fd1,
                rgb_rois: batch_proposals,
            }
            batch_fuse_probs, batch_fuse_deltas, batch_fuse_deltas_3dTo2D = sess.run(
                [fuse_probs, fuse_deltas, fuse_deltas_3dTo2D], fd2)

            probs, boxes2d, projections = rcnn_nms_2d(batch_fuse_probs,
                                                      batch_fuse_deltas,
                                                      batch_proposals,
                                                      batch_fuse_deltas_3dTo2D,
                                                      threshold=0.3)
            speed = time.time() - start_time
            print('speed: %0.4fs' % speed)

            # debug: ------------------------------------
            if is_show == 1:

                rgb = rgbs[idx]

                img_rpn_nms = draw_rpn_nms(rgb, batch_proposals,
                                           batch_proposal_scores)
                img_gt = draw_rpn_gt(rgb, batch_gt_boxes2d, batch_gt_labels)
                imshow('img_gt', img_gt)
                imshow('img_rpn_nms', img_rpn_nms)

                img_rcnn_nms = draw_rgb_projections(rgb,
                                                    projections,
                                                    color=(0, 0, 255),
                                                    thickness=1)
                img_rgb_2d_detection = draw_boxes(rgb,
                                                  boxes2d,
                                                  color=(255, 0, 255),
                                                  thickness=1)
                imshow('draw_rcnn_nms', img_rcnn_nms)
                imshow('img_rgb_2d_detection', img_rgb_2d_detection)

                cv2.waitKey(0)
def run_train():
    CFG.KEEPPROBS = 0.5
    # output dir  for tensorboard, checkpoints and log
    out_dir = CFG.PATH.TRAIN.OUTPUT
    makedirs(out_dir + '/tf')
    makedirs(out_dir + '/check_points')
    makedirs(out_dir + '/log')
    log = Logger(out_dir + '/log/log_%s.txt' %
                 (time.strftime('%Y-%m-%d %H:%M:%S')),
                 mode='a')

    index = np.load(CFG.PATH.TRAIN.TARGET + '/train.npy')
    index = sorted(index)
    index = np.array(index)
    num_frames = len(index)

    #lidar data -----------------
    if 1:
        ###generate anchor base
        ratios_rgb = np.array([0.5, 1, 2], dtype=np.float32)
        scales_rgb = np.array([0.5, 1, 2, 4, 5], dtype=np.float32)
        bases_rgb = make_bases(base_size=48,
                               ratios=ratios_rgb,
                               scales=scales_rgb)

        num_bases_rgb = len(bases_rgb)
        stride = 8
        out_shape = (2, 2)

        rgbs, gt_labels, gt_3dTo2Ds, gt_boxes2d, rgbs_norm, image_index = load_dummy_datas(
            index[10:13])

        rgb_shape = rgbs[0].shape

        # set anchor boxes
        num_class = 2  #incude background
    #load model ####################################################################################################

    rgb_anchors = tf.placeholder(shape=[None, 4],
                                 dtype=tf.int32,
                                 name='anchors_rgb')
    rgb_inside_inds = tf.placeholder(shape=[None],
                                     dtype=tf.int32,
                                     name='inside_inds_rgb')

    rgb_images = tf.placeholder(shape=[None, None, None, 3],
                                dtype=tf.float32,
                                name='rgb')
    rgb_rois = tf.placeholder(shape=[None, 5],
                              dtype=tf.float32,
                              name='rgb_rois')

    rgb_features, rgb_scores, rgb_probs, rgb_deltas= \
        top_feature_net(rgb_images, rgb_anchors, rgb_inside_inds, num_bases_rgb)

    fuse_scores, fuse_probs, fuse_deltas, fuse_deltas_3dTo2D = \
        fusion_net(
            ( [rgb_features,     rgb_rois,     7,7,1./(1*stride)],),num_class, out_shape) #<todo>  add non max suppression

    #loss ########################################################################################################
    rgb_inds = tf.placeholder(shape=[None], dtype=tf.int32, name='rgb_ind')
    rgb_pos_inds = tf.placeholder(shape=[None],
                                  dtype=tf.int32,
                                  name='rgb_pos_ind')
    rgb_labels = tf.placeholder(shape=[None], dtype=tf.int32, name='rgb_label')
    rgb_targets = tf.placeholder(shape=[None, 4],
                                 dtype=tf.float32,
                                 name='rgb_target')
    rgb_cls_loss, rgb_reg_loss = rpn_loss(2 * rgb_scores, rgb_deltas, rgb_inds,
                                          rgb_pos_inds, rgb_labels,
                                          rgb_targets)

    fuse_labels = tf.placeholder(shape=[None],
                                 dtype=tf.int32,
                                 name='fuse_label')
    fuse_targets = tf.placeholder(shape=[None, 4],
                                  dtype=tf.float32,
                                  name='fuse_target')
    fuse_targets_3dTo2Ds = tf.placeholder(shape=[None, 16],
                                          dtype=tf.float32,
                                          name='fuse_target')

    fuse_cls_loss, fuse_reg_loss, fuse_reg_loss_3dTo2D = rcnn_loss(
        fuse_scores, fuse_deltas, fuse_labels, fuse_targets,
        fuse_deltas_3dTo2D, fuse_targets_3dTo2Ds)

    tf.summary.scalar('rcnn_cls_loss', fuse_cls_loss)
    tf.summary.scalar('rcnn_reg_loss', fuse_reg_loss)
    tf.summary.scalar('rcnn_reg_loss_3dTo2D', fuse_reg_loss_3dTo2D)
    tf.summary.scalar('rpn_cls_loss', rgb_cls_loss)
    tf.summary.scalar('rpn_reg_loss', rgb_reg_loss)

    #solver
    l2 = l2_regulariser(decay=CFG.TRAIN.WEIGHT_DECAY)
    tf.summary.scalar('l2', l2)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    solver = tf.train.AdamOptimizer(learning_rate)
    solver_step = solver.minimize(2 * rgb_cls_loss + 1 * rgb_reg_loss +
                                  2 * fuse_cls_loss + 1 * fuse_reg_loss +
                                  0.5 * fuse_reg_loss_3dTo2D + l2)
    # 2*rgb_cls_loss+1*rgb_reg_loss+2*fuse_cls_loss+1*fuse_reg_loss+

    max_iter = 200000
    iter_debug = 1

    # start training here  #########################################################################################
    log.write(
        'epoch     iter    speed   rate   |  top_cls_loss   reg_loss   |  fuse_cls_loss  reg_loss  |  \n'
    )
    log.write(
        '-------------------------------------------------------------------------------------\n'
    )

    merged = tf.summary.merge_all()

    sess = tf.InteractiveSession()
    train_writer = tf.summary.FileWriter(
        './outputs/tensorboard/V_2dTo3d_finetune', sess.graph)
    with sess.as_default():
        sess.run(tf.global_variables_initializer(), {IS_TRAIN_PHASE: True})
        saver = tf.train.Saver()

        ##Initialize network ##

        ##Initialize network from relatively welll trained model
        # saver.restore(sess, './outputs/check_points/snap_2dTo3D__data_augmentation090000trainval.ckpt')

        #Initialize network from 2D Bounding Box  pretrained model whict trained on Last14000 datasets of  UISEE
        # var_lt_res=[v for v in tf.global_variables() if  not v.name.startswith('fuse/3D')]
        # saver_0=tf.train.Saver(var_lt_res)
        # saver_0.restore(sess, './outputs/check_points/snap_2D_pretrain.ckpt')

        ##Initialize network from ResNet50
        # var_lt_res=[v for v in tf.trainable_variables() if v.name.startswith('resnet_v1')]#resnet_v1_50
        # saver_0=tf.train.Saver(var_lt_res)
        # saver_0.restore(sess, './outputs/check_points/resnet_v1_50.ckpt')

        batch_top_cls_loss = 0
        batch_top_reg_loss = 0
        batch_fuse_cls_loss = 0
        batch_fuse_reg_loss = 0
        frame_range = np.arange(num_frames)
        idx = 0
        frame = 0

        for iter in range(max_iter):
            epoch = iter // num_frames + 1
            # rate=0.001
            start_time = time.time()
            if iter % (num_frames * 1) == 0:
                idx = 0
                frame = 0
                count = 0
                end_flag = 0
                frame_range1 = np.random.permutation(num_frames)
                if np.all(frame_range1 == frame_range):
                    raise Exception("Invalid level!", permutation)
                frame_range = frame_range1

            #load 500 samples every 2000 iterations
            freq = int(10)
            if idx % freq == 0:
                count += idx
                if count % (1 * freq) == 0:
                    frame += idx
                    frame_end = min(frame + freq, num_frames)
                    if frame_end == num_frames:
                        end_flag = 1
                    rgbs, gt_labels, gt_3dTo2Ds, gt_boxes2d, rgbs_norm, image_index = load_dummy_datas(
                        index[frame_range[frame:frame_end]])
                idx = 0
            if (end_flag == 1) and (idx + frame) == num_frames:
                idx = 0
            print('processing image : %s' % image_index[idx])

            if (iter + 1) % (CFG.TRAIN.LEARNING_RATE_DECAY_STEP) == 0:
                CFG.TRAIN.LEARNING_RATE = CFG.TRAIN.LEARNING_RATE_DECAY_SCALE * CFG.TRAIN.LEARNING_RATE

            rgb_shape = rgbs[idx].shape
            batch_rgb_images = rgbs_norm[idx].reshape(1, *rgb_shape)
            batch_gt_labels = gt_labels[idx]
            batch_gt_3dTo2Ds = gt_3dTo2Ds[idx]
            batch_gt_boxes2d = gt_boxes2d[idx]

            if len(batch_gt_labels) == 0:
                idx = idx + 1
                continue

            rgb_feature_shape = ((rgb_shape[0] - 1) // stride + 1,
                                 (rgb_shape[1] - 1) // stride + 1)
            anchors_rgb, inside_inds_rgb = make_anchors(
                bases_rgb, stride, rgb_shape[0:2], rgb_feature_shape[0:2])

            ## run propsal generation ------------
            fd1 = {
                rgb_images: batch_rgb_images,
                rgb_anchors: anchors_rgb,
                rgb_inside_inds: inside_inds_rgb,
                learning_rate: CFG.TRAIN.LEARNING_RATE,
                IS_TRAIN_PHASE: True,
            }
            batch_rgb_probs, batch_deltas, batch_rgb_features = sess.run(
                [rgb_probs, rgb_deltas, rgb_features], fd1)

            rpn_nms = rpn_nms_generator(
                stride,
                rgb_shape[1],
                rgb_shape[0],
                img_scale=1,
                nms_thresh=0.7,
                min_size=stride,
                nms_pre_topn=CFG.TRAIN.RPN_NMS_PRE_TOPN,
                nms_post_topn=CFG.TRAIN.RPN_NMS_POST_TOPN)
            batch_proposals, batch_proposal_scores = rpn_nms(
                batch_rgb_probs, batch_deltas, anchors_rgb, inside_inds_rgb)

            ## generate  train rois  ------------
            batch_rgb_inds, batch_rgb_pos_inds, batch_rgb_labels, batch_rgb_targets  = \
                rpn_target ( anchors_rgb, inside_inds_rgb, batch_gt_labels,  batch_gt_boxes2d)

            batch_rgb_rois, batch_fuse_labels, batch_fuse_targets2d, batch_fuse_targets_3dTo2Ds = rcnn_target(
                batch_proposals, batch_gt_labels, batch_gt_boxes2d,
                batch_gt_3dTo2Ds, rgb_shape[1], rgb_shape[0])

            print('nums of rcnn batch: %d' % len(batch_rgb_rois))
            ##debug gt generation
            if CFG.TRAIN.VISUALIZATION and iter % iter_debug == 0:
                rgb = rgbs[idx]

                img_gt = draw_rpn_gt(rgb, batch_gt_boxes2d, batch_gt_labels)
                rgb_label = draw_rpn_labels(img_gt, anchors_rgb,
                                            batch_rgb_inds, batch_rgb_labels)
                rgb_target = draw_rpn_targets(rgb, anchors_rgb,
                                              batch_rgb_pos_inds,
                                              batch_rgb_targets)
                #imshow('img_rpn_gt',img_gt)
                imshow('img_rgb_label', rgb_label)
                imshow('img_rpn_target', rgb_target)

                img_label = draw_rcnn_labels(rgb, batch_rgb_rois,
                                             batch_fuse_labels)
                img_target = draw_rcnn_targets(rgb, batch_rgb_rois,
                                               batch_fuse_labels,
                                               batch_fuse_targets2d)
                imshow('img_rcnn_label', img_label)
                imshow('img_rcnn_target', img_target)

                img_rgb_rois = draw_boxes(rgb,
                                          batch_rgb_rois[:, 1:5],
                                          color=(255, 0, 255),
                                          thickness=1)
                imshow('img_rgb_rois', img_rgb_rois)

                projections = box_transform_3dTo2D_inv(
                    batch_rgb_rois[:, 1:], batch_fuse_targets_3dTo2Ds)
                img_rcnn_3dTo2D = draw_rgb_projections(rgb,
                                                       projections,
                                                       color=(0, 0, 255),
                                                       thickness=1)
                imshow('img_rcnn_3dTo2D', img_rcnn_3dTo2D)
                # plt.pause(0.5)
                # cv2.waitKey(500)
                cv2.waitKey(0)

            ## run classification and regression loss -----------
            fd2 = {
                **fd1, rgb_images: batch_rgb_images,
                rgb_rois: batch_rgb_rois,
                rgb_inds: batch_rgb_inds,
                rgb_pos_inds: batch_rgb_pos_inds,
                rgb_labels: batch_rgb_labels,
                rgb_targets: batch_rgb_targets,
                fuse_labels: batch_fuse_labels,
                fuse_targets: batch_fuse_targets2d,
                fuse_targets_3dTo2Ds: batch_fuse_targets_3dTo2Ds
            }

            _, rcnn_probs, batch_rgb_cls_loss, batch_rgb_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss, batch_fuse_reg_loss_dTo2D = \
               sess.run([solver_step, fuse_probs, rgb_cls_loss, rgb_reg_loss, fuse_cls_loss, fuse_reg_loss, fuse_reg_loss_3dTo2D],fd2)

            speed = time.time() - start_time
            log.write('%5.1f   %5d    %0.4fs   %0.6f   |   %0.5f   %0.5f   |   %0.5f   %0.5f  |%0.5f   \n' %\
                (epoch, iter, speed, CFG.TRAIN.LEARNING_RATE, batch_rgb_cls_loss, batch_rgb_reg_loss, batch_fuse_cls_loss, batch_fuse_reg_loss, batch_fuse_reg_loss_dTo2D))

            if (iter) % 10 == 0:
                summary = sess.run(merged, fd2)
                train_writer.add_summary(summary, iter)
            # save: ------------------------------------

            if (iter) % 5000 == 0 and (iter != 0):
                saver.save(sess, out_dir + '/check_points/' +
                           CFG.PATH.TRAIN.CHECKPOINT_NAME +
                           '%06d.ckpt' % iter)  #iter
                # saver_rgb.save(sess, out_dir + '/check_points/pretrained_Res_rgb_model%06d.ckpt'%iter)
                # saver_top.save(sess, out_dir + '/check_points/pretrained_Res_top_model%06d.ckpt'%iter)
                # pdb.set_trace()

            idx = idx + 1