Ejemplo n.º 1
0
class QuickNetFactory(QuickNetBaseFactory):
    """QuickNet - a model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine).

    Configures the base factory with four-entry section settings, no
    squeeze-and-excite in any section, and the concatenating transition block.
    """

    name = "quicknet"

    # Each sequence below has four entries (presumably one per network
    # section; the consumer lives in the base factory, not shown here).
    blocks_per_section: Sequence[int] = Field((2, 3, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, False, False)
    )

    @Field
    def transition_block(self):
        """Default transition block: the factory's `concat_transition_block`."""
        return self.concat_transition_block

    @property
    def imagenet_weights_path(self):
        """Result of downloading the v0.2.1 weights file that includes the top."""
        checksum = (
            "7b4fa94f5241c7aad3412ca42b5db6517dbc4847cff710cb82be10c2f83bc0be"
        )
        return utils.download_pretrained_model(
            model="quicknet",
            version="v0.2.1",
            file="quicknet_weights.h5",
            file_hash=checksum,
        )

    @property
    def imagenet_no_top_weights_path(self):
        """Result of downloading the v0.2.1 "notop" weights file."""
        checksum = (
            "359eed6dae43525eddf520ea87ec9b54750ee0e022647775d115a38856be396f"
        )
        return utils.download_pretrained_model(
            model="quicknet",
            version="v0.2.1",
            file="quicknet_weights_notop.h5",
            file_hash=checksum,
        )
Ejemplo n.º 2
0
class TrainXNORNet(TrainLarqZooModel):
    """Training configuration for `XNORNetFactory`.

    Uses 100 epochs, a batch size of 1200, and an Adam optimizer created with
    `initial_lr`. `learning_rate_schedule` provides a piecewise-constant
    learning rate as a function of the epoch.
    """

    model = ComponentField(XNORNetFactory)

    epochs = Field(100)
    batch_size = Field(1200)

    initial_lr: float = Field(0.001)
    x_offset: float = Field(0.0)

    def learning_rate_schedule(self, epoch):
        """Return the learning rate to use at `epoch`.

        The rate is `initial_lr` scaled by a stage-dependent chain of
        multipliers. A stage applies while `epoch` is strictly below its
        boundary; from epoch 86 onwards the final multipliers are used.
        """
        # (exclusive upper epoch bound, multipliers applied left to right).
        # Multipliers are kept as separate factors rather than pre-multiplied
        # so the floating-point result is the same as chaining `*` inline.
        stages = (
            (19, ()),
            (30, (0.5,)),
            (44, (0.1,)),
            (53, (0.1, 0.5)),
            (66, (0.01,)),
            (76, (0.01, 0.5)),
            (86, (0.01, 0.1)),
        )
        multipliers = (0.001, 0.1)  # used once every boundary has been passed
        for boundary, stage_multipliers in stages:
            if epoch < boundary:
                multipliers = stage_multipliers
                break

        rate = self.initial_lr
        for factor in multipliers:
            rate = rate * factor
        return rate

    @Field
    def optimizer(self):
        """Adam optimizer constructed with `initial_lr`."""
        return tf.keras.optimizers.Adam(self.initial_lr)
