class QuickNetFactory(QuickNetBaseFactory):
    """Quicknet - A model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine)"""

    # NOTE(review): SOURCE also contains a `QuickNetFactory` with a `build()`
    # method (weights v0.2.0); if both live in one module, the later one wins.

    name = "quicknet"

    # Architecture, one entry per section: block count, filter count and
    # whether Squeeze-and-Excite is enabled in that section.
    blocks_per_section: Sequence[int] = Field((2, 3, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, False, False)
    )
    transition_block = Field(lambda self: self.concat_transition_block)

    @property
    def imagenet_weights_path(self):
        """Result of `utils.download_pretrained_model` for the full
        (with-top) ImageNet weights of release v0.2.1, checked against a
        pinned file hash."""
        expected_hash = (
            "7b4fa94f5241c7aad3412ca42b5db6517dbc4847cff710cb82be10c2f83bc0be"
        )
        return utils.download_pretrained_model(
            model="quicknet",
            version="v0.2.1",
            file="quicknet_weights.h5",
            file_hash=expected_hash,
        )

    @property
    def imagenet_no_top_weights_path(self):
        """Result of `utils.download_pretrained_model` for the ImageNet
        weights without the classification top, release v0.2.1."""
        expected_hash = (
            "359eed6dae43525eddf520ea87ec9b54750ee0e022647775d115a38856be396f"
        )
        return utils.download_pretrained_model(
            model="quicknet",
            version="v0.2.1",
            file="quicknet_weights_notop.h5",
            file_hash=expected_hash,
        )
class TrainXNORNet(TrainLarqZooModel):
    """Training configuration for XNOR-Net: Adam plus a stepwise schedule."""

    model = ComponentField(XNORNetFactory)
    epochs = Field(100)
    batch_size = Field(1200)
    initial_lr: float = Field(0.001)
    # NOTE(review): `x_offset` is not read anywhere in this class.
    x_offset: float = Field(0.0)

    def learning_rate_schedule(self, epoch):
        """Return the learning rate to use at `epoch`.

        The rate is `initial_lr` scaled down at epochs 19, 30, 44, 53, 66, 76
        and 86. The multipliers of a phase are applied one after another
        (not pre-multiplied) so the float result matches `initial_lr * a * b`
        exactly.
        """
        # (exclusive upper epoch bound, multipliers applied to initial_lr)
        phases = (
            (19, ()),
            (30, (0.5,)),
            (44, (0.1,)),
            (53, (0.1, 0.5)),
            (66, (0.01,)),
            (76, (0.01, 0.5)),
            (86, (0.01, 0.1)),
        )
        for upper_bound, multipliers in phases:
            if epoch < upper_bound:
                break
        else:
            # Every epoch from 86 onwards.
            multipliers = (0.001, 0.1)

        rate = self.initial_lr
        for multiplier in multipliers:
            rate = rate * multiplier
        return rate

    @Field
    def optimizer(self):
        # Adam is created at `initial_lr`; presumably the base class applies
        # `learning_rate_schedule` per epoch — TODO confirm.
        return tf.keras.optimizers.Adam(self.initial_lr)
