Example #1
0
 def handleNewRoi(self, roi, width, height):
     """Estimate gender and age for a single region of interest.

     The crop is resized to 96x96, wrapped into a one-image batch,
     preprocessed with ``AgenderNetMobileNetV2.prep_image`` and run through
     ``self.model.predict``. ``width`` and ``height`` are accepted but not
     used by this method.

     Returns:
         ``(gender, age)`` taken from the first decoded prediction, or
         ``(None, None)`` when ``decode_prediction`` yields no genders.
     """
     face = cv.resize(roi, (96, 96))
     batch = AgenderNetMobileNetV2.prep_image(np.array([face]))
     raw_output = self.model.predict(batch)
     genders, ages = AgenderNetMobileNetV2.decode_prediction(raw_output)
     print(genders, ages)
     # len() rather than truthiness: genders may be an array-like result.
     if len(genders) == 0:
         return None, None
     return genders[0], ages[0]
    def __init__(self, db, *args, **kwargs):
        """
        Build the main window of the program.

        Creates three pages (TablePage, CamPage, DisplayPage) stacked inside
        one shared container, plus a row of buttons that raise each page.
        It then opens video capture device 0, creates a dlib frontal face
        detector, loads the MobileNetV2 age/gender weights and finally calls
        self.show_frame().

        Args:
            db: database handle stored on self.db and passed to TablePage
                (its concrete type is not visible here -- TODO confirm).
            *args, **kwargs: forwarded unchanged to tk.Frame.__init__.
        """

        # store the database handle on the window
        self.db = db

        # initialise the underlying tk.Frame, then create the three pages
        tk.Frame.__init__(self, *args, **kwargs)
        self.p1 = TablePage(db, self)
        self.p2 = CamPage(self)
        self.p3 = DisplayPage(self)

        # button bar packed at the top; the page container fills the rest
        buttonframe = tk.Frame(self)
        container = tk.Frame(self)
        buttonframe.pack(side="top", fill="x", expand=False)
        container.pack(side="top", fill="both", expand=True)

        # stack all three pages on top of each other, each filling the whole
        # container; the page raised last with lift() is the visible one
        self.p1.place(in_=container, x=0, y=0, relwidth=1, relheight=1)
        self.p2.place(in_=container, x=0, y=0, relwidth=1, relheight=1)
        self.p3.place(in_=container, x=0, y=0, relwidth=1, relheight=1)

        # each button raises its page to the front of the stack
        b1 = tk.Button(buttonframe, text="Adverts", command=self.p1.lift)
        b2 = tk.Button(buttonframe, text="Cam", command=self.p2.lift)
        b3 = tk.Button(buttonframe, text="Display", command=self.p3.lift)

        # lay the buttons out left to right in the button bar
        b1.pack(side="left")
        b2.pack(side="left")
        b3.pack(side="left")

        # start on the first (Adverts) page; show() is defined on the page
        # class, which is not visible here
        self.p1.show()

        # requested capture resolution
        width, height = 600, 600

        # open camera 0 and request the resolution above; cap.set() is only
        # a request and its return value is not checked here
        self.cap = cv2.VideoCapture(0)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

        # crop margin and model input size; not used in this method --
        # presumably read by show_frame or the face-cropping code (TODO confirm)
        self.margin = 0.4
        self.img_size = 96

        # dlib HOG-based frontal face detector
        self.detector = dlib.get_frontal_face_detector()

        # build the age/gender network and load its trained weights
        self.model = AgenderNetMobileNetV2()
        self.model.load_weights(
            'model/weight/mobilenetv2/mobilenet_v2-02-4.0368-0.8085-9.6728.h5')

        # hand over to show_frame(), defined elsewhere on this class
        self.show_frame()
Example #3
0
def main():
    """Time a single age/gender prediction with the selected model.

    Reads ``--model`` from the command-line ``parser`` and builds the
    matching network: mobilenetv2, inceptionv3 or mobilenetv1, otherwise
    SSR-Net. It loads pretrained weights where available and prepares one
    sample IMDB image. It then logs the processing time of ``predictone``
    measured through ``wrapper``/``proces_time``.

    Raises:
        FileNotFoundError: if the sample image cannot be read.
    """
    args = parser.parse_args()
    MODEL = args.model

    model = None
    logger.info('Load model and weight')
    if MODEL == 'mobilenetv2':
        model = AgenderNetMobileNetV2()
        model.load_weights(
            'model/weight/mobilenetv2/model.10-3.8290-0.8965-6.9498.h5')
    elif MODEL == 'inceptionv3':
        model = AgenderNetInceptionV3()
        model.load_weights(
            'model/weight/inceptionv3/model.16-3.7887-0.9004-6.6744.h5')
    elif MODEL == 'mobilenetv1':
        model = AgenderNetMobileNetV1()
        # No weight file is referenced for this backbone, so the untrained
        # network is what gets timed; say so instead of failing silently.
        logger.warning('No weights loaded for mobilenetv1; timing untrained model')
    else:
        model = AgenderSSRNet(64, [3, 3, 3], 1.0, 1.0)
        model.load_weights(
            'model/weight/ssrnet/model.37-7.3318-0.8643-7.1952.h5')

    logger.info('Read image')
    # The original path ended in '.jp', a truncated '.jpg'.
    image_path = 'data/imdb_aligned/02/nm0000002_rm1346607872_1924-9-16_2004.jpg'
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error only shows up inside cv2.resize.
        raise FileNotFoundError('Could not read image: {}'.format(image_path))
    image = cv2.resize(image, (model.input_size, model.input_size))
    image = np.expand_dims(image, axis=0)  # add a batch axis of size 1
    image = model.prep_image(image)

    logger.info('Predict with {}'.format(MODEL))
    wrapped = wrapper(predictone, model, image)
    logger.info(proces_time(wrapped))