Ejemplo n.º 3
0
class QuickNetXLFactory(QuickNetBaseFactory):
    """QuickNetXL - a model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine)
    and high accuracy. It utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    name = "quicknet_xl"

    # Four-entry section settings; squeeze-and-excite is switched on for the
    # last two entries only.
    blocks_per_section: Sequence[int] = Field((6, 8, 12, 6))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )

    @Field
    def transition_block(self):
        """Default transition block: `fp_pointwise_transition_block`."""
        return self.fp_pointwise_transition_block

    @property
    def imagenet_weights_path(self):
        """Result of downloading the v0.1.1 weights file that includes the top."""
        request = {
            "model": "quicknet_xl",
            "version": "v0.1.1",
            "file": "quicknet_xl_weights.h5",
            "file_hash": "19a41e753dbd4fbc3cbdaecd3627fb536ef55d64702996aae3875a8de3cf8073",
        }
        return utils.download_pretrained_model(**request)

    @property
    def imagenet_no_top_weights_path(self):
        """Result of downloading the v0.1.1 "notop" weights file."""
        request = {
            "model": "quicknet_xl",
            "version": "v0.1.1",
            "file": "quicknet_xl_weights_notop.h5",
            "file_hash": "ad5cbfa333b0aabde75dc524c9ce4a5ae096061da0e2dcf362ec6e587a83a511",
        }
        return utils.download_pretrained_model(**request)
Ejemplo n.º 4
0
class QuickNetLargeFactory(QuickNetBaseFactory):
    """QuickNetLarge - a model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine)
    and high accuracy. It utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    name = "quicknet_large"

    # Four-entry section settings; squeeze-and-excite is switched on for the
    # last two entries only.
    blocks_per_section: Sequence[int] = Field((4, 4, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )

    @Field
    def transition_block(self):
        """Default transition block: `fp_pointwise_transition_block`."""
        return self.fp_pointwise_transition_block

    @property
    def imagenet_weights_path(self):
        """Result of downloading the v0.2.1 weights file that includes the top."""
        downloaded = utils.download_pretrained_model(
            model="quicknet_large",
            version="v0.2.1",
            file="quicknet_large_weights.h5",
            file_hash="6bf778e243466c678d6da0e3a91c77deec4832460046fca9e6ac8ae97a41299c",
        )
        return downloaded

    @property
    def imagenet_no_top_weights_path(self):
        """Result of downloading the v0.2.1 "notop" weights file."""
        downloaded = utils.download_pretrained_model(
            model="quicknet_large",
            version="v0.2.1",
            file="quicknet_large_weights_notop.h5",
            file_hash="b65d59dd2d5af63d019997b05faff9e003510e2512aa973ee05eb1b82b8792a9",
        )
        return downloaded
Ejemplo n.º 5
0
class TrainQuickNet(TrainLarqZooModel):
    """Training configuration for `QuickNetFactory` (600 epochs, batch size 2048)."""

    model = ComponentField(QuickNetFactory)
    epochs = Field(600)
    batch_size = Field(2048)

    @Field
    def optimizer(self):
        """Build a `CaseOptimizer` with separate optimizers per variable kind.

        Variables matched by `lq.optimizers.Bop.is_binary_variable` are
        handled by Adam (peak learning rate 1e-2); all remaining variables
        fall back to SGD with momentum 0.9 (peak learning rate 0.1). Both use
        a `CosineDecayWithWarmup` schedule with five epochs' worth of warmup
        steps and `epochs` epochs' worth of decay steps.
        """

        def warmup_cosine(peak_lr):
            # Both optimizers share the schedule shape; only the peak differs.
            return CosineDecayWithWarmup(
                max_learning_rate=peak_lr,
                warmup_steps=self.steps_per_epoch * 5,
                decay_steps=self.steps_per_epoch * self.epochs,
            )

        adam_for_binary = tf.keras.optimizers.Adam(
            learning_rate=warmup_cosine(1e-2)
        )
        sgd_for_rest = tf.keras.optimizers.SGD(
            learning_rate=warmup_cosine(0.1),
            momentum=0.9,
        )
        return lq.optimizers.CaseOptimizer(
            (lq.optimizers.Bop.is_binary_variable, adam_for_binary),
            default_optimizer=sgd_for_rest,
        )
Ejemplo n.º 6
0
class TrainR2B(MultiStageExperiment):
    """Multi-stage experiment that chains four training-phase components.

    The stages are declared as `stage_0` .. `stage_3`; how they are executed
    is decided by `MultiStageExperiment`, which is not shown here.
    """

    # Experiment-level settings. NOTE(review): presumably propagated to the
    # stage components by the framework - confirm against MultiStageExperiment.
    model_modifier: str = Field("default")
    use_unsign: bool = Field(False)

    # Training phases, numbered in declaration order.
    stage_0 = ComponentField(TrainFPResnet18)
    stage_1 = ComponentField(TrainR2BBFP)
    stage_2 = ComponentField(TrainR2BBAN)
    stage_3 = ComponentField(TrainR2BBNNAlternative)
Ejemplo n.º 7
0
class TrainR2BBAN(TrainR2BBFP):
    """Stage-2 training phase: `RealToBinNetFPFactory` teacher, `RealToBinNetBANFactory` student.

    Overrides the stage number, learning rate, teacher/student models and the
    teacher weight initialisation source of `TrainR2BBFP`.
    """

    stage = Field(2)
    learning_rate: float = Field(1e-3)

    teacher_model = ComponentField(RealToBinNetFPFactory)
    student_model = ComponentField(RealToBinNetBANFactory)

    # "r2b_fp" is the same string RealToBinNetFPFactory declares as its
    # `model_name`. NOTE(review): presumably the lookup key for the previous
    # stage's weights - confirm in the base training phase.
    initialize_teacher_weights_from = Field("r2b_fp")
Ejemplo n.º 8
0
Archivo: data.py Proyecto: lgeiger/zoo
class Default(ImageClassification):
    """Default preprocessing: feed the raw `"image"` entry to `preprocess_image_bytes`."""

    # The "image" feature is configured with `SkipDecoding`, so `input`
    # receives whatever undecoded representation TFDS provides for it.
    decoders = Field(lambda: {"image": tfds.decode.SkipDecoding()})

    input_shape = Field((IMAGE_SIZE, IMAGE_SIZE, 3))

    def input(self, data, training):
        """Return the preprocessed image for one example.

        Args:
            data: Mapping that contains an `"image"` entry.
            training: Forwarded to `preprocess_image_bytes` as `is_training`.
        """
        undecoded_image = data["image"]
        return preprocess_image_bytes(
            undecoded_image,
            is_training=training,
            image_size=IMAGE_SIZE,
        )
Ejemplo n.º 9
0
class TrainR2BBNNAlternative(TrainR2BBNN):
    """We deviate slightly from Martinez et al. here.

    The optimizer is Adam driven by a `CosineDecayWithWarmup` schedule that
    peaks at `learning_rate`, warms up for `warmup_duration` epochs' worth of
    steps and uses `epochs` epochs' worth of decay steps.
    """

    warmup_duration = Field(10)

    @Field
    def optimizer(self):
        """Adam with a warmup + cosine-decay learning-rate schedule."""
        schedule = CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.steps_per_epoch * self.warmup_duration,
            decay_steps=self.steps_per_epoch * self.epochs,
        )
        return tf.keras.optimizers.Adam(schedule)
