Example #1
def evaluate():
  """Evaluate."""
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

  if not os.path.exists(FLAGS.eval_dir):
    os.makedirs(FLAGS.eval_dir)

  with tf.Graph().as_default() as g:

    # Select model to evaluate
    if FLAGS.net == 'squeezeDet+PruneFilter':
      mc = kitti_squeezeDetPlus_config()
      mc.BATCH_SIZE = 1
      mc.IS_TRAINING = False
      mc.IS_PRUNING = False
      mc.LOAD_PRETRAINED_MODEL = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = SqueezeDetPlusPruneFilter(mc)
    elif FLAGS.net == 'squeezeDet+PruneFilterShape':
      mc = kitti_squeezeDetPlus_config()
      mc.BATCH_SIZE = 1
      mc.IS_TRAINING = False
      mc.IS_PRUNING = False
      mc.LOAD_PRETRAINED_MODEL = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = SqueezeDetPlusPruneFilterShape(mc)
    elif FLAGS.net == 'squeezeDet+PruneLayer':
      mc = kitti_squeezeDetPlus_config()
      mc.BATCH_SIZE = 1
      mc.IS_TRAINING = False
      mc.IS_PRUNING = False
      mc.LOAD_PRETRAINED_MODEL = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = SqueezeDetPlusPruneLayer(mc)

    imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

    no_evaluation = True
    # Logic to load checkpoints if available
    for file in os.listdir(FLAGS.checkpoint_path):
      if file.endswith(".meta"):
        saver = tf.train.Saver(model.model_params)
        # Strip the ".meta" suffix to recover the checkpoint prefix
        ckpt_path = FLAGS.checkpoint_path + "/" + file[:-5]
        # Extract the global step, assuming filenames of the form "model.ckpt-<step>.meta"
        step = file[11:-5]
        eval_once(saver, ckpt_path, imdb, model, step, True)
        no_evaluation = False

    if no_evaluation:
      saver = None
      eval_once(saver, '-', imdb, model, '0', False)
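
Note: the filename slicing above (file[:-5], file[11:-5]) only works for checkpoints named exactly "model.ckpt-<step>.meta". A more robust variant — a minimal sketch with a hypothetical helper, not part of the repo — derives both values by splitting the name:

import os

def ckpt_prefix_and_step(checkpoint_dir, meta_name):
    # Works for any '<prefix>-<step>.meta' checkpoint filename.
    prefix = os.path.join(checkpoint_dir, meta_name[:-len('.meta')])
    step = prefix.rsplit('-', 1)[-1]  # text after the last '-' is the global step
    return prefix, step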
Example #2
def evaluate():
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
    with tf.Graph().as_default() as g:
        mc = kitti_squeezeDetPlus_config()
        mc.BATCH_SIZE = 1  # TODO(bichen): allow batch size > 1
        mc.LOAD_PRETRAINED_MODEL = False
        model = SqueezeDetPlus(mc, state='val')
        output_node_names = ["out_det_probs", "out_det_boxes", "out_det_class"]
        det_probs = tf.identity(model.det_probs, name=output_node_names[0])
        det_boxes = tf.identity(model.det_boxes, name=output_node_names[1])
        det_class = tf.identity(model.det_class, name=output_node_names[2])
        imdb = kitti('val', FLAGS.data_path, mc)

        saver = tf.train.Saver(model.model_params)
        eval_once(saver, FLAGS.checkpoint_path, imdb, model, output_node_names)
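
The tf.identity wrappers above exist only to give the detection outputs stable node names (out_det_probs, out_det_boxes, out_det_class), which is the usual prelude to freezing the graph for deployment. A sketch of that follow-up step, assuming a TF1 session sess with the checkpoint already restored:

from tensorflow.python.framework import graph_util

# Convert all variables reachable from the named outputs into constants.
frozen_def = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), output_node_names)
with tf.gfile.GFile('frozen_squeezedet_plus.pb', 'wb') as f:
    f.write(frozen_def.SerializeToString())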
Example #3
def evaluate():
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
    with tf.Graph().as_default() as g:
        mc = kitti_squeezeDetPlus_config()
        #mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
        mc.LOAD_PRETRAINED_MODEL = False
        model = SqueezeDetPlus(mc, state='val')
        imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)
        # add summary ops and placeholders
        saver = tf.train.Saver(model.model_params)
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
        ckpts = set()
        while True:
            if FLAGS.run_once:
                # When run_once is true, checkpoint_path should point to the exact
                # checkpoint file.
                eval_once(saver, FLAGS.checkpoint_path, summary_writer, imdb,
                          model)
                return
            else:
                # When run_once is false, checkpoint_path should point to the directory
                # that stores checkpoint files.
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
                if ckpt and ckpt.model_checkpoint_path:
                    if ckpt.model_checkpoint_path in ckpts:
                        # Do not evaluate on the same checkpoint
                        print(
                            'Wait {:d}s for new checkpoints to be saved ... '.
                            format(FLAGS.eval_interval_secs))
                        time.sleep(FLAGS.eval_interval_secs)
                    else:
                        ckpts.add(ckpt.model_checkpoint_path)
                        print('Evaluating {}...'.format(
                            ckpt.model_checkpoint_path))
                        eval_once(saver, ckpt.model_checkpoint_path,
                                  summary_writer, imdb, model)

                else:
                    print('No checkpoint file found')
                    if not FLAGS.run_once:
                        print(
                            'Wait {:d}s for new checkpoints to be saved ... '.
                            format(FLAGS.eval_interval_secs))
                        time.sleep(FLAGS.eval_interval_secs)
Example #4
File: eval.py  Project: goan15910/ConvDet
def evaluate():
  """Evaluate."""
  assert FLAGS.net in ['vgg16', 'vgg16_v2', 'vgg16_v3', 'yolo_v2'], \
      'Selected neural net architecture not supported: {}'.format(FLAGS.net)
  assert FLAGS.dataset in ['KITTI', 'PASCAL_VOC', 'VID'], \
      'Invalid dataset {}: must be KITTI, PASCAL_VOC, or VID'.format(FLAGS.dataset)
  if FLAGS.dataset == 'KITTI':
    if FLAGS.net == 'yolo_v2':
      raise NotImplementedError
    else:
      mc = kitti_vgg16_config()
    mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
    imdb = kitti(FLAGS.train_set, FLAGS.data_path, mc)
  elif FLAGS.dataset == 'PASCAL_VOC':
    if FLAGS.net == 'yolo_v2':
      mc = pascal_voc_yolo_config()
    else:
      mc = pascal_voc_vgg16_config()
    mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
    imdb = pascal_voc(FLAGS.train_set, FLAGS.year, FLAGS.data_path, mc)
  elif FLAGS.dataset == 'VID':
    if FLAGS.net == 'yolo_v2':
      mc = vid_yolo_config()
    else:
      mc = vid_vgg16_config()
    mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
    imdb = vid(FLAGS.train_set, FLAGS.data_path, mc)

  with tf.Graph().as_default() as g:

    if FLAGS.dataset == 'KITTI':
      mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
      mc.LOAD_PRETRAINED_MODEL = False
      imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)
    elif FLAGS.dataset == 'PASCAL_VOC':
      mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
      mc.LOAD_PRETRAINED_MODEL = True
      imdb = pascal_voc(FLAGS.image_set, FLAGS.year, FLAGS.data_path, mc)
    elif FLAGS.dataset == 'VID':
      mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
      mc.LOAD_PRETRAINED_MODEL = False
      imdb = vid(FLAGS.image_set, FLAGS.data_path, mc)

    if FLAGS.net == 'vgg16':
      model = VGG16ConvDet(mc, FLAGS.gpu)
    elif FLAGS.net == 'vgg16_v2':
      model = VGG16ConvDetV2(mc, FLAGS.gpu)
    elif FLAGS.net == 'vgg16_v3':
      model = VGG16ConvDetV3(mc, FLAGS.gpu)
    elif FLAGS.net == 'yolo_v2':
      model = YOLO_V2(mc, FLAGS.gpu)

    saver = tf.train.Saver(model.model_params)

    summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g)
    
    ckpts = set() 
    while True:
      if FLAGS.run_once:
        # When run_once is true, checkpoint_path should point to the exact
        # checkpoint file.
        eval_once(saver, FLAGS.checkpoint_path, summary_writer, imdb, model)
        return
      else:
        # When run_once is false, checkpoint_path should point to the directory
        # that stores checkpoint files.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
        if ckpt and ckpt.model_checkpoint_path:
          if ckpt.model_checkpoint_path in ckpts:
            # Do not evaluate on the same checkpoint
            print ('Wait {:d}s for new checkpoints to be saved ... '
                      .format(FLAGS.eval_interval_secs))
            time.sleep(FLAGS.eval_interval_secs)
          else:
            ckpts.add(ckpt.model_checkpoint_path)
            print ('Evaluating {}...'.format(ckpt.model_checkpoint_path))
            eval_once(saver, ckpt.model_checkpoint_path, 
                      summary_writer, imdb, model)
        else:
          print('No checkpoint file found')
          if not FLAGS.run_once:
            print ('Wait {:d}s for new checkpoints to be saved ... '
                      .format(FLAGS.eval_interval_secs))
            time.sleep(FLAGS.eval_interval_secs)
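
One caveat to this polling loop: tf.train.get_checkpoint_state reports only the newest checkpoint, so any checkpoint written and superseded between two polls is never evaluated. The returned CheckpointState also carries all_model_checkpoint_paths, which a variant could walk instead — a sketch reusing the ckpts set and eval_once signature from this example:

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
if ckpt:
    for path in ckpt.all_model_checkpoint_paths:  # ordered oldest to newest
        if path not in ckpts:
            ckpts.add(path)
            eval_once(saver, path, summary_writer, imdb, model)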
Example #5
def evaluate():
  """Evaluate."""
  assert FLAGS.dataset == 'KITTI', \
      'Currently only supports KITTI dataset'

  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

  with tf.Graph().as_default() as g:

    assert FLAGS.net == 'vgg16' or FLAGS.net == 'resnet50' \
        or FLAGS.net == 'squeezeDet' or FLAGS.net == 'squeezeDet+', \
        'Selected neural net architecture not supported: {}'.format(FLAGS.net)
    if FLAGS.net == 'vgg16':
      mc = kitti_vgg16_config()
      mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
      mc.LOAD_PRETRAINED_MODEL = False
      model = VGG16ConvDet(mc)
    elif FLAGS.net == 'resnet50':
      mc = kitti_res50_config()
      mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
      mc.LOAD_PRETRAINED_MODEL = False
      model = ResNet50ConvDet(mc)
    elif FLAGS.net == 'squeezeDet':
      mc = kitti_squeezeDet_config()
      mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
      mc.LOAD_PRETRAINED_MODEL = False
      model = SqueezeDet(mc)
    elif FLAGS.net == 'squeezeDet+':
      mc = kitti_squeezeDetPlus_config()
      mc.BATCH_SIZE = 1 # TODO(bichen): allow batch size > 1
      mc.LOAD_PRETRAINED_MODEL = False
      model = SqueezeDetPlus(mc)

    imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

    # add summary ops and placeholders
    ap_names = []
    for cls in imdb.classes:
      ap_names.append(cls+'_easy')
      ap_names.append(cls+'_medium')
      ap_names.append(cls+'_hard')

    eval_summary_ops = []
    eval_summary_phs = {}
    for ap_name in ap_names:
      ph = tf.placeholder(tf.float32)
      eval_summary_phs['APs/'+ap_name] = ph
      eval_summary_ops.append(tf.summary.scalar('APs/'+ap_name, ph))

    ph = tf.placeholder(tf.float32)
    eval_summary_phs['APs/mAP'] = ph
    eval_summary_ops.append(tf.summary.scalar('APs/mAP', ph))

    ph = tf.placeholder(tf.float32)
    eval_summary_phs['timing/im_detect'] = ph
    eval_summary_ops.append(tf.summary.scalar('timing/im_detect', ph))

    ph = tf.placeholder(tf.float32)
    eval_summary_phs['timing/im_read'] = ph
    eval_summary_ops.append(tf.summary.scalar('timing/im_read', ph))

    ph = tf.placeholder(tf.float32)
    eval_summary_phs['timing/post_proc'] = ph
    eval_summary_ops.append(tf.summary.scalar('timing/post_proc', ph))

    ph = tf.placeholder(tf.float32)
    eval_summary_phs['num_det_per_image'] = ph
    eval_summary_ops.append(tf.summary.scalar('num_det_per_image', ph))

    saver = tf.train.Saver(model.model_params)

    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
    
    ckpts = set() 
    while True:
      if FLAGS.run_once:
        # When run_once is true, checkpoint_path should point to the exact
        # checkpoint file.
        eval_once(
            saver, FLAGS.checkpoint_path, summary_writer, eval_summary_ops,
            eval_summary_phs, imdb, model)
        return
      else:
        # When run_once is false, checkpoint_path should point to the directory
        # that stores checkpoint files.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
        if ckpt and ckpt.model_checkpoint_path:
          if ckpt.model_checkpoint_path in ckpts:
            # Do not evaluate on the same checkpoint
            print ('Wait {:d}s for new checkpoints to be saved ... '
                      .format(FLAGS.eval_interval_secs))
            time.sleep(FLAGS.eval_interval_secs)
          else:
            ckpts.add(ckpt.model_checkpoint_path)
            print ('Evaluating {}...'.format(ckpt.model_checkpoint_path))
            eval_once(
                saver, ckpt.model_checkpoint_path, summary_writer,
                eval_summary_ops, eval_summary_phs, imdb, model)
        else:
          print('No checkpoint file found')
          if not FLAGS.run_once:
            print ('Wait {:d}s for new checkpoints to be saved ... '
                      .format(FLAGS.eval_interval_secs))
            time.sleep(FLAGS.eval_interval_secs)
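
The six near-identical placeholder blocks above can be collapsed into one loop over the summary names; an equivalent construction (a sketch, not the repo's code):

summary_names = ['APs/' + n for n in ap_names] + [
    'APs/mAP', 'timing/im_detect', 'timing/im_read',
    'timing/post_proc', 'num_det_per_image']
eval_summary_ops = []
eval_summary_phs = {}
for name in summary_names:
    ph = tf.placeholder(tf.float32)
    eval_summary_phs[name] = ph
    eval_summary_ops.append(tf.summary.scalar(name, ph))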
Example #6
def train():
    """Train SqueezeDet model"""
    assert FLAGS.dataset == 'KITTI', \
        'Currently only support KITTI dataset'

    with tf.Graph().as_default():

        assert FLAGS.net == 'vgg16' or FLAGS.net == 'resnet50' \
            or FLAGS.net == 'squeezeDet' or FLAGS.net == 'squeezeDet+', \
            'Selected neural net architecture not supported: {}'.format(FLAGS.net)
        if FLAGS.net == 'vgg16':
            mc = kitti_vgg16_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = VGG16ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'resnet50':
            mc = kitti_res50_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = ResNet50ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'squeezeDet':
            mc = kitti_squeezeDet_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = SqueezeDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'squeezeDet+':
            mc = kitti_squeezeDetPlus_config()
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = SqueezeDetPlus(mc, FLAGS.gpu)

        imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

        # save model size, flops, activations by layers
        with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'),
                  'w') as f:
            f.write('Number of parameter by layer:\n')
            count = 0
            for c in model.model_size_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nActivation size by layer:\n')
            for c in model.activation_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nNumber of flops by layer:\n')
            for c in model.flop_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))
        print('Model statistics saved to {}.'.format(
            os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

        saver = tf.train.Saver(tf.all_variables())
        summary_op = tf.merge_all_summaries()
        init = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # read batch input
            image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
                bbox_per_batch = imdb.read_batch()

            label_indices, bbox_indices, box_delta_values, mask_indices, box_values, \
                = [], [], [], [], []
            aidx_set = set()
            num_discarded_labels = 0
            num_labels = 0
            for i in range(len(label_per_batch)):  # batch_size
                for j in range(len(
                        label_per_batch[i])):  # number of annotations
                    num_labels += 1
                    if (i, aidx_per_batch[i][j]) not in aidx_set:
                        aidx_set.add((i, aidx_per_batch[i][j]))
                        label_indices.append(
                            [i, aidx_per_batch[i][j], label_per_batch[i][j]])
                        mask_indices.append([i, aidx_per_batch[i][j]])
                        bbox_indices.extend([[i, aidx_per_batch[i][j], k]
                                             for k in range(4)])
                        box_delta_values.extend(box_delta_per_batch[i][j])
                        box_values.extend(bbox_per_batch[i][j])
                    else:
                        num_discarded_labels += 1

            if mc.DEBUG_MODE:
                print(
                    'Warning: Discarded {}/({}) labels that are assigned to the same '
                    'anchor'.format(num_discarded_labels, num_labels))

            feed_dict = {
                model.image_input:
                image_per_batch,
                model.keep_prob:
                mc.KEEP_PROB,
                model.input_mask:
                np.reshape(
                    sparse_to_dense(mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                                    [1.0] * len(mask_indices)),
                    [mc.BATCH_SIZE, mc.ANCHORS, 1]),
                model.box_delta_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_delta_values),
                model.box_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_values),
                model.labels:
                sparse_to_dense(label_indices,
                                [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
                                [1.0] * len(label_indices)),
            }

            if step % FLAGS.summary_step == 0:
                op_list = [
                    model.train_op, model.loss, summary_op, model.det_boxes,
                    model.det_probs, model.det_class, model.conf_loss,
                    model.bbox_loss, model.class_loss
                ]
                _, loss_value, summary_str, det_boxes, det_probs, det_class, conf_loss, \
                    bbox_loss, class_loss = sess.run(op_list, feed_dict=feed_dict)

                _viz_prediction_result(model, image_per_batch, bbox_per_batch,
                                       label_per_batch, det_boxes, det_class,
                                       det_probs)
                image_per_batch = bgr_to_rgb(image_per_batch)
                viz_summary = sess.run(
                    model.viz_op,
                    feed_dict={model.image_to_show: image_per_batch})

                # NOTE: these summary ops are re-created on every summary step,
                # which keeps growing the graph; building them once outside the
                # training loop would avoid that.
                num_discarded_labels_op = tf.scalar_summary(
                    'counter/num_discarded_labels', num_discarded_labels)
                num_labels_op = tf.scalar_summary('counter/num_labels',
                                                  num_labels)

                counter_summary_str = sess.run(
                    [num_discarded_labels_op, num_labels_op])

                summary_writer.add_summary(summary_str, step)
                summary_writer.add_summary(viz_summary, step)
                for sum_str in counter_summary_str:
                    summary_writer.add_summary(sum_str, step)

                print('conf_loss: {}, bbox_loss: {}, class_loss: {}'.format(
                    conf_loss, bbox_loss, class_loss))
            else:
                _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
                    [
                        model.train_op, model.loss, model.conf_loss,
                        model.bbox_loss, model.class_loss
                    ],
                    feed_dict=feed_dict)

            duration = time.time() - start_time

            assert not np.isnan(loss_value), \
                'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
                'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

            if step % 10 == 0:
                num_images_per_step = mc.BATCH_SIZE
                images_per_sec = num_images_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    images_per_sec, sec_per_batch))
                sys.stdout.flush()

            # Save the model checkpoint periodically.
            if step % FLAGS.checkpoint_step == 0 or (step +
                                                     1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
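
Examples #6, #7, #11, and #12 all densify sparse annotation lists with a sparse_to_dense helper before feeding the model. As used here, its behavior matches this numpy sketch (an illustration of the semantics, not the project's implementation):

import numpy as np

def sparse_to_dense(indices, shape, values, default=0.0):
    # Scatter each value into a dense array at the matching index tuple.
    dense = np.full(shape, default, dtype=np.float32)
    for idx, val in zip(indices, values):
        dense[tuple(idx)] = val
    return dense

# e.g. one-hot class labels: image 0, anchor 3, class 1 -> 1.0
labels = sparse_to_dense([[0, 3, 1]], [1, 9, 3], [1.0])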
Example #7
def train():
  """Train SqueezeDet model"""
  assert FLAGS.dataset == 'KITTI', \
      'Currently only support KITTI dataset'

  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

  with tf.Graph().as_default():

    assert FLAGS.net == 'vgg16' or FLAGS.net == 'resnet50' \
        or FLAGS.net == 'squeezeDet' or FLAGS.net == 'squeezeDet+', \
        'Selected neural net architecture not supported: {}'.format(FLAGS.net)
    if FLAGS.net == 'vgg16':
      mc = kitti_vgg16_config()
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = VGG16ConvDet(mc)
    elif FLAGS.net == 'resnet50':
      mc = kitti_res50_config()
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = ResNet50ConvDet(mc)
    elif FLAGS.net == 'squeezeDet':
      mc = kitti_squeezeDet_config()
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = SqueezeDet(mc)
    elif FLAGS.net == 'squeezeDet+':
      mc = kitti_squeezeDetPlus_config()
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = SqueezeDetPlus(mc)

    imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

    # save model size, flops, activations by layers
    with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'), 'w') as f:
      f.write('Number of parameter by layer:\n')
      count = 0
      for c in model.model_size_counter:
        f.write('\t{}: {}\n'.format(c[0], c[1]))
        count += c[1]
      f.write('\ttotal: {}\n'.format(count))

      count = 0
      f.write('\nActivation size by layer:\n')
      for c in model.activation_counter:
        f.write('\t{}: {}\n'.format(c[0], c[1]))
        count += c[1]
      f.write('\ttotal: {}\n'.format(count))

      count = 0
      f.write('\nNumber of flops by layer:\n')
      for c in model.flop_counter:
        f.write('\t{}: {}\n'.format(c[0], c[1]))
        count += c[1]
      f.write('\ttotal: {}\n'.format(count))
    print ('Model statistics saved to {}.'.format(
      os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

    def _load_data(load_to_placeholder=True):
      # read batch input
      image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
          bbox_per_batch = imdb.read_batch()

      label_indices, bbox_indices, box_delta_values, mask_indices, box_values, \
          = [], [], [], [], []
      aidx_set = set()
      num_discarded_labels = 0
      num_labels = 0
      for i in range(len(label_per_batch)): # batch_size
        for j in range(len(label_per_batch[i])): # number of annotations
          num_labels += 1
          if (i, aidx_per_batch[i][j]) not in aidx_set:
            aidx_set.add((i, aidx_per_batch[i][j]))
            label_indices.append(
                [i, aidx_per_batch[i][j], label_per_batch[i][j]])
            mask_indices.append([i, aidx_per_batch[i][j]])
            bbox_indices.extend(
                [[i, aidx_per_batch[i][j], k] for k in range(4)])
            box_delta_values.extend(box_delta_per_batch[i][j])
            box_values.extend(bbox_per_batch[i][j])
          else:
            num_discarded_labels += 1

      if mc.DEBUG_MODE:
        print ('Warning: Discarded {}/({}) labels that are assigned to the same '
               'anchor'.format(num_discarded_labels, num_labels))

      if load_to_placeholder:
        image_input = model.ph_image_input
        input_mask = model.ph_input_mask
        box_delta_input = model.ph_box_delta_input
        box_input = model.ph_box_input
        labels = model.ph_labels
      else:
        image_input = model.image_input
        input_mask = model.input_mask
        box_delta_input = model.box_delta_input
        box_input = model.box_input
        labels = model.labels

      feed_dict = {
          image_input: image_per_batch,
          input_mask: np.reshape(
              sparse_to_dense(
                  mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                  [1.0]*len(mask_indices)),
              [mc.BATCH_SIZE, mc.ANCHORS, 1]),
          box_delta_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
              box_delta_values),
          box_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
              box_values),
          labels: sparse_to_dense(
              label_indices,
              [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
              [1.0]*len(label_indices)),
      }

      return feed_dict, image_per_batch, label_per_batch, bbox_per_batch

    def _enqueue(sess, coord):
      try:
        while not coord.should_stop():
          feed_dict, _, _, _ = _load_data()
          sess.run(model.enqueue_op, feed_dict=feed_dict)
          if mc.DEBUG_MODE:
            print ("added to the queue")
        if mc.DEBUG_MODE:
          print ("Finished enqueue")
      except Exception as e:
        coord.request_stop(e)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    init = tf.global_variables_initializer()
    sess.run(init)

    # Restore from an existing checkpoint only after running the initializer,
    # otherwise the initializer would overwrite the restored weights.
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)

    coord = tf.train.Coordinator()

    if mc.NUM_THREAD > 0:
      enq_threads = []
      for _ in range(mc.NUM_THREAD):
        enq_thread = threading.Thread(target=_enqueue, args=[sess, coord])
        # enq_thread.isDaemon()
        enq_thread.start()
        enq_threads.append(enq_thread)

    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    run_options = tf.RunOptions(timeout_in_ms=60000)

    # try:
    for step in xrange(FLAGS.max_steps):
      if coord.should_stop():
        sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
        coord.request_stop()
        coord.join(threads)
        break

      start_time = time.time()

      if step % FLAGS.summary_step == 0:
        feed_dict, image_per_batch, label_per_batch, bbox_per_batch = \
            _load_data(load_to_placeholder=False)
        op_list = [
            model.train_op, model.loss, summary_op, model.det_boxes,
            model.det_probs, model.det_class, model.conf_loss,
            model.bbox_loss, model.class_loss
        ]
        _, loss_value, summary_str, det_boxes, det_probs, det_class, \
            conf_loss, bbox_loss, class_loss = sess.run(
                op_list, feed_dict=feed_dict)

        _viz_prediction_result(
            model, image_per_batch, bbox_per_batch, label_per_batch, det_boxes,
            det_class, det_probs)
        image_per_batch = bgr_to_rgb(image_per_batch)
        viz_summary = sess.run(
            model.viz_op, feed_dict={model.image_to_show: image_per_batch})

        summary_writer.add_summary(summary_str, step)
        summary_writer.add_summary(viz_summary, step)
        summary_writer.flush()

        print ('conf_loss: {}, bbox_loss: {}, class_loss: {}'.
            format(conf_loss, bbox_loss, class_loss))
      else:
        if mc.NUM_THREAD > 0:
          _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
              [model.train_op, model.loss, model.conf_loss, model.bbox_loss,
               model.class_loss], options=run_options)
        else:
          feed_dict, _, _, _ = _load_data(load_to_placeholder=False)
          _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
              [model.train_op, model.loss, model.conf_loss, model.bbox_loss,
               model.class_loss], feed_dict=feed_dict)

      duration = time.time() - start_time

      assert not np.isnan(loss_value), \
          'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
          'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

      if step % 10 == 0:
        num_images_per_step = mc.BATCH_SIZE
        images_per_sec = num_images_per_step / duration
        sec_per_batch = float(duration)
        format_str = ('%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             images_per_sec, sec_per_batch))
        sys.stdout.flush()

      # Save the model checkpoint periodically.
      if step % FLAGS.checkpoint_step == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
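
Example #7 feeds the model through two paths: background threads push batches into a FIFO queue via the ph_* placeholders for ordinary steps, while summary steps feed the model.* tensors directly (load_to_placeholder=False). A stripped-down sketch of that prefetch pattern, with a hypothetical single input tensor:

import threading
import tensorflow as tf

ph_x = tf.placeholder(tf.float32, shape=[8, 4])   # fed only by enqueue threads
queue = tf.FIFOQueue(capacity=16, dtypes=[tf.float32], shapes=[[8, 4]])
enqueue_op = queue.enqueue([ph_x])
x = queue.dequeue()                               # consumed by training steps

def enqueue_loop(sess, coord, next_batch):
    try:
        while not coord.should_stop():
            sess.run(enqueue_op, feed_dict={ph_x: next_batch()})
    except Exception as e:  # surface failures to the coordinator
        coord.request_stop(e)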
Example #8
                           """Can be train, trainval, val, or test""")
tf.app.flags.DEFINE_string('pretrained_model_path',
                           '../path/to/pretrained/weights.pkl',
                           """Path to the pretrained model weights""")
tf.app.flags.DEFINE_string('data_path', '../data/KITTI',
                           """Root directory of data""")

with tf.Graph().as_default() as g:

    mc = kitti_squeezeDetPlus_config()
    mc.BATCH_SIZE = 1
    mc.LOAD_PRETRAINED_MODEL = True
    print(FLAGS.pretrained_model_path)
    mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path

    imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

    # add summary ops and placeholders
    ap_names = []
    for cls in imdb.classes:
        ap_names.append(cls + '_easy')
        ap_names.append(cls + '_medium')
        ap_names.append(cls + '_hard')

    eval_summary_ops = []
    eval_summary_phs = {}
    for ap_name in ap_names:
        ph = tf.placeholder(tf.float32)
        eval_summary_phs['APs/' + ap_name] = ph
        eval_summary_ops.append(tf.summary.scalar('APs/' + ap_name, ph))
Example #9
def evaluate():
    """Evaluate."""
    assert FLAGS.dataset == 'KITTI', \
        'Currently only supports KITTI dataset'

    with tf.Graph().as_default() as g:

        assert FLAGS.net == 'vgg16' or FLAGS.net == 'resnet50' \
            or FLAGS.net == 'squeezeDet' or FLAGS.net == 'squeezeDet+', \
            'Selected neural net architecture not supported: {}'.format(FLAGS.net)
        if FLAGS.net == 'vgg16':
            mc = kitti_vgg16_config()
            mc.BATCH_SIZE = 1  # TODO(bichen): allow batch size > 1
            mc.LOAD_PRETRAINED_MODEL = False
            model = VGG16ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'resnet50':
            mc = kitti_res50_config()
            mc.BATCH_SIZE = 1  # TODO(bichen): allow batch size > 1
            mc.LOAD_PRETRAINED_MODEL = False
            model = ResNet50ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'squeezeDet':
            mc = kitti_squeezeDet_config()
            mc.BATCH_SIZE = 1  # TODO(bichen): allow batch size > 1
            mc.LOAD_PRETRAINED_MODEL = False
            model = SqueezeDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'squeezeDet+':
            mc = kitti_squeezeDetPlus_config()
            mc.BATCH_SIZE = 1  # TODO(bichen): allow batch size > 1
            mc.LOAD_PRETRAINED_MODEL = False
            model = SqueezeDetPlus(mc, FLAGS.gpu)

        imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

        saver = tf.train.Saver(model.model_params)

        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        ckpts = set()
        while True:
            if FLAGS.run_once:
                # When run_once is true, checkpoint_path should point to the exact
                # checkpoint file.
                eval_once(saver, FLAGS.checkpoint_path, summary_writer, imdb,
                          model)
                return
            else:
                # When run_once is false, checkpoint_path should point to the directory
                # that stores checkpoint files.
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
                if ckpt and ckpt.model_checkpoint_path:
                    if ckpt.model_checkpoint_path in ckpts:
                        # Do not evaluate on the same checkpoint
                        print(
                            'Wait {:d}s for new checkpoints to be saved ... '.
                            format(FLAGS.eval_interval_secs))
                        time.sleep(FLAGS.eval_interval_secs)
                    else:
                        ckpts.add(ckpt.model_checkpoint_path)
                        print('Evaluating {}...'.format(
                            ckpt.model_checkpoint_path))
                        eval_once(saver, ckpt.model_checkpoint_path,
                                  summary_writer, imdb, model)
                else:
                    print('No checkpoint file found')
                    if not FLAGS.run_once:
                        print(
                            'Wait {:d}s for new checkpoints to be saved ... '.
                            format(FLAGS.eval_interval_secs))
                        time.sleep(FLAGS.eval_interval_secs)
Example #10
def test_read_batch():
    """Test read batch function"""
    assert FLAGS.dataset in ['KITTI', 'PASCAL_VOC', 'VID', 'ILSVRC2013'], \
        """
      Invalid dataset {}
      Either KITTI / PASCAL_VOC / VID / ILSVRC2013""".format(FLAGS.dataset)
    if FLAGS.dataset == 'KITTI':
        mc = kitti_vgg16_config()
        imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)
    elif FLAGS.dataset == 'PASCAL_VOC':
        mc = pascal_voc_vgg16_config()
        imdb = pascal_voc(FLAGS.image_set, FLAGS.year, FLAGS.data_path, mc)
    elif FLAGS.dataset == 'VID':
        mc = vid_vgg16_config()
        imdb = vid(FLAGS.image_set, FLAGS.data_path, mc)
    elif FLAGS.dataset == 'ILSVRC2013':
        mc = imagenet_config()
        imdb = imagenet(FLAGS.image_set, FLAGS.data_path, mc)

    if FLAGS.dataset != 'ILSVRC2013':
        # read batch input
        image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
            bbox_per_batch = imdb.read_batch()
        #joblib.dump(bbox_per_batch, '/tmp3/jeff/bbox.pkl')

        label_indices, bbox_indices, box_delta_values, mask_indices, box_values, \
            = [], [], [], [], []
        aidx_set = set()
        num_discarded_labels = 0
        num_labels = 0
        for i in range(len(label_per_batch)):  # batch_size
            for j in range(len(label_per_batch[i])):  # number of annotations
                num_labels += 1
                if (i, aidx_per_batch[i][j]) not in aidx_set:
                    aidx_set.add((i, aidx_per_batch[i][j]))
                    label_indices.append(
                        [i, aidx_per_batch[i][j], label_per_batch[i][j]])
                    mask_indices.append([i, aidx_per_batch[i][j]])
                    bbox_indices.extend([[i, aidx_per_batch[i][j], k]
                                         for k in range(4)])
                    box_delta_values.extend(box_delta_per_batch[i][j])
                    box_values.extend(bbox_per_batch[i][j])
                else:
                    num_discarded_labels += 1

        print('Warning: Discarded {}/({}) labels that are assigned to the same '
              'anchor'.format(num_discarded_labels, num_labels))

        # Visualize detections
        _viz_gt_bboxes(mc, image_per_batch, bbox_per_batch, label_per_batch)

    elif FLAGS.dataset == 'ILSVRC2013':
        image_per_batch, label_per_batch, _ = imdb.read_cls_batch()
        label_per_batch = map(str, label_per_batch)

        # TODO(jeff): visualize classification image
        _viz_cls_labels(mc, image_per_batch, label_per_batch)

    # Save the images
    for i, im in enumerate(image_per_batch):
        fname = os.path.join(FLAGS.output_dir, '{}.jpg'.format(i))
        cv2.imwrite(fname, im)
Example #11
File: run.py  Project: Saiuz/squeezeDet
def main(argv):
    config_dir = os.path.abspath(argv[1])
    assert config_dir.startswith(
        Models), 'Invalid config directory %s' % config_dir
    model_name, config_name = config_dir[len(Models):].split('/')[:2]

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
    model_dir = Models + model_name + '/'

    mc = kitti_squeezeDetPlus_config()
    config_dir = model_dir + config_name + '/'
    config_path = config_dir + 'config.json'
    with open(config_path, 'r') as f:  # load custom params
        for key, value in json.load(f).items():
            mc[key] = value

    mc.IS_TRAINING = FLAGS.train
    mc.PRETRAINED_MODEL_PATH = model_dir + 'pretrained.pkl'
    train_dir = config_dir + 'train/'
    test_dir = config_dir + 'val/'
    if FLAGS.save_root:
        for dname in ['train', 'val']:
            orig_dir = os.path.join(config_dir, dname)
            save_dir = os.path.join(FLAGS.save_root, 'models', model_name,
                                    config_name, dname)
            make_dir(save_dir)

            if not os.path.exists(orig_dir):
                os.symlink(save_dir, orig_dir)
                print('Creating symlink %s -> %s' % (orig_dir, save_dir))
            elif not os.path.islink(orig_dir):
                raise RuntimeError(
                    '%s exists but is not a link. Cannot create new link to %s'
                    % (orig_dir, save_dir))

    if FLAGS.debug:
        # Run a single step of everything in debug mode.
        mc.PRINT_STEP = mc.SUMMARY_STEP = mc.CHECKPOINT_STEP = mc.MAX_STEPS = 1

    if mc.IS_TRAINING:
        kitti_set = 'train'
        summary_dir = train_dir
    else:
        kitti_set = 'val'
        summary_dir = test_dir
        mc.BATCH_SIZE = 1
    imdb = kitti(kitti_set, Root + 'data/KITTI', mc)
    summary_writer = tf.summary.FileWriter(summary_dir)

    sys.path.append(model_dir)
    from load_model import load_model
    model = load_model(mc)

    def _load_data(load_to_placeholder=True):
        image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, bbox_per_batch = imdb.read_batch(
        )

        label_indices, bbox_indices, box_delta_values, mask_indices, box_values = [], [], [], [], []
        aidx_set = set()
        for i in range(len(label_per_batch)):
            for j in range(len(label_per_batch[i])):
                if (i, aidx_per_batch[i][j]) not in aidx_set:
                    aidx_set.add((i, aidx_per_batch[i][j]))
                    label_indices.append(
                        [i, aidx_per_batch[i][j], label_per_batch[i][j]])
                    mask_indices.append([i, aidx_per_batch[i][j]])
                    bbox_indices.extend([[i, aidx_per_batch[i][j], k]
                                         for k in range(4)])
                    box_delta_values.extend(box_delta_per_batch[i][j])
                    box_values.extend(bbox_per_batch[i][j])
        if load_to_placeholder:
            image_input = model.ph_image_input
            input_mask = model.ph_input_mask
            box_delta_input = model.ph_box_delta_input
            box_input = model.ph_box_input
            labels = model.ph_labels
        else:
            image_input = model.image_input
            input_mask = model.input_mask
            box_delta_input = model.box_delta_input
            box_input = model.box_input
            labels = model.labels

        feed_dict = {
            image_input:
            image_per_batch,
            input_mask:
            np.reshape(
                sparse_to_dense(mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                                [1.0] * len(mask_indices)),
                [mc.BATCH_SIZE, mc.ANCHORS, 1]),
            box_delta_input:
            sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                            box_delta_values),
            box_input:
            sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                            box_values),
            labels:
            sparse_to_dense(label_indices,
                            [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
                            [1.0] * len(label_indices)),
        }
        return feed_dict, image_per_batch, label_per_batch, bbox_per_batch

    def _enqueue(sess, coord):
        try:
            while not coord.should_stop():
                feed_dict, _, _, _ = _load_data()
                sess.run(model.enqueue_op, feed_dict=feed_dict)
        except Exception as e:
            if not sess.run(model.FIFOQueue.is_closed()):
                coord.request_stop(e)
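
The mc[key] = value override loop in Example #11 works because the model config is dict-like (the SqueezeDet configs are built on easydict, whose instances support both item and attribute access). A minimal sketch of the same override step, with a hypothetical config.json:

import json
from easydict import EasyDict

mc = EasyDict({'BATCH_SIZE': 20, 'LEARNING_RATE': 0.01})
with open('config.json') as f:   # e.g. {"BATCH_SIZE": 1, "LEARNING_RATE": 0.001}
    for key, value in json.load(f).items():
        mc[key] = value
print(mc.BATCH_SIZE)             # attribute access sees the override: 1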
Example #12
def train():
  """Train SqueezeDet model"""
  assert FLAGS.dataset == 'KITTI' or FLAGS.dataset == 'CITYSCAPE', \
      'Currently only support KITTI and CITYSCAPE datasets'
  assert FLAGS.mask_parameterization in [4, 8], \
      'Values other than 4 and 8 are not supported!'

  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

  with tf.Graph().as_default():

    assert FLAGS.net == 'vgg16' or FLAGS.net == 'resnet50' \
        or FLAGS.net == 'squeezeDet' or FLAGS.net == 'squeezeDet+', \
        'Selected neural net architecture not supported: {}'.format(FLAGS.net)
    if FLAGS.net == 'vgg16':
      if FLAGS.dataset == 'KITTI':
        mc = kitti_vgg16_config(FLAGS.mask_parameterization, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      elif FLAGS.dataset == 'CITYSCAPE':
        mc = cityscape_vgg16_config(FLAGS.mask_parameterization, FLAGS.log_anchors, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      mc.IS_TRAINING = True
      # mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      print("Not using pretrained model for VGG, uncomment above line and comment below line to use pretrained model !")
      mc.LOAD_PRETRAINED_MODEL = False
      if FLAGS.warm_restart_lr != -1.0:
        print("Updating the learning rate for warm restart to", FLAGS.warm_restart_lr)
        mc.LEARNING_RATE = FLAGS.warm_restart_lr
      model = VGG16ConvDet(mc)
    elif FLAGS.net == 'resnet50':
      if FLAGS.dataset == 'KITTI':
        mc = kitti_res50_config(FLAGS.mask_parameterization, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      elif FLAGS.dataset == 'CITYSCAPE':
        mc = cityscape_res50_config(FLAGS.mask_parameterization, FLAGS.log_anchors, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      if FLAGS.warm_restart_lr != -1.0:
        print("Updating the learning rate for warm restart to", FLAGS.warm_restart_lr)
        mc.LEARNING_RATE = FLAGS.warm_restart_lr
      model = ResNet50ConvDet(mc)
    elif FLAGS.net == 'squeezeDet':
      if FLAGS.dataset == 'KITTI':
        mc = kitti_squeezeDet_config(FLAGS.mask_parameterization, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      elif FLAGS.dataset == 'CITYSCAPE':
        mc = cityscape_squeezeDet_config(FLAGS.mask_parameterization, FLAGS.log_anchors, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      if FLAGS.warm_restart_lr != -1.0:
        print("Updating the learning rate for warm restart to", FLAGS.warm_restart_lr)
        mc.LEARNING_RATE = FLAGS.warm_restart_lr
      model = SqueezeDet(mc)
    elif FLAGS.net == 'squeezeDet+':
      if FLAGS.dataset == 'KITTI':
        mc = kitti_squeezeDetPlus_config(FLAGS.mask_parameterization, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      elif FLAGS.dataset == 'CITYSCAPE':
        mc = cityscape_squeezeDetPlus_config(FLAGS.mask_parameterization, FLAGS.log_anchors, FLAGS.only_tune_last_layer, FLAGS.encoding_type)
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      if FLAGS.warm_restart_lr != -1.0:
        print("Updating the learning rate for warm restart to", FLAGS.warm_restart_lr)
        mc.LEARNING_RATE = FLAGS.warm_restart_lr
      model = SqueezeDetPlus(mc)

    imdb_valid = None
    if FLAGS.dataset == 'KITTI':
      imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)
      if FLAGS.eval_valid:
        imdb_valid = kitti('val', FLAGS.data_path, mc)
        imdb_valid.mc.DATA_AUGMENTATION = False
    elif FLAGS.dataset == 'CITYSCAPE':
      imdb = cityscape(FLAGS.image_set, FLAGS.data_path, mc)
      print("Margins for Training data:", imdb.left_margin, imdb.top_margin, imdb.right_margin, imdb.bottom_margin)
      if FLAGS.eval_valid:
        imdb_valid = cityscape('val', FLAGS.data_path, mc)
        imdb_valid.mc.DATA_AUGMENTATION = False
        print("Margins for Validation data:", imdb_valid.left_margin, imdb_valid.top_margin, imdb_valid.right_margin, imdb_valid.bottom_margin)

    print("Training model data augmentation:", imdb.mc.DATA_AUGMENTATION)
    if imdb_valid != None:
      print("Validation model data augmentation:", imdb_valid.mc.DATA_AUGMENTATION)
    # save model size, flops, activations by layers
    with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'), 'w') as f:
      f.write('Number of parameter by layer:\n')
      count = 0
      for c in model.model_size_counter:
        f.write('\t{}: {}\n'.format(c[0], c[1]))
        count += c[1]
      f.write('\ttotal: {}\n'.format(count))

      count = 0
      f.write('\nActivation size by layer:\n')
      for c in model.activation_counter:
        f.write('\t{}: {}\n'.format(c[0], c[1]))
        count += c[1]
      f.write('\ttotal: {}\n'.format(count))

      count = 0
      f.write('\nNumber of flops by layer:\n')
      for c in model.flop_counter:
        f.write('\t{}: {}\n'.format(c[0], c[1]))
        count += c[1]
      f.write('\ttotal: {}\n'.format(count))
    print ('Model statistics saved to {}.'.format(
      os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

    def _load_data(load_to_placeholder=True, eval_valid=False):
      # read batch input
      if eval_valid:
        # Only for validation set
        image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
          bbox_per_batch, edge_adhesions_per_batch = imdb_valid.read_batch(shuffle=False, wrap_around=False)
        keep_prob_value = 1.0
      else:
        image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
            bbox_per_batch, edge_adhesions_per_batch = imdb.read_batch()
        keep_prob_value = mc.DROP_OUT_PROB

      label_indices, bbox_indices, box_delta_values, mask_indices, box_values, edge_adhesions, edge_indices\
          = [], [], [], [], [], [], []
      aidx_set = set()
      num_discarded_labels = 0
      num_labels = 0
      for i in range(len(label_per_batch)): # batch_size
        for j in range(len(label_per_batch[i])): # number of annotations
          num_labels += 1
          if (i, aidx_per_batch[i][j]) not in aidx_set:
            aidx_set.add((i, aidx_per_batch[i][j]))
            label_indices.append(
                [i, aidx_per_batch[i][j], label_per_batch[i][j]])
            mask_indices.append([i, aidx_per_batch[i][j]])
            bbox_indices.extend(
                [[i, aidx_per_batch[i][j], k] for k in range(FLAGS.mask_parameterization)])
            box_delta_values.extend(box_delta_per_batch[i][j])
            box_values.extend(bbox_per_batch[i][j])
            edge_adhesions.extend(edge_adhesions_per_batch[i][j])
            edge_indices.extend(
                [[i, aidx_per_batch[i][j], k] for k in range(FLAGS.mask_parameterization)])
          else:
            num_discarded_labels += 1

      if mc.DEBUG_MODE:
        print ('Warning: Discarded {}/({}) labels that are assigned to the same '
               'anchor'.format(num_discarded_labels, num_labels))

      if load_to_placeholder:
        image_input = model.ph_image_input
        input_mask = model.ph_input_mask
        box_delta_input = model.ph_box_delta_input
        box_input = model.ph_box_input
        labels = model.ph_labels
        edge_scenarios = model.ph_edge_adhesions
      else:
        image_input = model.image_input
        input_mask = model.input_mask
        box_delta_input = model.box_delta_input
        box_input = model.box_input
        labels = model.labels
        edge_scenarios = model.edge_adhesions

      feed_dict = {
          image_input: image_per_batch,
          input_mask: np.reshape(
              sparse_to_dense(
                  mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                  [1.0]*len(mask_indices)),
              [mc.BATCH_SIZE, mc.ANCHORS, 1]),
          box_delta_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, FLAGS.mask_parameterization],
              box_delta_values),
          box_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, FLAGS.mask_parameterization],
              box_values),
          labels: sparse_to_dense(
              label_indices,
              [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
              [1.0]*len(label_indices)),
          model.keep_prob: keep_prob_value,
          edge_scenarios: sparse_to_dense(
              edge_indices, [mc.BATCH_SIZE, mc.ANCHORS, FLAGS.mask_parameterization],
              edge_adhesions).astype(np.bool, copy=False),
      }

      return feed_dict, image_per_batch, label_per_batch, bbox_per_batch, edge_indices


    def _enqueue(sess, coord):
      try:
        while not coord.should_stop():
          feed_dict, _, _, _, _ = _load_data()
          sess.run(model.enqueue_op, feed_dict=feed_dict)
          if mc.DEBUG_MODE:
            print ("added to the queue")
        if mc.DEBUG_MODE:
          print ("Finished enqueue")
      except tf.errors.CancelledError:
        coord.request_stop()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()
    sess.run(init)
    glb_step = sess.run(model.global_step)
    print("Global step before restore:", glb_step)

    print("Kernels before restore")
    for v in tf.trainable_variables():
      if 'kernels' in v.name:
        print("First few weights of ", v.name, " are ", sess.run(v)[0,0,0,0:5])

    print("Learning rate before restore", sess.run(model.lr))

    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and ckpt.model_checkpoint_path:
      print("Found checkpoint at step: ", int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]))
      last_layer_name = model.preds.name.split('/')[0]
      if FLAGS.mask_parameterization == 8 and FLAGS.bounding_box_checkpoint:
        print("Loading only partial weights (except last layer", last_layer_name, ")")
        saver_partial_weights = tf.train.Saver([v for v in tf.global_variables() if last_layer_name not in v.name])
        saver_partial_weights.restore(sess, ckpt.model_checkpoint_path)
        if FLAGS.warm_restart_lr != -1.0:
          print("Resetting global step")
          sess.run([model.global_step.assign(0)])
      else:
        print("Loading all weights (including the last layer", last_layer_name, ")")
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      print("Checkpoint not found !")
    glb_step = sess.run(model.global_step)
    print("Global step after restore:", glb_step)
    
    print("Kernels after restore")
    for v in tf.trainable_variables():
      if 'kernels' in v.name:
        print("First few weights of ", v.name, " are ", sess.run(v)[0,0,0,0:5])

    print("Learning rate after restore", sess.run(model.lr))

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
    with open(os.path.join(FLAGS.train_dir, 'training_metrics.txt'), 'a') as f:
      f.write("Global step after restore: "+str(glb_step)+"\n")
    if FLAGS.eval_valid:
      with open(os.path.join(FLAGS.train_dir, 'validation_metrics.txt'), 'a') as f:
        f.write("Global step after restore: "+str(glb_step)+"\n")
    coord = tf.train.Coordinator()

    if mc.NUM_THREAD > 0:
      enq_threads = []
      for _ in range(mc.NUM_THREAD):
        enq_thread = threading.Thread(target=_enqueue, args=[sess, coord])
        # enq_thread.isDaemon()
        enq_thread.start()
        enq_threads.append(enq_thread)

    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    run_options = tf.RunOptions(timeout_in_ms=60000)

    try: 
      for step in xrange(glb_step, FLAGS.max_steps):
        if coord.should_stop():
          sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
          coord.request_stop()
          coord.join(threads)
          break

        start_time = time.time()

        if step % FLAGS.summary_step == 0:
          feed_dict, image_per_batch, label_per_batch, bbox_per_batch, edge_ids = \
              _load_data(load_to_placeholder=False)
          op_list = [
              model.train_op, model.loss, summary_op, model.det_boxes,
              model.det_probs, model.det_class, model.conf_loss,
              model.bbox_loss, model.class_loss, model.edge_adhesions,
          ]
          _, loss_value, summary_str, det_boxes, det_probs, det_class, \
              conf_loss, bbox_loss, class_loss, edge_adhesions_pre_filtered = sess.run(
                  op_list, feed_dict=feed_dict)

          summary_writer.add_summary(summary_str, step)
          # Visualize the training examples only if validation is not enabled
          if not FLAGS.eval_valid:
            visualize_gt_masks = False
            visualize_pred_masks = False
            if mc.EIGHT_POINT_REGRESSION:
              visualize_gt_masks = True
              visualize_pred_masks = True

            assert np.array_equal(feed_dict[model.edge_adhesions], edge_adhesions_pre_filtered), \
                "Training Gt edge adhesion not matching edge adhesion tensor" 
            edge_adhesions_per_batch = [[0]]*mc.BATCH_SIZE
            for id_val in range(mc.BATCH_SIZE):
              selected_ids = np.where(np.asarray(edge_ids)[:,0] == id_val)[0]
              # print("Before",np.asarray(edge_ids)[selected_ids][:,1])
              indexes_int = np.unique(np.asarray(edge_ids)[selected_ids][:,1], return_index=True)[1]
              anchors_ids = np.asarray([np.asarray(edge_ids)[selected_ids][:,1][index] for index in sorted(indexes_int)])
              # print("After",anchors_ids)
              batch_id = [id_val]*len(anchors_ids)
              edge_adhesions_per_batch[id_val] = edge_adhesions_pre_filtered[batch_id, anchors_ids]

            _viz_prediction_result(
                model, image_per_batch, bbox_per_batch, label_per_batch, det_boxes,
                det_class, det_probs, visualize_gt_masks, visualize_pred_masks)
            image_per_batch = bgr_to_rgb(image_per_batch)
            viz_summary = sess.run(
                model.viz_op, feed_dict={model.image_to_show: image_per_batch})
            summary_writer.add_summary(viz_summary, step)
          
          print ('total_loss: {}, conf_loss: {}, bbox_loss: {}, class_loss: {}'.\
                format(loss_value, conf_loss, bbox_loss, class_loss))
          with open(os.path.join(FLAGS.train_dir, 'training_metrics.txt'), 'a') as f:
            f.write('step: {}, total_loss: {}, conf_loss: {}, bbox_loss: {}, class_loss: {}\n'.\
                format(step, loss_value, conf_loss, bbox_loss, class_loss))
          if FLAGS.eval_valid:
            print ('\n!! Validation Set evaluation at step ', step, ' !!')
            with open(os.path.join(FLAGS.train_dir, 'validation_metrics.txt'), 'a') as f:
              f.write('\n!! Validation Set evaluation at step '+str(step)+' !!\n')
              loss_list = []
              batch_nr = 0
              while True:
                batch_nr += 1
                if len(imdb_valid._image_idx) % mc.BATCH_SIZE > 0:
                  # BATCH_SIZE does not evenly divide the number of samples,
                  # so round the number of batches up to cover the remainder.
                  num_of_batches = (len(imdb_valid._image_idx) // mc.BATCH_SIZE) + 1
                else:
                  num_of_batches = (len(imdb_valid._image_idx) // mc.BATCH_SIZE)
                if batch_nr > num_of_batches:
                  break
                feed_dict_val, image_per_batch_val, label_per_batch_val, bbox_per_batch_val, edge_ids_val = \
                    _load_data(load_to_placeholder=False, eval_valid=True)
                op_list_val = [
                    model.loss, model.conf_loss, model.bbox_loss,
                    model.class_loss, model.det_boxes,
                    model.det_probs, model.det_class,
                    model.edge_adhesions,
                ]
                loss_value_val, conf_loss_val, bbox_loss_val, class_loss_val, det_boxes_val, \
                  det_probs_val, det_class_val, edge_adhesions_pre_filtered_val = sess.run(op_list_val, feed_dict=feed_dict_val)

                if batch_nr == 1:
                  # Sample the first batch for visualization
                  visualize_gt_masks = False
                  visualize_pred_masks = False
                  if mc.EIGHT_POINT_REGRESSION:
                    visualize_gt_masks = True
                    visualize_pred_masks = True

                  assert np.array_equal(feed_dict_val[model.edge_adhesions], edge_adhesions_pre_filtered_val), \
                      "Validation GT edge adhesions do not match the edge adhesion tensor"
                  edge_adhesions_per_batch_val = [[0]]*mc.BATCH_SIZE
                  for id_val in range(mc.BATCH_SIZE):
                    selected_ids = np.where(np.asarray(edge_ids_val)[:,0] == id_val)[0]
                    indexes_int = np.unique(np.asarray(edge_ids_val)[selected_ids][:,1], return_index=True)[1]
                    anchors_ids = np.asarray([np.asarray(edge_ids_val)[selected_ids][:,1][index] for index in sorted(indexes_int)])
                    batch_id = [id_val]*len(anchors_ids)
                    edge_adhesions_per_batch_val[id_val] = edge_adhesions_pre_filtered_val[batch_id, anchors_ids]

                  _viz_prediction_result(
                      model, image_per_batch_val, bbox_per_batch_val, label_per_batch_val, det_boxes_val,
                      det_class_val, det_probs_val, visualize_gt_masks, visualize_pred_masks)
                  image_per_batch_visualize = bgr_to_rgb(image_per_batch_val)

                loss_list.append([loss_value_val, conf_loss_val, bbox_loss_val, class_loss_val])
                f.write('Batch: {}, total_loss: {}, conf_loss: {}, bbox_loss: {}, class_loss: {}\n'.\
                        format(batch_nr, loss_value_val, conf_loss_val, bbox_loss_val, class_loss_val))
              loss_list = np.asarray(loss_list)
              loss_means = [np.mean(loss_list[:,0]), np.mean(loss_list[:,1]), np.mean(loss_list[:,2]), np.mean(loss_list[:,3])]
              loss_stds = [np.std(loss_list[:,0]), np.std(loss_list[:,1]), np.std(loss_list[:,2]), np.std(loss_list[:,3])]
              print ('Mean values : total_loss: {}, conf_loss: {}, bbox_loss: {}, class_loss: {}'.\
                format(loss_means[0], loss_means[1], loss_means[2], loss_means[3]))
              print ('Standard Deviation values : total_loss: {}, conf_loss: {}, bbox_loss: {}, class_loss: {}'.\
                format(loss_stds[0], loss_stds[1], loss_stds[2], loss_stds[3]))
              f.write('Mean values : total_loss: {}, conf_loss: {}, bbox_loss: {}, class_loss: {}\n'.\
                format(loss_means[0], loss_means[1], loss_means[2], loss_means[3]))
              f.write('Standard Deviation values : total_loss: {}, conf_loss: {}, bbox_loss: {}, class_loss: {}\n'.\
                format(loss_stds[0], loss_stds[1], loss_stds[2], loss_stds[3]))
              # Visualize the validation examples
              if len(image_per_batch_visualize) != 0:
                viz_summary = sess.run(
                    model.viz_op, feed_dict={model.image_to_show: image_per_batch_visualize})
                summary_writer.add_summary(viz_summary, step)
          summary_writer.flush()
        else:
          if mc.NUM_THREAD > 0:
            _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
                [model.train_op, model.loss, model.conf_loss, model.bbox_loss,
                 model.class_loss], options=run_options)
          else:
            feed_dict, _, _, _, _ = _load_data(load_to_placeholder=False)
            _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
                [model.train_op, model.loss, model.conf_loss, model.bbox_loss,
                 model.class_loss], feed_dict=feed_dict)

        duration = time.time() - start_time

        assert not np.isnan(loss_value), \
            'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
            'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

        if step % 10 == 0:
          num_images_per_step = mc.BATCH_SIZE
          images_per_sec = num_images_per_step / duration
          sec_per_batch = float(duration)
          format_str = ('%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), step, loss_value,
                               images_per_sec, sec_per_batch))
          with open(os.path.join(FLAGS.train_dir, 'training_metrics.txt'), 'a') as f:
            f.write(format_str % (datetime.now(), step, loss_value,
                               images_per_sec, sec_per_batch) + '\n')
          sys.stdout.flush()

        # Save the model checkpoint periodically.
        if step % FLAGS.checkpoint_step == 0 or (step + 1) == FLAGS.max_steps:
          checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
          print("Checkpointing at ", step)
          saver.save(sess, checkpoint_path, global_step=step)
      sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
      coord.request_stop()
      coord.join(threads)
    except KeyboardInterrupt:
      print("Keyboard interrupt caught ! Terminating..")
      sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
      coord.request_stop()
      coord.join(threads)
      sys.exit(0)
    except Exception:
      print("Unexpected error:", sys.exc_info()[0])
      sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
      coord.request_stop()
      coord.join(threads)
      sys.exit(0)
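
Two details in the validation loop above invite simplification: the batch count is just a ceiling division, and the per-column mean/std over the collected losses can be computed with a single NumPy call over axis 0. A minimal sketch, assuming only NumPy; summarize_losses and num_batches are illustrative helpers, not part of the original code:

import numpy as np

def num_batches(num_samples, batch_size):
    # Ceiling division: one extra batch covers any remainder.
    return -(-num_samples // batch_size)

def summarize_losses(loss_list):
    # Rows are [total_loss, conf_loss, bbox_loss, class_loss] per batch.
    losses = np.asarray(loss_list)
    return losses.mean(axis=0), losses.std(axis=0)

print(num_batches(10, 4))  # -> 3
means, stds = summarize_losses([[4.0, 1.0, 2.0, 1.0],
                                [6.0, 2.0, 3.0, 1.0]])
print(means)  # -> [5.  1.5 2.5 1. ]
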
Example #13
def train():
    """Train SqueezeDet model"""
    assert FLAGS.dataset == 'KITTI', \
        'Currently only support KITTI dataset'

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    with tf.Graph().as_default():

        # assert FLAGS.net == 'vgg16' or FLAGS.net == 'resnet50' \
        #     or FLAGS.net == 'squeezeDet' or FLAGS.net=="zynqDet" or FLAGS.net == 'squeezeDet+', \
        #     'Selected neural net architecture not supported: {}'.format(FLAGS.net)
        if FLAGS.net == 'vgg16':
            mc = kitti_vgg16_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = VGG16ConvDet(mc)
        elif FLAGS.net == 'resnet50':
            mc = kitti_res50_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = ResNet50ConvDet(mc)
        elif FLAGS.net == 'squeezeDet':
            mc = kitti_squeezeDet_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = SqueezeDet(mc)
        elif FLAGS.net == 'squeezeDet+':
            mc = kitti_squeezeDetPlus_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = SqueezeDetPlus(mc)

        elif FLAGS.net == "zynqDet":
            mc = kitti_zynqDet_FPN_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = ZynqDet_FPN(mc)

        elif FLAGS.net == "squeezeDet_FPN":
            mc = kitti_squeezeDet_FPN_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = SqueezeDet_FPN(mc)
        elif FLAGS.net == "yolo":
            mc = kitti_vgg16_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = tinyDarkNet_FPN(mc)
        elif FLAGS.net == "ZynqDet_Quant":
            mc = kitti_zynqDet_FPN_config()
            mc.IS_TRAINING = True
            mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
            model = ZynqDet_FPN_Quant(mc)
        else:
            assert False, \
                'Selected neural net architecture not supported: {}'.format(FLAGS.net)

        imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

        # save model size, flops, activations by layers
        with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'),
                  'w') as f:
            f.write('Number of parameter by layer:\n')
            count = 0
            for c in model.model_size_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nActivation size by layer:\n')
            for c in model.activation_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nNumber of flops by layer:\n')
            for c in model.flop_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))
        print('Model statistics saved to {}.'.format(
            os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

        def _load_data_per_scale(label_per_batch, box_delta_per_batch,
                                 aidx_per_batch, bbox_per_batch, s, ANCHORS):
            label_indices, bbox_indices, box_delta_values, mask_indices, box_values = \
                [], [], [], [], []
            aidx_set = set()
            num_discarded_labels = 0

            for i in range(len(label_per_batch)):  # batch_size
                for j in range(
                        len(aidx_per_batch[i]
                            [s])):  # number of annotations with IOU > 0.5
                    aidx, gt_id = aidx_per_batch[i][s][j]

                    if (i, aidx) not in aidx_set:
                        aidx_set.add((i, aidx))
                        label_indices.append(
                            [i, aidx, label_per_batch[i][gt_id]])
                        mask_indices.append([i, aidx])
                        bbox_indices.extend([[i, aidx, k] for k in range(4)])
                        box_delta_values.extend(box_delta_per_batch[i][s][j])
                        box_values.extend(bbox_per_batch[i][gt_id])
                    else:
                        num_discarded_labels += 1

            input_mask = np.reshape(
                sparse_to_dense(mask_indices, [mc.BATCH_SIZE, ANCHORS],
                                [1.0] * len(mask_indices)),
                [mc.BATCH_SIZE, ANCHORS, 1])
            box_delta_input = sparse_to_dense(bbox_indices,
                                              [mc.BATCH_SIZE, ANCHORS, 4],
                                              box_delta_values)
            box_input = sparse_to_dense(bbox_indices,
                                        [mc.BATCH_SIZE, ANCHORS, 4],
                                        box_values)
            labels = sparse_to_dense(label_indices,
                                     [mc.BATCH_SIZE, ANCHORS, mc.CLASSES],
                                     [1.0] * len(label_indices))

            return input_mask, box_delta_input, box_input, labels

        def _load_data(load_to_placeholder=True):
            # read batch input
            image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
                bbox_per_batch = imdb.read_batch()
            # print("label", label_per_batch[0])
            # print(aidx_per_batch[0])
            input_mask, box_delta_input, box_input, labels  =\
            _load_data_per_scale(label_per_batch, box_delta_per_batch, aidx_per_batch, bbox_per_batch, 0, mc.ANCHORS)

            input_mask2, box_delta_input2, box_input2, labels2 =\
            _load_data_per_scale(label_per_batch, box_delta_per_batch, aidx_per_batch, bbox_per_batch, 1, mc.ANCHORS2)

            input_mask3, box_delta_input3, box_input3, labels3 =\
            _load_data_per_scale(label_per_batch, box_delta_per_batch, aidx_per_batch, bbox_per_batch, 2, mc.ANCHORS3)

            input_mask_total = np.concatenate(
                [input_mask, input_mask2, input_mask3], 1)
            box_delta_input_total = np.concatenate(
                [box_delta_input, box_delta_input2, box_delta_input3], 1)
            box_input_total = np.concatenate(
                [box_input, box_input2, box_input3], 1)
            labels_total = np.concatenate([labels, labels2, labels3], 1)

            # print("input_mask_shape", input_mask_total.shape)
            # if mc.DEBUG_MODE:
            #   print ('Warning: Discarded {}/({}) labels that are assigned to the same '
            #          'anchor'.format(num_discarded_labels, num_labels))

            if load_to_placeholder:
                image_input = model.ph_image_input
                input_mask = model.ph_input_mask
                box_delta_input = model.ph_box_delta_input
                box_input = model.ph_box_input
                labels = model.ph_labels
            else:
                image_input = model.image_input
                input_mask = model.input_mask
                box_delta_input = model.box_delta_input
                box_input = model.box_input
                labels = model.labels

            feed_dict = {
                image_input: image_per_batch,
                input_mask: input_mask_total,
                box_delta_input: box_delta_input_total,
                box_input: box_input_total,
                labels: labels_total
            }
            # feed_dict = {
            #     image_input: image_per_batch,
            #     input_mask: np.reshape(
            #         sparse_to_dense(
            #             mask_indices, [mc.BATCH_SIZE, mc.ANCHOR_TOTAL],
            #             [1.0]*len(mask_indices)),
            #         [mc.BATCH_SIZE, mc.ANCHOR_TOTAL, 1]),
            #     box_delta_input: sparse_to_dense(
            #         bbox_indices, [mc.BATCH_SIZE, mc.ANCHOR_TOTAL, 4],
            #         box_delta_values),
            #     box_input: sparse_to_dense(
            #         bbox_indices, [mc.BATCH_SIZE, mc.ANCHOR_TOTAL, 4],
            #         box_values),
            #     labels: sparse_to_dense(
            #         label_indices,
            #         [mc.BATCH_SIZE, mc.ANCHOR_TOTAL, mc.CLASSES],
            #         [1.0]*len(label_indices)),
            # }

            # feed_dict = {
            #     image_input: image_per_batch,
            #     input_mask: np.reshape(
            #         sparse_to_dense(
            #             mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
            #             [1.0]*len(mask_indices)),
            #         [mc.BATCH_SIZE, mc.ANCHORS, 1]),
            #     box_delta_input: sparse_to_dense(
            #         bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
            #         box_delta_values),
            #     box_input: sparse_to_dense(
            #         bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
            #         box_values),
            #     labels: sparse_to_dense(
            #         label_indices,
            #         [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
            #         [1.0]*len(label_indices)),
            # }
            return feed_dict, image_per_batch, label_per_batch, bbox_per_batch

        def _enqueue(sess, coord):
            try:
                while not coord.should_stop():
                    feed_dict, _, _, _ = _load_data()
                    # print("input.shape", feed_dict[model.ph_labels].shape)
                    sess.run(model.enqueue_op, feed_dict=feed_dict)
                    if mc.DEBUG_MODE:
                        print("added to the queue")
                if mc.DEBUG_MODE:
                    print("Finished enqueue")
            except Exception as e:
                coord.request_stop(e)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)

        if FLAGS.resume:
            if ckpt and ckpt.model_checkpoint_path:
                print("restoring...", FLAGS.train_dir,
                      ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            init = tf.global_variables_initializer()
            sess.run(init)

        coord = tf.train.Coordinator()

        if mc.NUM_THREAD > 0:
            enq_threads = []
            for _ in range(mc.NUM_THREAD):
                enq_thread = threading.Thread(target=_enqueue,
                                              args=[sess, coord])
                # enq_thread.isDaemon()
                enq_thread.start()
                enq_threads.append(enq_thread)

        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        run_options = tf.RunOptions(timeout_in_ms=60000)
        best_map = 0
        # try:
        for step in xrange(FLAGS.max_steps):
            # for step in xrange(1):
            if coord.should_stop():
                sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
                coord.request_stop()
                coord.join(threads)
                break

            start_time = time.time()

            if step % FLAGS.summary_step == 0:
                feed_dict, image_per_batch, label_per_batch, bbox_per_batch = \
                    _load_data(load_to_placeholder=False)
                op_list = [
                    model.train_op, model.loss, summary_op, model.det_boxes,
                    model.det_probs, model.det_class, model.conf_loss,
                    model.bbox_loss, model.class_loss
                ]
                _, loss_value, summary_str, det_boxes, det_probs, det_class, \
                    conf_loss, bbox_loss, class_loss = sess.run(
                        op_list, feed_dict=feed_dict)

                _viz_prediction_result(model, image_per_batch, bbox_per_batch,
                                       label_per_batch, det_boxes, det_class,
                                       det_probs)
                image_per_batch = bgr_to_rgb(image_per_batch)
                viz_summary = sess.run(
                    model.viz_op,
                    feed_dict={model.image_to_show: image_per_batch})

                summary_writer.add_summary(summary_str, step)
                summary_writer.add_summary(viz_summary, step)
                summary_writer.flush()

                # if step > 10000:
                #   test.main()
                #   opts = options.parse_opts()
                #   opts.eval_dir = './data/out/tmp/bbox'
                #   scene_list =  []
                #   for i in test_set:
                #     scene_list.append('exp{:03d}_B'.format(i))
                #   mAP = eval_det.eval_batch(scene_list,opts)
                #   if (mAP > best_map):
                #     best_map = mAP
                #     print("best_map at step ", step)
                #     checkpoint_path = os.path.join(FLAGS.train_dir, 'best_model.ckpt')
                #     saver.save(sess, checkpoint_path, global_step=step)

                print('conf_loss: {}, bbox_loss: {}, class_loss: {}'.format(
                    conf_loss, bbox_loss, class_loss))
            else:
                # print("input", feed_dict[model.labels], label_per_batch)
                # print(label_per_batch[0])
                # print(feed_dict[model.labels][0,0,:])
                # print(feed_dict[model.labels][0,1,:])
                # print("input.shape", feed_dict[model.labels].shape)

                if mc.NUM_THREAD > 0:
                    _, loss_value, conf_loss, bbox_loss, class_loss, num_objects = sess.run(
                        [
                            model.train_op, model.loss, model.conf_loss,
                            model.bbox_loss, model.class_loss,
                            model.num_objects
                        ],
                        options=run_options)
                else:
                    feed_dict, _, label_per_batch, bbox_per_batch = _load_data(
                        load_to_placeholder=False)
                    _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
                        [
                            model.train_op, model.loss, model.conf_loss,
                            model.bbox_loss, model.class_loss
                        ],
                        feed_dict=feed_dict)

            duration = time.time() - start_time

            assert not np.isnan(loss_value), \
                'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
                'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

            if step % 10 == 0:
                num_images_per_step = mc.BATCH_SIZE
                images_per_sec = num_images_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f, conf_loss = %.2f, bbox_loss = %.2f, class_loss = %.2f (%.1f images/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, loss_value, conf_loss, bbox_loss,
                       class_loss, images_per_sec, sec_per_batch))
                sys.stdout.flush()

            # Save the model checkpoint periodically.
            if step % FLAGS.checkpoint_step == 0 or (step +
                                                     1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir,
                                               'model_{}.ckpt'.format(step))
                saver.save(sess, checkpoint_path, global_step=step)
                tf.train.write_graph(sess.graph.as_graph_def(),
                                     FLAGS.train_dir,
                                     'tensorflowModel.pbtxt',
                                     as_text=True)
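
The helper sparse_to_dense, used throughout these examples to build the anchor masks, box deltas, and labels, is assumed here to scatter a flat list of values into a zero-initialized dense array at the given index tuples. The sketch below matches the call sites above but is not the original helper:

import numpy as np

def sparse_to_dense(indices, output_shape, values, default_value=0.0):
    # Scatter `values` into a dense array of `output_shape` at `indices`.
    dense = np.full(output_shape, default_value, dtype=np.float32)
    for idx, val in zip(indices, values):
        dense[tuple(idx)] = val
    return dense

# Mark anchors 2 and 5 of a (1 batch, 8 anchors) input mask.
mask = sparse_to_dense([[0, 2], [0, 5]], [1, 8], [1.0, 1.0])
print(mask)  # [[0. 0. 1. 0. 0. 1. 0. 0.]]
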
Example #14
def train():
  """Train SqueezeDet model"""
  assert FLAGS.dataset == 'KITTI', \
      'Currently only support KITTI dataset'
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
  with tf.Graph().as_default():
    assert FLAGS.net == 'shuffleDet' or FLAGS.net == 'squeezeDet' or FLAGS.net == 'squeezeDet+', \
        'Selected neural net architecture not supported: {}'.format(FLAGS.net)
    if FLAGS.net == 'squeezeDet':
      mc = kitti_squeezeDet_config()
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = SqueezeDet(mc)
    elif FLAGS.net == 'shuffleDet':
      mc = kitti_shuffleDet_config()
      mc.IS_TRAINING = True
      mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
      model = shuffleDet(mc)
    imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

    # save model size, flops, activations by layers
    _build_model_metrics(model)   

    def _load_data(load_to_placeholder=True):
      # read batch input
      image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
          bbox_per_batch = imdb.read_batch()

      label_indices, bbox_indices, box_delta_values, mask_indices, box_values, \
          = [], [], [], [], []
      aidx_set = set()
      num_discarded_labels = 0
      num_labels = 0
      for i in range(len(label_per_batch)): # batch_size
        for j in range(len(label_per_batch[i])): # number of annotations
          num_labels += 1
          if (i, aidx_per_batch[i][j]) not in aidx_set:
            aidx_set.add((i, aidx_per_batch[i][j]))
            label_indices.append(
                [i, aidx_per_batch[i][j], label_per_batch[i][j]])
            mask_indices.append([i, aidx_per_batch[i][j]])
            bbox_indices.extend(
                [[i, aidx_per_batch[i][j], k] for k in range(4)])
            box_delta_values.extend(box_delta_per_batch[i][j])
            box_values.extend(bbox_per_batch[i][j])
          else:
            num_discarded_labels += 1

      if load_to_placeholder:
        image_input = model.ph_image_input
        input_mask = model.ph_input_mask
        box_delta_input = model.ph_box_delta_input
        box_input = model.ph_box_input
        labels = model.ph_labels
      else:
        image_input = model.image_input
        input_mask = model.input_mask
        box_delta_input = model.box_delta_input
        box_input = model.box_input
        labels = model.labels

      feed_dict = {
          image_input: image_per_batch,
          input_mask: np.reshape(
              sparse_to_dense(
                  mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                  [1.0]*len(mask_indices)),
              [mc.BATCH_SIZE, mc.ANCHORS, 1]),
          box_delta_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
              box_delta_values),
          box_input: sparse_to_dense(
              bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
              box_values),
          labels: sparse_to_dense(
              label_indices,
              [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
              [1.0]*len(label_indices)),
      }

      return feed_dict, image_per_batch, label_per_batch, bbox_per_batch

    def _enqueue(sess, coord):
      try:
        while not coord.should_stop():
          feed_dict, _, _, _ = _load_data()
          sess.run(model.enqueue_op, feed_dict=feed_dict)
          if mc.DEBUG_MODE:
            print ("added to the queue")
        if mc.DEBUG_MODE:
          print ("Finished enqueue")
      except Exception as e:
        coord.request_stop(e)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # print(tf.global_variables())


    # variables = tf.trainable_variables() 
    variables = tf.contrib.framework.get_variables_to_restore()

    variables_to_restore = [v for v in variables
                            if 'Momentum' not in v.name and 'iou' not in v.name
                            and 'global_step' not in v.name and 'preds' not in v.name]
    saver = tf.train.Saver(variables_to_restore)
    summary_op = tf.summary.merge_all()
    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
    # saver.restore(sess, checkpoint_path)

    # ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    # print('ckpt is:'+str(ckpt))
    # if ckpt and ckpt.model_checkpoint_path:
    #     saver.restore(sess, ckpt.model_checkpoint_path)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
    init = tf.global_variables_initializer()
    sess.run(init)

    coord = tf.train.Coordinator()

    if mc.NUM_THREAD > 0:
      enq_threads = []
      for _ in range(mc.NUM_THREAD):
        enq_thread = threading.Thread(target=_enqueue, args=[sess, coord])
        # enq_thread.isDaemon()
        enq_thread.start()
        enq_threads.append(enq_thread)

    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    run_options = tf.RunOptions(timeout_in_ms=60000)
    if FLAGS.reload:
      print('load the parameters from: ' + str(FLAGS.train_dir))
      ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
      if ckpt and ckpt.model_checkpoint_path:
        print('Restores from checkpoint....')
        # Restores from checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Assuming model_checkpoint_path looks something like:
        #   /my-favorite-path/cifar10_train/model.ckpt-0,
        # extract global_step from it.
        # step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        step = 1
      else:
        step = 1
        print('No checkpoint found!')
    else:
      if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
      tf.gfile.MakeDirs(FLAGS.train_dir)
      step = 1

    # try:
    for step in xrange(FLAGS.max_steps):
      if coord.should_stop():
        sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
        coord.request_stop()
        coord.join(threads)
        break

      start_time = time.time()

      if step % FLAGS.summary_step == 0:
        feed_dict, image_per_batch, label_per_batch, bbox_per_batch = \
            _load_data(load_to_placeholder=False)
        op_list = [
            model.train_op, model.loss, summary_op, model.det_boxes,
            model.det_probs, model.det_class, model.conf_loss,
            model.bbox_loss, model.class_loss
        ]
        _, loss_value, summary_str, det_boxes, det_probs, det_class, \
            conf_loss, bbox_loss, class_loss = sess.run(
                op_list, feed_dict=feed_dict)

        _viz_prediction_result(
            model, image_per_batch, bbox_per_batch, label_per_batch, det_boxes,
            det_class, det_probs)
        image_per_batch = bgr_to_rgb(image_per_batch)
        viz_summary = sess.run(
            model.viz_op, feed_dict={model.image_to_show: image_per_batch})

        summary_writer.add_summary(summary_str, step)
        summary_writer.add_summary(viz_summary, step)
        summary_writer.flush()

        print ('conf_loss: {}, bbox_loss: {}, class_loss: {}'.
            format(conf_loss, bbox_loss, class_loss))
      else:
        if mc.NUM_THREAD > 0:
          _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
              [model.train_op, model.loss, model.conf_loss, model.bbox_loss,
               model.class_loss], options=run_options)
        else:
          feed_dict, _, _, _ = _load_data(load_to_placeholder=False)
          _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
              [model.train_op, model.loss, model.conf_loss, model.bbox_loss,
               model.class_loss], feed_dict=feed_dict)

      duration = time.time() - start_time

      assert not np.isnan(loss_value), \
          'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
          'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

      if step % 10 == 0:
        num_images_per_step = mc.BATCH_SIZE
        images_per_sec = num_images_per_step / duration
        sec_per_batch = float(duration)
        format_str = ('%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             images_per_sec, sec_per_batch))
        sys.stdout.flush()

      # Save the model checkpoint periodically.
      if step % FLAGS.checkpoint_step == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step, write_state=True)
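
The variables_to_restore filter above is a common pattern: drop optimizer slots and bookkeeping variables by name before building the Saver, so a checkpoint trained under a different optimizer setup still restores cleanly. A standalone sketch of the same pattern; the substring list is illustrative:

import tensorflow as tf

def restorable_variables(exclude=('Momentum', 'iou', 'global_step', 'preds')):
    # Keep only variables whose names contain none of the excluded substrings.
    return [v for v in tf.global_variables()
            if not any(s in v.name for s in exclude)]

# saver = tf.train.Saver(restorable_variables())
# saver.restore(sess, ckpt.model_checkpoint_path)
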
Example #15
def train():
    """Train SqueezeDetPlus model"""

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    if not os.path.exists(FLAGS.train_dir):
        os.makedirs(FLAGS.train_dir)

    with tf.Graph().as_default():
        # The logic below chooses between pruning and regular training,
        # and selects the network type (depending on the pruned structure).
        # Learning rates depend on the number of parameters to be updated,
        # as do the lambdas for the loss functions (see paper).
        if FLAGS.pruning:
            if FLAGS.net == 'squeezeDet+PruneFilter':
                mc = kitti_squeezeDetPlus_config()
                mc.LEARNING_RATE = 0.1
                mc.IS_TRAINING = True
                mc.IS_PRUNING = True
                mc.LITE_MODE = False
                mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
                model = SqueezeDetPlusPruneFilter(mc)
            elif FLAGS.net == 'squeezeDet+PruneFilterShape':
                mc = kitti_squeezeDetPlus_config()
                mc.LEARNING_RATE = 0.001
                mc.IS_TRAINING = True
                mc.IS_PRUNING = True
                mc.LITE_MODE = False
                mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
                model = SqueezeDetPlusPruneFilterShape(mc)
            elif FLAGS.net == 'squeezeDet+PruneLayer':
                mc = kitti_squeezeDetPlus_config()
                mc.LEARNING_RATE = 0.0003
                mc.IS_TRAINING = True
                mc.IS_PRUNING = True
                mc.LITE_MODE = False
                mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
                model = SqueezeDetPlusPruneLayer(mc)
        else:
            if FLAGS.net == 'squeezeDet+PruneFilter':
                mc = kitti_squeezeDetPlus_config()
                mc.IS_TRAINING = True
                mc.IS_PRUNING = False
                mc.LITE_MODE = False
                mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
                model = SqueezeDetPlusPruneFilter(mc)
            elif FLAGS.net == 'squeezeDet+PruneFilterShape':
                mc = kitti_squeezeDetPlus_config()
                mc.IS_TRAINING = True
                mc.IS_PRUNING = False
                mc.LITE_MODE = False
                mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
                model = SqueezeDetPlusPruneFilterShape(mc)
            elif FLAGS.net == 'squeezeDet+PruneLayer':
                mc = kitti_squeezeDetPlus_config()
                mc.IS_TRAINING = True
                mc.IS_PRUNING = False
                mc.LITE_MODE = False
                mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
                model = SqueezeDetPlusPruneLayer(mc)

        imdb = kitti(FLAGS.image_set, FLAGS.data_path, mc)

        # save model size, flops, activations by layers
        with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'),
                  'w') as f:
            f.write('Number of parameter by layer:\n')
            count = 0
            for c in model.model_size_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nActivation size by layer:\n')
            for c in model.activation_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nNumber of flops by layer:\n')
            for c in model.flop_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))
        print('Model statistics saved to {}.'.format(
            os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

        def _load_data(load_to_placeholder=True):
            # read batch input
            image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
                bbox_per_batch = imdb.read_batch()

            label_indices, bbox_indices, box_delta_values, mask_indices, box_values, \
                = [], [], [], [], []
            aidx_set = set()
            num_discarded_labels = 0
            num_labels = 0
            for i in range(len(label_per_batch)):  # batch_size
                for j in range(len(
                        label_per_batch[i])):  # number of annotations
                    num_labels += 1
                    if (i, aidx_per_batch[i][j]) not in aidx_set:
                        aidx_set.add((i, aidx_per_batch[i][j]))
                        label_indices.append(
                            [i, aidx_per_batch[i][j], label_per_batch[i][j]])
                        mask_indices.append([i, aidx_per_batch[i][j]])
                        bbox_indices.extend([[i, aidx_per_batch[i][j], k]
                                             for k in range(4)])
                        box_delta_values.extend(box_delta_per_batch[i][j])
                        box_values.extend(bbox_per_batch[i][j])
                    else:
                        num_discarded_labels += 1

            if mc.DEBUG_MODE:
                print(
                    'Warning: Discarded {}/({}) labels that are assigned to the same '
                    'anchor'.format(num_discarded_labels, num_labels))

            if load_to_placeholder:
                image_input = model.ph_image_input
                input_mask = model.ph_input_mask
                box_delta_input = model.ph_box_delta_input
                box_input = model.ph_box_input
                labels = model.ph_labels
            else:
                image_input = model.image_input
                input_mask = model.input_mask
                box_delta_input = model.box_delta_input
                box_input = model.box_input
                labels = model.labels

            feed_dict = {
                image_input:
                image_per_batch,
                input_mask:
                np.reshape(
                    sparse_to_dense(mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                                    [1.0] * len(mask_indices)),
                    [mc.BATCH_SIZE, mc.ANCHORS, 1]),
                box_delta_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_delta_values),
                box_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_values),
                labels:
                sparse_to_dense(label_indices,
                                [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
                                [1.0] * len(label_indices)),
            }

            return feed_dict, image_per_batch, label_per_batch, bbox_per_batch

        def _enqueue(sess, coord):
            try:
                while not coord.should_stop():
                    feed_dict, _, _, _ = _load_data()
                    sess.run(model.enqueue_op, feed_dict=feed_dict)
                    if mc.DEBUG_MODE:
                        print("added to the queue")
                if mc.DEBUG_MODE:
                    print("Finished enqueue")
            except Exception as e:
                coord.request_stop(e)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

        summary_op = tf.summary.merge_all()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)

        if ckpt and ckpt.model_checkpoint_path:
            print("restoring checkpoint")
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            uninitialized_vars = []
            for var in tf.global_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninitialized_vars.append(var)

            init_new_vars_op = tf.variables_initializer(uninitialized_vars)

            sess.run(init_new_vars_op)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        coord = tf.train.Coordinator()

        if mc.NUM_THREAD > 0:
            enq_threads = []
            for _ in range(mc.NUM_THREAD):
                enq_thread = threading.Thread(target=_enqueue,
                                              args=[sess, coord])
                # enq_thread.isDaemon()
                enq_thread.start()
                enq_threads.append(enq_thread)

        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        run_options = tf.RunOptions(timeout_in_ms=60000)

        for step in range(FLAGS.max_steps):
            if coord.should_stop():
                sess.run(model.FIFOQueue.close(cancel_pending_enqueues=True))
                coord.request_stop()
                coord.join(threads)
                break

            start_time = time.time()

            if step % FLAGS.summary_step == 0:
                feed_dict, image_per_batch, label_per_batch, bbox_per_batch = \
                    _load_data(load_to_placeholder=False)
                op_list = [
                    model.train_op, model.loss, summary_op, model.det_boxes,
                    model.det_probs, model.det_class, model.conf_loss,
                    model.bbox_loss, model.gamma_loss, model.class_loss
                ]
                _, loss_value, summary_str, det_boxes, det_probs, det_class, \
                   conf_loss, bbox_loss, gamma_loss, class_loss = sess.run(
                       op_list, feed_dict=feed_dict)

                _viz_prediction_result(model, image_per_batch, bbox_per_batch,
                                       label_per_batch, det_boxes, det_class,
                                       det_probs)
                image_per_batch = bgr_to_rgb(image_per_batch)
                viz_summary = sess.run(
                    model.viz_op,
                    feed_dict={model.image_to_show: image_per_batch})

                summary_writer.add_summary(summary_str, step)
                summary_writer.add_summary(viz_summary, step)
                summary_writer.flush()

                print(
                    'conf_loss: {}, bbox_loss: {}, gamma_loss: {}, class_loss: {}'
                    .format(conf_loss, bbox_loss, gamma_loss, class_loss))
            else:
                if mc.NUM_THREAD > 0:

                    _, loss_value, conf_loss, bbox_loss, gamma_loss, class_loss = sess.run(
                        [
                            model.train_op, model.loss, model.conf_loss,
                            model.bbox_loss, model.gamma_loss, model.class_loss
                        ],
                        options=run_options)
                else:

                    feed_dict, _, _, _ = _load_data(load_to_placeholder=False)
                    _, loss_value, conf_loss, bbox_loss, gamma_loss, class_loss = sess.run(
                        [
                            model.train_op, model.loss, model.conf_loss,
                            model.bbox_loss, model.gamma_loss, model.class_loss
                        ],
                        feed_dict=feed_dict)

            duration = time.time() - start_time

            assert not np.isnan(loss_value), \
                'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, gamma_loss: {}, ' \
                'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, gamma_loss, class_loss)

            if step % 10 == 0:
                num_images_per_step = mc.BATCH_SIZE
                images_per_sec = num_images_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    images_per_sec, sec_per_batch))

                sys.stdout.flush()

            # Save the model checkpoint periodically.
            if step % FLAGS.checkpoint_step == 0 or (step +
                                                     1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        coord.request_stop()
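
The uninitialized-variable probe above runs every variable through sess.run and catches FailedPreconditionError one variable at a time. TF1 also provides tf.report_uninitialized_variables(), which returns the uninitialized names in a single run; a sketch of that alternative, assuming equivalent behavior for global variables:

import tensorflow as tf

def init_uninitialized(sess):
    # Names come back as bytes, without the ':0' output suffix.
    uninit_names = set(sess.run(tf.report_uninitialized_variables()))
    uninit_vars = [v for v in tf.global_variables()
                   if v.name.split(':')[0].encode() in uninit_names]
    sess.run(tf.variables_initializer(uninit_vars))
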
Example #16
def train():
    """Train SqueezeDet model"""
    assert FLAGS.net in ['vgg16', 'vgg16_v2', 'vgg16_v3', 'yolo_v2'], \
        'Selected neural net architecture not supported: {}'.format(FLAGS.net)
    assert FLAGS.dataset in ['KITTI', 'PASCAL_VOC', 'VID'], \
        'Either KITTI / PASCAL_VOC / VID'
    if FLAGS.dataset == 'KITTI':
        if FLAGS.net == 'yolo_v2':
            raise NotImplementedError
        else:
            mc = kitti_vgg16_config()
        mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
        train_imdb = kitti(FLAGS.train_set, FLAGS.data_path, mc)
    elif FLAGS.dataset == 'PASCAL_VOC':
        if FLAGS.net == 'yolo_v2':
            mc = pascal_voc_yolo_config()
        else:
            mc = pascal_voc_vgg16_config()
        mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
        train_imdb = pascal_voc(FLAGS.train_set, FLAGS.year, FLAGS.data_path,
                                mc)
    elif FLAGS.dataset == 'VID':
        if FLAGS.net == 'yolo_v2':
            mc = vid_yolo_config()
        else:
            mc = vid_vgg16_config()
        mc.PRETRAINED_MODEL_PATH = FLAGS.pretrained_model_path
        train_imdb = vid(FLAGS.train_set, FLAGS.data_path, mc)

    with tf.Graph().as_default():

        if FLAGS.net == 'vgg16':
            model = VGG16ConvDet(mc, FLAGS.gpu)
        elif FLAGS.net == 'vgg16_v2':
            model = VGG16ConvDetV2(mc, FLAGS.gpu)
        elif FLAGS.net == 'vgg16_v3':
            model = VGG16ConvDetV3(mc, FLAGS.gpu)
        elif FLAGS.net == 'yolo_v2':
            model = YOLO_V2(mc, FLAGS.gpu)

        # save model size, flops, activations by layers
        with open(os.path.join(FLAGS.train_dir, 'model_metrics.txt'),
                  'w') as f:
            f.write('Number of parameter by layer:\n')
            count = 0
            for c in model.model_size_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nActivation size by layer:\n')
            for c in model.activation_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))

            count = 0
            f.write('\nNumber of flops by layer:\n')
            for c in model.flop_counter:
                f.write('\t{}: {}\n'.format(c[0], c[1]))
                count += c[1]
            f.write('\ttotal: {}\n'.format(count))
        print('Model statistics saved to {}.'.format(
            os.path.join(FLAGS.train_dir, 'model_metrics.txt')))

        saver = tf.train.Saver(tf.global_variables())
        summary_op = tf.summary.merge_all()
        init = tf.global_variables_initializer()

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # read batch input
            image_per_batch, label_per_batch, box_delta_per_batch, aidx_per_batch, \
                bbox_per_batch = train_imdb.read_batch()

            label_indices, bbox_indices, box_delta_values, mask_indices, box_values, \
                num_discarded_labels, num_labels = _process_batch(image_per_batch, \
                                label_per_batch, box_delta_per_batch, \
                                aidx_per_batch, bbox_per_batch, mc.DEBUG_MODE)

            feed_dict = {
                model.image_input:
                image_per_batch,
                model.is_training:
                True,
                model.keep_prob:
                mc.KEEP_PROB,
                model.input_mask:
                np.reshape(
                    sparse_to_dense(mask_indices, [mc.BATCH_SIZE, mc.ANCHORS],
                                    [1.0] * len(mask_indices)),
                    [mc.BATCH_SIZE, mc.ANCHORS, 1]),
                model.box_delta_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_delta_values),
                model.box_input:
                sparse_to_dense(bbox_indices, [mc.BATCH_SIZE, mc.ANCHORS, 4],
                                box_values),
                model.labels:
                sparse_to_dense(label_indices,
                                [mc.BATCH_SIZE, mc.ANCHORS, mc.CLASSES],
                                [1.0] * len(label_indices)),
            }

            if step % FLAGS.summary_step == 0:
                op_list = [
                    model.train_op, model.loss, summary_op, model.det_boxes,
                    model.det_probs, model.det_class, model.conf_loss,
                    model.bbox_loss, model.class_loss
                ]
                _, loss_value, summary_str, det_boxes, det_probs, det_class, conf_loss, \
                    bbox_loss, class_loss = sess.run(op_list, feed_dict=feed_dict)

                _viz_prediction_result(model, image_per_batch, bbox_per_batch,
                                       label_per_batch, det_boxes, det_class,
                                       det_probs)
                image_per_batch = bgr_to_rgb(image_per_batch)
                viz_summary = sess.run(
                    model.viz_op,
                    feed_dict={model.image_to_show: image_per_batch})

                num_discarded_labels_op = tf.summary.scalar(
                    'counter/num_discarded_labels', num_discarded_labels)
                num_labels_op = tf.summary.scalar('counter/num_labels',
                                                  num_labels)

                counter_summary_str = sess.run(
                    [num_discarded_labels_op, num_labels_op])

                summary_writer.add_summary(summary_str, step)
                summary_writer.add_summary(viz_summary, step)
                for sum_str in counter_summary_str:
                    summary_writer.add_summary(sum_str, step)

                print('conf_loss: {}, bbox_loss: {}, class_loss: {}'.format(
                    conf_loss, bbox_loss, class_loss))
            else:
                _, loss_value, conf_loss, bbox_loss, class_loss = sess.run(
                    [
                        model.train_op, model.loss, model.conf_loss,
                        model.bbox_loss, model.class_loss
                    ],
                    feed_dict=feed_dict)

            duration = time.time() - start_time

            assert not np.isnan(loss_value), \
                'Model diverged. Total loss: {}, conf_loss: {}, bbox_loss: {}, ' \
                'class_loss: {}'.format(loss_value, conf_loss, bbox_loss, class_loss)

            if step % 10 == 0:
                num_images_per_step = mc.BATCH_SIZE
                images_per_sec = num_images_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f images/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    images_per_sec, sec_per_batch))
                sys.stdout.flush()

            # Save the model checkpoint periodically.
            if step % FLAGS.checkpoint_step == 0 or (step +
                                                     1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
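
One caveat in the summary step above: tf.summary.scalar is called inside the training loop, so two new graph ops are added on every summary step and the graph grows over time. Building the counter summaries once with placeholders avoids this; a minimal sketch, with illustrative placeholder names:

import tensorflow as tf

# Build once, before the training loop.
num_labels_ph = tf.placeholder(tf.float32, shape=[], name='num_labels_ph')
num_discarded_ph = tf.placeholder(tf.float32, shape=[],
                                  name='num_discarded_labels_ph')
counter_summary_op = tf.summary.merge([
    tf.summary.scalar('counter/num_labels', num_labels_ph),
    tf.summary.scalar('counter/num_discarded_labels', num_discarded_ph),
])

# Inside the loop:
#   summary_str = sess.run(counter_summary_op,
#                          feed_dict={num_labels_ph: num_labels,
#                                     num_discarded_ph: num_discarded_labels})
#   summary_writer.add_summary(summary_str, step)
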