コード例 #1
0
ファイル: detect.py プロジェクト: hdchenjian/ssd-tensorflow
def main():
    """Run the frozen SSD300 (VGG16) graph on the demo image and save the result.

    Steps:
      1. Read ``demo/test.jpg``.
      2. Load ``model/ssd300_vgg16_short.pb`` with ``load_graph``.
      3. Convert the image to float32, resize it to 300x300 and add a batch axis.
      4. Feed it to ``import/define_input/image_input:0`` and fetch
         ``import/result/result:0``.
      5. Decode the first batch entry against the ``vgg300`` anchors, run
         ``suppress_overlaps`` and keep at most 200 boxes.
      6. Draw those boxes on the unresized image and write it to
         ``demo/test.jpg_out.jpg``.

    Intermediate values are printed for debugging.

    Raises:
        FileNotFoundError: if OpenCV cannot read the demo image.
    """
    model_file = 'model/ssd300_vgg16_short.pb'
    image_path = 'demo/test.jpg'

    # cv2.imread signals failure by returning None instead of raising; check
    # up front so a bad path fails clearly before the graph is even loaded.
    original = cv2.imread(image_path)
    if original is None:
        raise FileNotFoundError('Unable to read image: {}'.format(image_path))

    graph = load_graph(model_file)

    with tf.Session(graph=graph) as sess:
        image_input = sess.graph.get_tensor_by_name(
            'import/define_input/image_input:0')
        result = sess.graph.get_tensor_by_name("import/result/result:0")

        # np.float32(...) returns a converted copy, so `original` stays intact
        # for drawing below and does not have to be re-read from disk.
        img = np.float32(original)
        img = cv2.resize(img, (300, 300))
        img = np.expand_dims(img, axis=0)  # batch of one
        print('image_input', image_input)
        print('img', type(img), img.shape, img[0][1][1])
        enc_boxes = sess.run(result, feed_dict={image_input: img})
        print('enc_boxes', type(enc_boxes), len(enc_boxes), type(enc_boxes[0]),
              enc_boxes[0].shape)
        print('detect_result_[0][0]', enc_boxes[0][0])

        # Label id -> class name for the 20 Pascal VOC classes.
        lid2name = {
            0: 'Aeroplane',
            1: 'Bicycle',
            2: 'Bird',
            3: 'Boat',
            4: 'Bottle',
            5: 'Bus',
            6: 'Car',
            7: 'Cat',
            8: 'Chair',
            9: 'Cow',
            10: 'Diningtable',
            11: 'Dog',
            12: 'Horse',
            13: 'Motorbike',
            14: 'Person',
            15: 'Pottedplant',
            16: 'Sheep',
            17: 'Sofa',
            18: 'Train',
            19: 'Tvmonitor'
        }
        preset = get_preset_by_name('vgg300')
        anchors = get_anchors_for_preset(preset)
        print('anchors', type(anchors))
        # NOTE(review): 0.5 is presumably the confidence threshold of
        # decode_boxes — confirm against its signature.
        boxes = decode_boxes(enc_boxes[0], anchors, 0.5, lid2name, None)
        boxes = suppress_overlaps(boxes)[:200]
        print('boxes', boxes)

        box_color = (31, 119, 180)
        for box in boxes:
            # Each entry's element [1] is the box object (it exposes label,
            # labelid, center and size); presumably [0] is a score — confirm.
            draw_box(original, box[1], box_color)

            box_data = '{} {} {} {} {} {}\n'.format(
                box[1].label, box[1].labelid, box[1].center.x, box[1].center.y,
                box[1].size.w, box[1].size.h)
            print('box_data', box_data)
        cv2.imwrite(image_path + '_out.jpg', original)