Ejemplo n.º 10
0
class LarqZooModelTrainingPhase(TrainingPhase):
    """Training phase whose loss weights cover classification, attention matching and output matching."""

    # --- Standard cross-entropy training of the student on the target labels.
    # Weight of the plain classification loss component.
    classification_weight: float = Field(1.0)

    # --- Attention matching between teacher and student activation volumes.
    # Weight of the spatial attention-matching loss component.
    attention_matching_weight: float = Field(0.0)
    # Partial names of the layers whose outputs should be matched.
    attention_matching_volume_names: Optional[List[str]] = Field(
        allow_missing=True
    )
    # Optional separate list of partial names for the teacher; when it is not
    # given, the names above are used.
    attention_matching_volume_names_teacher: Optional[List[str]] = Field(
        allow_missing=True
    )
    # Whether the teacher may be trained to better match the student's
    # activations.
    attention_matching_train_teacher: bool = Field(False)

    # --- Matching of the teacher's and student's output predictions.
    # Weight of the knowledge-distillation loss component.
    output_matching_weight: float = Field(0.0)
    # Softmax temperature used when matching the two distributions.
    output_matching_softmax_temperature: float = Field(1.0)
    # Whether the teacher may be trained during output distribution matching.
    output_matching_train_teacher: bool = Field(False)

    @Field
    def loss(self):
        """Return the `classification_loss` attribute of this phase's model.

        The model is fetched through `__base_getattribute__("model")` rather
        than plain attribute access.
        """
        model = self.__base_getattribute__("model")
        return model.classification_loss

    metrics = Field(
        lambda: ["accuracy", "sparse_top_k_categorical_accuracy"]
    )
