コード例 #1
0
 def vis_one_im(self):
     """Save the visualizations enabled in ``cfgs`` for the current image.

     Reads ``self.vis_image``, ``self.vis_anno``, ``self.vis_pred``,
     ``self.vis_pred_prob`` and ``self.filename`` and writes one JPEG per
     enabled flag into the matching ``self.re_save_dir_*`` folder.
     """
     if cfgs.anno:
         # Overlay the annotation (not the prediction) on the input image.
         im_ = pred_visualize(self.vis_image.copy(),
                              self.vis_anno).astype(np.uint8)
         utils.save_image(im_,
                          self.re_save_dir_im,
                          name='inp_' + self.filename + '.jpg')
     if cfgs.fit_ellip:
         im_ellip = fit_ellipse_findContours(
             self.vis_image.copy(),
             np.expand_dims(self.vis_pred, axis=2).astype(np.uint8))
         utils.save_image(im_ellip,
                          self.re_save_dir_ellip,
                          name='ellip_' + self.filename + '.jpg')
     if cfgs.heatmap:
         # Channel 1 of the predicted probabilities drives the heat map.
         heat_map = density_heatmap(self.vis_pred_prob[:, :, 1])
         utils.save_image(heat_map,
                          self.re_save_dir_heat,
                          name='heat_' + self.filename + '.jpg')
         if cfgs.trans_heat:
             # The translucent overlay reuses the heat map computed above,
             # so it is only produced when the heat map is enabled too.
             trans_heat_map = translucent_heatmap(
                 self.vis_image.copy(),
                 heat_map.astype(np.uint8).copy())
             utils.save_image(trans_heat_map,
                              self.re_save_dir_transheat,
                              name='trans_heat_' + self.filename + '.jpg')