コード例 #2
0
def main():
    """Prepare SSD training data from a dataset source (brats18 by default).

    Parses the command line, loads the requested data source (train/valid
    split and optionally the test set), optionally annotates the samples, and
    with ``--compute-td`` pickles into ``--data-dir``:
      * ``train-samples.pkl`` and ``valid-samples.pkl``, the sample lists;
      * ``training-data.pkl``, a dict with the preset, class metadata and the
        train/valid transforms.

    Returns:
        int: 0 on success, 1 if the data source could not be loaded.
    """
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Process a dataset for SSD')
    parser.add_argument('--data-source', default='brats18', help='data source')
    parser.add_argument('--data-dir',
                        default='/data/tng016/brats18',
                        help='data directory')
    parser.add_argument('--validation-fraction',
                        type=float,
                        default=0.025,
                        help='fraction of the data to be used for validation')
    parser.add_argument('--expand-probability',
                        type=float,
                        default=0.5,
                        help='probability of running sample expander')
    parser.add_argument('--sampler-trials',
                        type=int,
                        default=50,
                        help='number of time a sampler tries to find a sample')
    # String defaults are passed through `type` by argparse, so these become
    # real booleans via str2bool.
    parser.add_argument('--annotate',
                        type=str2bool,
                        default='False',
                        help="Annotate the data samples")
    parser.add_argument('--compute-td',
                        type=str2bool,
                        default='True',
                        help="Compute training data")
    parser.add_argument('--preset',
                        default='vgg300',
                        choices=['vgg300', 'vgg512'],
                        help="The neural network preset")
    parser.add_argument('--process-test',
                        type=str2bool,
                        default='False',
                        help="process the test dataset")
    args = parser.parse_args()

    print('[i] Data source:          ', args.data_source)
    print('[i] Data directory:       ', args.data_dir)
    print('[i] Validation fraction:  ', args.validation_fraction)
    print('[i] Expand probability:   ', args.expand_probability)
    print('[i] Sampler trials:       ', args.sampler_trials)
    print('[i] Annotate:             ', args.annotate)
    print('[i] Compute training data:', args.compute_td)
    print('[i] Preset:               ', args.preset)
    print('[i] Process test dataset: ', args.process_test)

    #---------------------------------------------------------------------------
    # Load the data source
    #---------------------------------------------------------------------------
    print('[i] Configuring the data source...')
    # Restored error handling (it was commented out): report a load failure
    # and exit with status 1, matching the documented return contract.
    try:
        source = load_data_source(args.data_source)
        source.load_trainval_data(args.data_dir, args.validation_fraction)
        if args.process_test:
            source.load_test_data(args.data_dir)
        print('[i] # training samples:   ', source.num_train)
        print('[i] # validation samples: ', source.num_valid)
        print('[i] # testing samples:    ', source.num_test)
        print('[i] # classes:            ', source.num_classes)
    except (ImportError, AttributeError, RuntimeError) as e:
        print('[!] Unable to load data source:', str(e))
        return 1

    #---------------------------------------------------------------------------
    # Annotate samples
    #---------------------------------------------------------------------------
    if args.annotate:
        print('[i] Annotating samples...')
        annotate(args.data_dir, source.train_samples, source.colors, 'train')
        annotate(args.data_dir, source.valid_samples, source.colors, 'valid')
        if args.process_test:
            # Was 'test ' (stray trailing space), inconsistent with the
            # 'train'/'valid' labels above.
            annotate(args.data_dir, source.test_samples, source.colors, 'test')

    #---------------------------------------------------------------------------
    # Compute the training data
    #---------------------------------------------------------------------------
    if args.compute_td:
        preset = get_preset_by_name(args.preset)
        with open(args.data_dir + '/train-samples.pkl', 'wb') as f:
            pickle.dump(source.train_samples, f)
        with open(args.data_dir + '/valid-samples.pkl', 'wb') as f:
            pickle.dump(source.valid_samples, f)

        with open(args.data_dir + '/training-data.pkl', 'wb') as f:
            data = {
                'preset': preset,
                'num-classes': source.num_classes,
                'colors': source.colors,
                'lid2name': source.lid2name,
                'lname2id': source.lname2id,
                'train-transforms': build_train_transforms(
                    preset, source.num_classes, args.sampler_trials,
                    args.expand_probability),
                'valid-transforms': build_valid_transforms(
                    preset, source.num_classes),
            }
            pickle.dump(data, f)

    return 0