Ejemplo n.º 11
0
class TrainBinaryAlexNet(TrainLarqZooModel):
    """Training configuration for `BinaryAlexNetFactory` (150 epochs, batch size 512)."""

    model = ComponentField(BinaryAlexNetFactory)

    batch_size: int = Field(512)
    epochs: int = Field(150)

    def learning_rate_schedule(self, epoch):
        """Return 1e-2 halved once for every ten completed epochs."""
        completed_decades = epoch // 10
        return 1e-2 * 0.5 ** completed_decades

    @Field
    def optimizer(self):
        """Adam constructed with the schedule's value at epoch 0."""
        starting_rate = self.learning_rate_schedule(0)
        return tf.keras.optimizers.Adam(starting_rate)
Ejemplo n.º 12
0
class TrainR2BBNN(TrainR2BBFP):
    """Stage-3 training phase: `RealToBinNetBANFactory` teacher, `RealToBinNetBNNFactory` student.

    Overrides the stage number, learning rate, weight decay, loss weights,
    teacher/student models and weight initialisation sources of `TrainR2BBFP`.
    """

    stage = Field(3)
    learning_rate: float = Field(2e-4)
    weight_decay_constant: float = Field(0.0)

    # Loss component weights: classification at 1.0, attention matching
    # disabled (0.0), output matching at 0.8 with softmax temperature 1.0.
    classification_weight = Field(1.0)
    attention_matching_weight = Field(0.0)
    output_matching_weight = Field(0.8)
    output_matching_softmax_temperature = Field(1.0)

    x_offset: float = Field(0.0)
    teacher_model = ComponentField(RealToBinNetBANFactory)
    student_model = ComponentField(RealToBinNetBNNFactory)

    # Both teacher and student are initialised from "r2b_ban", the string
    # RealToBinNetBANFactory declares as its `model_name`.
    # NOTE(review): presumably a lookup key for the previous stage's weights -
    # confirm in the base training phase.
    initialize_teacher_weights_from = Field("r2b_ban")
    initialize_student_weights_from = Field("r2b_ban")
Ejemplo n.º 13
0
class TrainDoReFaNet(TrainLarqZooModel):
    """Training configuration for `DoReFaNetFactory` (90 epochs, batch size 256)."""

    model = ComponentField(DoReFaNetFactory)

    epochs = Field(90)
    batch_size = Field(256)

    learning_rate: float = Field(2e-4)
    # Epoch boundaries consumed by `learning_rate_schedule`.
    decay_start: int = Field(60)
    decay_step_2: int = Field(75)
    fast_decay_start: int = Field(82)

    def learning_rate_schedule(self, epoch):
        """Return the learning rate to use at `epoch`.

        Before `decay_start` the base rate applies; it is then scaled by 0.2
        until `decay_step_2`, by 0.2 * 0.2 until `fast_decay_start`, and from
        there on additionally by 0.1 raised to one plus the number of
        completed two-epoch periods since `fast_decay_start`.
        """
        if epoch < self.decay_start:
            return self.learning_rate
        if epoch < self.decay_step_2:
            return self.learning_rate * 0.2

        twice_decayed = self.learning_rate * 0.2 * 0.2
        if epoch < self.fast_decay_start:
            return twice_decayed

        tenfold_drops = (epoch - self.fast_decay_start) // 2 + 1
        return twice_decayed * 0.1 ** tenfold_drops

    @Field
    def optimizer(self):
        """Adam constructed with `learning_rate` and epsilon 1e-5."""
        return tf.keras.optimizers.Adam(self.learning_rate, epsilon=1e-5)
Ejemplo n.º 14
0
class RealToBinNetBANFactory(RealToBinNetFactory):
    """Factory variant named "r2b_ban": `"ste_sign"` input quantizer, no kernel quantizer or constraint."""

    model_name = Field("r2b_ban")
    input_quantizer = "ste_sign"
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def kernel_regularizer(self):
        """A new L2 regularizer with factor 1e-5."""
        l2_factor = 1e-5
        return tf.keras.regularizers.l2(l2_factor)
Ejemplo n.º 15
0
class ResNet18FPFactory(ResNet18Factory):
    """Factory variant named "resnet_fp": no input quantizer, kernel quantizer or kernel constraint."""

    model_name = Field("resnet_fp")
    input_quantizer = None
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def kernel_regularizer(self):
        """A new L2 regularizer whose factor is `weight_decay_constant`."""
        l2_factor = self.weight_decay_constant
        return tf.keras.regularizers.l2(l2_factor)