class QuickNetXLFactory(QuickNetBaseFactory):
    """QuickNetXL - A model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine) and high accuracy. This utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    # NOTE(review): SOURCE also contains a `QuickNetXLFactory` with a `build()`
    # method (weights v0.1.0); if both live in one module, the later one wins.

    name = "quicknet_xl"

    # Architecture, one entry per section; Squeeze-and-Excite is enabled
    # only in the last two sections.
    blocks_per_section: Sequence[int] = Field((6, 8, 12, 6))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )
    transition_block = Field(lambda self: self.fp_pointwise_transition_block)

    @property
    def imagenet_weights_path(self):
        """Result of `utils.download_pretrained_model` for the full
        (with-top) ImageNet weights of release v0.1.1."""
        expected_hash = (
            "19a41e753dbd4fbc3cbdaecd3627fb536ef55d64702996aae3875a8de3cf8073"
        )
        return utils.download_pretrained_model(
            model="quicknet_xl",
            version="v0.1.1",
            file="quicknet_xl_weights.h5",
            file_hash=expected_hash,
        )

    @property
    def imagenet_no_top_weights_path(self):
        """Result of `utils.download_pretrained_model` for the ImageNet
        weights without the classification top, release v0.1.1."""
        expected_hash = (
            "ad5cbfa333b0aabde75dc524c9ce4a5ae096061da0e2dcf362ec6e587a83a511"
        )
        return utils.download_pretrained_model(
            model="quicknet_xl",
            version="v0.1.1",
            file="quicknet_xl_weights_notop.h5",
            file_hash=expected_hash,
        )
class QuickNetLargeFactory(QuickNetBaseFactory):
    """QuickNetLarge - A model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine) and high accuracy. This utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    # NOTE(review): SOURCE also contains a `QuickNetLargeFactory` with a
    # `build()` method (weights v0.2.0); if both live in one module, the
    # later one wins.

    name = "quicknet_large"

    # Architecture, one entry per section; Squeeze-and-Excite is enabled
    # only in the last two sections.
    blocks_per_section: Sequence[int] = Field((4, 4, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )
    transition_block = Field(lambda self: self.fp_pointwise_transition_block)

    @property
    def imagenet_weights_path(self):
        """Result of `utils.download_pretrained_model` for the full
        (with-top) ImageNet weights of release v0.2.1."""
        expected_hash = (
            "6bf778e243466c678d6da0e3a91c77deec4832460046fca9e6ac8ae97a41299c"
        )
        return utils.download_pretrained_model(
            model="quicknet_large",
            version="v0.2.1",
            file="quicknet_large_weights.h5",
            file_hash=expected_hash,
        )

    @property
    def imagenet_no_top_weights_path(self):
        """Result of `utils.download_pretrained_model` for the ImageNet
        weights without the classification top, release v0.2.1."""
        expected_hash = (
            "b65d59dd2d5af63d019997b05faff9e003510e2512aa973ee05eb1b82b8792a9"
        )
        return utils.download_pretrained_model(
            model="quicknet_large",
            version="v0.2.1",
            file="quicknet_large_weights_notop.h5",
            file_hash=expected_hash,
        )
class TrainQuickNet(TrainLarqZooModel):
    """Training configuration for QuickNet.

    Binary variables are optimized with Adam and all other variables with
    SGD + momentum. Both use a warmup + cosine-decay schedule.
    """

    model = ComponentField(QuickNetFactory)
    epochs = Field(600)
    batch_size = Field(2048)

    @Field
    def optimizer(self):
        """Per-variable `CaseOptimizer` combining the two optimizers.

        Each schedule warms up over 5 epochs and uses `epochs` epochs'
        worth of steps as its decay length.
        """

        def warmup_cosine(peak_learning_rate):
            return CosineDecayWithWarmup(
                max_learning_rate=peak_learning_rate,
                warmup_steps=self.steps_per_epoch * 5,
                decay_steps=self.steps_per_epoch * self.epochs,
            )

        binary_variable_optimizer = tf.keras.optimizers.Adam(
            learning_rate=warmup_cosine(1e-2)
        )
        full_precision_optimizer = tf.keras.optimizers.SGD(
            learning_rate=warmup_cosine(0.1),
            momentum=0.9,
        )
        return lq.optimizers.CaseOptimizer(
            (lq.optimizers.Bop.is_binary_variable, binary_variable_optimizer),
            default_optimizer=full_precision_optimizer,
        )
class TrainR2B(MultiStageExperiment):
    """Multi-stage Real-to-Binary (R2B) experiment.

    Stages, in declaration order:
      0. full-precision ResNet-18 (`TrainFPResnet18`)
      1. R2B full-precision stage (`TrainR2BBFP`)
      2. R2B binary-activation stage (`TrainR2BBAN`)
      3. R2B fully binary stage, alternative schedule (`TrainR2BBNNAlternative`)
    """

    # Experiment-wide options; presumably forwarded to the stages through
    # configuration scoping — TODO confirm.
    model_modifier: str = Field("default")
    use_unsign: bool = Field(False)

    stage_0 = ComponentField(TrainFPResnet18)
    stage_1 = ComponentField(TrainR2BBFP)
    stage_2 = ComponentField(TrainR2BBAN)
    stage_3 = ComponentField(TrainR2BBNNAlternative)
class TrainR2BBAN(TrainR2BBFP):
    """R2B stage 2: train the binary-activation (BAN) student.

    The teacher is built by `RealToBinNetFPFactory` and initialised from the
    `r2b_fp` weights.
    """

    stage = Field(2)
    learning_rate: float = Field(1e-3)

    teacher_model = ComponentField(RealToBinNetFPFactory)
    student_model = ComponentField(RealToBinNetBANFactory)

    initialize_teacher_weights_from = Field("r2b_fp")
class Default(ImageClassification):
    """Default preprocessing: images stay encoded and are decoded and
    processed by `preprocess_image_bytes` at `IMAGE_SIZE` resolution."""

    # Skip tfds image decoding; `input` receives the raw encoded bytes.
    decoders = Field(lambda: {"image": tfds.decode.SkipDecoding()})
    input_shape = Field((IMAGE_SIZE, IMAGE_SIZE, 3))

    def input(self, data, training):
        encoded_image = data["image"]
        return preprocess_image_bytes(
            encoded_image,
            is_training=training,
            image_size=IMAGE_SIZE,
        )
class TrainR2BBNNAlternative(TrainR2BBNN):
    """We deviate slightly from Martinez et al. here.

    Overrides `optimizer` to use Adam with a `CosineDecayWithWarmup` schedule
    that warms up over `warmup_duration` epochs.
    """

    warmup_duration = Field(10)

    @Field
    def optimizer(self):
        schedule = CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.steps_per_epoch * self.warmup_duration,
            decay_steps=self.steps_per_epoch * self.epochs,
        )
        return tf.keras.optimizers.Adam(schedule)