コード例 #3
0
def main():
    """Restore the SSD300 (VGG16) checkpoint, run it on the demo image, save detections.

    Steps:
      1. Read ``demo/test.jpg``.
      2. Build an ``SSDVGG`` network from ``model/e25.ckpt.meta`` and
         ``model/e25.ckpt``.
      3. Convert the image to float32, resize it to 300x300 and add a batch axis.
      4. Run ``net.result`` on it.
      5. Decode the first batch entry against the ``vgg300`` anchors, run
         ``suppress_overlaps`` and keep at most 200 boxes.
      6. Draw the boxes on the unresized image and write it to
         ``demo/test.jpg_out.jpg``.

    Intermediate values are printed for debugging.

    Raises:
        FileNotFoundError: if OpenCV cannot read the demo image.
    """
    checkpoint_file = 'model/e25.ckpt'
    metagraph_file = checkpoint_file + '.meta'
    image_path = 'demo/test.jpg'

    # cv2.imread returns None instead of raising on a missing/unreadable
    # file; fail clearly before building the network.
    original = cv2.imread(image_path)
    if original is None:
        raise FileNotFoundError('Unable to read image: {}'.format(image_path))

    with tf.Session() as sess:
        # The former tf.global_variables_initializer() call was removed. It ran
        # before any variables existed, so (assuming the default graph starts
        # empty) it was a no-op. Weights are expected to come from
        # checkpoint_file via build_from_metagraph (presumably a Saver restore;
        # confirm).
        preset = get_preset_by_name('vgg300')
        anchors = get_anchors_for_preset(preset)
        net = SSDVGG(sess, preset)
        net.build_from_metagraph(metagraph_file, checkpoint_file)

        # np.float32(...) returns a converted copy, so `original` stays intact
        # for drawing below and does not have to be re-read from disk.
        img = np.float32(original)
        img = cv2.resize(img, (300, 300))
        img = np.expand_dims(img, axis=0)  # batch of one
        print('image_input', net.image_input)
        print('img', type(img), img.shape, img[0][1][1])
        enc_boxes = sess.run(net.result, feed_dict={net.image_input: img})
        print('enc_boxes', type(enc_boxes), len(enc_boxes), type(enc_boxes[0]),
              enc_boxes[0].shape)

        # Label id -> class name for the 20 Pascal VOC classes.
        lid2name = {
            0: 'Aeroplane',
            1: 'Bicycle',
            2: 'Bird',
            3: 'Boat',
            4: 'Bottle',
            5: 'Bus',
            6: 'Car',
            7: 'Cat',
            8: 'Chair',
            9: 'Cow',
            10: 'Diningtable',
            11: 'Dog',
            12: 'Horse',
            13: 'Motorbike',
            14: 'Person',
            15: 'Pottedplant',
            16: 'Sheep',
            17: 'Sofa',
            18: 'Train',
            19: 'Tvmonitor'
        }
        print('anchors', type(anchors))
        # NOTE(review): 0.5 is presumably the confidence threshold of
        # decode_boxes — confirm against its signature.
        boxes = decode_boxes(enc_boxes[0], anchors, 0.5, lid2name, None)
        boxes = suppress_overlaps(boxes)[:200]

        box_color = (31, 119, 180)
        for box in boxes:
            # Element [1] of each entry is the box object (label, labelid,
            # center, size); presumably [0] is a score — confirm.
            draw_box(original, box[1], box_color)

            box_data = '{} {} {} {} {} {}\n'.format(
                box[1].label, box[1].labelid, box[1].center.x, box[1].center.y,
                box[1].size.w, box[1].size.h)
            print('box_data', box_data)
        cv2.imwrite(image_path + '_out.jpg', original)
