def test_mobilenet_v2_base_and_softmax_classifier_adam(self):
   """End-to-end check: MobileNetV2 base + softmax head trained with Adam
   reaches at least 80% accuracy on the test dataset."""
   transfer_model = TransferModel(
       self.dataset_dir,
       bases.MobileNetV2Base(),
       heads.SoftmaxClassifierHead(BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES),
       optimizers.Adam())
   self.assertModelAchievesAccuracy(transfer_model, 0.80)
 def test_mobilenet_v2_saved_model_quantized_and_softmax_classifier(self):
   """End-to-end check: a quantized SavedModel base with a softmax head,
   trained with SGD, reaches at least 80% accuracy."""
   quantized_base = bases.SavedModelBase(self.mobilenet_dir, quantize=True)
   classifier_head = heads.SoftmaxClassifierHead(BATCH_SIZE, BOTTLENECK_SHAPE,
                                                 NUM_CLASSES)
   transfer_model = TransferModel(self.dataset_dir, quantized_base,
                                  classifier_head,
                                  optimizers.SGD(LEARNING_RATE))
   self.assertModelAchievesAccuracy(transfer_model, 0.80)
# Ejemplo n.º 3 (0)  -- scraping artifact: example separator from the source site
    def test_mobilenet_v2_saved_model_and_softmax_classifier_model(self):
        """Verifies the input/output signatures of every TFLite model produced
        when converting a SavedModel base with a softmax classifier head."""
        input_size = DEFAULT_INPUT_SIZE
        output_size = 5
        batch_size = DEFAULT_BATCH_SIZE

        head = heads.SoftmaxClassifierHead(batch_size, (input_size, ),
                                           output_size)
        converter = tflite_transfer_converter.TFLiteTransferConverter(
            output_size, self._default_base_model, head,
            optimizers.SGD(LEARNING_RATE), batch_size)
        models = converter._convert()

        # Trainable parameters: a dense kernel and its bias vector.
        parameter_shapes = [(input_size, output_size), (output_size, )]
        # Expected (input shapes, output shapes) per generated model, in the
        # same order the original assertions ran.
        expected_signatures = {
            'initialize': ([()], parameter_shapes),
            'bottleneck': ([(1, input_size)], [(1, input_size)]),
            'train_head': ([(batch_size, input_size),
                            (batch_size, output_size)] + parameter_shapes,
                           [()] + parameter_shapes),
            'inference': ([(1, input_size)] + parameter_shapes,
                          [(1, output_size)]),
            'optimizer': (parameter_shapes + parameter_shapes,
                          parameter_shapes),
        }
        for model_name, (in_shapes, out_shapes) in expected_signatures.items():
            self.assertSignatureEqual(models[model_name], in_shapes, out_shapes)
# Ejemplo n.º 4 (0)  -- scraping artifact: example separator from the source site
    def test_mobilenet_v2_base_and_softmax_classifier_model(self):
        """Verifies the input/output signatures of every TFLite model produced
        when converting a MobileNetV2 base with a softmax classifier head."""
        input_size = 224
        output_size = 5
        batch_size = DEFAULT_BATCH_SIZE

        base = bases.MobileNetV2Base(image_size=input_size)
        head = heads.SoftmaxClassifierHead(batch_size, base.bottleneck_shape(),
                                           output_size)
        converter = tflite_transfer_converter.TFLiteTransferConverter(
            output_size, base, head, optimizers.SGD(LEARNING_RATE), batch_size)
        models = converter._convert()

        # MobileNetV2's bottleneck feature map; the head consumes it flattened.
        bottleneck = (7, 7, 1280)
        flat_bottleneck = 7 * 7 * 1280
        parameter_shapes = [(flat_bottleneck, output_size), (output_size, )]

        self.assertSignatureEqual(models['initialize'], [()], parameter_shapes)
        self.assertSignatureEqual(models['bottleneck'],
                                  [(1, input_size, input_size, 3)],
                                  [(1, ) + bottleneck])
        self.assertSignatureEqual(models['train_head'],
                                  [(batch_size, ) + bottleneck,
                                   (batch_size, output_size)] +
                                  parameter_shapes, [()] + parameter_shapes)
        self.assertSignatureEqual(models['inference'],
                                  [(1, ) + bottleneck] + parameter_shapes,
                                  [(1, output_size)])
        self.assertSignatureEqual(models['optimizer'],
                                  parameter_shapes + parameter_shapes,
                                  parameter_shapes)