class LarqZooModelTrainingPhase(TrainingPhase):
    """Training phase with configuration for optional teacher-student losses.

    Declares weights and options for three loss components: standard
    classification, attention matching and output-distribution matching.
    How they are combined is not visible here.
    """

    # --- Standard cross-entropy training of the student on the labels ---
    # Weight of the classification loss component.
    classification_weight: float = Field(1.0)

    # --- Attention matching between teacher and student activations ---
    # Weight of the spatial attention-matching loss component.
    attention_matching_weight: float = Field(0.0)
    # Partial layer names whose outputs should be matched.
    attention_matching_volume_names: Optional[List[str]] = Field(
        allow_missing=True
    )
    # Optional separate partial names for the teacher; if not given, the
    # names above are used.
    attention_matching_volume_names_teacher: Optional[List[str]] = Field(
        allow_missing=True
    )
    # Allow the teacher to be trained to better match the student.
    attention_matching_train_teacher: bool = Field(False)

    # --- Matching of teacher and student output predictions ---
    # Weight of the knowledge-distillation loss component.
    output_matching_weight: float = Field(0.0)
    # Softmax temperature used when matching the distributions.
    output_matching_softmax_temperature: float = Field(1.0)
    # Allow the teacher to be trained during output matching.
    output_matching_train_teacher: bool = Field(False)

    @Field
    def loss(self):
        # `model` is fetched through `__base_getattribute__` rather than
        # normal attribute access; presumably to bypass zookeeper's field
        # resolution — TODO confirm.
        base_model = self.__base_getattribute__("model")
        return base_model.classification_loss

    metrics = Field(lambda: ["accuracy", "sparse_top_k_categorical_accuracy"])
