Esempio n. 1
0
    def test_value_manipulation(self):
        """Check that the KTH and KTF backends agree on basic variable handling.

        Covers get_value before and after set_value, count_params,
        print_tensor on tensors of rank 0 through 3, and get_variable_shape.
        (KTH/KTF are presumably the Theano and TensorFlow backends.)
        """
        def _assert_backends_agree(th_var, tf_var):
            # Read both variables back and require equal shapes and close data.
            th_data = KTH.get_value(th_var)
            tf_data = KTF.get_value(tf_var)
            assert tf_data.shape == th_data.shape
            assert_allclose(th_data, tf_data, atol=1e-05)

        initial = np.random.random((4, 2))
        th_var, tf_var = KTH.variable(initial), KTF.variable(initial)
        _assert_backends_agree(th_var, tf_var)

        # Overwrite both variables with the same fresh array and compare again.
        replacement = np.random.random((4, 2))
        KTH.set_value(th_var, replacement)
        KTF.set_value(tf_var, replacement)
        _assert_backends_agree(th_var, tf_var)

        assert KTH.count_params(th_var) == KTF.count_params(tf_var)

        for shape in [(), (2,), (4, 3), (1, 2, 3)]:
            check_single_tensor_operation('print_tensor', shape)

        other = np.random.random((3, 2))
        th_var, tf_var = KTH.variable(other), KTF.variable(other)
        assert KTH.get_variable_shape(th_var) == KTF.get_variable_shape(tf_var)
Esempio n. 2
0
    def test_value_manipulation(self):
        """Verify the KTH and KTF backends handle variables identically.

        Exercises get_value/set_value round trips, count_params, print_tensor
        for several tensor ranks, and get_variable_shape. (KTH/KTF are
        presumably the Theano and TensorFlow backends.)
        """
        def _check_same_value(var_th, var_tf):
            # Values fetched from each backend must match in shape and content.
            data_th = KTH.get_value(var_th)
            data_tf = KTF.get_value(var_tf)
            assert data_tf.shape == data_th.shape
            assert_allclose(data_th, data_tf, atol=1e-05)

        start = np.random.random((4, 2))
        var_th = KTH.variable(start)
        var_tf = KTF.variable(start)
        _check_same_value(var_th, var_tf)

        # Push a new array into both variables, then re-check.
        updated = np.random.random((4, 2))
        for backend, var in ((KTH, var_th), (KTF, var_tf)):
            backend.set_value(var, updated)
        _check_same_value(var_th, var_tf)

        assert KTH.count_params(var_th) == KTF.count_params(var_tf)

        for shape in ((), (2, ), (4, 3), (1, 2, 3)):
            check_single_tensor_operation('print_tensor', shape)

        shaped = np.random.random((3, 2))
        var_th = KTH.variable(shaped)
        var_tf = KTF.variable(shaped)
        assert KTH.get_variable_shape(var_th) == KTF.get_variable_shape(var_tf)
Esempio n. 3
0
    def test_value_manipulation(self):
        """Check KTH/KTF agreement on get_value, set_value and count_params.

        (KTH/KTF are presumably the Theano and TensorFlow backends.)
        """
        def _assert_backends_agree(th_var, tf_var):
            # Both reads must produce arrays of equal shape and close values.
            th_data = KTH.get_value(th_var)
            tf_data = KTF.get_value(tf_var)
            assert tf_data.shape == th_data.shape
            assert_allclose(th_data, tf_data, atol=1e-05)

        initial = np.random.random((4, 2))
        th_var, tf_var = KTH.variable(initial), KTF.variable(initial)
        _assert_backends_agree(th_var, tf_var)

        # Overwrite both variables with the same new array and compare again.
        replacement = np.random.random((4, 2))
        KTH.set_value(th_var, replacement)
        KTF.set_value(tf_var, replacement)
        _assert_backends_agree(th_var, tf_var)

        assert KTH.count_params(th_var) == KTF.count_params(tf_var)
