def test_annotation_only_changes_hlo_metadata_conv(self, weight_prec,
                                                   acts_prec):
  FLAGS.metadata_enabled = False
  quant_act = quantization.QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      prec=acts_prec,
      bounds=1.0)
  input_shape = (1, 8, 8, 3)
  module_no_annotation = aqt_flax_layers.ConvAqt(
      features=4,
      kernel_size=(3, 3),
      padding='VALID',
      paxis_name='batch',
      quant_context=quant_config.QuantContext(update_bounds=False),
      train=False,
      hparams=aqt_flax_layers.ConvAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant),
      kernel_init=initializers.ones,
      bias_init=initializers.ones,
      dtype=jnp.float32)
  init_state = module_no_annotation.init(self.rng_key,
                                          jnp.ones(input_shape, jnp.float32))
  output_no_annotation = module_no_annotation.apply(init_state,
                                                    jnp.ones(input_shape))
  hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
      module_no_annotation, init_state, [input_shape])
  del init_state

  FLAGS.metadata_enabled = True
  module_w_annotation = aqt_flax_layers.ConvAqt(
      features=4,
      kernel_size=(3, 3),
      padding='VALID',
      paxis_name='batch',
      quant_context=quant_config.QuantContext(update_bounds=False),
      train=False,
      hparams=aqt_flax_layers.ConvAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant),
      kernel_init=initializers.ones,
      bias_init=initializers.ones,
      dtype=jnp.float32)
  init_state = module_w_annotation.init(self.rng_key,
                                        jnp.ones(input_shape, jnp.float32))
  output_w_annotation = module_w_annotation.apply(init_state,
                                                  jnp.ones(input_shape))
  hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
      module_w_annotation, init_state, [input_shape])
  del init_state

  onp.testing.assert_array_equal(output_no_annotation, output_w_annotation)
  self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
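# Note on intent (inferred from the assertions above, not stated elsewhere in
# this file): toggling FLAGS.metadata_enabled should only attach op metadata to
# the lowered HLO and must not change the computation itself. That is why the
# two outputs are compared exactly with assert_array_equal and why
# compare_hlo_instructions (defined on the test class) checks that the two HLO
# protos match at the instruction level.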
def __call__(
    self,
    inputs,
):
  """Applies ResNet model. Number of residual blocks inferred from hparams."""
  num_classes = self.num_classes
  hparams = self.hparams
  num_filters = self.num_filters
  dtype = self.dtype

  x = aqt_flax_layers.ConvAqt(
      features=num_filters,
      kernel_size=(7, 7),
      strides=(2, 2),
      padding=[(3, 3), (3, 3)],
      use_bias=False,
      dtype=dtype,
      name='init_conv',
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.conv_init,
  )(inputs)
  x = nn.BatchNorm(
      use_running_average=not self.train,
      momentum=0.9,
      epsilon=1e-5,
      dtype=dtype,
      name='init_bn')(x)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  filter_multiplier = hparams.filter_multiplier
  for i, block_hparams in enumerate(hparams.residual_blocks):
    proj = block_hparams.conv_proj
    # For projection layers (unless it is the first layer), strides = (2, 2).
    if i > 0 and proj is not None:
      filter_multiplier *= 2
      strides = (2, 2)
    else:
      strides = (1, 1)
    x = ResidualBlock(
        filters=int(num_filters * filter_multiplier),
        hparams=block_hparams,
        quant_context=self.quant_context,
        strides=strides,
        train=self.train,
        dtype=dtype)(x)
  x = jnp.mean(x, axis=(1, 2))
  x = aqt_flax_layers.DenseAqt(
      features=num_classes,
      dtype=dtype,
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.dense_layer,
  )(x, padding_mask=None)
  x = jnp.asarray(x, dtype)
  output = nn.log_softmax(x)
  return output
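# Illustrative note (hypothetical block layout, not a config defined in this
# file): with filter_multiplier starting at 1 and residual blocks whose
# conv_proj is set at indices 0, 3, and 7 (and None elsewhere), the loop above
# would produce:
#   i=0:     filters=num_filters,     strides=(1, 1)  # first block never downsamples
#   i=1..2:  filters=num_filters,     strides=(1, 1)
#   i=3:     filters=num_filters * 2, strides=(2, 2)  # projection block doubles filters
#   i=4..6:  filters=num_filters * 2, strides=(1, 1)
#   i=7:     filters=num_filters * 4, strides=(2, 2)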
def __call__(self,
             inputs,
             hparams,
             kernel_size,
             num_filters,
             strides,
             dtype=jnp.float32):
  output = aqt_flax_layers.ConvAqt(
      features=num_filters,
      kernel_size=kernel_size,
      strides=strides,
      use_bias=False,
      dtype=dtype,
      train=False,
      quant_context=quant_config.QuantContext(update_bounds=False),
      paxis_name='batch',
      hparams=hparams)(inputs)
  return output
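# Minimal usage sketch (hypothetical: the instance name `conv_layer`, the input
# shape, and the HParams values below are illustrative and mirror the tests in
# this file; they are not fixtures defined here):
#
#   conv_hparams = aqt_flax_layers.ConvAqt.HParams(
#       weight_prec=8, quant_act=None, quant_type=QuantType.fake_quant)
#   y, variables = conv_layer.init_with_output(
#       jax.random.PRNGKey(0),
#       jnp.ones((1, 8, 8, 3)),
#       hparams=conv_hparams,
#       kernel_size=(3, 3),
#       num_filters=4,
#       strides=(1, 1))
#
# The spatial shape of `y` depends on the padding ConvAqt uses by default, which
# is not shown in this snippet.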
def test_group_conv(self, weight_prec=None):
  x = jnp.ones((1, 8, 8, 4))
  conv_module = flax_layers.ConvAqt(
      features=4,
      kernel_size=(3, 3),
      feature_group_count=2,
      padding='VALID',
      paxis_name='batch',
      quant_context=quant_config.QuantContext(update_bounds=False),
      train=False,
      hparams=flax_layers.ConvAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant),
      kernel_init=initializers.ones,
      bias_init=initializers.ones,
      dtype=jnp.float32)
  y, state = conv_module.init_with_output(self.rng_key, x)
  self.assertEqual(state['params']['kernel'].shape, (3, 3, 2, 4))
  test_utils.assert_all_close_prec(y, onp.full((1, 6, 6, 4), 19.),
                                   weight_prec)
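# Worked arithmetic for the expected value above (added note): with 4 input
# channels and feature_group_count=2, each output channel convolves over
# 4 / 2 = 2 input channels, so the kernel has shape (3, 3, 2, 4). With all-ones
# input, kernel, and bias, every output element is 3 * 3 * 2 + 1 = 19, and
# 'VALID' padding shrinks the spatial dims from 8 to 6, hence the expected
# array onp.full((1, 6, 6, 4), 19.).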
def __call__(
    self,
    inputs,
):
  """Applies ResNet model. Number of residual blocks inferred from hparams."""
  num_classes = self.num_classes
  hparams = self.hparams
  num_filters = self.num_filters
  dtype = self.dtype
  assert hparams.act_function in act_function_zoo.keys(), (
      'Activation function type is not supported.')

  x = aqt_flax_layers.ConvAqt(
      features=num_filters,
      kernel_size=(7, 7),
      strides=(2, 2),
      padding=[(3, 3), (3, 3)],
      use_bias=False,
      dtype=dtype,
      name='init_conv',
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.conv_init,
  )(inputs)
  x = nn.BatchNorm(
      use_running_average=not self.train,
      momentum=0.9,
      epsilon=1e-5,
      dtype=dtype,
      name='init_bn')(x)
  if hparams.act_function == 'relu':
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  else:
    # TODO(yichi): try adding other activation functions here.
    # Use avg pool so that for binary nets, the distribution is symmetric.
    x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  filter_multiplier = hparams.filter_multiplier
  for i, block_hparams in enumerate(hparams.residual_blocks):
    proj = block_hparams.conv_proj
    # For projection layers (unless it is the first layer), strides = (2, 2).
    if i > 0 and proj is not None:
      filter_multiplier *= 2
      strides = (2, 2)
    else:
      strides = (1, 1)
    x = ResidualBlock(
        filters=int(num_filters * filter_multiplier),
        hparams=block_hparams,
        quant_context=self.quant_context,
        strides=strides,
        train=self.train,
        dtype=dtype)(x)
  if hparams.act_function == 'none':
    # The DenseAqt below is not binarized. If the activation functions are
    # removed, there is no activation between the last residual block and the
    # dense layer, so add a ReLU in that case.
    # TODO(yichi): try BPReLU.
    x = nn.relu(x)
  x = jnp.mean(x, axis=(1, 2))
  x = aqt_flax_layers.DenseAqt(
      features=num_classes,
      dtype=dtype,
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.dense_layer,
  )(x, padding_mask=None)
  x = jnp.asarray(x, dtype)
  output = nn.log_softmax(x)
  return output