class TrainBinaryAlexNet(TrainLarqZooModel):
    """Training configuration for BinaryAlexNet: Adam with a step schedule
    that halves the learning rate every 10 epochs."""

    model = ComponentField(BinaryAlexNetFactory)
    batch_size: int = Field(512)
    epochs: int = Field(150)

    def learning_rate_schedule(self, epoch):
        """Return 1e-2 halved once for every completed block of 10 epochs."""
        completed_blocks = epoch // 10
        return 1e-2 * 0.5 ** completed_blocks

    @Field
    def optimizer(self):
        # Adam starts at the schedule's epoch-0 value.
        return tf.keras.optimizers.Adam(self.learning_rate_schedule(0))
class TrainR2BBNN(TrainR2BBFP):
    """R2B stage 3: fully binary (BNN) student distilled from the BAN teacher.

    Uses classification loss plus output matching (weight 0.8); attention
    matching is disabled. Both teacher and student start from the `r2b_ban`
    weights.
    """

    stage = Field(3)
    learning_rate: float = Field(2e-4)
    weight_decay_constant: float = Field(0.0)

    # Loss-component weights.
    classification_weight = Field(1.0)
    attention_matching_weight = Field(0.0)
    output_matching_weight = Field(0.8)
    output_matching_softmax_temperature = Field(1.0)

    # NOTE(review): `x_offset` is not read in this class.
    x_offset: float = Field(0.0)

    teacher_model = ComponentField(RealToBinNetBANFactory)
    student_model = ComponentField(RealToBinNetBNNFactory)

    initialize_teacher_weights_from = Field("r2b_ban")
    initialize_student_weights_from = Field("r2b_ban")
class TrainDoReFaNet(TrainLarqZooModel):
    """Training configuration for DoReFa-Net: Adam with a multi-step schedule."""

    model = ComponentField(DoReFaNetFactory)
    epochs = Field(90)
    batch_size = Field(256)

    learning_rate: float = Field(2e-4)
    decay_start: int = Field(60)
    decay_step_2: int = Field(75)
    fast_decay_start: int = Field(82)

    def learning_rate_schedule(self, epoch):
        """Return the learning rate to use at `epoch`.

        The rate is multiplied by 0.2 at `decay_start` and again at
        `decay_step_2`. From `fast_decay_start` on, it is additionally
        multiplied by 0.1 once immediately and then once more every 2 epochs.
        """
        if epoch < self.decay_start:
            return self.learning_rate
        if epoch < self.decay_step_2:
            return self.learning_rate * 0.2

        twice_decayed = self.learning_rate * 0.2 * 0.2
        if epoch < self.fast_decay_start:
            return twice_decayed

        fast_decay_steps = (epoch - self.fast_decay_start) // 2 + 1
        return twice_decayed * 0.1 ** fast_decay_steps

    @Field
    def optimizer(self):
        return tf.keras.optimizers.Adam(self.learning_rate, epsilon=1e-5)
class RealToBinNetBANFactory(RealToBinNetFactory):
    """Real-to-Binary net variant whose layer inputs are quantized with
    `ste_sign` while the kernels stay real-valued."""

    # NOTE(review): SOURCE also contains a newer `RealToBinNetBANFactory`
    # (configurable quantizer and weight decay); if both live in one module,
    # the later one wins.

    model_name = Field("r2b_ban")

    input_quantizer = "ste_sign"
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def kernel_regularizer(self):
        # Fixed L2 coefficient of 1e-5.
        return tf.keras.regularizers.l2(1e-5)
class ResNet18FPFactory(ResNet18Factory):
    """Full-precision ResNet-18: no input/kernel quantizers and no kernel
    constraint; L2 regularization uses `weight_decay_constant`."""

    model_name = Field("resnet_fp")

    input_quantizer = None
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def kernel_regularizer(self):
        return tf.keras.regularizers.l2(self.weight_decay_constant)