Ejemplo n.º 16
0
class StrongBaselineNetBANFactory(StrongBaselineNetFactory):
    """Factory variant named "baseline_ban": `"ste_sign"` input quantizer, no kernel quantizer or constraint."""

    model_name = Field("baseline_ban")
    input_quantizer = "ste_sign"
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def kernel_regularizer(self):
        """A new L2 regularizer with factor 1e-5."""
        l2_factor = 1e-5
        return tf.keras.regularizers.l2(l2_factor)
Ejemplo n.º 17
0
class StrongBaselineNetBNNFactory(StrongBaselineNetFactory):
    """Factory variant named "baseline_bnn": `"ste_sign"` kernels clipped with `"weight_clip"`."""

    model_name = Field("baseline_bnn")
    kernel_quantizer = "ste_sign"
    kernel_constraint = "weight_clip"

    @property
    def input_quantizer(self):
        """A new `SteUnsign` instance when `use_unsign` is truthy, else a new `SteSign`."""
        quantizer_cls = (
            lq.quantizers.SteUnsign if self.use_unsign else lq.quantizers.SteSign
        )
        return quantizer_cls()
Ejemplo n.º 18
0
class TrainBiRealNet(TrainLarqZooModel):
    """Training configuration for `BiRealNetFactory` (300 epochs, batch size 512)."""

    model = ComponentField(BiRealNetFactory)

    epochs = Field(300)
    batch_size = Field(512)

    learning_rate: float = Field(5e-3)
    # Recognised values: "linear_cosine" and "linear"; anything else means a
    # constant learning rate.
    decay_schedule: str = Field("linear")

    @Field
    def optimizer(self):
        """Adam with the learning-rate schedule selected by `decay_schedule`."""
        # Hard-coded step count shared by both schedules.
        # NOTE(review): presumably the total number of training steps - confirm.
        decay_steps = 750684

        schedule_name = self.decay_schedule
        if schedule_name == "linear_cosine":
            rate = tf.keras.experimental.LinearCosineDecay(
                self.learning_rate, decay_steps
            )
        elif schedule_name == "linear":
            rate = tf.keras.optimizers.schedules.PolynomialDecay(
                self.learning_rate,
                decay_steps,
                end_learning_rate=0,
                power=1.0,
            )
        else:
            rate = self.learning_rate
        return tf.keras.optimizers.Adam(rate)
Ejemplo n.º 19
0
class RealToBinNetFPFactory(RealToBinNetFactory):
    """Factory variant named "r2b_fp": tanh activation as input "quantizer", no kernel quantizer or constraint."""

    model_name = Field("r2b_fp")
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def input_quantizer(self):
        """A new Keras `Activation("tanh")` layer."""
        tanh_layer = tf.keras.layers.Activation("tanh")
        return tanh_layer

    @property
    def kernel_regularizer(self):
        """A new L2 regularizer with factor 1e-5."""
        l2_factor = 1e-5
        return tf.keras.regularizers.l2(l2_factor)
Ejemplo n.º 20
0
class RealToBinNetBNNFactory(RealToBinNetFactory):
    """Factory variant named "r2b_bnn": `"ste_sign"` kernels clipped with `"weight_clip"`."""

    model_name = Field("r2b_bnn")
    kernel_quantizer = "ste_sign"
    kernel_constraint = "weight_clip"

    @property
    def input_quantizer(self):
        """A new `SteUnsign` instance when `use_unsign` is truthy, else a new `SteSign`."""
        quantizer_cls = (
            lq.quantizers.SteUnsign if self.use_unsign else lq.quantizers.SteSign
        )
        return quantizer_cls()
