Ejemplo n.º 1
0
def kitti_model_config():
    """Build the model configuration for the KITTI dataset.

    Starts from ``base_model_config('KITTI')`` and overrides the tunable
    hyper-parameters below; edit the values here to tune the model.

    Returns:
        The object returned by ``base_model_config('KITTI')`` with the
        overrides below applied to it.
    """
    config = base_model_config('KITTI')

    # Input resolution and batch size.
    config.IMAGE_WIDTH = 1248
    config.IMAGE_HEIGHT = 384
    config.BATCH_SIZE = 20

    # Optimiser and learning-rate schedule.
    config.LEARNING_RATE = 0.01
    config.MOMENTUM = 0.9
    config.WEIGHT_DECAY = 0.0001
    config.DECAY_STEPS = 10000
    config.LR_DECAY_FACTOR = 0.5
    config.MAX_GRAD_NORM = 1.0

    # Loss-term coefficients.
    config.LOSS_COEF_BBOX = 5.0
    config.LOSS_COEF_CONF_POS = 75.0
    config.LOSS_COEF_CONF_NEG = 100.0
    config.LOSS_COEF_CLASS = 1.0

    # Detection filtering and plotting.
    config.PROB_THRESH = 0.005
    config.TOP_N_DETECTION = 64
    config.NMS_THRESH = 0.4
    config.PLOT_PROB_THRESH = 0.4

    # Data augmentation.
    config.DATA_AUGMENTATION = True
    config.DRIFT_X = 150
    config.DRIFT_Y = 100

    # Anchors are computed from the config once every override above is in
    # place. NOTE(review): ANCHOR_PER_GRID (and the two flags below) are
    # assigned only afterwards, so if set_anchors reads them it sees the
    # base values — confirm this ordering is intended.
    config.ANCHOR_BOX = set_anchors(config)
    config.ANCHORS = len(config.ANCHOR_BOX)
    config.ANCHOR_PER_GRID = 9

    config.USE_DECONV = False
    config.EXCLUDE_HARD_EXAMPLES = False

    return config
def coco_config():
    """Build the model configuration for the COCO dataset.

    Starts from ``base_model_config('COCO')`` and overrides the tunable
    hyper-parameters below; edit the values here to tune the model.

    Returns:
        The object returned by ``base_model_config('COCO')`` with the
        overrides below applied to it.
    """
    config = base_model_config('COCO')

    # Input resolution and batch size.
    config.IMAGE_WIDTH = 256
    config.IMAGE_HEIGHT = 256
    config.BATCH_SIZE = 15

    # Optimiser and learning-rate schedule.
    config.LEARNING_RATE = 0.001
    config.MOMENTUM = 0.9
    config.WEIGHT_DECAY = 0.0001
    config.DECAY_STEPS = 10000
    config.LR_DECAY_FACTOR = 0.5
    config.MAX_GRAD_NORM = 1.0

    # Loss-term coefficients.
    config.LOSS_COEF_BBOX = 5.0
    config.LOSS_COEF_CONF_POS = 75.0
    config.LOSS_COEF_CONF_NEG = 100.0
    config.LOSS_COEF_CLASS = 1.0

    # Detection filtering and plotting.
    config.PROB_THRESH = 0.005
    config.TOP_N_DETECTION = 64
    config.NMS_THRESH = 0.4
    config.PLOT_PROB_THRESH = 0.4

    # Data augmentation and example selection.
    config.DATA_AUGMENTATION = True
    config.DRIFT_X = 150
    config.DRIFT_Y = 100
    config.EXCLUDE_HARD_EXAMPLES = False

    # Anchors are computed from the fully overridden config.
    # NOTE(review): ANCHOR_PER_GRID is assigned only afterwards, so if
    # set_anchors reads it, it sees the base value — confirm.
    config.ANCHOR_BOX = set_anchors(config)
    config.ANCHORS = len(config.ANCHOR_BOX)
    config.ANCHOR_PER_GRID = 9

    return config
import logging
import os
import math

import tensorflow as tf
import numpy as np

from config.config import base_model_config
from data.kitti_raw_manager import load_raw_forward_data, get_spherical_data
from plot.plot import plot_points

# Module-level configuration read by main(); built once at import time.
# NOTE(review): called with no arguments — presumably base_model_config
# supplies a default dataset; confirm against config.config.
cfg = base_model_config()


def main(args=None, drive='0001'):
    """Reset the log directory, then plot one raw frame in spherical form.

    ``cfg.log_dir`` is deleted (if present) and recreated on every run, so
    each run starts with an empty log directory.

    Args:
        args: Argument list handed over by ``tf.app.run()``; unused.
        drive: Identifier passed to ``load_raw_forward_data``. Defaults to
            ``'0001'``, the value that used to be hard-coded. Presumably a
            KITTI raw drive number — confirm against kitti_raw_manager.
    """
    del args  # Supplied by tf.app.run(); not needed here.

    # Start every run from a clean log directory.
    if tf.gfile.Exists(cfg.log_dir):
        tf.gfile.DeleteRecursively(cfg.log_dir)
    tf.gfile.MakeDirs(cfg.log_dir)

    # Only the first frame returned for the drive is visualised.
    frame = load_raw_forward_data(drive)[0]
    spherical = get_spherical_data(frame)
    plot_points(np.array(spherical))


if __name__ == '__main__':
    # tf.app.run() (TF1 API) parses command-line flags, then calls this
    # module's main() with the remaining argv list.
    tf.app.run()