Example #4
0
def main():
    """Predict age and gender for every image listed in a database CSV.

    Reads ``--db_name`` and ``--model`` from the command-line ``parser``. It
    loads ``data/db/<db>.csv`` and builds the selected network (with
    pretrained weights where available). Each image listed in the
    ``full_path`` column is read from ``<db>_aligned/`` and resized to the
    model's input size. The predicted ages and genders are written to
    ``result/<db>.csv`` (the directory is created if missing).

    Raises:
        FileNotFoundError: if any listed image cannot be read.
    """
    import os

    args = parser.parse_args()
    DB = args.db_name
    MODEL = args.model

    data = pd.read_csv('data/db/{}.csv'.format(DB))
    model = None
    logger.info('Load model and weight')
    if MODEL == 'mobilenetv2':
        model = AgenderNetMobileNetV2()
        model.load_weights(
            'model/weight/mobilenetv2/model.10-3.8290-0.8965-6.9498.h5')
    elif MODEL == 'inceptionv3':
        model = AgenderNetInceptionV3()
        model.load_weights(
            'model/weight/inceptionv3/model.16-3.7887-0.9004-6.6744.h5')
    elif MODEL == 'mobilenetv1':
        model = AgenderNetMobileNetV1()
        # No weight file is referenced for this backbone, so predictions
        # come from the untrained network; say so explicitly.
        logger.warning('No weights loaded for mobilenetv1; predicting with untrained model')
    else:
        model = AgenderSSRNet(64, [3, 3, 3], 1.0, 1.0)
        model.load_weights(
            'model/weight/ssrnet/model.37-7.3318-0.8643-7.1952.h5')

    logger.info('Read image')
    # Read and resize in a single pass so only the small resized copy of
    # each image is kept, rather than holding every full-size image at once.
    target_size = (model.input_size, model.input_size)
    images = []
    for path in tqdm(data.full_path.values):
        image_path = '{}_aligned/{}'.format(DB, path)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None on failure instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(image_path))
        images.append(cv2.resize(image, target_size))
    images = np.array(images)
    images = model.prep_image(images)

    logger.info('Predict data')
    start = time.perf_counter()  # monotonic clock for elapsed-time measurement
    prediction = model.predict(images)
    pred_gender, pred_age = model.decode_prediction(prediction)
    elapsed = time.perf_counter() - start
    logger.info('Time elapsed {:.2f} sec'.format(elapsed))

    result = pd.DataFrame()
    result['full_path'] = data['full_path']
    result['age'] = pred_age
    result['gender'] = pred_gender
    os.makedirs('result', exist_ok=True)
    result.to_csv('result/{}.csv'.format(DB), index=False)
Example #5
0
 def __init__(self):
     """Create the MobileNetV2 age/gender network and load its weights."""
     weights_file = 'model/weight/mobilenetv2/model.10-3.8290-0.8965-6.9498.h5'
     net = AgenderNetMobileNetV2()
     self.model = net
     net.load_weights(weights_file)
     print("DemographicDetector() -> __init__ complete!")