Esempio n. 4
0
    def test_value_manipulation(self):
        """Verify that KTH and KTF variables read, write and count identically.

        (KTH/KTF are presumably the Theano and TensorFlow backends.)
        """
        def _check_same_value(var_th, var_tf):
            # Fetch from each backend; shapes must match and data must be close.
            data_th = KTH.get_value(var_th)
            data_tf = KTF.get_value(var_tf)
            assert data_tf.shape == data_th.shape
            assert_allclose(data_th, data_tf, atol=1e-05)

        start = np.random.random((4, 2))
        var_th = KTH.variable(start)
        var_tf = KTF.variable(start)
        _check_same_value(var_th, var_tf)

        # Write a new array into both variables, then re-check.
        updated = np.random.random((4, 2))
        for backend, var in ((KTH, var_th), (KTF, var_tf)):
            backend.set_value(var, updated)
        _check_same_value(var_th, var_tf)

        assert KTH.count_params(var_th) == KTF.count_params(var_tf)
Esempio n. 5
0
def train(batch_size, input_shape,
          x_train, y_train,
          x_valid, y_valid,
          model_name, num_workers,
          resume):
    """Train, or resume training of, an image classifier.

    When ``resume`` is not the string ``'True'``, training runs in two phases:

    1. Build the network selected by ``model_name``, freeze every layer
       except the last one, and train for 5 epochs with Adam (lr=0.001).
    2. Clear the session and reload the best checkpoint. Make every layer
       trainable and fine-tune for 30 epochs with Adam (lr=3e-5).

    When ``resume == 'True'``, the checkpoint is reloaded and training
    continues for ``args.epochs`` epochs. ``args`` is a module-level object
    that is not defined in this function.

    Args:
        batch_size: Batch size passed to both data generators.
        input_shape: Input shape passed to both data generators.
        x_train, y_train: Training samples and labels.
        x_valid, y_valid: Validation samples and labels.
        model_name: One of ``'densenet_201'``, ``'inception_v3'``,
            ``'inception_resnet_v2'`` or ``'xception'``. Also selects the
            checkpoint path ``checkpoint/<model_name>/iter1.hdf5``.
        num_workers: ``workers`` argument for ``fit_generator``.
        resume: The string ``'True'`` to resume from the checkpoint; any
            other value starts a new two-phase run.

    Raises:
        ValueError: If a new run is requested with an unknown ``model_name``.
    """
    # NOTE(review): the class count (128) is hard-coded in these messages.
    print('Found {} images belonging to {} classes'.format(len(x_train), 128))
    print('Found {} images belonging to {} classes'.format(len(x_valid), 128))
    train_generator = AugmentedDataset(
        x_train, y_train,
        batch_size=batch_size, input_shape=input_shape)
    valid_generator = FurnituresDatasetNoAugmentation(
        x_valid, y_valid,
        batch_size=batch_size, input_shape=input_shape)

    classes = np.unique(y_train)
    class_weight = compute_class_weight('balanced', classes=classes, y=y_train)
    # compute_class_weight returns one weight per entry of `classes`, in the
    # same order. Pair them positionally; indexing the weight array by label
    # value is only correct when the labels are exactly 0..n_classes-1.
    class_weight_dict = dict(zip(classes, class_weight))

    filepath = 'checkpoint/{}/iter1.hdf5'.format(model_name)
    # Keeps the checkpoint with the best validation accuracy seen so far.
    save_best = ModelCheckpoint(filepath=filepath,
                                verbose=1,
                                monitor='val_acc',
                                save_best_only=True,
                                mode='max')
    # Saves unconditionally every `args.epochs` epochs (to the same path).
    save_on_train_end = ModelCheckpoint(filepath=filepath,
                                        verbose=1,
                                        monitor='val_acc',
                                        period=args.epochs)
    reduce_lr = ReduceLROnPlateau(monitor='val_acc',
                                  factor=0.2,
                                  patience=2,
                                  verbose=1)
    callbacks = [save_best, save_on_train_end, reduce_lr]

    if resume == 'True':
        print('\nResume training from the last checkpoint')
        model = load_model(filepath)
        trainable_count = int(
            np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
        print('Trainable params: {:,}'.format(trainable_count))
        model.fit_generator(generator=train_generator,
                            epochs=args.epochs,
                            callbacks=callbacks,
                            validation_data=valid_generator,
                            class_weight=class_weight_dict,
                            workers=num_workers)
    else:
        print('\nTrain the last Dense layer')
        if model_name == 'densenet_201':
            model = build_densenet_201()
        elif model_name == 'inception_v3':
            model = build_inception_v3()
        elif model_name == 'inception_resnet_v2':
            model = build_inception_resnet_v2()
        elif model_name == 'xception':
            model = build_xception()
        else:
            raise ValueError('Unknown model_name: {!r}'.format(model_name))

        # Freeze everything except the final layer, then compile ONCE.
        # Compiling inside the loop (as before) recompiled per layer.
        for layer in model.layers[:-1]:
            layer.trainable = False
        model.compile(optimizer=Adam(lr=0.001),
                      loss='categorical_crossentropy',
                      metrics=['acc'])
        model.fit_generator(generator=train_generator,
                            epochs=5,
                            callbacks=callbacks,
                            validation_data=valid_generator,
                            class_weight=class_weight_dict,
                            workers=num_workers)
        K.clear_session()

        print("\nFine-tune the network")
        # Reload the best phase-1 checkpoint written by `save_best`.
        model = load_model(filepath)
        for layer in model.layers:
            layer.trainable = True
            if hasattr(layer, 'kernel_regularizer'):
                # NOTE(review): Keras creates regularization losses when a
                # layer is built. Assigning kernel_regularizer on an
                # already-built (loaded) layer therefore likely adds no
                # penalty. Confirm; rebuilding from the model config would be
                # needed for it to take effect.
                layer.kernel_regularizer = regularizers.l2(0.0001)
        trainable_count = int(
            np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
        print('Trainable params: {:,}'.format(trainable_count))
        model.compile(optimizer=Adam(lr=3e-5),
                      loss='categorical_crossentropy',
                      metrics=['acc'])
        model.fit_generator(generator=train_generator,
                            epochs=30,
                            callbacks=callbacks,
                            validation_data=valid_generator,
                            class_weight=class_weight_dict,
                            workers=num_workers)
        K.clear_session()


def train_with_sift_features(batch_size, input_shape, x_train, y_train,
                             x_valid, y_valid, sift_features_train,
                             sift_features_valid, model_name, num_workers,
                             resume):
    """Train an Xception image model whose pooled features are joined with SIFT features.

    When ``resume`` is not the string ``'True'``, training runs in two phases:

    1. Build an Xception backbone (no top, max pooling). Concatenate its
       output with a 512-wide SIFT-feature input and feed the result into a
       128-unit linear Dense layer named ``'predictions'``. Only that layer is
       trained, for 5 epochs with Adam (lr=0.001) and a categorical hinge
       loss.
    2. Clear the session and reload the best checkpoint. Make every layer
       trainable and fine-tune for 30 epochs with SGD (lr=1e-4,
       momentum=0.9).

    When ``resume == 'True'``, the checkpoint is reloaded and training
    continues for ``args.epochs`` epochs. ``args`` is a module-level object
    that is not defined in this function.

    Args:
        batch_size: Batch size passed to both data generators.
        input_shape: Input shape passed to both data generators.
        x_train, y_train: Training samples and labels.
        x_valid, y_valid: Validation samples and labels.
        sift_features_train, sift_features_valid: SIFT features passed to the
            training and validation generators.
        model_name: Only used to build the checkpoint path
            ``checkpoint/<model_name>/sift_iter1.hdf5``. The backbone is
            always Xception.
        num_workers: ``workers`` argument for ``fit_generator``.
        resume: The string ``'True'`` to resume from the checkpoint; any
            other value starts a new two-phase run.
    """
    # NOTE(review): the class count (128) is hard-coded in these messages.
    print('Found {} images belonging to {} classes'.format(len(x_train), 128))
    print('Found {} images belonging to {} classes'.format(len(x_valid), 128))
    train_generator = AugmentedDatasetWithSiftFeatures(x_train,
                                                       y_train,
                                                       sift_features_train,
                                                       batch_size=batch_size,
                                                       input_shape=input_shape)
    valid_generator = DatasetWithSiftFeatures(x_valid,
                                              y_valid,
                                              sift_features_valid,
                                              batch_size=batch_size,
                                              input_shape=input_shape)

    classes = np.unique(y_train)
    class_weight = compute_class_weight('balanced', classes=classes, y=y_train)
    # compute_class_weight returns one weight per entry of `classes`, in the
    # same order. Pair them positionally; indexing the weight array by label
    # value is only correct when the labels are exactly 0..n_classes-1.
    class_weight_dict = dict(zip(classes, class_weight))

    filepath = 'checkpoint/{}/sift_iter1.hdf5'.format(model_name)
    # Keeps the checkpoint with the best validation accuracy seen so far.
    save_best = ModelCheckpoint(filepath=filepath,
                                verbose=1,
                                monitor='val_acc',
                                save_best_only=True,
                                mode='max')
    # Saves unconditionally every `args.epochs` epochs (to the same path).
    save_on_train_end = ModelCheckpoint(filepath=filepath,
                                        verbose=1,
                                        monitor='val_acc',
                                        period=args.epochs)
    reduce_lr = ReduceLROnPlateau(monitor='val_acc',
                                  factor=0.2,
                                  patience=2,
                                  verbose=1)
    callbacks = [save_best, save_on_train_end, reduce_lr]

    if resume == 'True':
        print('\nResume training from the last checkpoint')
        model = load_model(filepath)
        trainable_count = int(
            np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
        print('Trainable params: {:,}'.format(trainable_count))
        model.fit_generator(generator=train_generator,
                            epochs=args.epochs,
                            callbacks=callbacks,
                            validation_data=valid_generator,
                            class_weight=class_weight_dict,
                            workers=num_workers)
    else:
        model = Xception(include_top=False, pooling='max')
        # NOTE(review): the SIFT feature width is hard-coded to 512. It
        # presumably must match the per-sample length of
        # sift_features_train/valid; confirm.
        sift_features = Input(shape=(512, ))
        x = Concatenate()([model.layers[-1].output, sift_features])
        x = Dense(units=128,
                  activation='linear',
                  name='predictions',
                  kernel_regularizer=regularizers.l2(0.0001))(x)
        model = Model([model.layers[0].input, sift_features], x)

        # Train only the new 'predictions' layer in the first phase.
        for layer in model.layers[:-1]:
            layer.trainable = False

        model.compile(optimizer=Adam(lr=0.001),
                      loss='categorical_hinge',
                      metrics=['acc'])
        model.fit_generator(generator=train_generator,
                            epochs=5,
                            callbacks=callbacks,
                            validation_data=valid_generator,
                            class_weight=class_weight_dict,
                            workers=num_workers)
        K.clear_session()

        print("\nFine-tune the network")
        # Reload the best phase-1 checkpoint written by `save_best`.
        model = load_model(filepath)
        for layer in model.layers:
            layer.trainable = True
        trainable_count = int(
            np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
        print('Trainable params: {:,}'.format(trainable_count))
        model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),
                      loss='categorical_hinge',
                      metrics=['acc'])
        model.fit_generator(generator=train_generator,
                            epochs=30,
                            callbacks=callbacks,
                            validation_data=valid_generator,
                            class_weight=class_weight_dict,
                            workers=num_workers)
        K.clear_session()