# Ejemplo n.º 5 (0)  -- scraping artifact: example separator from the source site
    def test_mobilenet_v2_base_and_softmax_classifier_model_adam(self):
        """Verifies the optimizer model's signature when Adam is used: it
        takes and returns extra per-parameter state compared to plain SGD."""
        input_size = 224
        output_size = 5
        batch_size = DEFAULT_BATCH_SIZE

        base = bases.MobileNetV2Base(image_size=input_size)
        head = heads.SoftmaxClassifierHead(batch_size, base.bottleneck_shape(),
                                           output_size)
        converter = tflite_transfer_converter.TFLiteTransferConverter(
            output_size, base, head, optimizers.Adam(), batch_size)
        models = converter._convert()

        param_shapes = [(7 * 7 * 1280, output_size), (output_size, )]
        # 4 parameter-shaped groups plus a scalar in; 3 groups plus a scalar
        # out — presumably params/grads and Adam's two moments + step counter
        # (TODO confirm against the converter's optimizer model definition).
        adam_inputs = 4 * param_shapes + [()]
        adam_outputs = 3 * param_shapes + [()]
        self.assertSignatureEqual(models['optimizer'], adam_inputs,
                                  adam_outputs)
def _build_arg_parser():
    """Builds the CLI argument parser; single place every flag is defined."""
    parser = argparse.ArgumentParser(
        description='Combines two TF models into a transfer learning model')
    parser.add_argument('--train_batch_size',
                        help='Training batch size',
                        type=int,
                        default=20)
    parser.add_argument('--num_classes',
                        help='Number of classes for the output',
                        type=int,
                        default=4)

    # Base model configuration.
    base_group = parser.add_mutually_exclusive_group(required=True)
    base_group.add_argument('--base_mobilenetv2',
                            help='Use MobileNetV2 as the base model',
                            dest='base_mobilenetv2',
                            action='store_true')
    base_group.add_argument(
        '--base_model_dir',
        help='Use a SavedModel under a given path as the base model',
        type=str)
    parser.add_argument('--base_quantize',
                        help='Whether the base model should be quantized',
                        dest='base_quantize',
                        action='store_true')
    parser.set_defaults(base_quantize=False)

    # Head model configuration.
    head_group = parser.add_mutually_exclusive_group(required=True)
    head_group.add_argument(
        '--head_model_dir',
        help='Use a SavedModel under a given path as the head model',
        type=str)
    head_group.add_argument('--head_softmax',
                            help='Use SoftmaxClassifier for the head model',
                            dest='head_softmax',
                            action='store_true')
    parser.add_argument(
        '--head_l2_reg',
        help='L2 regularization parameter for SoftmaxClassifier',
        type=float)

    # Optimizer configuration.
    parser.add_argument('--optimizer',
                        required=True,
                        type=str,
                        choices=['sgd', 'adam'],
                        help='Which optimizer should be used')
    parser.add_argument('--sgd_learning_rate',
                        help='Learning rate for SGD',
                        type=float)

    parser.add_argument(
        '--out_model_dir',
        help='Where the generated transfer learning model is saved',
        required=True,
        type=str)
    return parser


def _make_base_model(args):
    """Builds the base (feature-extractor) model from parsed CLI args."""
    if args.base_mobilenetv2:
        return bases.MobileNetV2Base(quantize=args.base_quantize)
    # The mutually exclusive group guarantees --base_model_dir was given.
    return bases.SavedModelBase(args.base_model_dir,
                                quantize=args.base_quantize)


def _make_head_model(args, base):
    """Builds the trainable head model from parsed CLI args.

    `base` is needed because the softmax head is sized to the base model's
    bottleneck shape. `--head_l2_reg` may be None, which is passed through.
    """
    if args.head_model_dir:
        return heads.LogitsSavedModelHead(args.head_model_dir)
    return heads.SoftmaxClassifierHead(args.train_batch_size,
                                       base.bottleneck_shape(),
                                       args.num_classes,
                                       l2_reg=args.head_l2_reg)


def _make_optimizer(args):
    """Builds the optimizer from parsed CLI args.

    Raises:
      RuntimeError: if SGD was selected without --sgd_learning_rate.
    """
    if args.optimizer == 'sgd':
        if args.sgd_learning_rate is None:
            raise RuntimeError(
                '--sgd_learning_rate is required when SGD is used as an optimizer'
            )
        return optimizers.SGD(args.sgd_learning_rate)
    # argparse `choices` guarantees the only other value is 'adam'.
    return optimizers.Adam()


def main():
    """CLI entry point: parses arguments, assembles the base/head/optimizer,
    converts them into a transfer learning model and saves the result to
    --out_model_dir."""
    args = _build_arg_parser().parse_args()

    base = _make_base_model(args)
    head = _make_head_model(args, base)
    optimizer = _make_optimizer(args)

    converter = tflite_transfer_converter.TFLiteTransferConverter(
        args.num_classes, base, head, optimizer, args.train_batch_size)
    converter.convert_and_save(args.out_model_dir)