Example #6
0
def main():
    """Train an age/gender model with 10-fold cross-validation.

    Configures a TensorFlow session with GPU memory growth and reads
    ``--model``, ``--epoch``, ``--batch_size`` and ``--num_worker`` from the
    command-line ``parser``. For each of the 10 folds of a shuffled
    ``KFold`` (``random_state=1``), it builds a fresh model, compiles it and
    fits it on the training split while validating on the held-out split.
    Best weights (by ``val_loss``) are saved under ``train_weight/`` and
    per-epoch logs are written to ``train_log/<model>-<fold>.log``.
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    K.tensorflow_backend.set_session(sess)

    args = parser.parse_args()
    MODEL = args.model
    EPOCH = args.epoch
    BATCH_SIZE = args.batch_size
    NUM_WORKER = args.num_worker

    db, paths, age_label, gender_label = load_data()
    print('[K-FOLD] Started...')
    kf = KFold(n_splits=10, shuffle=True, random_state=1)
    for n_fold, (train_idx, test_idx) in enumerate(kf.split(age_label), start=1):
        if MODEL == 'ssrnet':
            model = AgenderSSRNet(64, [3, 3, 3], 1.0, 1.0)
        elif MODEL == 'inceptionv3':
            model = AgenderNetInceptionV3()
        elif MODEL == 'mobilenetv1':
            model = AgenderNetMobileNetV1()
        else:
            model = AgenderNetMobileNetV2()
        train_db = db[train_idx]
        train_paths = paths[train_idx]
        train_age = age_label[train_idx]
        train_gender = gender_label[train_idx]

        test_db = db[test_idx]
        test_paths = paths[test_idx]
        test_age = age_label[test_idx]
        test_gender = gender_label[test_idx]

        # The checkpoint path has two layers of formatting: the model name
        # is filled in here, while {epoch}/{val_*} are filled in by Keras at
        # save time. Passing the whole template through str.format (as
        # before) raised KeyError: 'epoch', so only the prefix is formatted.
        prefix = 'train_weight/{}-'.format(MODEL)
        if MODEL == 'ssrnet':
            losses = {
                "age_prediction": "mae",
                "gender_prediction": "mae",
            }
            metrics = {
                "age_prediction": "mae",
                "gender_prediction": "binary_accuracy",
            }
            # Keras logs the string metric "mae" as mean_absolute_error.
            checkpoint_path = prefix + (
                '{epoch:02d}-{val_loss:.4f}-'
                '{val_gender_prediction_binary_accuracy:.4f}-'
                '{val_age_prediction_mean_absolute_error:.4f}.h5')
        else:
            losses = {
                "age_prediction": "categorical_crossentropy",
                "gender_prediction": "categorical_crossentropy",
            }
            metrics = {
                "age_prediction": mae,
                "gender_prediction": "acc",
            }
            checkpoint_path = prefix + (
                '{epoch:02d}-{val_loss:.4f}-'
                '{val_gender_prediction_acc:.4f}-'
                '{val_age_prediction_mae:.4f}.h5')

        callbacks = [
            ModelCheckpoint(
                checkpoint_path,
                verbose=1,
                save_best_only=True,
                save_weights_only=True),
            CSVLogger('train_log/{}-{}.log'.format(MODEL, n_fold)),
        ]
        if MODEL == 'ssrnet':
            callbacks.append(DecayLearningRate([30, 60]))

        model.compile(optimizer='adam', loss=losses, metrics=metrics)
        model.fit_generator(DataGenerator(model, train_db, train_paths,
                                          train_age, train_gender, BATCH_SIZE),
                            validation_data=DataGenerator(
                                model, test_db, test_paths, test_age,
                                test_gender, BATCH_SIZE),
                            epochs=EPOCH,
                            verbose=2,
                            workers=NUM_WORKER,
                            use_multiprocessing=True,
                            max_queue_size=int(BATCH_SIZE * 2),
                            callbacks=callbacks)
        # release the per-fold splits before building the next fold
        del train_db, train_paths, train_age, train_gender
        del test_db, test_paths, test_age, test_gender
Example #7
0
        mot_tracker.update(dets, genders, ages)
        for tracker in mot_tracker.trackers:
            (left, top, right, bottom) = convert_x_to_bbox(
                tracker.kf.x[:4, :]).astype('int').flatten()
            cv2.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), 2)
            age = tracker.smooth_age()
            gender = 'M' if tracker.smooth_gender() == 1 else 'F'
            cv2.putText(frame, "id: {} {} {}".format(tracker.id, gender, age),
                        (left - 10, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 2)
        cv2.putText(frame, "{:.1f} FPS".format(fps.fps()), (1100, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
        cv2.namedWindow("Frame", cv2.WINDOW_NORMAL)
        cv2.imshow("Frame", frame)
        key = cv2.waitKey(1) & 0xFF
        fps.update()
        grabbed, frame = stream.read()
    fps.stop()
    stream.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    # Script entry point: build the MobileNetV2 age/gender network, load its
    # pretrained weights and capture the default TensorFlow graph, then run
    # main(). `model` and `graph` are bound as module-level names here;
    # presumably main() (or code it calls) reads them as globals, e.g. to
    # predict from another thread -- the part of main() that would confirm
    # this is not visible. TODO confirm.
    print('[INFO] Load model')
    model = AgenderNetMobileNetV2()
    print('[INFO] Load weight')
    model.load_weights(
        'model/weight/mobilenetv2/model.10-3.8290-0.8965-6.9498.h5')
    graph = tf.get_default_graph()
    main()