class StrongBaselineNetBANFactory(StrongBaselineNetFactory):
    """Strong-baseline net variant whose layer inputs are quantized with
    `ste_sign` while the kernels stay real-valued."""

    model_name = Field("baseline_ban")

    input_quantizer = "ste_sign"
    kernel_quantizer = None
    kernel_constraint = None

    @property
    def kernel_regularizer(self):
        # Fixed L2 coefficient of 1e-5.
        return tf.keras.regularizers.l2(1e-5)
class StrongBaselineNetBNNFactory(StrongBaselineNetFactory):
    """Fully binary strong-baseline net: `ste_sign` kernels with
    `weight_clip`, and sign/unsign quantized layer inputs."""

    model_name = Field("baseline_bnn")

    kernel_quantizer = "ste_sign"
    kernel_constraint = "weight_clip"

    @property
    def input_quantizer(self):
        """A fresh `SteUnsign` if `use_unsign` is set, else a fresh `SteSign`."""
        quantizer_class = (
            lq.quantizers.SteUnsign if self.use_unsign else lq.quantizers.SteSign
        )
        return quantizer_class()
class TrainBiRealNet(TrainLarqZooModel):
    """Training configuration for Bi-Real Net: Adam with a selectable decay."""

    model = ComponentField(BiRealNetFactory)
    epochs = Field(300)
    batch_size = Field(512)

    learning_rate: float = Field(5e-3)
    # "linear_cosine" or "linear"; any other value gives a constant rate.
    decay_schedule: str = Field("linear")

    @Field
    def optimizer(self):
        # NOTE(review): decay length is hard-coded and does not follow
        # `epochs`/`batch_size`. It equals round(1281167 * 300 / 512),
        # presumably ImageNet train images * default epochs / default batch
        # size — verify before changing those fields.
        decay_steps = 750684

        # Builders stay lazy so only the selected schedule is constructed.
        schedule_builders = {
            "linear_cosine": lambda: tf.keras.experimental.LinearCosineDecay(
                self.learning_rate, decay_steps
            ),
            "linear": lambda: tf.keras.optimizers.schedules.PolynomialDecay(
                self.learning_rate, decay_steps, end_learning_rate=0, power=1.0
            ),
        }
        build_schedule = schedule_builders.get(self.decay_schedule)
        if build_schedule is None:
            learning_rate = self.learning_rate
        else:
            learning_rate = build_schedule()
        return tf.keras.optimizers.Adam(learning_rate)
class RealToBinNetFPFactory(RealToBinNetFactory):
    """Full-precision Real-to-Binary net: layer inputs pass through `tanh`
    instead of a binary quantizer, and kernels are not quantized."""

    # NOTE(review): SOURCE also contains a newer `RealToBinNetFPFactory`
    # (sigmoid/hard variants, configurable weight decay); if both live in one
    # module, the later one wins.

    model_name = Field("r2b_fp")

    kernel_quantizer = None
    kernel_constraint = None

    @property
    def input_quantizer(self):
        return tf.keras.layers.Activation("tanh")

    @property
    def kernel_regularizer(self):
        # Fixed L2 coefficient of 1e-5.
        return tf.keras.regularizers.l2(1e-5)
class RealToBinNetBNNFactory(RealToBinNetFactory):
    """Fully binary Real-to-Binary net: `ste_sign` kernels with
    `weight_clip`, and sign/unsign quantized layer inputs."""

    model_name = Field("r2b_bnn")

    kernel_quantizer = "ste_sign"
    kernel_constraint = "weight_clip"

    @property
    def input_quantizer(self):
        """A fresh `SteUnsign` if `use_unsign` is set, else a fresh `SteSign`."""
        quantizer_class = (
            lq.quantizers.SteUnsign if self.use_unsign else lq.quantizers.SteSign
        )
        return quantizer_class()
