def main():
    """Classify one or more images with a trained SqueezeNet checkpoint.

    Command-line arguments:
        --checkpoint-path: weights file handed to ``model.load_weights``.
        --image: one or more image paths to classify.
        --num-classes: number of output classes the model was built with.

    For every image, prints its path followed by the index of the
    highest-scoring class.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--checkpoint-path', required=True)
    cli.add_argument('--image', nargs='+', required=True)
    cli.add_argument('--num-classes', type=int, required=True)
    opts = cli.parse_args()

    net = SqueezeNet(weights=None, classes=opts.num_classes)
    net.load_weights(opts.checkpoint_path)

    # Load every image at SIZE x SIZE, stack them into one batch, then
    # apply the model's input preprocessing to the whole batch.
    pixel_arrays = [
        image.img_to_array(image.load_img(img_path, target_size=(SIZE, SIZE)))
        for img_path in opts.image
    ]
    batch = preprocess_input(np.array(pixel_arrays))

    scores = net.predict(batch)

    print('')
    # predict returns one row of class scores per input image, in order.
    for img_path, row in zip(opts.image, scores):
        print('%s' % img_path)
        print('    Prediction: %s' % np.argmax(row))
# Example #2
# 0
class TLClassifier:
    """Traffic-light colour classifier backed by a SqueezeNet model.

    Only the simulator model is available; constructing with a true
    ``is_site`` is rejected.
    """

    # Edge length of the square input image the network is built for.
    INPUT_SIZE = 224

    def __init__(self, is_site):
        """Build the network and load its weights from disk.

        Args:
            is_site (bool): True when running on the real-world site. No
                site classifier exists yet, so this must be false.

        Raises:
            NotImplementedError: if ``is_site`` is true.
        """
        # TODO: load a real-world classifier for the site.
        # An explicit raise (rather than ``assert``) still fires under ``python -O``.
        if is_site:
            raise NotImplementedError(
                'no real-world (site) traffic light classifier is available yet')
        weights_file = r'light_classification/models/squeezenet_weights.h5'  # Replace with real world classifier

        image_shape = (self.INPUT_SIZE, self.INPUT_SIZE, 3)

        # Output index i of the model maps to self.states[i] (see get_classification).
        self.states = (TrafficLight.RED, TrafficLight.YELLOW,
                       TrafficLight.GREEN, TrafficLight.UNKNOWN)

        print('Loading model..')
        self.model = SqueezeNet(len(self.states), *image_shape)
        self.model.load_weights(weights_file, by_name=True)
        # Build the predict function eagerly, before predict() is first called.
        self.model._make_predict_function()
        print('Loaded weights: %s' % weights_file)

    def get_classification(self, image):
        """Determine the color of the traffic light in the image.

        Args:
            image (cv::Mat): image containing the traffic light.

        Returns:
            int: ID of traffic light color (specified in styx_msgs/TrafficLight).
        """
        size = (self.INPUT_SIZE, self.INPUT_SIZE)
        # cv2.resize's third positional parameter is ``dst``; the interpolation
        # flag only takes effect when passed by keyword.
        resized = cv2.resize(image, size, interpolation=cv2.INTER_AREA)
        # Add a batch axis, reverse the channel axis (presumably OpenCV BGR ->
        # RGB; confirm against the training pipeline), and scale to [0, 1].
        mini_batch = resized.astype('float')[np.newaxis, ..., ::-1] / 255.
        light = self.states[np.argmax(self.model.predict(mini_batch))]

        return light
# Example #3
# 0
def create_squeezenet(nclass, train=True,
                      weights_path="models/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5"):
    """Build a SqueezeNet classifier with a new ``nclass``-way head.

    Args:
        nclass: number of output classes for the new head.
        train: when true, load pretrained weights into the headless backbone
            and insert a 0.2 dropout layer before the head.
        weights_path: weights file loaded into the backbone when ``train`` is
            true. Defaults to the previously hard-coded path.

    Returns:
        A ``Model`` named ``'squeezenet'`` that maps a (HEIGHT, WIDTH, 3) image
        input to a softmax over ``nclass`` classes.
    """
    img_input = Input(shape=(HEIGHT, WIDTH, 3), name="img_input")
    base_model = SqueezeNet(include_top=False,
                            weights=None,
                            input_tensor=img_input)
    x = base_model.output

    if train:
        # Start from pretrained backbone weights; dropout regularizes training only.
        base_model.load_weights(weights_path)
        x = Dropout(0.2, name='drop9')(x)
    # Classification head: 1x1 conv to nclass channels -> ReLU ->
    # global average pooling -> softmax.
    x = Convolution2D(nclass, (1, 1), padding='valid', name='conv10')(x)
    x = Activation('relu', name='relu_conv10')(x)
    x = GlobalAveragePooling2D()(x)
    x = Activation('softmax', name='loss')(x)
    model = Model(img_input, x, name='squeezenet')
    return model
# Example #4
# 0
# Training script: build a SqueezeNet, resume from saved weights if present,
# train from a generator, then save the weights back to the same file.
#
# Names used here but defined outside this view (presumably set earlier in
# the script -- confirm): nb_classes, input_shape, weights_file,
# training_data, samples_per_epoch, validation_data, nb_val_samples,
# nb_epoch, initial_epoch.

# images = np.array([cv2.resize(cv2.cvtColor(im, cv2.COLOR_GRAY2RGB), (227, 227)) for im in images])
# images = np.array(images)
# print images.shape
# classes = to_categorical(classes, nb_classes=nr_classes)

print('Loading model..')
model = SqueezeNet(nb_classes, input_shape=input_shape)
# NOTE(review): `adam` (lr=0.040) and `sgd` are constructed but never used.
# compile() below receives the string 'adam', so Keras builds an Adam optimizer
# with its default settings instead. Pass `optimizer=adam` (or `sgd`) if the
# custom hyper-parameters were intended.
adam = Adam(lr=0.040)
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss="categorical_crossentropy",
              optimizer='adam',
              metrics=['accuracy'])
# Resume from previously saved weights when the file exists; by_name=True
# loads weights into layers whose names match.
if os.path.isfile(weights_file):
    print('Loading weights: %s' % weights_file)
    model.load_weights(weights_file, by_name=True)

print('Fitting model')
# model.fit(images, classes, batch_size=batch_size, nb_epoch=nb_epoch, verbose=1, validation_split=0.2, initial_epoch=0)
# NOTE(review): samples_per_epoch / nb_val_samples / nb_epoch are the Keras 1
# argument names. Keras 2 uses steps_per_epoch / validation_steps / epochs,
# counted in batches rather than samples. Confirm the installed Keras version.
model.fit_generator(training_data,
                    samples_per_epoch=samples_per_epoch,
                    validation_data=validation_data,
                    nb_val_samples=nb_val_samples,
                    nb_epoch=nb_epoch,
                    verbose=1,
                    initial_epoch=initial_epoch)
print("Finished fitting model")

print('Saving weights')
# Overwrites the same file the weights may have been resumed from above.
model.save_weights(weights_file, overwrite=True)
# NOTE(review): no evaluation code follows this message in the visible
# source -- confirm it was not lost.
print('Evaluating model')
# Example #5
# 0
def _collect_labelled_paths(test_dir, classes):
    """Return (path, class_index) pairs for every entry under each class sub-directory."""
    data = []
    for ic, class_name in enumerate(classes):
        directory = os.path.join(test_dir, class_name)
        # os.listdir order is arbitrary; sorting makes the caller's fixed-seed
        # shuffle (and therefore --limit) select the same samples everywhere.
        for path in sorted(os.listdir(directory)):
            data.append((os.path.join(directory, path), ic))
    return data


def main():
    """Evaluate a SqueezeNet checkpoint on a labelled test directory.

    ``--test-dir`` must contain one sub-directory per class; the sorted
    sub-directory names define the class indices. Samples are shuffled with a
    fixed seed, optionally truncated to ``--limit``, classified in batches of
    ``--batch-size``, and tallied into a confusion matrix (rows = actual
    class, columns = predicted class). The matrix is printed and then plotted
    to ``--output-file`` via ``plot_cmat``.

    Raises:
        ValueError: if the number of class sub-directories differs from
            ``--num-classes``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-path', required=True)
    parser.add_argument('--test-dir', default='data/test')
    parser.add_argument('--output-file', default='confusion_matrix.png')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--limit', type=int, default=0)
    parser.add_argument('--num-classes', type=int, required=True)
    args = parser.parse_args()

    model = SqueezeNet(weights=None, classes=args.num_classes)
    model.load_weights(args.checkpoint_path)

    classes = sorted(os.listdir(args.test_dir))
    if len(classes) != args.num_classes:
        raise ValueError('expecting %d classes, found %d in %s' % (
            args.num_classes,
            len(classes),
            args.test_dir
        ))

    data = _collect_labelled_paths(args.test_dir, classes)

    rng = random.Random(0)
    rng.shuffle(data)
    if args.limit > 0:
        data = data[:args.limit]

    chunked = list(chunks(data, args.batch_size))
    gstart = time.time()
    # np.int was only an alias for the builtin int and was removed in NumPy 1.24.
    cmat = np.zeros((len(classes), len(classes)), dtype=int)
    last_print = 0
    for i, chunk in enumerate(chunked):
        start = time.time()
        paths, ys = zip(*chunk)
        xs = np.array([
            image.img_to_array(image.load_img(path, target_size=(SIZE, SIZE)))
            for path in paths
        ])
        xs = preprocess_input(xs)

        probs = model.predict(xs, batch_size=args.batch_size)
        preds = probs.argmax(axis=1)
        for actual, predicted in zip(ys, preds):
            cmat[actual, predicted] += 1

        batch_time = time.time() - start
        gdiff = time.time() - gstart
        # Throttle progress output to about once per second, always reporting the last batch.
        if time.time() - last_print > 1 or i == len(chunked) - 1:
            last_print = time.time()
            print('batch %d/%d (in %.3fs, %.1fs elapsed, %.1fs remaining)' % (
                i + 1,
                len(chunked),
                batch_time,
                gdiff,
                gdiff / (i + 1) * (len(chunked) - i - 1)
            ))

    print(cmat)
    plot_cmat(cmat, classes, args.output_file)
    print('saved figure to %s' % args.output_file)
