コード例 #1
0
    ]

    # Build the image / annotation path lists for the train and val splits.
    # NOTE(review): the path suggests Pascal VOC 2012; what the lists contain
    # is decided by make_datapath_list, which is not visible here.
    root_path = "./data/VOCdevkit/VOC2012/"
    train_img_list, train_annotation_list, val_img_list, val_annotation_list = make_datapath_list(
        root_path)

    # Arguments forwarded to DataTransform below.
    # Presumably a per-channel colour mean and the network input size in
    # pixels -- TODO confirm against DataTransform.
    color_mean = (104, 117, 123)
    input_size = 300
    # A single shared transform is not created here; each dataset below builds
    # its own DataTransform(input_size, color_mean) instance.

    # The two datasets differ only in their file lists and the phase string.
    train_dataset = MyDataset(train_img_list,
                              train_annotation_list,
                              phase="train",
                              transform=DataTransform(input_size, color_mean),
                              anno_xml=Anno_xml(classes))
    val_dataset = MyDataset(val_img_list,
                            val_annotation_list,
                            phase="val",
                            transform=DataTransform(input_size, color_mean),
                            anno_xml=Anno_xml(classes))

    # Debug aid: print(train_dataset.__getitem__(1)) to inspect one sample.
    batch_size = 4
    # Training batches are shuffled and assembled by my_collate_fn.
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       collate_fn=my_collate_fn)
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
コード例 #2
0
ファイル: Transform_data.py プロジェクト: PhanTom2003/SSD
    # Category names handed to Anno_xml for annotation parsing.
    classes = ["aeroplane", "bicycle", "bird",  "boat", "bottle", 
               "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant",
               "sheep", "sofa", "train", "tvmonitor"]

    # Build the image / annotation path lists for the train and val splits.
    root_path = "./data/VOCdevkit/VOC2012/"
    train_img_list, train_annotation_list, val_img_list, val_annotation_list = make_datapath_list(root_path)

    # Read the first training image.
    # NOTE(review): cv2.imread returns None when the file cannot be read, in
    # which case the .shape access below fails -- consider checking for None.
    img_file_path = train_img_list[0]
    img = cv2.imread(img_file_path)  # array indexed (height, width, channel), BGR order
    height, width, channels = img.shape

    # Parse the matching annotation file; the image size is passed along.
    # Presumably it is used to scale the box coordinates -- TODO confirm in Anno_xml.
    trans_anno = Anno_xml(classes)
    anno_info_list = trans_anno(train_annotation_list[0], width, height)

    # Show the original image.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB, so convert from BGR
    plt.show()

    # Set up the transform pipeline.
    color_mean = (104, 117, 123)
    input_size = 300
    transform = DataTransform(input_size, color_mean)

    # Apply the "train" phase transform to the image.
    # The first four annotation columns and the fifth column are passed
    # separately -- presumably box coordinates and class label; TODO confirm.
    phase = "train"
    img_transformed, boxes, labels = transform(img, phase, anno_info_list[:,:4], anno_info_list[:, 4])
    plt.imshow(cv2.cvtColor(img_transformed, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB, so convert from BGR
コード例 #3
0

if __name__ == "__main__":
    # Category names handed to Anno_xml for annotation parsing.
    classes = [
        'Car',
        'Van',
        'Truck',
        'Pedestrian',
        'Person_sitting',
        'Cyclist',
        'Tram',
        'Misc',
        'DontCare',
    ]

    # Resolve the image / annotation file lists for both splits.
    root_path = "../stereo_datasets/training"
    (
        train_img_list,
        train_annotation_list,
        val_img_list,
        val_annotation_list,
    ) = make_datapath_list(root_path)

    # Settings forwarded to DataTransform.
    color_mean = (104, 117, 123)
    input_size = 300

    # Each split gets its own transform and annotation-parser instance.
    train_dataset = MyDataset(
        train_img_list,
        train_annotation_list,
        phase="train",
        transform=DataTransform(input_size, color_mean),
        anno_xml=Anno_xml(classes),
    )
    val_dataset = MyDataset(
        val_img_list,
        val_annotation_list,
        phase="val",
        transform=DataTransform(input_size, color_mean),
        anno_xml=Anno_xml(classes),
    )

    # Debug aids: print(len(train_dataset)) or
    # print(train_dataset.__getitem__(1)) to inspect the dataset.

    # Only the training loader shuffles; both use the custom collate function.
    batch_size = 4
    train_dataloader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=my_collate_fn,
    )
    val_dataloader = data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=my_collate_fn,
    )

    # Loaders keyed by phase name.
    dataloader_dict = {"train": train_dataloader, "val": val_dataloader}