class ModelFactory:
    """A base class for Larq Zoo models. Defines some common fields."""

    input_quantizer = None
    kernel_quantizer = None
    kernel_constraint = None

    # Used only to infer `num_classes` automatically when no value is
    # otherwise provided. `allow_missing` is set so that omitting a dataset is
    # not an error, as long as `num_classes` is overridden.
    dataset: Optional[Dataset] = ComponentField(allow_missing=True)

    @Field
    def num_classes(self) -> int:
        """Number of classes, taken from `dataset` unless overridden.

        Raises:
            TypeError: if neither `dataset` nor `num_classes` is provided.
        """
        dataset = self.dataset
        if dataset is not None:
            return dataset.num_classes
        raise TypeError(
            "No `dataset` is defined so unable to infer `num_classes`. Please "
            "provide a `dataset` or override `num_classes` directly."
        )

    include_top: bool = Field(True)
    weights: Optional[str] = Field(None)
    input_shape: Optional[Tuple[DimType, DimType, DimType]] = Field(None)
    input_tensor: Optional[utils.TensorType] = Field(None)

    @property
    def image_input(self) -> utils.TensorType:
        """Model input, created on first access and cached on the instance."""
        if hasattr(self, "_image_input"):
            return self._image_input

        validated_shape = utils.validate_input(
            self.input_shape,
            self.weights,
            self.include_top,
            self.num_classes,
        )
        self._image_input = utils.get_input_layer(validated_shape, self.input_tensor)
        return self._image_input