Ejemplo n.º 21
0
class ModelFactory:
    """A base class for Larq Zoo models. Defines some common fields."""

    input_quantizer = None
    kernel_quantizer = None
    kernel_constraint = None

    # `dataset` exists so that `num_classes` can be inferred automatically
    # when no value is provided otherwise. `allow_missing` is set because a
    # missing dataset should not raise an error as long as `num_classes` is
    # overridden.
    dataset: Optional[Dataset] = ComponentField(allow_missing=True)

    @Field
    def num_classes(self) -> int:
        """Default for `num_classes`: taken from `dataset`.

        Raises:
            TypeError: If `dataset` is `None`.
        """
        dataset = self.dataset
        if dataset is not None:
            return dataset.num_classes
        raise TypeError(
            "No `dataset` is defined so unable to infer `num_classes`. Please "
            "provide a `dataset` or override `num_classes` directly."
        )

    include_top: bool = Field(True)
    weights: Optional[str] = Field(None)

    input_shape: Optional[Tuple[DimType, DimType, DimType]] = Field(None)
    input_tensor: Optional[utils.TensorType] = Field(None)

    @property
    def image_input(self) -> utils.TensorType:
        """The model's input layer, created on first access and then reused.

        The result is stored on the instance as `_image_input`, so validation
        and layer creation happen at most once per instance.
        """
        if hasattr(self, "_image_input"):
            return self._image_input

        validated_shape = utils.validate_input(
            self.input_shape,
            self.weights,
            self.include_top,
            self.num_classes,
        )
        self._image_input = utils.get_input_layer(
            validated_shape, self.input_tensor
        )
        return self._image_input