コード例 #2
0
ファイル: CaculateAccurary.py プロジェクト: gepu0221/FCN
def caculate_ellip_accu_once(im,
                             filename,
                             pred,
                             pred_pro,
                             gt_ellip,
                             if_valid=False):
    """Fit an ellipse to a predicted mask and score it against ground truth.

    :param im: image the prediction belongs to. Its height*width normalizes
        the loss; when the sample is saved it is converted RGB -> BGR and the
        ellipses are drawn on it.
    :param filename: file name as bytes (stripped and decoded as UTF-8 for
        the output paths).
    :param pred: single-channel mask passed to ``cv2.findContours``.
    :param pred_pro: per-class probability map; channel 1 is rendered as a
        heat map for poorly fitted samples.
    :param gt_ellip: ground-truth ellipse as ``[x, y, w, h]``.
    :param if_valid: save poor fits under ``cfgs.error_path + '_valid'``
        instead of ``cfgs.error_path``.
    :return: sum of squared differences between the predicted and
        ground-truth ``[x, y, w, h]`` vectors, divided by the image area.
    """
    #gt_ellipse [(x,y), w, h]
    # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3 but
    # (contours, hierarchy) in OpenCV 2/4; the contours are always second last.
    contours = cv2.findContours(pred, cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Pool the points of every contour into a single point set for the fit.
    pts_ = np.array([pt for contour in contours for pt in contour])
    if pts_.shape[0] > 5:
        (cx, cy), (w, h), _ = cv2.fitEllipse(pts_)
        pred_ellip = np.array([cx, cy, w, h])
    else:
        # Too few points to fit: fall back to a degenerate ellipse.
        pred_ellip = np.array([0, 0, 0, 0])
    # The fitted rotation angle is dropped (drawn with angle 0), matching the
    # angle-less [x, y, w, h] ground-truth format.
    ellipse_info = ((pred_ellip[0], pred_ellip[1]),
                    (pred_ellip[2], pred_ellip[3]), 0)

    sz_ = im.shape
    loss = np.sum(np.power(
        (np.array(gt_ellip) - pred_ellip), 2)) / (sz_[0] * sz_[1])

    #save worse result
    if loss > cfgs.loss_thresh:
        error_path = (cfgs.error_path + '_valid'
                      if if_valid else cfgs.error_path)
        fn = filename.strip().decode('utf-8')

        # Prediction in green, ground truth in red (BGR colour order).
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
        cv2.ellipse(im, ellipse_info, (0, 255, 0), 1)
        gt_ellip_info = (tuple(np.array([gt_ellip[0], gt_ellip[1]])),
                         tuple(np.array([gt_ellip[2], gt_ellip[3]])), 0)
        cv2.ellipse(im, gt_ellip_info, (0, 0, 255), 1)
        cv2.imwrite(
            os.path.join(error_path, fn + '_' + str(int(loss)) + '.bmp'), im)

        #heatmap
        heat_map = density_heatmap(pred_pro[:, :, 1])
        cv2.imwrite(os.path.join(error_path, fn + '_heatseq_.bmp'), heat_map)

    return loss
コード例 #3
0
    def view_one_valid(self, fn, pred_anno, pred_pro, im, step):
        # Dump one validation sample into <cfgs.view_path>/valid: the scaled
        # predicted annotation, the first input channel, heat maps of
        # probability channels 1 and 2 and, if cfgs.view_seq, the sequence
        # frames starting at channel self.cur_channel - 1.
        out_dir = os.path.join(cfgs.view_path, 'valid')
        name = fn.strip().decode('utf-8')
        if not cfgs.test_view:
            # Outside test view the step number prefixes every file name.
            name = str(step) + '_' + name

        def _write(suffix, img):
            cv2.imwrite(os.path.join(out_dir, name + suffix), img)

        _write('_anno.bmp', (pred_anno * 127).astype(np.uint8))
        _write('_im.bmp', im[:, :, 0])
        for channel, suffix in ((1, '_heat1.bmp'), (2, '_heat2.bmp')):
            _write(suffix, density_heatmap(pred_pro[:, :, channel]))

        if cfgs.view_seq:
            first = self.cur_channel - 1
            for k in range(cfgs.seq_num):
                _write('seq_' + str(k + 1) + '.bmp', im[:, :, first + k])
コード例 #4
0
ファイル: train_fusion_sub.py プロジェクト: gepu0221/FCN
    def view_one_valid(self, fn, pred_anno, pred_pro, sm_pred_pro,
                       soft_pred_pro, im, step):
        # Dump one validation sample into image/<cfgs.view_path>/valid: the
        # predicted annotation scaled by 255, the first input channel, heat
        # maps of channel 1 of each of the three probability maps and, if
        # cfgs.view_seq, the sequence frames starting at self.cur_channel - 1.
        out_dir = os.path.join('image', cfgs.view_path, 'valid')
        stem = str(step) + '_' + fn.strip().decode('utf-8')

        def save(suffix, img):
            cv2.imwrite(os.path.join(out_dir, stem + suffix), img)

        save('_anno.bmp', (pred_anno * 255).astype(np.uint8))
        save('_im.bmp', im[:, :, 0])
        for suffix, prob in (('_heat.bmp', pred_pro),
                             ('_sm_heat.bmp', sm_pred_pro),
                             ('_soft_heat.bmp', soft_pred_pro)):
            save(suffix, density_heatmap(prob[:, :, 1]))

        if cfgs.view_seq:
            offset = self.cur_channel - 1
            for k in range(1, cfgs.seq_num + 1):
                save('seq_' + str(k) + '.bmp', im[:, :, offset + k - 1])
コード例 #5
0
    def im_mask_view_one(self, fn, pred_anno, pred_pro, im, step):
        # Save one training sample into <cfgs.view_path>/train: the predicted
        # overlay (cast to uint8, RGB -> BGR), the heat map of probability
        # channel 1 and, if cfgs.view_seq, the input sequence frames.
        target = os.path.join(cfgs.view_path, 'train')
        decoded = fn.strip().decode('utf-8')
        base = decoded if cfgs.test_view else str(step) + '_' + decoded

        def dest(suffix):
            return os.path.join(target, base + suffix)

        overlay = cv2.cvtColor(pred_anno.astype(np.uint8), cv2.COLOR_RGB2BGR)
        cv2.imwrite(dest('.bmp'), overlay)
        heat = density_heatmap(pred_pro[:, :, 1])
        cv2.imwrite(dest('_heat.bmp'), heat)
        if cfgs.view_seq:
            for idx in range(cfgs.seq_num):
                frame = im[:, :, self.cur_channel - 1 + idx]
                cv2.imwrite(dest('seq_' + str(idx + 1) + '.bmp'), frame)
コード例 #6
0
ファイル: FCN_pred_soft.py プロジェクト: gepu0221/FCN
def _ensure_dir(path):
    """Create directory ``path`` (with parents) unless it already exists."""
    if not os.path.exists(path):
        os.makedirs(path)


def _setup_result_dirs(logs_file):
    """Create the timestamped result directory and its output sub-folders.

    :param logs_file: open log file that the setup messages are appended to.
    :return: ``(re_save_dir, dirs)`` where ``dirs`` maps each output kind
        ('anno', 'pred', 'train_heat', 'heat', 'ellip', 'transheat') to the
        folder it is written to.
    """
    re_save_dir = "%s%s" % (FLAGS.result_dir, datetime.datetime.now())
    logs_file.write("The result is save at file'%s'.\n" % re_save_dir)
    logs_file.write("The number of part visualization is %d.\n" %
                    FLAGS.v_batch_size)

    #Check the result path if exists.
    if not os.path.exists(re_save_dir):
        print("The path '%s' is not found." % re_save_dir)
        print("Create now ...")
        os.makedirs(re_save_dir)
        print("Create '%s' successfully." % re_save_dir)
        logs_file.write("Create '%s' successfully.\n" % re_save_dir)

    dirs = {
        'anno': os.path.join(re_save_dir, 'anno'),
        'pred': os.path.join(re_save_dir, 'pred'),
        'train_heat': os.path.join(re_save_dir, 'train_heat'),
        'heat': os.path.join(re_save_dir, 'heatmap'),
        'ellip': os.path.join(re_save_dir, 'ellip'),
        'transheat': os.path.join(re_save_dir, 'transheat'),
    }
    for sub_dir in dirs.values():
        _ensure_dir(sub_dir)
    return re_save_dir, dirs


def _save_item_visualizations(im, anno, pred_mask, prob, filename, dirs,
                              soft_pred):
    """Write every visualization enabled in FLAGS for one sample.

    :param im: input image of the sample.
    :param anno: its annotation; channel 1 is visualized.
    :param pred_mask: 2-D predicted label map, used for the prediction
        overlay and the ellipse fit.
    :param prob: per-class probabilities; channel 1 drives the heat maps.
    :param filename: suffix of every output file name.
    :param dirs: output folders as returned by ``_setup_result_dirs``.
    :param soft_pred: overlay the prediction with ``soft_pred_visualize``
        instead of ``pred_visualize``.
    """
    if FLAGS.anno == 'T':
        valid_images_anno = anno_visualize(im.copy(), anno[:, :, 1])
        utils.save_image(valid_images_anno.astype(np.uint8),
                         dirs['anno'],
                         name="anno_" + filename)
    if FLAGS.pred == 'T':
        if soft_pred:
            valid_images_pred = soft_pred_visualize(im.copy(), pred_mask)
        else:
            valid_images_pred = pred_visualize(
                im.copy(), np.expand_dims(pred_mask, axis=2))
        utils.save_image(valid_images_pred.astype(np.uint8),
                         dirs['pred'],
                         name="pred_" + filename)
    if FLAGS.train_heat == 'T':
        train_heat_map = density_heatmap(anno[:, :, 1] / FLAGS.normal)
        utils.save_image(train_heat_map.astype(np.uint8),
                         dirs['train_heat'],
                         name="trainheat_" + filename)
    if FLAGS.fit_ellip == 'T':
        valid_images_ellip = fit_ellipse_findContours(
            im.copy(),
            np.expand_dims(pred_mask, axis=2).astype(np.uint8))
        utils.save_image(valid_images_ellip.astype(np.uint8),
                         dirs['ellip'],
                         name="ellip_" + filename)
    # The prediction heat map feeds both its own output and the translucent
    # overlay, so compute it whenever either is requested; the overlay must
    # never fall back to an undefined or unrelated (training) heat map.
    if FLAGS.heatmap == 'T' or FLAGS.trans_heat == 'T':
        heat_map = density_heatmap(prob[:, :, 1])
        if FLAGS.heatmap == 'T':
            utils.save_image(heat_map.astype(np.uint8),
                             dirs['heat'],
                             name="heat_" + filename)
        if FLAGS.trans_heat == 'T':
            trans_heat_map = translucent_heatmap(
                im, heat_map.astype(np.uint8).copy())
            utils.save_image(trans_heat_map,
                             dirs['transheat'],
                             name="trans_heat_" + filename)


def main(argv=None):
    """Build the FCN graph, restore the latest checkpoint and run FLAGS.mode.

    Supported modes:
      * ``"accurary"``: print and log pixel / IoU accuracy over the
        validation set.
      * ``"all_visualize"``: save the visualizations enabled in FLAGS for
        every validation sample.
      * ``"check_training"``: the same, for the training set read through
        ``dataset_soft`` and fed as soft annotations.

    A log file is written to ``FLAGS.logs_dir``; in the two visualization
    modes it is also copied into the timestamped result directory.
    """
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")
    soft_annotation = tf.placeholder(tf.float32,
                                     shape=[None, IMAGE_SIZE, IMAGE_SIZE, 2],
                                     name="soft_annotation")
    _, pred_annotation, logits, pred_prob = inference(image,
                                                      keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    #logits:the last layer of conv net
    #labels:the ground truth
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, squeeze_dims=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    # NOTE(review): the train op is never run here; presumably it is built so
    # the graph's variables match the checkpoint restored below.
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    #Create a file to write logs.
    # The log name has its own variable because the visualization loops bind
    # per-image file names, and this name is needed again for the final copy.
    logs_filename = "logs_%s%s.txt" % (FLAGS.mode, datetime.datetime.now())
    path_ = os.path.join(FLAGS.logs_dir, logs_filename)
    # Set by the visualization modes; the log file is copied there at the end.
    re_save_dir = None
    with open(path_, 'w') as logs_file:
        logs_file.write("The logs file is created at %s\n" %
                        datetime.datetime.now())
        logs_file.write("The mode is %s\n" % (FLAGS.mode))
        logs_file.write(
            "The train data batch size is %d and the validation batch size is %d.\n"
            % (FLAGS.batch_size, FLAGS.v_batch_size))
        logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" %
                        (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)

        print("Setting up image reader...")
        logs_file.write("Setting up image reader...\n")
        train_records, valid_records = scene_parsing.my_read_dataset(
            FLAGS.data_dir)
        print('number of train_records', len(train_records))
        print('number of valid_records', len(valid_records))
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

        print("Setting up dataset reader")
        image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
        if FLAGS.mode == 'check_training':
            train_dataset_reader = dataset_soft.BatchDatset(
                train_records, image_options)
        validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                        image_options)

        sess = tf.Session()

        print("Setting up Saver...")
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        #if not train,restore the model trained before
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        if FLAGS.mode == "accurary":
            count = 0
            if_con = True
            accu_iou_t = 0
            accu_pixel_t = 0

            while if_con:
                count = count + 1
                valid_images, valid_annotations, valid_filenames, if_con, start, end = validation_dataset_reader.next_batch_valid(
                    FLAGS.v_batch_size)
                valid_loss, pred_anno = sess.run(
                    [loss, pred_annotation],
                    feed_dict={
                        image: valid_images,
                        annotation: valid_annotations,
                        keep_probability: 1.0
                    })
                accu_iou, accu_pixel = accu.caculate_accurary(
                    pred_anno, valid_annotations)
                print("Ture %d ---> the data from %d to %d" %
                      (count, start, end))
                print("%s ---> Validation_pixel_accuary: %g" %
                      (datetime.datetime.now(), accu_pixel))
                print("%s ---> Validation_iou_accuary: %g" %
                      (datetime.datetime.now(), accu_iou))
                #Output logs.
                logs_file.write("Ture %d ---> the data from %d to %d\n" %
                                (count, start, end))
                logs_file.write("%s ---> Validation_pixel_accuary: %g\n" %
                                (datetime.datetime.now(), accu_pixel))
                logs_file.write("%s ---> Validation_iou_accuary: %g\n" %
                                (datetime.datetime.now(), accu_iou))

                accu_iou_t = accu_iou_t + accu_iou
                accu_pixel_t = accu_pixel_t + accu_pixel
            print("%s ---> Total validation_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel_t / count))
            print("%s ---> Total validation_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou_t / count))
            #Output logs
            logs_file.write("%s ---> Total validation_pixel_accurary: %g\n" %
                            (datetime.datetime.now(), accu_pixel_t / count))
            logs_file.write("%s ---> Total validation_iou_accurary: %g\n" %
                            (datetime.datetime.now(), accu_iou_t / count))

        elif FLAGS.mode == "all_visualize":
            re_save_dir, dirs = _setup_result_dirs(logs_file)
            if_con = True
            while if_con:
                valid_images, valid_annotations, valid_filename, if_con, start, end = validation_dataset_reader.next_batch_valid(
                    FLAGS.v_batch_size)
                pred, pred_prob_ = sess.run(
                    [pred_annotation, pred_prob],
                    feed_dict={
                        image: valid_images,
                        annotation: valid_annotations,
                        keep_probability: 1.0
                    })
                # Labels for the overlay and ellipse fit: argmax after adding
                # a fixed 0.85 to the class-0 probability.
                pred_ellip = np.argmax(pred_prob_ + [0.85, 0], axis=3)
                for itr in range(len(pred)):
                    _save_item_visualizations(
                        valid_images[itr], valid_annotations[itr],
                        pred_ellip[itr], pred_prob_[itr],
                        valid_filename[itr]['filename'], dirs,
                        soft_pred=False)

        elif FLAGS.mode == 'check_training':
            re_save_dir, dirs = _setup_result_dirs(logs_file)
            if_con = True
            while if_con:
                train_images, train_annotations, train_filenames, if_con, start, end = train_dataset_reader.next_batch_valid(
                    FLAGS.v_batch_size)
                pred, pred_prob_ = sess.run(
                    [pred_annotation, pred_prob],
                    feed_dict={
                        image: train_images,
                        soft_annotation: train_annotations,
                        keep_probability: 1.0
                    })
                pred = np.squeeze(pred, axis=3)
                for itr in range(len(pred)):
                    _save_item_visualizations(
                        train_images[itr], train_annotations[itr],
                        pred[itr], pred_prob_[itr],
                        train_filenames[itr]['filename'], dirs,
                        soft_pred=True)

    # Copy the (now closed) log next to the visualization results.
    if re_save_dir is not None:
        shutil.copyfile(path_, os.path.join(re_save_dir, logs_filename))
コード例 #7
0
ファイル: CaculateAccurary_filp.py プロジェクト: gepu0221/FCN
    def caculate_ellip_accu_once(self,
                                 im,
                                 filename,
                                 pred,
                                 pred_pro,
                                 gt_ellip,
                                 if_valid=False):
        """Fit an ellipse to a predicted mask and score it against ground truth.

        The loss is ``self.haus_loss`` of the fitted ``[x, y, w, h]`` when more
        than five contour points are found, otherwise ``self.ellip_loss`` of an
        all-zero ellipse. Samples whose decoded file name is in
        ``self.shelter_map`` are saved for inspection: the image with both
        ellipses drawn, the heat map of ``pred_pro`` channel 1 and the raw
        contours.

        :param im: RGB image the prediction belongs to.
        :param filename: file name as bytes (stripped and decoded as UTF-8).
        :param pred: single-channel mask passed to ``cv2.findContours``.
        :param pred_pro: per-class probability map.
        :param gt_ellip: ground-truth ellipse as ``[x, y, w, h]``.
        :param if_valid: save under ``cfgs.error_path + '_valid'`` instead of
            ``cfgs.error_path``.
        :return: the loss value.
        """
        #gt_ellipse [(x,y), w, h]
        fn = filename.strip().decode('utf-8')
        # cv2.findContours returns 3 values in OpenCV 3 and 2 in OpenCV 2/4;
        # the contour list is always the second-to-last element.
        contours = cv2.findContours(pred, cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        # Debug canvas with the raw contours drawn in green, sized from the
        # mask the contours were extracted from.
        c_im_show = np.zeros((pred.shape[0], pred.shape[1], 3), np.uint8)
        cv2.drawContours(c_im_show, contours, -1, (0, 255, 0), 1)

        # Pool the points of every contour into one point set for the fit.
        pts_ = np.array([pt for contour in contours for pt in contour])
        if pts_.shape[0] > 5:
            (cx, cy), (w, h), _ = cv2.fitEllipse(pts_)
            pred_ellip = np.array([cx, cy, w, h])
            loss = self.haus_loss(pred_ellip, gt_ellip)
        else:
            pred_ellip = np.array([0, 0, 0, 0])
            loss = self.ellip_loss(pred_ellip, gt_ellip)
        # Drawn without rotation (angle 0), matching the angle-less
        # [x, y, w, h] ground-truth format.
        ellipse_info = ((pred_ellip[0], pred_ellip[1]),
                        (pred_ellip[2], pred_ellip[3]), 0)

        if fn in self.shelter_map:
            error_path = (cfgs.error_path + '_valid'
                          if if_valid else cfgs.error_path)
            loss_tag = fn + '_' + str(int(loss))

            # Prediction in green, ground truth in red (BGR colour order).
            im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
            cv2.ellipse(im, ellipse_info, (0, 255, 0), 1)
            gt_ellip_info = (tuple(np.array([gt_ellip[0], gt_ellip[1]])),
                             tuple(np.array([gt_ellip[2], gt_ellip[3]])), 0)
            cv2.ellipse(im, gt_ellip_info, (0, 0, 255), 1)
            cv2.imwrite(os.path.join(error_path, loss_tag + '.bmp'), im)

            #heatmap
            heat_map = density_heatmap(pred_pro[:, :, 1])
            cv2.imwrite(os.path.join(error_path, fn + '_heatseq_.bmp'),
                        heat_map)

            cv2.imwrite(os.path.join(error_path, loss_tag + '_contour.bmp'),
                        c_im_show)

        return loss
コード例 #8
0
ファイル: train_seq_new.py プロジェクト: gepu0221/FCN
class FCNNet(object):

    def __init__(self, mode, max_epochs, batch_size, n_classes, train_records, valid_records, im_sz, init_lr, keep_prob):
        """Store the run configuration and create the graph inputs.

        :param mode: run mode string (compared against e.g. 'all_visualize').
        :param max_epochs: stored as ``self.max_epochs``.
        :param batch_size: stored as ``self.batch_size``.
        :param n_classes: number of classes, stored as ``self.NUM_OF_CLASSESS``.
        :param train_records: records handed to the training dataset reader.
        :param valid_records: records handed to the validation reader.
        :param im_sz: side length of the square input (``self.IMAGE_SIZE``).
        :param init_lr: initial learning rate, stored as a float.
        :param keep_prob: dropout keep probability.
        """
        self.max_epochs = max_epochs
        self.keep_prob = keep_prob
        self.batch_size = batch_size
        self.NUM_OF_CLASSESS = n_classes
        self.IMAGE_SIZE = im_sz
        self.graph = tf.get_default_graph()
        self.lr = tf.placeholder(dtype=tf.float32, name='learning_rate')
        self.learning_rate = float(init_lr)
        self.mode = mode
        self.logs_dir = cfgs.logs_dir
        self.current_itr_var = tf.Variable(0, dtype=tf.int32, name='current_itr', trainable=False)
        self.cur_epoch = tf.Variable(1, dtype=tf.int32, name='cur_epoch', trainable=False)
        self.seq_num = cfgs.seq_num

        self.train_records = train_records
        self.valid_records = valid_records

        vis = self.mode == 'all_visualize'
        # Use the instance's image size (im_sz) -- the same one the
        # placeholders below are built with -- rather than a module global.
        self.image_options = {'resize': True, 'resize_size': self.IMAGE_SIZE, 'visualize': vis, 'annotation': True}

        # The input carries seq_num + 3 channels.
        self.images = tf.placeholder(tf.float32, shape=[None, self.IMAGE_SIZE, self.IMAGE_SIZE, self.seq_num+3], name="input_image")
        self.annotations = tf.placeholder(tf.int32, shape=[None, self.IMAGE_SIZE, self.IMAGE_SIZE, 1], name="annotation")

        # Result output directory, recorded in every mode.
        self.result_dir = cfgs.result_dir
        # Focal-loss parameters (alpha weighting and gamma exponent).
        self.at = cfgs.at
        self.gamma = cfgs.gamma

        
    #1. get data
    def get_data(self):
        """Create the dataset readers on the CPU.

        The training reader is only built in 'train' mode; the validation
        reader is always built.
        """
        def make_reader(records):
            return dataset.BatchDatset(records, self.image_options)

        with tf.device('/cpu:0'):
            if self.mode == 'train':
                self.train_dataset_reader = make_reader(self.train_records)
            self.validation_dataset_reader = make_reader(self.valid_records)

    def get_data_vis(self):
        """Create the visualization reader over the validation records (CPU)."""
        with tf.device('/cpu:0'):
            reader = dataset.BatchDatset(self.valid_records, self.image_options)
            self.vis_dataset_reader = reader

    def get_data_video(self):
        """Create the video reader over the validation records (CPU)."""
        with tf.device('/cpu:0'):
            reader = dataset.BatchDatset(self.valid_records, self.image_options)
            self.video_dataset_reader = reader



    #2. net
    def vgg_net(self, weights, image):
        """Build the pretrained VGG-style trunk on top of ``image``.

        conv1_1/relu1_1 are omitted, so the layer at position ``idx`` reads
        its parameters from ``weights[idx + 2]`` (presumably skipping the two
        omitted entries). Pooling layers use average pooling.

        :param weights: squeezed MatConvNet layer list.
        :param image: input tensor for conv1_2.
        :return: dict mapping each layer name to its output tensor.
        """
        layer_names = (
            # 'conv1_1', 'relu1_1' intentionally omitted.
            'conv1_2', 'relu1_2', 'pool1',

            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',

            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
            'relu3_3', 'conv3_4', 'relu3_4', 'pool3',

            'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
            'relu4_3', 'conv4_4', 'relu4_4', 'pool4',

            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
            'relu5_3', 'conv5_4', 'relu5_4'
        )

        net = {}
        current = image
        for idx, layer in enumerate(layer_names):
            prefix = layer[:4]
            if prefix == 'conv':
                raw_kernels, raw_bias = weights[idx + 2][0][0][0][0]
                # MatConvNet stores [width, height, in, out]; TensorFlow
                # expects [height, width, in, out].
                w = utils.get_variable(np.transpose(raw_kernels, (1, 0, 2, 3)),
                                       name=layer + "_w")
                b = utils.get_variable(raw_bias.reshape(-1), name=layer + "_b")
                current = utils.conv2d_basic(current, w, b)
            elif prefix == 'relu':
                current = tf.nn.relu(current, name=layer)
                if cfgs.debug:
                    utils.add_activation_summary(current)
            elif prefix == 'pool':
                current = utils.avg_pool_2x2(current)
            net[layer] = current

        return net


    def inference(self, image, keep_prob):
        """
        Semantic segmentation network definition (FCN-8s-style decoder that
        fuses pool4 and pool3 on top of the pretrained VGG trunk).

        :param image: input image tensor. Should have values in range 0-255.
            Its static channel count sets the input width of the first conv.
        :param keep_prob: dropout keep probability for the conv6/conv7 layers.
        :return: None. Sets ``self.logits`` (class scores at the input's
            spatial size) and ``self.pred_annotation`` (their argmax with a
            trailing singleton axis).
        """
        print("setting up vgg initialized conv layers ...")
        model_data = utils.get_model_data(cfgs.model_dir, MODEL_URL)
        weights = np.squeeze(model_data['layers'])
        # Use the instance's class count so the logits agree with the loss.
        n_classes = self.NUM_OF_CLASSESS
        # Take the input channel count from the tensor itself (seq_num + 3 for
        # self.images) instead of hard-coding it.
        in_channels = image.get_shape()[3].value

        # NOTE: no mean subtraction is applied; the raw image feeds W1.
        with tf.variable_scope("inference"):
            W1 = utils.weight_variable([3, 3, in_channels, 64], name="W1")
            b1 = utils.bias_variable([64], name="b1")
            conv1 = utils.conv2d_basic(image, W1, b1)
            relu1 = tf.nn.relu(conv1, name='relu1')

            #pretrain
            image_net = self.vgg_net(weights, relu1)

            conv_final_layer = image_net["conv5_3"]

            pool5 = utils.max_pool_2x2(conv_final_layer)

            W6 = utils.weight_variable([7, 7, 512, 4096], name="W6")
            b6 = utils.bias_variable([4096], name="b6")
            conv6 = utils.conv2d_basic(pool5, W6, b6)
            relu6 = tf.nn.relu(conv6, name="relu6")
            if cfgs.debug:
                utils.add_activation_summary(relu6)
            relu_dropout6 = tf.nn.dropout(relu6, keep_prob=keep_prob)

            W7 = utils.weight_variable([1, 1, 4096, 4096], name="W7")
            b7 = utils.bias_variable([4096], name="b7")
            conv7 = utils.conv2d_basic(relu_dropout6, W7, b7)
            relu7 = tf.nn.relu(conv7, name="relu7")
            if cfgs.debug:
                utils.add_activation_summary(relu7)
            relu_dropout7 = tf.nn.dropout(relu7, keep_prob=keep_prob)

            W8 = utils.weight_variable([1, 1, 4096, n_classes], name="W8")
            b8 = utils.bias_variable([n_classes], name="b8")
            conv8 = utils.conv2d_basic(relu_dropout7, W8, b8)

            # now to upscale to actual image size
            deconv_shape1 = image_net["pool4"].get_shape()
            W_t1 = utils.weight_variable([4, 4, deconv_shape1[3].value, n_classes], name="W_t1")
            b_t1 = utils.bias_variable([deconv_shape1[3].value], name="b_t1")
            conv_t1 = utils.conv2d_transpose_strided(conv8, W_t1, b_t1, output_shape=tf.shape(image_net["pool4"]))
            fuse_1 = tf.add(conv_t1, image_net["pool4"], name="fuse_1")

            deconv_shape2 = image_net["pool3"].get_shape()
            W_t2 = utils.weight_variable([4, 4, deconv_shape2[3].value, deconv_shape1[3].value], name="W_t2")
            b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
            conv_t2 = utils.conv2d_transpose_strided(fuse_1, W_t2, b_t2, output_shape=tf.shape(image_net["pool3"]))
            fuse_2 = tf.add(conv_t2, image_net["pool3"], name="fuse_2")

            # Final x8 upsampling back to the input's batch/height/width.
            shape = tf.shape(image)
            deconv_shape3 = tf.stack([shape[0], shape[1], shape[2], n_classes])
            W_t3 = utils.weight_variable([16, 16, n_classes, deconv_shape2[3].value], name="W_t3")
            b_t3 = utils.bias_variable([n_classes], name="b_t3")
            conv_t3 = utils.conv2d_transpose_strided(fuse_2, W_t3, b_t3, output_shape=deconv_shape3, stride=8)

            annotation_pred = tf.argmax(conv_t3, dimension=3, name="prediction")

        self.pred_annotation = tf.expand_dims(annotation_pred, dim=3)
        self.logits = conv_t3



    #3. optmizer
    def train_optimizer(self):
        """Build the Adam training op that minimises ``self.loss``.

        Sets ``self.var_list`` (all trainable variables), ``self.grads`` (the
        list of ``(gradient, variable)`` pairs from ``compute_gradients``) and
        ``self.train_op``. With ``cfgs.debug`` on, a gradient summary is
        registered for every pair.
        """
        adam = tf.train.AdamOptimizer(self.lr)
        self.var_list = tf.trainable_variables()
        self.grads = adam.compute_gradients(self.loss, var_list=self.var_list)

        if cfgs.debug:
            for gradient, variable in self.grads:
                utils.add_gradient_summary(gradient, variable)

        self.train_op = adam.apply_gradients(self.grads)

    
    #4. loss
    def loss(self):
        """Build the cross-entropy loss, the softmax probabilities and a focal loss.

        NOTE: this rebinds ``self.loss`` from this method to the scalar loss
        tensor, so it can only be called once per instance. ``train_optimizer``
        and ``summary`` read ``self.loss`` as that tensor afterwards.

        Also sets:
            self.pro: softmax of ``self.logits`` (fetched by ``visualize``).
            self.focal_loss: focal-weighted cross entropy. It is built here
                but is not the quantity minimised by ``train_optimizer``.
        """
        # Labels with the trailing singleton channel axis removed.
        labels = tf.squeeze(self.annotations, squeeze_dims=[3])

        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
                                                           labels=labels,
                                                           name="entropy"))

        # Focal loss. The alpha weight is linear in the label: `at` where the
        # label is 0 and `1 - at` where it is 1 (presumably binary labels).
        alpha_weight = (1 - 2 * self.at) * tf.cast(labels, tf.float32) + self.at
        self.pro = tf.nn.softmax(self.logits)
        # Probability assigned to the true class of every pixel.
        true_class_prob = tf.reduce_sum(
            self.pro * tf.one_hot(labels, self.NUM_OF_CLASSESS), 3)
        modulating = tf.pow(1 - true_class_prob, self.gamma)

        per_pixel_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
                                                                      labels=labels,
                                                                      name="entropy")
        self.focal_loss = tf.reduce_mean(modulating * alpha_weight * per_pixel_ce)

        
    #5. evaluation
    def calculate_acc(self, pred_anno, anno):
        """Score one batch of predictions against its annotations.

        Stores the two values returned by ``accu.caculate_accurary`` in
        ``self.accu_iou`` and ``self.accu`` (in that order).
        """
        with tf.name_scope('accu'):
            scores = accu.caculate_accurary(pred_anno, anno)
            self.accu_iou, self.accu = scores
    
    #6. summary
    def summary(self):
        """Register the loss and learning-rate scalars and merge all summaries.

        ``self.summary_op`` merges every summary in the graph, including any
        debug activation/gradient summaries registered earlier.
        """
        scalars = (('train_loss', self.loss), ('learning_rate', self.lr))
        with tf.name_scope('summary'):
            for tag, tensor in scalars:
                tf.summary.scalar(tag, tensor)
            self.summary_op = tf.summary.merge_all()
    
    #7. graph build
    def build(self):
        """Build the training graph: network, loss, optimizer, then summaries.

        Order matters: ``loss`` rebinds ``self.loss`` to the loss tensor that
        ``train_optimizer`` and ``summary`` then read.
        """
        self.inference(self.images, self.keep_prob)
        for stage in (self.loss, self.train_optimizer, self.summary):
            stage()

    def build_vis(self):
        """Build the visualization graph: vis input data, network and loss.

        ``loss`` is built here because it also defines ``self.pro``, the
        softmax probabilities that ``visualize`` fetches.
        """
        self.get_data_vis()
        # Read the inputs only after get_data_vis has run.
        images, keep_prob = self.images, self.keep_prob
        self.inference(images, keep_prob)
        self.loss()

    def build_video(self):
        """Build the video graph: video input data, network and loss.

        ``loss`` is built here because it also defines ``self.pro``, the
        softmax probabilities.
        """
        self.get_data_video()
        # Read the inputs only after get_data_video has run.
        images, keep_prob = self.images, self.keep_prob
        self.inference(images, keep_prob)
        self.loss()
    


    #8. update lr
    def try_update_lr(self):
        """Best-effort reload of the learning rate from ``cfgs.learning_rate_path``.

        The first line of the file is parsed as a float. When it differs from
        ``self.learning_rate``, the attribute is updated and the change is
        printed.

        A missing path setting, a missing or unreadable file, or an unparsable
        value all leave ``self.learning_rate`` unchanged. Other errors are no
        longer silently swallowed.
        """
        lr_path = getattr(cfgs, 'learning_rate_path', None)
        if lr_path is None:
            return
        try:
            with open(lr_path) as f:
                new_lr = float(f.readline().split('\n')[0])
        except (IOError, OSError, ValueError):
            # Deliberately best-effort: keep the current rate.
            return
        if self.learning_rate != new_lr:
            old_lr = self.learning_rate
            self.learning_rate = new_lr
            print('learning rate change from %g to %g' % (old_lr, new_lr))

    #9. Model recover
    def recover_model(self, sess):
        """Restore the latest checkpoint in ``self.logs_dir`` into ``sess``, if any.

        Returns:
            The ``tf.train.Saver`` that was created, so callers can reuse it to
            save checkpoints. Without a checkpoint the session is left as is.
        """
        saver = tf.train.Saver()
        state = tf.train.get_checkpoint_state(self.logs_dir)
        latest = state.model_checkpoint_path if state else None
        if latest:
            saver.restore(sess, latest)
            print('Model restore finished')
        return saver
    
    #10. train or valid once 
    def create_re_dir(self):
        """Create a timestamped result directory with one sub-directory per output.

        ``self.re_save_dir`` is ``self.result_dir`` immediately followed (no
        separator added) by the current local time. The attributes
        ``re_save_dir_im``/``_heat``/``_ellip``/``_transheat`` point at
        ``images``/``heatmap``/``ellip``/``transheat`` under it, and any of
        them that do not exist yet are created.
        """
        self.re_save_dir = "%s%s" % (self.result_dir, datetime.datetime.now())

        layout = (
            ('re_save_dir_im', 'images'),
            ('re_save_dir_heat', 'heatmap'),
            ('re_save_dir_ellip', 'ellip'),
            ('re_save_dir_transheat', 'transheat'),
        )
        # Set every attribute first, then create the directories.
        for attr, sub_dir in layout:
            setattr(self, attr, os.path.join(self.re_save_dir, sub_dir))
        for attr, _ in layout:
            target = getattr(self, attr)
            if not os.path.exists(target):
                os.makedirs(target)

    def vis_one_im(self):
        """Save the enabled visualizations for the current image.

        Reads ``self.vis_image``, ``self.vis_anno``, ``self.vis_pred``,
        ``self.vis_pred_prob`` and ``self.filename``, all set by the caller.
        Depending on the ``cfgs`` flags it writes:

        * ``cfgs.anno``: ``pred_visualize`` of the image with ``vis_anno``
          -> ``inp_<filename>.jpg``
        * ``cfgs.fit_ellip``: ellipse fitted to ``vis_pred`` -> ``ellip_<filename>.jpg``
        * ``cfgs.heatmap``: heat map of channel 1 of ``vis_pred_prob``
          -> ``heat_<filename>.jpg``
        * ``cfgs.trans_heat`` (only together with ``cfgs.heatmap``, because it
          reuses that heat map): translucent overlay -> ``trans_heat_<filename>.jpg``
        """
        if cfgs.anno:
            im_ = pred_visualize(self.vis_image.copy(), self.vis_anno).astype(np.uint8)
            utils.save_image(im_, self.re_save_dir_im, name='inp_' + self.filename + '.jpg')
        if cfgs.fit_ellip:
            pred_mask = np.expand_dims(self.vis_pred, axis=2).astype(np.uint8)
            im_ellip = fit_ellipse_findContours(self.vis_image.copy(), pred_mask)
            utils.save_image(im_ellip, self.re_save_dir_ellip, name='ellip_' + self.filename + '.jpg')
        if cfgs.heatmap:
            heat_map = density_heatmap(self.vis_pred_prob[:, :, 1])
            utils.save_image(heat_map, self.re_save_dir_heat, name='heat_' + self.filename + '.jpg')
            if cfgs.trans_heat:
                trans_heat_map = translucent_heatmap(self.vis_image.copy(),
                                                     heat_map.astype(np.uint8).copy())
                # Was `self.filenaem` (typo) -> AttributeError.
                utils.save_image(trans_heat_map, self.re_save_dir_transheat,
                                 name='trans_heat_' + self.filename + '.jpg')

    #Visualize the result
    def visualize(self, sess):
        """Run the network over the visualization set and save per-image outputs.

        Batches come from ``self.vis_dataset_reader.next_batch_val_vis`` until
        its fifth return value is falsy (or ``tf.errors.OutOfRangeError`` is
        raised). For every image, the predicted labels and softmax
        probabilities (``self.pro``) are stored on ``self`` and ``vis_one_im``
        is called.
        """
        self.create_re_dir()

        more_batches = True
        try:
            while more_batches:
                (images_, annos_, cur_ims_, filenames_,
                 more_batches, _, _) = self.vis_dataset_reader.next_batch_val_vis(self.batch_size)
                labels, probs = sess.run([self.pred_annotation, self.pro],
                                         feed_dict={self.images: images_})
                # Drop the trailing singleton channel axis of the label map.
                labels = np.squeeze(labels, axis=3)

                for idx, label_map in enumerate(labels):
                    self.filename = filenames_[idx].strip().decode('utf-8')
                    self.vis_image = cur_ims_[idx]
                    self.vis_anno = annos_[idx]
                    self.vis_pred = label_map
                    self.vis_pred_prob = probs[idx]
                    self.vis_one_im()
        except tf.errors.OutOfRangeError:
            pass

                    
    #Visualize the video result
    def vis_video(self, sess):
        """Run the network over the video frames and save per-frame outputs.

        Initialises the video input with ``self.video_init``. Batches then come
        from ``self.video_dataset_reader.next_batch_video_valid`` until its
        fourth return value is falsy (or ``tf.errors.OutOfRangeError`` is
        raised). For every frame, the predicted labels and softmax
        probabilities are stored on ``self`` and ``vis_one_im`` is called.

        NOTE(review): this method never sets ``self.vis_anno``, which
        ``vis_one_im`` reads when ``cfgs.anno`` is on. It also passes
        ``filenames_[i]`` unchanged, and ``vis_one_im`` concatenates it with
        ``str``. Confirm the video reader yields ``str`` names.
        """
        sess.run(self.video_init)
        self.create_re_dir()

        if_con = True
        try:
            while if_con:
                # Was bound to `cur_ims` while `cur_ims_` was read (NameError).
                images_, cur_ims_, filenames_, if_con, _, _ = \
                    self.video_dataset_reader.next_batch_video_valid(self.batch_size)
                # Fetch softmax probabilities (self.pro), as visualize() does.
                # vis_one_im draws its heat map from channel 1 of these, so
                # raw logits would give a wrong heat map.
                pred_anno, pred_prob = sess.run([self.pred_annotation, self.pro],
                                                feed_dict={self.images: images_})
                pred_anno = np.squeeze(pred_anno, axis=3)

                for i in range(len(pred_anno)):
                    self.filename = filenames_[i]
                    self.vis_image = cur_ims_[i]
                    self.vis_pred = pred_anno[i]
                    self.vis_pred_prob = pred_prob[i]
                    self.vis_one_im()
        except tf.errors.OutOfRangeError:
            pass



    def valid_once(self, sess, writer, epoch, step):
        """Run one full pass over the validation set and print loss/accuracy.

        Args:
            sess: session holding the built graph.
            writer: unused; accepted for symmetry with ``train_one_epoch``.
            epoch: epoch number, only printed.
            step: global step, only printed.

        Per batch, ``self.loss`` and ``self.pred_annotation`` are evaluated and
        accuracy is computed by ``calculate_acc``. A running line is printed
        each batch, and a summary is printed at the end.
        """
        count = 0
        sum_acc = 0
        sum_acc_iou = 0
        total_loss = 0
        t0 = time.time()

        # The reader's third return value is truthy on the last batch.
        last_batch = False
        while not last_batch:
            count += 1
            images_, annos_, last_batch, _ = self.validation_dataset_reader.next_batch(self.batch_size)
            feed_dict = {self.images: images_, self.annotations: annos_, self.lr: self.learning_rate}
            # The merged summary was fetched here but never used; skip it.
            loss, pred_anno = sess.run(fetches=[self.loss, self.pred_annotation], feed_dict=feed_dict)

            #1. accuracy
            self.calculate_acc(pred_anno, annos_)
            sum_acc += self.accu
            sum_acc_iou += self.accu_iou
            #2. loss
            total_loss += loss

            #3. pick up a learning rate edited on disk
            if count % 100 == 0:
                self.try_update_lr()

            print('\r' + 32 * ' ', end='')
            print('epoch %5d\t learning_rate = %g\t step = %4d\t loss = %.3f\t valid_accuracy = %.2f%%\t valid_iou_accuracy = %.2f%%' % (epoch, self.learning_rate, step, (total_loss/count), (sum_acc/count), (sum_acc_iou/count)))

        # `count` is exactly the number of evaluated batches (>= 1). The old
        # `count -= 1` skewed every mean and divided by zero for one batch.
        print('epoch %5d\t learning_rate = %g\t loss = %.3f\t valid_accuracy = %.2f%%\t valid_iou_accuracy = %.2f%%' %
              (epoch, self.learning_rate, total_loss/count, sum_acc/count, sum_acc_iou/count))
        print('Take time %3.1f' % (time.time() - t0))
   


    def train_one_epoch(self, sess, writer, epoch, step):
        """Train for one full pass over the training set.

        Args:
            sess: session holding the built graph.
            writer: summary writer; receives the merged summary every batch.
            epoch: epoch number, only printed.
            step: global step before this epoch.

        Returns:
            The global step after the last batch of this epoch.
        """
        count = 0
        sum_acc = 0
        sum_acc_iou = 0
        total_loss = 0
        t0 = time.time()

        # The reader's third return value is truthy on the last batch.
        last_batch = False
        while not last_batch:
            count += 1
            step += 1
            images_, annos_, last_batch, _ = self.train_dataset_reader.next_batch(self.batch_size)
            feed_dict = {self.images: images_, self.annotations: annos_, self.lr: self.learning_rate}
            _, loss, summary_str, pred_anno = sess.run(
                fetches=[self.train_op, self.loss, self.summary_op, self.pred_annotation],
                feed_dict=feed_dict)

            #1. accuracy
            self.calculate_acc(pred_anno, annos_)
            sum_acc += self.accu
            sum_acc_iou += self.accu_iou
            mean_acc = sum_acc/count
            mean_acc_iou = sum_acc_iou/count

            #2. loss and summary
            total_loss += loss
            writer.add_summary(summary_str, global_step=step)

            #3. time per batch
            time_per_batch = (time.time() - t0)/count

            #4. pick up a learning rate edited on disk
            if count % 100 == 0:
                self.try_update_lr()

            #5. print
            print('\r' + 12 * ' ', end='')
            print('epoch %5d\t learning_rate = %g\t step = %4d\t loss = %.3f\t mean_loss=%.3f\t train_accuracy = %.2f%%\t train_iou_accuracy = %.2f%%\t time = %.2f' % (epoch, self.learning_rate, step, loss, (total_loss/count), mean_acc, mean_acc_iou, time_per_batch))

        # `count` is exactly the number of trained batches (>= 1). The old
        # `count -= 1` skewed the means and divided by zero for one batch.
        print('epoch %5d\t learning_rate = %g\t mean_loss = %.3f\t train_accuracy = %.2f%%\t train_iou_accuracy = %.2f%%' % (epoch, self.learning_rate, (total_loss/count), (sum_acc/count), (sum_acc_iou/count)))
        print('Take time %3.1f' % (time.time() - t0))

        return step


    def train(self):
        """Train with per-epoch validation and checkpointing.

        Creates ``self.logs_dir`` if needed and restores the latest checkpoint
        from it (via ``recover_model``). It then runs epochs ``cur_epoch``
        through ``self.max_epochs`` inclusive. Each epoch:

        * re-reads the learning-rate file;
        * divides the learning rate by 10 on every non-zero multiple of 20;
        * trains and validates;
        * stores the epoch/step counters in their graph variables;
        * saves a checkpoint with prefix ``<logs_dir>/model.ckpt`` (``global_step=step``).

        The summary writer is closed even if training raises.
        """
        #Check if has the log file
        if not os.path.exists(self.logs_dir):
            print("The logs path '%s' is not found" % self.logs_dir)
            print("Create now..")
            os.makedirs(self.logs_dir)
            print("%s is created successfully!" % self.logs_dir)

        print('prepare to train...')
        writer = tf.summary.FileWriter(logdir=self.logs_dir, graph=self.graph)

        print('The graph path is %s' % self.logs_dir)
        config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
        # Join rather than concatenate. If logs_dir lacks a trailing separator,
        # `logs_dir + 'model.ckpt'` lands next to the directory, and
        # recover_model (which looks inside logs_dir) never finds it.
        ckpt_prefix = os.path.join(self.logs_dir, 'model.ckpt')
        try:
            with tf.device('/gpu:0'):
                with tf.Session(config=config) as sess:
                    #1. initialize all variables
                    sess.run(tf.global_variables_initializer())

                    #2. Try to recover model (restored values overwrite the init)
                    saver = self.recover_model(sess)
                    step = self.current_itr_var.eval()
                    cur_epoch = self.cur_epoch.eval()
                    print(step)

                    # NOTE(review): cur_epoch is stored after an epoch finishes, so a
                    # resumed run appears to repeat that epoch -- confirm intended.
                    #3. start to train
                    for epoch in range(cur_epoch, self.max_epochs + 1):

                        #3.1 learning-rate schedule
                        self.try_update_lr()
                        if epoch != 0 and epoch % 20 == 0:
                            # NOTE(review): try_update_lr resets learning_rate to the
                            # file's value whenever they differ, so with a rate file
                            # present this decay is undone next epoch -- confirm.
                            self.learning_rate /= 10

                        #3.2 train one epoch
                        step = self.train_one_epoch(sess, writer, epoch, step)

                        #3.3 validate and save model
                        self.valid_once(sess, writer, epoch, step)
                        self.cur_epoch.load(epoch, sess)
                        self.current_itr_var.load(step, sess)
                        saver.save(sess, ckpt_prefix, step)
        finally:
            writer.close()

    def vis(self):
        """Restore the latest checkpoint from ``self.logs_dir`` and run ``visualize``.

        Raises:
            Exception: if ``self.logs_dir`` does not exist.
        """
        logs_dir = self.logs_dir
        if not os.path.exists(logs_dir):
            raise Exception('The logs path %s is not found!' % logs_dir)
        print('The logs path is %s.' % logs_dir)

        session_config = tf.ConfigProto(log_device_placement=False,
                                        allow_soft_placement=True)
        with tf.device('/gpu:0'), tf.Session(config=session_config) as sess:
            sess.run(tf.global_variables_initializer())
            # Restored values overwrite the freshly initialised ones.
            self.recover_model(sess)
            self.visualize(sess)
コード例 #9
0
ファイル: pred_video_ori.py プロジェクト: gepu0221/FCN
def main(argv=None):
    """Run the restored FCN over the validation video and save visualizations.

    Builds the inference graph and writes a run log to ``FLAGS.logs_dir``, with
    a copy in the timestamped result directory. It then restores the latest
    checkpoint from ``FLAGS.logs_dir``. For every batch from the video reader
    it saves the prediction overlay, plus, depending on ``FLAGS.pred``,
    ``FLAGS.fit_ellip``, ``FLAGS.heatmap`` and ``FLAGS.trans_heat``:

    * the raw prediction;
    * the fitted ellipse;
    * the probability heat map;
    * its translucent overlay.

    Args:
        argv: unused.
    """
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")

    pred_annotation_value, pred_annotation, logits, pred_prob = inference(
        image, keep_probability)
    # Use the softmax of the logits as class probabilities (replaces the
    # pred_prob returned by inference).
    pred_prob = tf.nn.softmax(logits)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    # Create a file to write logs; `with` guarantees it is closed on errors.
    filename = "logs_%s%s.txt" % (FLAGS.mode, datetime.datetime.now())
    path_ = os.path.join(FLAGS.logs_dir, filename)
    with open(path_, 'w') as logs_file:
        logs_file.write("The logs file is created at %s\n" %
                        datetime.datetime.now())
        logs_file.write("The mode is %s\n" % (FLAGS.mode))
        logs_file.write(
            "The train data batch size is %d and the validation batch size is %d.\n"
            % (FLAGS.batch_size, FLAGS.v_batch_size))
        logs_file.write("The data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The data size is %d.\n" % (IMAGE_SIZE))
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)

        print("Setting up image reader...")
        logs_file.write("Setting up image reader...\n")
        valid_video_records = scene_parsing.read_validation_video_data(
            FLAGS.data_dir)
        print('number of valid_records', len(valid_video_records))
        logs_file.write('number of valid_records %d\n' % len(valid_video_records))

        print("Setting up dataset reader")
        image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
        validation_dataset_reader = dataset.BatchDatset(valid_video_records,
                                                        image_options)
        # Same records without the resize options (presumably original-size
        # frames); iterated in lockstep with the resized reader.
        val_ori_dataset_reader = dataset.BatchDatset(valid_video_records)

        sess = tf.Session()

        print("Setting up Saver...")
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        # Restore the previously trained model, if any.
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        re_save_dir = "%s%s" % (FLAGS.result_dir, datetime.datetime.now())
        logs_file.write("The result is save at file'%s'.\n" % re_save_dir)
        logs_file.write("The number of part visualization is %d.\n" %
                        FLAGS.v_batch_size)

        # Check the result path if exists.
        if not os.path.exists(re_save_dir):
            print("The path '%s' is not found." % re_save_dir)
            print("Create now ...")
            os.makedirs(re_save_dir)
            print("Create '%s' successfully." % re_save_dir)
            logs_file.write("Create '%s' successfully.\n" % re_save_dir)

    # Copy the logs to the result directory.
    result_logs_file = os.path.join(re_save_dir, filename)
    shutil.copyfile(path_, result_logs_file)
    re_save_dir_im = os.path.join(re_save_dir, 'images')
    re_save_dir_pred = os.path.join(re_save_dir, 'pred')
    re_save_dir_heat = os.path.join(re_save_dir, 'heatmap')
    re_save_dir_ellip = os.path.join(re_save_dir, 'ellip')
    re_save_dir_transheat = os.path.join(re_save_dir, 'transheat')
    for sub_dir in (re_save_dir_im, re_save_dir_pred, re_save_dir_heat,
                    re_save_dir_ellip, re_save_dir_transheat):
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)

    count = 0
    if_con = True
    while if_con:
        count = count + 1
        valid_images, valid_filename, if_con, start, end = validation_dataset_reader.next_batch_video_valid(
            FLAGS.v_batch_size)
        val_ori_images, _, _, _, _ = val_ori_dataset_reader.next_batch_video_valid(
            FLAGS.v_batch_size)
        # Only the labels and probabilities are used below.
        pred, pred_prob_ = sess.run(
            [pred_annotation, pred_prob],
            feed_dict={
                image: valid_images,
                keep_probability: 1.0
            })
        print("Turn %d :----start from %d ------- to %d" % (count, start, end))
        pred = np.squeeze(pred, axis=3)

        for itr in range(len(pred)):
            frame_name = valid_filename[itr]['filename']
            valid_images_ = pred_ori_visualize(val_ori_images[itr].copy(),
                                               pred[itr])
            utils.save_image(valid_images_.astype(np.uint8),
                             re_save_dir_im,
                             name="inp_" + frame_name)

            if FLAGS.pred:
                utils.save_gray_image(np.expand_dims(pred[itr],
                                                     axis=2).astype(np.uint8),
                                      re_save_dir_pred,
                                      name="pred_" + frame_name)
            if FLAGS.fit_ellip:
                valid_images_ellip = fit_ellipse_findContours_ori(
                    val_ori_images[itr].copy(),
                    np.expand_dims(pred[itr], axis=2).astype(np.uint8))
                utils.save_image(valid_images_ellip.astype(np.uint8),
                                 re_save_dir_ellip,
                                 name="ellip_" + frame_name)
            # The translucent overlay needs the heat map too. Previously it was
            # only computed under FLAGS.heatmap, so trans_heat alone raised
            # NameError.
            heat_map = None
            if FLAGS.heatmap or FLAGS.trans_heat:
                heat_map = density_heatmap(pred_prob_[itr, :, :, 1])
            if FLAGS.heatmap:
                utils.save_image(heat_map.astype(np.uint8),
                                 re_save_dir_heat,
                                 name="heat_" + frame_name)
            if FLAGS.trans_heat:
                trans_heat_map = translucent_heatmap(
                    valid_images[itr],
                    heat_map.astype(np.uint8).copy())
                utils.save_image(trans_heat_map,
                                 re_save_dir_transheat,
                                 name="trans_heat_" + frame_name)

    summary_writer.close()
    sess.close()
