    def test_keras_model_functional_get_op_product_graph(self):
        """ Test connected graph construction on keras model functional """
        tf.compat.v1.reset_default_graph()

        _ = keras_model_functional()
        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), ['input_1'], ['keras_model_functional/Softmax'])
        self.assertTrue(validate_branch_ops(conn_graph))
        self.assertTrue(validate_product_tensor_lists(conn_graph))
        self.assertEqual(0, conn_graph.branch_count)
        self.assertEqual(14, len(conn_graph.get_all_ops()))

        # 13 products from inter module connections
        # 22 products from parameters
        self.assertEqual(35, len(conn_graph.get_all_products()))
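
    # All tests in this class build the shared keras_model_functional()
    # fixture. A minimal sketch of what that fixture could look like,
    # reconstructed from the assertions in these tests (layer widths and
    # activations are assumptions; the op/scope names, the 'is_training'
    # placeholder and the three BN training modes are implied by the tests):
    @staticmethod
    def _sketch_keras_model_functional():
        """ Hypothetical stand-in for the keras_model_functional() fixture """
        is_training = tf.compat.v1.placeholder_with_default(tf.constant(True), shape=(), name='is_training')
        inputs = tf.keras.Input(shape=(32, 32, 3,))
        x = tf.keras.layers.Conv2D(32, (3, 3))(inputs)
        x = tf.keras.layers.BatchNormalization()(x, training=True)
        with tf.compat.v1.variable_scope('scope_1'):
            x = tf.keras.layers.Conv2D(16, (2, 2), activation=tf.nn.tanh)(x)
            x = tf.keras.layers.BatchNormalization()(x, training=is_training)
            x = tf.keras.layers.Conv2D(8, (2, 2), activation=tf.nn.tanh)(x)
            x = tf.keras.layers.BatchNormalization()(x, training=False)
            x = tf.keras.layers.Conv2D(4, (2, 2), activation=tf.nn.relu6)(x)
        x = tf.keras.layers.Flatten()(x)
        outputs = tf.keras.layers.Dense(2, activation=tf.nn.softmax, name='keras_model_functional')(x)
        return tf.keras.Model(inputs=inputs, outputs=outputs)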

    def test_compute_encodings(self):
        """ Test that ops not evaluated during compute encodings are set to passThrough mode. """
        tf.compat.v1.reset_default_graph()
        sess = tf.compat.v1.Session()
        test_inp = np.ndarray((1, 32, 32, 3))  # uninitialized values; this test only checks op modes

        def dummy_forward_func(sess, _):
            # Deliberately stop the forward pass at flatten/Reshape so that
            # downstream activation quantizers are never exercised.
            input_tensor = sess.graph.get_tensor_by_name('input_1:0')
            output_tensor = sess.graph.get_tensor_by_name('flatten/Reshape:0')
            sess.run(output_tensor, feed_dict={input_tensor: test_inp})

        with sess.as_default():
            _ = keras_model_functional()
            init = tf.compat.v1.global_variables_initializer()
            sess.run(init)
            sim = QuantizationSimModel(sess, ['input_1'],
                                       ['keras_model_functional/Softmax'])
            sim.compute_encodings(dummy_forward_func, None)

            for name, quant_info in sim._activation_quantizers.items():
                if name in [
                        'keras_model_functional/Softmax_quantized',
                        'keras_model_functional/BiasAdd_quantized'
                ]:
                    # Quantizers downstream of the last op run during compute_encodings
                    # (flatten/Reshape) should be in passThrough (3) mode
                    self.assertEqual(quant_info.get_op_mode(), 3)
                    self.assertFalse(
                        quant_info.tensor_quantizer.isEncodingValid)
                elif name in ['scope_1/conv2d_3/BiasAdd_quantized']:
                    # Check that passThrough quantizers remain as passThrough (3)
                    self.assertEqual(quant_info.get_op_mode(), 3)
                    self.assertFalse(
                        quant_info.tensor_quantizer.isEncodingValid)
                else:
                    # Check that all other quantizers are in quantizeDequantize (2) mode
                    self.assertEqual(quant_info.get_op_mode(), 2)
                    self.assertTrue(
                        quant_info.tensor_quantizer.isEncodingValid)

            input_tensor = sim.session.graph.get_tensor_by_name('input_1:0')
            output_tensor = sim.session.graph.get_tensor_by_name(
                'keras_model_functional/Softmax:0')
            sim.session.run(output_tensor, feed_dict={input_tensor: test_inp})
            sim.session.close()
            del sim
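
    # In practice the forward-pass callback given to compute_encodings()
    # streams representative data through the whole model rather than
    # stopping early. A minimal sketch; `calibration_batches` (an iterable
    # of (N, 32, 32, 3) numpy arrays) is a hypothetical argument, not part
    # of this suite:
    @staticmethod
    def _sketch_calibration_forward_func(sess, calibration_batches):
        """ Hypothetical calibration callback for compute_encodings() """
        input_tensor = sess.graph.get_tensor_by_name('input_1:0')
        output_tensor = sess.graph.get_tensor_by_name('keras_model_functional/Softmax:0')
        for batch in calibration_batches:
            # Running through to the model output exercises every activation
            # quantizer, so none would be left in passThrough mode.
            sess.run(output_tensor, feed_dict={input_tensor: batch})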

    def test_keras_model_functional_with_training_ops_get_op_product_graph(self):
        """ Test connected graph construction on keras model functional with training ops attached """
        tf.compat.v1.reset_default_graph()
        _ = keras_model_functional()

        # Attach training ops; ConnectedGraph should ignore them since they
        # fall outside the input/output op boundary given below.
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3, name='Adam_new')
        _ = optimizer.minimize(
            loss=tf.compat.v1.get_default_graph().get_tensor_by_name('keras_model_functional/Softmax:0'),
            name='train_step_new')
        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), ["input_1"],
                                    output_op_names=['keras_model_functional/Softmax'])
        self.assertTrue(validate_branch_ops(conn_graph))
        self.assertTrue(validate_product_tensor_lists(conn_graph))
        self.assertEqual(0, conn_graph.branch_count)
        self.assertEqual(14, len(conn_graph.get_all_ops()))

        # 13 products from inter module connections
        # 22 products from parameters
        self.assertEqual(35, len(conn_graph.get_all_products()))

    def test_get_output_data(self):
        """
        Test BiasCorrection._get_output_data method
        """

        tf.compat.v1.reset_default_graph()

        sess = tf.compat.v1.Session(graph=tf.Graph())
        input_op_names = ['input_1']
        output_op_name = 'scope_1/conv2d_2/Conv2D'
        with sess.graph.as_default():
            _ = keras_model_functional()
            init = tf.compat.v1.global_variables_initializer()
            sess.run(init)

        data = np.random.rand(1, 32, 32, 3)
        output = BiasCorrection._get_output_data(sess, input_op_names,
                                                 output_op_name, data)
        self.assertEqual(output.shape[3], 8)
        sess.close()
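
    # _get_output_data is presumably equivalent to evaluating the named
    # op's first output tensor directly; a raw-TF sketch of that assumption
    # (not taken from the helper's source):
    @staticmethod
    def _sketch_get_output_data(sess, input_op_name, output_op_name, data):
        """ Hypothetical raw-TF equivalent of BiasCorrection._get_output_data """
        input_tensor = sess.graph.get_tensor_by_name(input_op_name + ':0')
        output_tensor = sess.graph.get_tensor_by_name(output_op_name + ':0')
        return sess.run(output_tensor, feed_dict={input_tensor: data})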

    def test_fused_batch_norm_matcher_keras(self):
        """ Test fused batch norm matchers """

        tf.compat.v1.reset_default_graph()
        _ = keras_model_functional()

        module_identifier = StructureModuleIdentifier(
            tf.compat.v1.get_default_graph(), ["input_1"],
            set(tf.compat.v1.get_default_graph().get_operations()))
        bn_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'batch_normalization/FusedBatchNormV3')
        self.assertIn(bn_op, module_identifier.op_to_module_dict)
        self.assertEqual(
            module_identifier.op_to_module_dict[bn_op].module_name,
            'batch_normalization')
        switch_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_1/cond/'
            'FusedBatchNormV3/Switch')
        self.assertEqual(
            module_identifier.op_to_module_dict[switch_op].module_name,
            'scope_1/batch_normalization_1')
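
    # The op-to-module mapping built above can also be inverted to list the
    # identified Keras modules; a small illustrative helper that relies only
    # on the module_name attribute asserted on above:
    @staticmethod
    def _sketch_list_identified_modules(module_identifier):
        """ Collect the distinct module names found by the structure matcher """
        return sorted({info.module_name for info in module_identifier.op_to_module_dict.values()})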

    def test_param_read_keras_model_with_fused_batchnorms(self):
        """
        Test AIMET API(s) for reading fused BN op params from Keras layers.
        This test also reproduces an SFTI issue:
        tensorflow.python.framework.errors_impl.InvalidArgumentError
        :return:
        """

        tf.keras.backend.clear_session()
        with tf.device('/cpu:0'):
            model = keras_model_functional()
            model.summary()

        sess = tf.compat.v1.keras.backend.get_session()
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        # layers at indices 2, 4 and 6 are fused BNs of different types
        with sess.as_default():
            # read weights (beta, gamma, moving mean, moving variance)
            bn_1 = model.layers[2]
            bn_2 = model.layers[4]
            bn_3 = model.layers[6]
            keras_bn_1_params = get_bn_params_keras_layer(bn_1)
            keras_bn_2_params = get_bn_params_keras_layer(bn_2)
            keras_bn_3_params = get_bn_params_keras_layer(bn_3)

            bn_op_1 = sess.graph.get_operation_by_name('batch_normalization/FusedBatchNormV3')
            bn_op_2 = sess.graph.get_operation_by_name('scope_1/batch_normalization_1/cond/FusedBatchNormV3_1')
            bn_op_3 = sess.graph.get_operation_by_name('scope_1/batch_normalization_2/FusedBatchNormV3')
            bn_1_params = get_bn_params_aimet_api(sess, bn_op_1)
            bn_2_params = get_bn_params_aimet_api(sess, bn_op_2)
            bn_3_params = get_bn_params_aimet_api(sess, bn_op_3)

            self.assertTrue(np.allclose(keras_bn_1_params, bn_1_params))
            self.assertTrue(np.allclose(keras_bn_2_params, bn_2_params))
            self.assertTrue(np.allclose(keras_bn_3_params, bn_3_params))

        sess.close()
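
    # A sketch of what get_bn_params_keras_layer() could boil down to: for a
    # Keras BatchNormalization layer with scale and center enabled,
    # get_weights() returns [gamma, beta, moving_mean, moving_variance]
    # (standard Keras ordering); the reordering to (beta, gamma, mean,
    # variance) is an assumption made to match the comparison above:
    @staticmethod
    def _sketch_get_bn_params_keras_layer(layer):
        """ Hypothetical equivalent of get_bn_params_keras_layer() """
        gamma, beta, moving_mean, moving_variance = layer.get_weights()
        return [beta, gamma, moving_mean, moving_variance]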

    def test_training_batchnorm(self):
        """ Test BNUtils.get_training() with both fused and non-fused batchnorms, in all three training modes """

        tf.compat.v1.reset_default_graph()

        # Model with fused batchnorms
        _ = keras_model_functional()
        fused_bn_training_true_op = tf.compat.v1.get_default_graph().get_operation_by_name('batch_normalization/FusedBatchNormV3')
        self.assertTrue(BNUtils.get_training(fused_bn_training_true_op))
        self.assertTrue(isinstance(BNUtils.get_training(fused_bn_training_true_op), bool))

        fused_bn_training_tensor_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_1/cond/FusedBatchNormV3_1')
        training_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name('is_training:0')
        self.assertEqual(BNUtils.get_training(fused_bn_training_tensor_op), training_tensor)

        fused_bn_training_false_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_2/FusedBatchNormV3')
        self.assertFalse(BNUtils.get_training(fused_bn_training_false_op))

        tf.compat.v1.reset_default_graph()

        # Model with non fused batchnorms
        _ = keras_model_functional_with_non_fused_batchnorms()
        bn_training_true_op = tf.compat.v1.get_default_graph().get_operation_by_name('batch_normalization/batchnorm/mul_1')
        self.assertTrue(BNUtils.get_training(bn_training_true_op))
        self.assertTrue(isinstance(BNUtils.get_training(bn_training_true_op), bool))

        bn_training_tensor_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_1/batchnorm/mul_1')
        training_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name('is_training:0')
        self.assertEqual(BNUtils.get_training(bn_training_tensor_op), training_tensor)

        bn_training_false_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_2/batchnorm/mul_1')
        self.assertFalse(BNUtils.get_training(bn_training_false_op))

        tf.compat.v1.reset_default_graph()
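
    # TF also records a fused BN's compile-time training mode in the op's
    # standard 'is_training' attribute, so for the training=False case the
    # mode can be read without BNUtils; a sketch under that assumption:
    @staticmethod
    def _sketch_fused_bn_is_training_attr():
        """ Read the is_training attribute of a fused BN op directly """
        _ = keras_model_functional()
        bn_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_2/FusedBatchNormV3')
        return bn_op.get_attr('is_training')  # expected False for this BN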

    def test_conv_subgraph_with_a_model(self):
        """ Detect Conv2D, Conv2D with Bias and FusedBatchNorm subgraphs in the session graph. """

        patterns_to_match = []
        op_to_pattern_dict = {}
        for pattern_name, subgraph_constructor in subgraph_constructors.items():
            input_shape = subgraph_constructor['input_shape']
            constructor_string = subgraph_constructor['constructor']
            logger.debug(pattern_name)
            subgraph = create_subgraph_for_op_default(input_shape,
                                                      constructor_string)
            patterns = create_op_type_patterns_from_subgraph(
                subgraph, additional_starting_ops=[])
            patterns_to_match.append(patterns[-1])
            op_to_pattern_dict[pattern_name] = patterns[-1]
            logger.debug("Length of %s pattern: %d", pattern_name,
                         len(patterns))

        # OneofPattern for Conv2D, Conv2D with Bias and FusedBatchNorm
        all_patterns = graph_matcher.OneofPattern(patterns_to_match)
        layer_matcher = graph_matcher.GraphMatcher(all_patterns)

        # Use the keras_model_functional.
        sess = tf.compat.v1.Session(graph=tf.Graph())
        with sess.graph.as_default():
            with tf.device('/cpu:0'):
                model = keras_model_functional()
                model.summary()
            init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        # Uncomment to use Tensorboard
        # _ = tf.compat.v1.summary.FileWriter('./subgraph', sess.graph)

        # Graph Match
        matched_op_set = set()  # Set to keep track of ops that have been detected already.
        match_counter = 0
        for match_result in layer_matcher.match_graph(sess.graph):
            if match_result:
                match_counter += 1

                # Conv2D Ops could be with or without Bias.
                # As the first step, detect all the Conv2D Ops with Bias.
                conv_bias_op = match_result.get_op(
                    op_to_pattern_dict['Conv2D_with_bias'])
                if conv_bias_op:
                    if conv_bias_op.inputs[0]._op not in matched_op_set:
                        logger.debug("Conv Op with bias: %s, %d",
                                     conv_bias_op.name, match_counter)
                        matched_op_set.add(conv_bias_op.inputs[0]._op)

                # Since the Conv Op with Bias is already added to the matched_op_set,
                # Conv Ops with Bias won't be duplicated by the following match.
                conv_op = match_result.get_op(op_to_pattern_dict['Conv2D'])
                if conv_op:
                    if conv_op not in matched_op_set:
                        logger.debug("Conv Op no bias: %s, %d", conv_op.name,
                                     match_counter)
                        matched_op_set.add(conv_op)

                bn_1_op = match_result.get_op(
                    op_to_pattern_dict['BN_keras_with_training_tensor'])
                if bn_1_op:
                    if bn_1_op.inputs[0]._op not in matched_op_set:
                        matched_op_set.add(bn_1_op.inputs[0]._op)
                        logger.debug("FusedBatchNorm 1 Op: %s, %d",
                                     bn_1_op.inputs[0]._op.name, match_counter)
                    if bn_1_op.inputs[1]._op not in matched_op_set:
                        matched_op_set.add(bn_1_op.inputs[1]._op)
                        logger.debug("FusedBatchNorm 1 Op: %s, %d",
                                     bn_1_op.inputs[1]._op.name, match_counter)

                bn_2_op = match_result.get_op(
                    op_to_pattern_dict['BN_keras_with_training_False'])
                if bn_2_op:
                    logger.debug("FusedBatchNorm 2 Op: %s", bn_2_op.name)

                bn_3_op = match_result.get_op(
                    op_to_pattern_dict['BN_keras_with_training_True'])
                if bn_3_op:
                    logger.debug("FusedBatchNorm 3 Op: %s", bn_3_op.name)

                flatten_op = match_result.get_op(op_to_pattern_dict['Flatten'])
                if flatten_op:
                    logger.debug("Flatten Op: %s, %d", flatten_op.name,
                                 match_counter)
                    matched_op_set.add(flatten_op)

                dense_op = match_result.get_op(op_to_pattern_dict['Dense'])
                if dense_op:
                    logger.debug("dense Op: %s, %d", dense_op.name,
                                 match_counter)
                    matched_op_set.add(dense_op)

        logger.debug(len(matched_op_set))
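
    # A minimal standalone sketch of the same matcher API with a hand-written
    # pattern instead of patterns generated from subgraphs (OpTypePattern and
    # its '*' wildcard come from the same graph_matcher module):
    @staticmethod
    def _sketch_match_conv_bias(graph):
        """ Match Conv2D -> BiasAdd pairs with an explicit OpTypePattern """
        conv_pattern = graph_matcher.OpTypePattern('Conv2D')
        bias_pattern = graph_matcher.OpTypePattern('BiasAdd', inputs=[conv_pattern, '*'])
        matcher = graph_matcher.GraphMatcher(bias_pattern)
        return [match_result.get_op(bias_pattern) for match_result in matcher.match_graph(graph)]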