コード例 #4
0
def main():
    """Build the SSD training-data pickles for a dataset (pascal_voc by default).

    Command-line driven. Loads the chosen data source from ``--data-dir`` and
    writes these files into that same directory:
      * ``train-samples.pkl`` and ``valid-samples.pkl``, the sample lists;
      * ``training-data.pkl``, a dict with the preset, class metadata and the
        train/valid transforms.

    Returns:
        int: 0 on success, 1 when the data source cannot be loaded.
    """
    # --- Command line ---------------------------------------------------------
    cli = argparse.ArgumentParser(description='Process a dataset for SSD')
    cli.add_argument('--data-source', default='pascal_voc', help='data source')
    cli.add_argument('--data-dir', default='pascal-voc', help='data directory')
    cli.add_argument('--expand-probability', type=float, default=0.5,
                     help='probability of running sample expander')
    cli.add_argument('--sampler-trials', type=int, default=50,
                     help='number of time a sampler tries to find a sample')
    cli.add_argument('--preset', default='vgg300',
                     choices=['vgg300', 'vgg512'])
    opts = cli.parse_args()

    settings = (
        ('[i] Data source:          ', opts.data_source),
        ('[i] Data directory:       ', opts.data_dir),
        ('[i] Expand probability:   ', opts.expand_probability),
        ('[i] Sampler trials:       ', opts.sampler_trials),
        ('[i] Preset:               ', opts.preset),
    )
    for caption, value in settings:
        print(caption, value)

    # --- Data source ----------------------------------------------------------
    print('[i] Configuring the data source...')
    try:
        source = load_data_source(opts.data_source)
        source.load_trainval_data(opts.data_dir)
        # Attributes are fetched one at a time so a missing one fails after
        # the preceding lines have already been printed.
        for caption, attr in (('[i] # training samples:   ', 'num_train'),
                              ('[i] # validation samples: ', 'num_valid'),
                              ('[i] # classes:            ', 'num_classes')):
            print(caption, getattr(source, attr))
    except (ImportError, AttributeError, RuntimeError) as err:
        print('[!] Unable to load data source:', str(err))
        return 1

    # --- Training data --------------------------------------------------------
    preset = get_preset_by_name(opts.preset)
    for file_name, attr in (('train-samples.pkl', 'train_samples'),
                            ('valid-samples.pkl', 'valid_samples')):
        with open(opts.data_dir + '/' + file_name, 'wb') as handle:
            pickle.dump(getattr(source, attr), handle)

    with open(opts.data_dir + '/training-data.pkl', 'wb') as handle:
        training_data = {
            'preset': preset,
            'num-classes': source.num_classes,
            'colors': source.colors,
            'lid2name': source.lid2name,
            'lname2id': source.lname2id,
            'train-transforms': build_train_transforms(
                preset, source.num_classes, opts.sampler_trials,
                opts.expand_probability),
            'valid-transforms': build_valid_transforms(
                preset, source.num_classes),
        }
        pickle.dump(training_data, handle)

    return 0
コード例 #5
0
def main():
    """Preprocess a VOC or KITTI dataset into the pickles used for SSD training.

    Writes four files into ``--train_save_dir``:
      * ``train-samples.pkl`` and ``valid-samples.pkl``, the sample lists from
        ``load_trainval_data`` and ``load_validval_data``;
      * ``train-details.pkl`` and ``valid-details.pkl``, dicts with the preset,
        class count, label maps and the train/valid transforms.

    Returns:
        int: always 0.
    """
    parser = argparse.ArgumentParser(description='Data preprocess for SSD')
    parser.add_argument('--data_dir',
                        default='/home/gs/data/VOCdevkit.../VOC2007'
                        )  # KITTI path: '/home/gs/data/KITTI'
    parser.add_argument('--train_save_dir', default=os.path.dirname(__file__))
    parser.add_argument('--preset',
                        default='vgg300',
                        choices=['vgg300', 'vgg512'])
    parser.add_argument('--data_set', default='VOC', choices=['VOC', 'KITTI'])
    parser.add_argument('--sampler_trials', type=int, default=50)
    parser.add_argument('--expand_probability', type=float, default=0.5)
    args = parser.parse_args()

    def _pickle_to(file_name, obj):
        """Pickle *obj* to ``train_save_dir/file_name``.

        pickle stores Python objects as-is so they can be loaded back without
        conversion; its files must be opened in binary mode ('wb' / 'rb').
        The payload is fully built before the file is opened, so a failure
        while building it does not leave a truncated pickle behind.
        """
        with open(os.path.join(args.train_save_dir, file_name), 'wb') as fp:
            pickle.dump(obj, fp)

    def _details(src, preset, transforms_key, transforms):
        """Return the per-split details dict (key order kept as before)."""
        return {
            'preset': preset,
            'num_classes': src.num_classes,
            'idx2label': src.idx_to_label,
            'label2idx': src.label_to_idx,
            transforms_key: transforms,
        }

    # Training split.
    source = DATASETSource(args.data_set)
    source.load_trainval_data(args.data_dir, args.data_set)

    preset = get_preset_by_name(args.preset)

    _pickle_to('train-samples.pkl', source.train_samples)
    train_transforms = build_train_transform(
        preset, source.num_classes, args.sampler_trials,
        args.expand_probability, source.label_to_idx, args.data_set)
    _pickle_to('train-details.pkl',
               _details(source, preset, 'train_transforms', train_transforms))

    # Validation split: a fresh source object, loaded separately.
    source = DATASETSource(args.data_set)
    source.load_validval_data(args.data_dir, args.data_set)

    _pickle_to('valid-samples.pkl', source.valid_samples)
    valid_transforms = build_valid_transforms(
        preset, source.num_classes, source.label_to_idx, args.data_set)
    _pickle_to('valid-details.pkl',
               _details(source, preset, 'valid_transforms', valid_transforms))

    return 0