コード例 #10
0
    def train_one_epoch(self, sess, writer, epoch, step):
        """Train on up to ``self.per_e_train_batch`` batches from the input pipeline.

        Each batch is fetched from the graph's input tensors and fed to the
        mask network. On every 10th batch, one randomly chosen sample's outputs
        are written as ``.bmp`` files under ``image/mask`` (the return value of
        ``cv2.imwrite`` is not checked; the directory presumably must already
        exist):

        * the predicted mask;
        * the current-frame mask;
        * the heat map of ``seq_pro`` channel 1;
        * the sequence input.

        Stops early on ``tf.errors.OutOfRangeError``.

        Returns:
            The global step after the last batch that was actually trained.
        """
        sum_acc = 0
        sum_acc_iou = 0
        count = 0
        total_loss = 0
        t0 = time.time()
        try:
            while count < self.per_e_train_batch:
                #1. fetch the next batch. Count it only once it exists, so an
                # exhausted pipeline does not inflate `step` or `count`.
                images_, cur_ims, annos_, filenames = sess.run(
                    [self.train_images, self.train_cur_ims,
                     self.train_annotations, self.train_filenames])
                step += 1
                count += 1

                #2. train
                pred_anno, pred_cur_anno, pred_seq_pro, summary_str, loss, _ = sess.run(
                    [self.pred_annotation, self.pred_cur_anno, self.seq_pro,
                     self.summary_op, self.loss_mask, self.train_op],
                    feed_dict={self.images: cur_ims, self.mask_images: images_,
                               self.annotations: annos_, self.lr: self.learning_rate})

                if count % 10 == 0:
                    # Pick within the fetched batch, which may be shorter than
                    # batch_size.
                    choosen = random.randint(0, len(filenames) - 1)
                    fn = filenames[choosen].strip().decode('utf-8')
                    pred_anno_im = (pred_anno[choosen]*255).astype(np.uint8)
                    cv2.imwrite(os.path.join('image', 'mask', str(step)+'_'+fn+'.bmp'), pred_anno_im)
                    pred_cur_im = (pred_cur_anno[choosen]*255).astype(np.uint8)
                    cv2.imwrite(os.path.join('image', 'mask', str(step)+'_cur_'+fn+'.bmp'), pred_cur_im)
                    heat_map = density_heatmap(pred_seq_pro[choosen, :, :, 1])
                    cv2.imwrite(os.path.join('image', 'mask', str(step)+'_heatseq_'+fn+'.bmp'), heat_map)
                    img_seq = images_[choosen]
                    cv2.imwrite(os.path.join('image', 'mask', str(step)+'_seq_'+fn+'.bmp'), img_seq)

                #3. accuracy
                self.calculate_acc(pred_anno, annos_)
                sum_acc += self.accu
                sum_acc_iou += self.accu_iou
                mean_acc = sum_acc/count
                mean_acc_iou = sum_acc_iou/count
                #4. loss
                total_loss += loss

                #5. time per batch
                time_per_batch = (time.time() - t0)/count

                #6. pick up a learning rate edited on disk
                if count % 100 == 0:
                    self.try_update_lr()
                #7. summary
                writer.add_summary(summary_str, global_step=step)

                #8. print
                print('\r' + 8 * ' ', end='')
                print('epoch %5d\t lr = %g\t step = %4d\t count = %4d\t loss = %.3f\t mean_loss=%.3f\t train_accuracy = %.2f%%\t train_iou_accuracy = %.2f%%\t time = %.2f' % (epoch, self.learning_rate, step, count, loss, (total_loss/count), mean_acc, mean_acc_iou, time_per_batch))

        except tf.errors.OutOfRangeError:
            print('Error!')

        # End of epoch. `count` is exactly the number of trained batches. The
        # old `count -= 1` on normal exit skewed every mean.
        if count > 0:
            print('epoch %5d\t learning_rate = %g\t mean_loss = %.3f\t train_accuracy = %.2f%%\t train_iou_accuracy = %.2f%%' % (epoch, self.learning_rate, (total_loss/count), (sum_acc/count), (sum_acc_iou/count)))
        print('Take time %3.1f' % (time.time() - t0))

        return step