# Example #6
# 0
# Classify every .jpg under images_dir and print the top-5 predictions per image.
#
# Names defined outside this view (presumably set earlier in the script --
# confirm): images_dir, load_image, decode, weights_file.
print('Loading images..')
paths = [
    os.path.join(subdir, f) for subdir, dirs, files in os.walk(images_dir)
    for f in files if f.endswith('.jpg')
]

images = [load_image(path) for path in paths]
nr_classes = len(decode)  # presumably decode maps class index -> label (see lookup below)
images = np.array(images)

print('Loading model..')
model = SqueezeNet(nr_classes)
model.compile(loss="categorical_crossentropy", optimizer="adam")
if os.path.isfile(weights_file):
    print('Loading weights...')
    model.load_weights(weights_file)

print("Classifying images...")
# Iterate directly rather than with `xrange`, which exists only on Python 2.
# NOTE: this calls model.predict once per image; a single
# model.predict(images, batch_size=...) call would batch the work.
for path, pixels in zip(paths, images):
    img = np.expand_dims(pixels, axis=0)  # add a batch axis of size 1
    res = model.predict(img)
    # Indices of the five highest scores, best first.
    results = res[0].argsort()[-5:][::-1]
    print('%s: ' % path)
    for class_idx in results:
        text = '%.3f: %s' % (res[0][class_idx], decode[class_idx])
        print(text)
