示例#1
0
# Inference setup: shuffle the dataset, rebuild the network, restore its
# weights, and precompute the per-cell offset grid used to decode predictions.
# NOTE(review): `images` and `labels` are created before this excerpt starts.

# Convert the labels to a float32 NumPy array.
labels=np.array(labels, dtype=np.float32)
# Apply one shared random permutation to both arrays so that every image stays
# aligned with its label. Assumes `images` supports NumPy fancy indexing
# (i.e. is an ndarray) -- TODO confirm.
index=np.random.permutation(len(images))
images=images[index]
labels=labels[index]

# NOTE(review): this rebinds the name `model` from the module to the object it
# builds, so `model.model(...)` cannot be called a second time afterwards.
model=model.model(config.nbr_classes, config.nbr_boxes, config.cellule_y, config.cellule_x)

# Restore the weights from the most recent checkpoint found under ./training/.
# NOTE(review): tf.train.latest_checkpoint returns None when the directory has
# no checkpoint -- confirm that running with unrestored weights is acceptable.
checkpoint=tf.train.Checkpoint(model=model)
checkpoint.restore(tf.train.latest_checkpoint("./training/"))

# Cell-index grid. With meshgrid's default 'xy' indexing both outputs have
# shape (cellule_y, cellule_x); the first holds column (x) indices and the
# second holds row (y) indices.
grid=np.meshgrid(np.arange(config.cellule_x, dtype=np.float32), np.arange(config.cellule_y, dtype=np.float32))
# Stack to (cellule_y, cellule_x, 2), then insert a box axis:
# (cellule_y, cellule_x, 1, 2), where [..., 0] is x and [..., 1] is y.
grid=np.expand_dims(np.stack(grid, axis=-1), axis=2)
# Repeat along the box axis: (cellule_y, cellule_x, nbr_boxes, 2).
grid=np.tile(grid, (1, 1, config.nbr_boxes, 1))

# Run the model on each image and decode its raw output into pixel-space boxes.
# NOTE(review): this excerpt is truncated -- the loop body continues past the
# last line shown here (e.g. no y_max is computed in the visible part).
for i in range(len(images)):
  # Presumably renders the ground-truth labels onto the image; the meaning of
  # the trailing False flag is defined in common.prepare_image -- TODO confirm.
  img=common.prepare_image(images[i], labels[i], False)
  # Wrap the single image in a batch of size 1 before calling the model.
  predictions=model(np.array([images[i]]))
  
  # Take batch element 0. The indexing fixes a rank-5 output; the last axis is
  # split as 0:4 -> box parameters, 4 -> confidence logit, 5: -> class scores.
  pred_boxes=predictions[0, :, :, :, 0:4]
  pred_conf=common.sigmoid(predictions[0, :, :, :, 4])
  pred_classes=common.softmax(predictions[0, :, :, :, 5:])
  # Index of the highest-scoring class for every box.
  ids=np.argmax(pred_classes, axis=-1)

  # Box centre = (cell index + sigmoid offset) scaled by config.r_x / r_y.
  # Presumably r_x / r_y are the cell sizes in pixels -- TODO confirm.
  x_center=((grid[:, :, :, 0]+common.sigmoid(pred_boxes[:, :, :, 0]))*config.r_x)
  y_center=((grid[:, :, :, 1]+common.sigmoid(pred_boxes[:, :, :, 1]))*config.r_y)
  # Box size = exp(raw value) * anchor dimension, with the same scaling.
  # Assumes config.anchors has shape (nbr_boxes, 2) so that it broadcasts
  # over the last (box) axis -- TODO confirm.
  w=(np.exp(pred_boxes[:, :, :, 2])*config.anchors[:, 0]*config.r_x)
  h=(np.exp(pred_boxes[:, :, :, 3])*config.anchors[:, 1]*config.r_y)

  # Convert centre/size to corner coordinates; astype(np.int32) truncates
  # toward zero.
  x_min=(x_center-w/2).astype(np.int32)
  y_min=(y_center-h/2).astype(np.int32)
  x_max=(x_center+w/2).astype(np.int32)
示例#2
0
文件: images.py 项目: L42Project/Yolo
import cv2
import numpy as np
import common
import config as cfg

# Dataset viewer: load the annotated dataset without augmentation and show
# every image, as rendered by common.prepare_image, in an OpenCV window.
# Press 'q' in the window to stop early.

# NOTE(review): the original unpacked the 2nd and 3rd return values into the
# same name `list_labels`, so only the 3rd one survived. That effective
# behavior is kept here by discarding the 2nd value explicitly; confirm
# against common.infos_xmls that the 3rd value is the one intended.
list_all_labels, _, list_labels, list_attributs = common.infos_xmls(
    cfg.dir_dataset, with_attribut=cfg.with_attribut, verbose=True)

# Attributes are optional: pass None downstream when they are disabled.
if not cfg.with_attribut:
    list_attributs = None

images, labels, labels2, mask_attributs = common.prepare_dataset(
    cfg.dir_dataset,
    list_labels=list_labels,
    list_attributs=list_attributs,
    data_augmentation=False,
    verbose=True)

print("Nbr image:", len(images))
# Scale pixel values by 1/255. Assumes `images` is a NumPy array of 0-255
# values, so the result is in [0, 1] -- TODO confirm.
images = images / 255

# The preview is shown at twice the configured image size; the size does not
# change between frames, so compute it once.
display_size = (2 * cfg.image_size, 2 * cfg.image_size)

try:
    for raw_image, image_labels in zip(images, labels):
        image = common.prepare_image(raw_image, image_labels, True)
        cv2.imshow("image", cv2.resize(image, display_size))
        # waitKey also services the GUI event loop; mask to the low byte so
        # the comparison with ord('q') works across platforms.
        key = cv2.waitKey(3) & 0xFF
        if key == ord('q'):
            break
finally:
    # Release the HighGUI window even when the loop exits early or raises.
    cv2.destroyAllWindows()