コード例 #11
0
    def train_one_epoch(self, sess, writer, epoch, step):
        """Train on up to ``self.per_e_train_batch`` batches, also tracking ellipse accuracy.

        Each batch (images, ellipse infos, annotations, file names) is fetched
        from the graph's input tensors and trained on. On every 10th batch, one
        random sample's predicted mask and heat map (channel 1 of ``self.pro``)
        are written under ``image/``.

        Stops early on ``tf.errors.OutOfRangeError``. The epoch summary is
        printed either way, averaged over the batches actually trained.

        Returns:
            The global step after the last batch that was actually trained.
        """
        print('sub_train_one_epoch')
        sum_acc = 0
        sum_acc_iou = 0
        sum_acc_ellip = 0
        count = 0
        total_loss = 0
        t0 = time.time()
        try:
            while count < self.per_e_train_batch:
                #1. fetch the next batch. Count it only once it exists, so an
                # exhausted pipeline does not inflate `step` or `count`.
                images_, ellip_infos_, annos_, filenames = sess.run([
                    self.train_images, self.train_ellip_infos,
                    self.train_annotations, self.train_filenames
                ])
                step += 1
                count += 1

                #2. train
                pred_anno, pred_pro, summary_str, loss, _ = sess.run(
                    [
                        self.pred_annotation, self.pro, self.summary_op,
                        self.loss, self.train_op
                    ],
                    feed_dict={
                        self.images: images_,
                        self.annotations: annos_,
                        self.lr: self.learning_rate
                    })

                if count % 10 == 0:
                    # Pick within the fetched batch, which may be shorter than
                    # batch_size.
                    choosen = random.randint(0, len(filenames) - 1)
                    fn = filenames[choosen].strip().decode('utf-8')
                    pred_anno_im = (pred_anno[choosen] * 255).astype(np.uint8)
                    cv2.imwrite(
                        os.path.join('image',
                                     str(step) + '_' + fn + '.bmp'),
                        pred_anno_im)
                    heat_map = density_heatmap(pred_pro[choosen, :, :, 1])
                    cv2.imwrite(
                        os.path.join('image',
                                     str(step) + '_heatseq_' + fn + '.bmp'),
                        heat_map)

                #3. accuracy (pixel, IoU and ellipse)
                self.calculate_acc(images_.copy(), filenames, pred_anno,
                                   annos_, ellip_infos_)
                sum_acc += self.accu
                sum_acc_iou += self.accu_iou
                sum_acc_ellip += self.ellip_acc
                mean_acc = sum_acc / count
                mean_acc_iou = sum_acc_iou / count
                mean_acc_ellip = sum_acc_ellip / count
                #4. loss
                total_loss += loss

                #5. time per batch
                time_per_batch = (time.time() - t0) / count

                #6. pick up a learning rate edited on disk
                if count % 100 == 0:
                    self.try_update_lr()
                #7. summary
                writer.add_summary(summary_str, global_step=step)

                #8. print
                print(
                    'epoch %5d\t lr = %g\t step = %4d\t count = %4d\t loss = %.4f\t mean_loss=%.4f\t train_acc = %.2f%%\t train_iou_acc = %.2f%%\t train_ellip_acc = %.2f\t time = %.2f'
                    % (epoch, self.learning_rate, step, count, loss,
                       (total_loss / count), mean_acc, mean_acc_iou,
                       mean_acc_ellip, time_per_batch))

        except tf.errors.OutOfRangeError:
            print('Error!')

        # End of epoch. Skip the averages when no batch was trained
        # (per_e_train_batch == 0 or an immediately exhausted pipeline), which
        # previously divided by zero.
        if count > 0:
            print(
                'epoch %5d\t learning_rate = %g\t mean_loss = %.3f\t train_acc = %.2f%%\t train_iou_acc = %.2f%%\t train_ellip_acc = %.2f'
                % (epoch, self.learning_rate, (total_loss / count),
                   (sum_acc / count), (sum_acc_iou / count),
                   (sum_acc_ellip / count)))
        print('Take time %3.1f' % (time.time() - t0))

        return step