# Example #7
# 0
def _num_steps(num_samples, batch_size):
    """Return how many batches of ``batch_size`` cover ``num_samples`` (ceiling division)."""
    return (num_samples + batch_size - 1) // batch_size


def main():
    """Train SqueezeNet on image-folder train/test datasets.

    ``--train-dir`` and ``--test-dir`` are read with ``flow_from_directory``
    (one sub-directory per class). Training starts from ImageNet weights, or
    from the ``--restore`` weights file when given. The best model by
    ``val_acc`` is saved using ``--checkpoint-pattern``, and TensorBoard logs
    are written to ``--logdir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-dir', default='data/train')
    parser.add_argument('--test-dir', default='data/test')
    parser.add_argument('--logdir', default='logs')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--num-classes', type=int, required=True)
    parser.add_argument('--checkpoint-pattern',
                        default='weights-{epoch:d}-{val_acc:.4f}.hdf5')
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--restore')
    args = parser.parse_args()

    # count samples
    # NOTE(review): only '.png' files are counted, while the generators below
    # presumably also load other image formats -- confirm the data is PNG-only.
    train_files = count_files(args.train_dir, '.png')
    print('Found %d train files.' % train_files)
    test_files = count_files(args.test_dir, '.png')
    print('Found %d test files.' % test_files)

    if args.restore:
        model = SqueezeNet(weights=None, classes=args.num_classes)
        model.load_weights(args.restore)
    else:
        model = SqueezeNet(weights='imagenet', classes=args.num_classes)

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.Adam(lr=args.learning_rate),
                  metrics=['accuracy'])

    # Augmentation (shear, zoom, horizontal flip) is applied to training data only.
    train_datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        preprocessing_function=preprocess_single)

    test_datagen = ImageDataGenerator(preprocessing_function=preprocess_single)

    train_generator = train_datagen.flow_from_directory(
        args.train_dir,
        target_size=(SIZE, SIZE),
        batch_size=args.batch_size,
        class_mode='categorical')

    test_generator = test_datagen.flow_from_directory(
        args.test_dir,
        target_size=(SIZE, SIZE),
        batch_size=args.batch_size,
        class_mode='categorical')

    # NOTE(review): newer Keras releases log this metric as 'val_accuracy';
    # 'val_acc' (here and in the default --checkpoint-pattern) matches older versions.
    checkpoint = ModelCheckpoint(args.checkpoint_pattern,
                                 monitor='val_acc',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='max')
    tensorboard = TensorBoard(log_dir=args.logdir,
                              histogram_freq=0,
                              batch_size=args.batch_size,
                              write_graph=True,
                              write_grads=False,
                              write_images=False,
                              embeddings_freq=0,
                              embeddings_layer_names=None,
                              embeddings_metadata=None)
    callbacks = [checkpoint, tensorboard]

    # Ceiling division: the final partial batch still counts as a step, and a
    # dataset smaller than one batch no longer yields 0 steps per epoch.
    model.fit_generator(train_generator,
                        steps_per_epoch=_num_steps(train_files, args.batch_size),
                        epochs=args.epochs,
                        validation_data=test_generator,
                        validation_steps=_num_steps(test_files, args.batch_size),
                        callbacks=callbacks)