Ejemplo n.º 22
0
class QuickNetXLFactory(QuickNetBaseFactory):
    """QuickNetXL - a model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine)
    and high accuracy. It utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    name = "quicknet_xl"

    # Four-entry section settings; squeeze-and-excite is switched on for the
    # last two entries only.
    blocks_per_section: Sequence[int] = Field((6, 8, 12, 6))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )

    @Field
    def transition_block(self):
        """Default transition block: `fp_pointwise_transition_block`."""
        return self.fp_pointwise_transition_block

    def build(self) -> tf.keras.models.Model:
        """Build the model via the base factory and optionally load weights.

        `weights == "imagenet"` downloads the v0.1.0 weights (with or without
        top, depending on `include_top`) and loads them; any other non-`None`
        value is passed to `model.load_weights` directly.
        """
        model = super().build()

        if self.weights is None:
            return model
        if self.weights != "imagenet":
            model.load_weights(self.weights)
            return model

        # Pick the file/checksum pair matching the requested head.
        if self.include_top:
            weights_file = "quicknet_xl_weights.h5"
            weights_hash = (
                "a85eea1204fa9a8401f922f94531858493e3518e3374347978ed7ba615410498"
            )
        else:
            weights_file = "quicknet_xl_weights_notop.h5"
            weights_hash = (
                "b97074d6618acde4201d1f8676d32272d27743ddfe27c6c97e4516511ebb5008"
            )
        weights_path = utils.download_pretrained_model(
            model="quicknet_xl",
            version="v0.1.0",
            file=weights_file,
            file_hash=weights_hash,
        )
        model.load_weights(weights_path)
        return model
Ejemplo n.º 23
0
class QuickNetFactory(QuickNetBaseFactory):
    """QuickNet - a model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine)."""

    name = "quicknet"

    # Four-entry section settings; squeeze-and-excite is off everywhere.
    blocks_per_section: Sequence[int] = Field((2, 3, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, False, False)
    )

    @Field
    def transition_block(self):
        """Default transition block: `concat_transition_block`."""
        return self.concat_transition_block

    def build(self) -> tf.keras.models.Model:
        """Build the model via the base factory and optionally load weights.

        `weights == "imagenet"` downloads the v0.2.0 weights (with or without
        top, depending on `include_top`) and loads them; any other non-`None`
        value is passed to `model.load_weights` directly.
        """
        model = super().build()

        if self.weights is None:
            return model
        if self.weights != "imagenet":
            model.load_weights(self.weights)
            return model

        # Pick the file/checksum pair matching the requested head.
        if self.include_top:
            weights_file = "quicknet_weights.h5"
            weights_hash = (
                "6a765f120ba7b62a7740e842c4f462eb7ba3dd65eb46b4694c5bc8169618fae7"
            )
        else:
            weights_file = "quicknet_weights_notop.h5"
            weights_hash = (
                "5bf2fc450fb8cc322b33a16410bf88fed09d05c221550c2d5805a04985383ac2"
            )
        weights_path = utils.download_pretrained_model(
            model="quicknet",
            version="v0.2.0",
            file=weights_file,
            file_hash=weights_hash,
        )
        model.load_weights(weights_path)
        return model
Ejemplo n.º 24
0
class QuickNetLargeFactory(QuickNetBaseFactory):
    """QuickNetLarge - a model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine)
    and high accuracy. It utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    name = "quicknet_large"

    # Four-entry section settings; squeeze-and-excite is switched on for the
    # last two entries only.
    blocks_per_section: Sequence[int] = Field((4, 4, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )

    @Field
    def transition_block(self):
        """Default transition block: `fp_pointwise_transition_block`."""
        return self.fp_pointwise_transition_block

    def build(self) -> tf.keras.models.Model:
        """Build the model via the base factory and optionally load weights.

        `weights == "imagenet"` downloads the v0.2.0 weights (with or without
        top, depending on `include_top`) and loads them; any other non-`None`
        value is passed to `model.load_weights` directly.
        """
        model = super().build()

        if self.weights is None:
            return model
        if self.weights != "imagenet":
            model.load_weights(self.weights)
            return model

        # Pick the file/checksum pair matching the requested head.
        if self.include_top:
            weights_file = "quicknet_large_weights.h5"
            weights_hash = (
                "2d9ebbf8ba0500552e4dd243c3e52fd8291f965ef6a0e1dbba13cc72bf6eee8b"
            )
        else:
            weights_file = "quicknet_large_weights_notop.h5"
            weights_hash = (
                "067655ef8a1a1e99ef1c71fa775c09aca44bdfad0b9b71538b4ec500c3beee4f"
            )
        weights_path = utils.download_pretrained_model(
            model="quicknet_large",
            version="v0.2.0",
            file=weights_file,
            file_hash=weights_hash,
        )
        model.load_weights(weights_path)
        return model
Ejemplo n.º 25
0
class RealToBinNetFPFactory(RealToBinNetFactory):
    """Factory variant named "r2b_fp": a real-valued activation stands in for the input quantizer.

    There is no kernel quantizer or kernel constraint; kernels are
    L2-regularised with `weight_decay_constant`.
    """

    model_name = Field("r2b_fp")
    use_hard_activation: bool = Field()
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def input_quantizer(self):
        """Activation selected by `use_unsign` and `use_hard_activation`.

        | use_unsign | use_hard_activation | result                  |
        |------------|---------------------|-------------------------|
        | truthy     | truthy              | `HardSigmoid()`         |
        | truthy     | falsy               | `Activation("sigmoid")` |
        | falsy      | truthy              | `HardTanh(-1.0, 1.0)`   |
        | falsy      | falsy               | `Activation("tanh")`    |
        """
        unsigned = self.use_unsign
        hard = self.use_hard_activation

        if unsigned and hard:
            # TODO(VINN): this should be shifted to match the sigmoid; a
            # HardTanh with bounds 0.0 and 1.0 was the noted alternative.
            return lq.activations.HardSigmoid()
        if unsigned:
            return tf.keras.layers.Activation("sigmoid")
        if hard:
            return lq.activations.HardTanh(lower_b=-1.0, upper_b=1.0)
        return tf.keras.layers.Activation("tanh")

    @property
    def kernel_regularizer(self):
        """A new L2 regularizer whose factor is `weight_decay_constant`."""
        l2_factor = self.weight_decay_constant
        return tf.keras.regularizers.l2(l2_factor)
Ejemplo n.º 26
0
class RealToBinNetBANFactory(RealToBinNetFactory):
    """Factory variant named "r2b_ban": sign-type input quantizer, no kernel quantizer or constraint."""

    model_name = Field("r2b_ban")
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def input_quantizer(self):
        """A new `SteUnsign` instance when `use_unsign` is truthy, else a new `SteSign`."""
        quantizer_cls = (
            lq.quantizers.SteUnsign if self.use_unsign else lq.quantizers.SteSign
        )
        return quantizer_cls()

    @property
    def kernel_regularizer(self):
        """A new L2 regularizer whose factor is `weight_decay_constant`."""
        l2_factor = self.weight_decay_constant
        return tf.keras.regularizers.l2(l2_factor)
Ejemplo n.º 27
0
class PadCropAndFlip(Preprocessing):
    """Preprocessing with pad / random-crop / random-flip augmentation during training."""

    pad_size: int = Field()

    def input(self, data, training):
        """Return the image of one example as float32 scaled by `x / 127.5 - 1`.

        When `training` is truthy the image is cropped or padded to
        `pad_size` x `pad_size`, randomly cropped to `input_shape` and
        randomly flipped left/right. Otherwise it is only cropped or padded
        to the first two dimensions of `input_shape`.
        """
        pixels = data["image"]
        if not training:
            pixels = tf.image.resize_with_crop_or_pad(
                pixels, *self.input_shape[:2]
            )
        else:
            padded = tf.image.resize_with_crop_or_pad(
                pixels, self.pad_size, self.pad_size
            )
            cropped = tf.image.random_crop(padded, self.input_shape)
            pixels = tf.image.random_flip_left_right(cropped)

        half_range = 255.0 / 2.0
        return tf.cast(pixels, tf.float32) / half_range - 1.0

    def output(self, data):
        """Return the `"label"` entry of one example."""
        label = data["label"]
        return label
Ejemplo n.º 28
0
class TrainR2BStrongBaselineBAN(LarqZooModelTrainingPhase):
    """Stage-0 training phase with `StrongBaselineNetBANFactory` as the student on `ImageNet`."""

    stage = Field(0)

    dataset = ComponentField(ImageNet)
    model_modifier: str = Field("default")

    learning_rate: float = Field(1e-3)
    learning_rate_decay: float = Field(0.1)
    epochs: int = Field(75)
    batch_size: int = Field(8)
    # Not referenced by the optimizer defined in this class.
    warmup_duration: int = Field(5)

    @Field
    def optimizer(self):
        """Adam driven by an `R2BStepSchedule`.

        The schedule starts at `learning_rate` and receives
        `learning_rate_decay` as its decay fraction.
        """
        schedule = R2BStepSchedule(
            initial_learning_rate=self.learning_rate,
            steps_per_epoch=self.steps_per_epoch,
            decay_fraction=self.learning_rate_decay,
        )
        return tf.keras.optimizers.Adam(schedule)

    student_model = ComponentField(StrongBaselineNetBANFactory)
Ejemplo n.º 29
0
class TrainFPResnet18(LarqZooModelTrainingPhase):
    """Stage-0 training phase with `ResNet18FPFactory` as the student on `ImageNet`.

    Uses Adam with a warmup + cosine-decay schedule peaking at 1e-3.
    """

    stage = Field(0)
    dataset = ComponentField(ImageNet)
    model_modifier: str = Field("default")
    learning_rate: float = Field(1e-3)
    weight_decay_constant: float = Field(1e-5)
    epochs: int = Field(100)
    batch_size: int = Field(512)
    warmup_duration: int = Field(5)

    @Field
    def optimizer(self):
        """Adam driven by `CosineDecayWithWarmup`.

        Warmup covers `warmup_duration` epochs' worth of steps; the decay
        covers the remaining `epochs - warmup_duration` epochs' worth.
        """
        schedule = CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.warmup_duration * self.steps_per_epoch,
            decay_steps=(self.epochs - self.warmup_duration)
            * self.steps_per_epoch,
        )
        return tf.keras.optimizers.Adam(schedule)

    student_model = ComponentField(ResNet18FPFactory)
Ejemplo n.º 30
0
class TrainFPResnet18(LarqZooModelTrainingPhase):
    """Stage-0 training phase with `ResNet18FPFactory` as the student on `ImageNet`.

    Uses SGD with a warmup + cosine-decay schedule peaking at 1e-1.
    """

    stage = Field(0)
    dataset = ComponentField(ImageNet)
    learning_rate: float = Field(1e-1)
    epochs: int = Field(100)
    batch_size: int = Field(512)
    warmup_duration: int = Field(5)

    @Field
    def optimizer(self):
        """SGD driven by `CosineDecayWithWarmup`.

        Warmup covers `warmup_duration` epochs' worth of steps; the decay
        step count is the full `epochs` epochs' worth.
        """
        schedule = CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.warmup_duration * self.steps_per_epoch,
            decay_steps=self.epochs * self.steps_per_epoch,
        )
        return tf.keras.optimizers.SGD(schedule)

    student_model = ComponentField(ResNet18FPFactory)