class QuickNetXLFactory(QuickNetBaseFactory):
    """QuickNetXL - A model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine) and high accuracy. This utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    name = "quicknet_xl"

    # Architecture, one entry per section; Squeeze-and-Excite is enabled
    # only in the last two sections.
    blocks_per_section: Sequence[int] = Field((6, 8, 12, 6))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )
    transition_block = Field(lambda self: self.fp_pointwise_transition_block)

    def build(self) -> tf.keras.models.Model:
        """Build the network and load weights if requested.

        `weights == "imagenet"` downloads the pretrained v0.1.0 checkpoint,
        with or without the top depending on `include_top`. Any other
        non-None value is passed directly to `model.load_weights`.
        """
        model = super().build()

        if self.weights is None:
            return model
        if self.weights != "imagenet":
            model.load_weights(self.weights)
            return model

        if self.include_top:
            weights_file = "quicknet_xl_weights.h5"
            expected_hash = (
                "a85eea1204fa9a8401f922f94531858493e3518e3374347978ed7ba615410498"
            )
        else:
            weights_file = "quicknet_xl_weights_notop.h5"
            expected_hash = (
                "b97074d6618acde4201d1f8676d32272d27743ddfe27c6c97e4516511ebb5008"
            )
        weights_path = utils.download_pretrained_model(
            model="quicknet_xl",
            version="v0.1.0",
            file=weights_file,
            file_hash=expected_hash,
        )
        model.load_weights(weights_path)
        return model
class QuickNetFactory(QuickNetBaseFactory):
    """Quicknet - A model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine)"""

    name = "quicknet"

    # Architecture, one entry per section: block count, filter count and
    # whether Squeeze-and-Excite is enabled in that section.
    blocks_per_section: Sequence[int] = Field((2, 3, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, False, False)
    )
    transition_block = Field(lambda self: self.concat_transition_block)

    def build(self) -> tf.keras.models.Model:
        """Build the network and load weights if requested.

        `weights == "imagenet"` downloads the pretrained v0.2.0 checkpoint,
        with or without the top depending on `include_top`. Any other
        non-None value is passed directly to `model.load_weights`.
        """
        model = super().build()

        if self.weights is None:
            return model
        if self.weights != "imagenet":
            model.load_weights(self.weights)
            return model

        if self.include_top:
            weights_file = "quicknet_weights.h5"
            expected_hash = (
                "6a765f120ba7b62a7740e842c4f462eb7ba3dd65eb46b4694c5bc8169618fae7"
            )
        else:
            weights_file = "quicknet_weights_notop.h5"
            expected_hash = (
                "5bf2fc450fb8cc322b33a16410bf88fed09d05c221550c2d5805a04985383ac2"
            )
        weights_path = utils.download_pretrained_model(
            model="quicknet",
            version="v0.2.0",
            file=weights_file,
            file_hash=expected_hash,
        )
        model.load_weights(weights_path)
        return model
class QuickNetLargeFactory(QuickNetBaseFactory):
    """QuickNetLarge - A model designed for fast inference using [Larq Compute Engine](https://github.com/larq/compute-engine) and high accuracy. This utilises Squeeze and Excite blocks as per [Training binary neural networks with real-to-binary convolutions](https://openreview.net/forum?id=BJg4NgBKvH)."""

    name = "quicknet_large"

    # Architecture, one entry per section; Squeeze-and-Excite is enabled
    # only in the last two sections.
    blocks_per_section: Sequence[int] = Field((4, 4, 4, 4))
    section_filters: Sequence[int] = Field((64, 128, 256, 512))
    use_squeeze_and_excite_in_section: Sequence[bool] = Field(
        (False, False, True, True)
    )
    transition_block = Field(lambda self: self.fp_pointwise_transition_block)

    def build(self) -> tf.keras.models.Model:
        """Build the network and load weights if requested.

        `weights == "imagenet"` downloads the pretrained v0.2.0 checkpoint,
        with or without the top depending on `include_top`. Any other
        non-None value is passed directly to `model.load_weights`.
        """
        model = super().build()

        if self.weights is None:
            return model
        if self.weights != "imagenet":
            model.load_weights(self.weights)
            return model

        if self.include_top:
            weights_file = "quicknet_large_weights.h5"
            expected_hash = (
                "2d9ebbf8ba0500552e4dd243c3e52fd8291f965ef6a0e1dbba13cc72bf6eee8b"
            )
        else:
            weights_file = "quicknet_large_weights_notop.h5"
            expected_hash = (
                "067655ef8a1a1e99ef1c71fa775c09aca44bdfad0b9b71538b4ec500c3beee4f"
            )
        weights_path = utils.download_pretrained_model(
            model="quicknet_large",
            version="v0.2.0",
            file=weights_file,
            file_hash=expected_hash,
        )
        model.load_weights(weights_path)
        return model
class RealToBinNetFPFactory(RealToBinNetFactory):
    """Full-precision Real-to-Binary net.

    Layer inputs pass through a real-valued activation instead of a binary
    quantizer, and kernels are not quantized. L2 regularization uses
    `weight_decay_constant`.
    """

    model_name = Field("r2b_fp")
    # Required: choose the hard (lq.activations) variants over the smooth
    # Keras activations.
    use_hard_activation: bool = Field()

    kernel_quantizer = None
    kernel_constraint = None

    @property
    def input_quantizer(self):
        """Activation standing in for the binary input quantizer.

        Sigmoid-style when `use_unsign` is set, tanh-style otherwise; the
        `lq.activations` variants are used when `use_hard_activation` is set.
        """
        if not self.use_unsign:
            if self.use_hard_activation:
                return lq.activations.HardTanh(lower_b=-1.0, upper_b=1.0)
            return tf.keras.layers.Activation("tanh")

        if self.use_hard_activation:
            # TODO(VINN): this should be shifted to match the sigmoid. An
            # alternative considered was HardTanh(lower_b=0.0, upper_b=1.0).
            return lq.activations.HardSigmoid()
        return tf.keras.layers.Activation("sigmoid")

    @property
    def kernel_regularizer(self):
        return tf.keras.regularizers.l2(self.weight_decay_constant)
class RealToBinNetBANFactory(RealToBinNetFactory):
    """Real-to-Binary net with binary-quantized layer inputs and real-valued
    kernels; L2 regularization uses `weight_decay_constant`."""

    model_name = Field("r2b_ban")

    kernel_quantizer = None
    kernel_constraint = None

    @property
    def input_quantizer(self):
        """A fresh `SteUnsign` if `use_unsign` is set, else a fresh `SteSign`."""
        quantizer_class = (
            lq.quantizers.SteUnsign if self.use_unsign else lq.quantizers.SteSign
        )
        return quantizer_class()

    @property
    def kernel_regularizer(self):
        return tf.keras.regularizers.l2(self.weight_decay_constant)
class PadCropAndFlip(Preprocessing):
    """Pad/crop-and-flip augmentation for images.

    Training: resize to `pad_size` x `pad_size` by centered crop/pad, take a
    random `input_shape` crop, and apply a random horizontal flip.
    Evaluation: centered crop/pad to the spatial dims of `input_shape`.
    Pixels are then mapped linearly so that 0 -> -1.0 and 255 -> 1.0
    (presumably uint8 input — TODO confirm).
    """

    pad_size: int = Field()

    def input(self, data, training):
        image = data["image"]
        if not training:
            image = tf.image.resize_with_crop_or_pad(image, *self.input_shape[:2])
        else:
            image = tf.image.resize_with_crop_or_pad(
                image, self.pad_size, self.pad_size
            )
            image = tf.image.random_crop(image, self.input_shape)
            image = tf.image.random_flip_left_right(image)

        half_range = 255.0 / 2.0
        return tf.cast(image, tf.float32) / half_range - 1.0

    def output(self, data):
        return data["label"]
class TrainR2BStrongBaselineBAN(LarqZooModelTrainingPhase):
    """Train the strong-baseline BAN model on ImageNet with Adam and an
    `R2BStepSchedule` learning-rate schedule."""

    stage = Field(0)
    dataset = ComponentField(ImageNet)
    model_modifier: str = Field("default")

    learning_rate: float = Field(1e-3)
    learning_rate_decay: float = Field(0.1)
    epochs: int = Field(75)
    # NOTE(review): batch size 8 is far below the 512 used by the other
    # ImageNet phases here; possibly a debugging value — verify.
    batch_size: int = Field(8)
    # NOTE(review): not passed to R2BStepSchedule below.
    warmup_duration: int = Field(5)

    @Field
    def optimizer(self):
        schedule = R2BStepSchedule(
            initial_learning_rate=self.learning_rate,
            steps_per_epoch=self.steps_per_epoch,
            decay_fraction=self.learning_rate_decay,
        )
        return tf.keras.optimizers.Adam(schedule)

    student_model = ComponentField(StrongBaselineNetBANFactory)
class TrainFPResnet18(LarqZooModelTrainingPhase):
    """Train a full-precision ResNet-18 on ImageNet with Adam and a
    warmup + cosine-decay schedule (previously SGD at lr 1e-1)."""

    # NOTE(review): SOURCE also contains an SGD-based `TrainFPResnet18`; if
    # both live in one module, the later one wins.

    stage = Field(0)
    dataset = ComponentField(ImageNet)
    model_modifier: str = Field("default")

    learning_rate: float = Field(1e-3)
    weight_decay_constant: float = Field(1e-5)
    epochs: int = Field(100)
    batch_size: int = Field(512)
    warmup_duration: int = Field(5)

    @Field
    def optimizer(self):
        # The decay covers only the epochs that remain after the warmup.
        schedule = CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.warmup_duration * self.steps_per_epoch,
            decay_steps=(self.epochs - self.warmup_duration) * self.steps_per_epoch,
        )
        return tf.keras.optimizers.Adam(schedule)

    student_model = ComponentField(ResNet18FPFactory)
class TrainFPResnet18(LarqZooModelTrainingPhase):
    """Train a full-precision ResNet-18 on ImageNet with plain SGD and a
    warmup + cosine-decay schedule."""

    stage = Field(0)
    dataset = ComponentField(ImageNet)

    learning_rate: float = Field(1e-1)
    epochs: int = Field(100)
    batch_size: int = Field(512)
    warmup_duration: int = Field(5)

    @Field
    def optimizer(self):
        # The decay length spans all `epochs` epochs of steps.
        schedule = CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.warmup_duration * self.steps_per_epoch,
            decay_steps=self.epochs * self.steps_per_epoch,
        )
        return tf.keras.optimizers.SGD(schedule)

    student_model = ComponentField(ResNet18FPFactory)