def test_build_feature_network(self):
    """Keras and legacy feature networks must produce identical features.

    Each implementation is built in its own fresh graph/session from the same
    all-ones multi-level inputs and the same graph-level random seed. The
    evaluated outputs are then compared for every level in
    [config.min_level, config.max_level].
    """
    config = hparams_config.get_efficientdet_config('efficientdet-d0')
    # Input level index -> shape of the constant (all-ones) tensor for it.
    level_shapes = {
        0: [1, 512, 512, 3],
        1: [1, 256, 256, 16],
        2: [1, 128, 128, 24],
        3: [1, 64, 64, 40],
        4: [1, 32, 32, 112],
        5: [1, 16, 16, 320],
    }

    def evaluate_in_fresh_graph(build_feature_network):
        """Build `build_feature_network(inputs, config)` in a new graph, run it."""
        with tf.Session(graph=tf.Graph()) as sess:
            feature_inputs = {
                level: tf.ones(shape) for level, shape in level_shapes.items()
            }
            # Same graph-level seed in both graphs; presumably this is what
            # makes the randomly initialized variables comparable.
            tf.random.set_random_seed(SEED)
            outputs = build_feature_network(feature_inputs, config)
            sess.run(tf.global_variables_initializer())
            return sess.run(outputs)

    keras_feats = evaluate_in_fresh_graph(
        efficientdet_arch_keras.build_feature_network)
    legacy_feats = evaluate_in_fresh_graph(legacy_arch.build_feature_network)

    for level in range(config.min_level, config.max_level + 1):
        self.assertAllEqual(keras_feats[level], legacy_feats[level])
Example #2
0
    def test_model_output(self):
        """Keras, legacy, and eager EfficientDet-D0 outputs must agree per level.

        Builds the Keras pipeline (backbone -> feature network -> class/box
        heads) and the legacy `efficientdet` graph in separate fresh graphs with
        the same seed, and compares their evaluated outputs level by level. It
        then repeats the comparison for an `EfficientDetModel` called directly,
        outside any `tf.Session`.
        """
        inputs_shape = [1, 512, 512, 3]
        config = hparams_config.get_efficientdet_config('efficientdet-d0')
        # Levels to compare; derived from the config instead of hard-coding
        # 3..7, matching test_build_feature_network.
        levels = range(config.min_level, config.max_level + 1)

        with tf.Session(graph=tf.Graph()) as sess:
            feats = tf.ones(inputs_shape)
            tf.random.set_random_seed(SEED)
            feats = efficientdet_arch_keras.build_backbone(feats, config)
            feats = efficientdet_arch_keras.build_feature_network(
                feats, config)
            feats = efficientdet_arch_keras.build_class_and_box_outputs(
                feats, config)
            # TODO(tanmingxing): Fix the failure for keras Model, i.e. computing
            # `feats` via
            # `efficientdet_arch_keras.EfficientDetModel(config=config)(feats)`.
            sess.run(tf.global_variables_initializer())
            keras_class_out, keras_box_out = sess.run(feats)

        with tf.Session(graph=tf.Graph()) as sess:
            feats = tf.ones(inputs_shape)
            tf.random.set_random_seed(SEED)
            feats = legacy_arch.efficientdet(feats, config=config)
            sess.run(tf.global_variables_initializer())
            legacy_class_out, legacy_box_out = sess.run(feats)

        # Keras outputs are indexed by position within `levels` (0-based);
        # legacy outputs are indexed by the level number itself.
        for offset, level in enumerate(levels):
            self.assertAllEqual(
                keras_class_out[offset], legacy_class_out[level])
            self.assertAllEqual(keras_box_out[offset], legacy_box_out[level])

        # Model call outside any tf.Session (outputs named `eager_*`).
        feats = tf.ones(inputs_shape)
        tf.random.set_random_seed(SEED)
        model = efficientdet_arch_keras.EfficientDetModel(config=config)
        eager_class_out, eager_box_out = model(feats)
        # TODO(tanmingxing): fix the failing case.
        for offset, level in enumerate(levels):
            self.assertAllEqual(
                eager_class_out[offset], legacy_class_out[level])
            self.assertAllEqual(eager_box_out[offset], legacy_box_out[level])
 def test_model_output(self):
     """Keras and legacy EfficientDet-D0 pipelines must give identical outputs.

     Both pipelines (backbone -> feature network -> class/box heads) are built
     in separate fresh graphs from the same all-ones input and graph-level
     seed. The evaluated class and box outputs are compared for every level in
     [config.min_level, config.max_level].
     """
     inputs_shape = [1, 512, 512, 3]
     config = hparams_config.get_efficientdet_config('efficientdet-d0')

     with tf.Session(graph=tf.Graph()) as sess:
         feats = tf.ones(inputs_shape)
         tf.random.set_random_seed(SEED)
         # The Keras backbone returns two values; only the first (the
         # features) feeds the feature network here.
         feats, _ = efficientdet_arch_keras.build_backbone(feats, config)
         feats = efficientdet_arch_keras.build_feature_network(
             feats, config)
         feats = efficientdet_arch_keras.build_class_and_box_outputs(
             feats, config)
         sess.run(tf.global_variables_initializer())
         keras_class_out, keras_box_out = sess.run(feats)

     with tf.Session(graph=tf.Graph()) as sess:
         feats = tf.ones(inputs_shape)
         tf.random.set_random_seed(SEED)
         feats = legacy_arch.build_backbone(feats, config)
         feats = legacy_arch.build_feature_network(feats, config)
         feats = legacy_arch.build_class_and_box_outputs(feats, config)
         sess.run(tf.global_variables_initializer())
         legacy_class_out, legacy_box_out = sess.run(feats)

     # Levels come from the config rather than a hard-coded 3..7, matching
     # test_build_feature_network; both outputs are indexed by level number.
     for level in range(config.min_level, config.max_level + 1):
         self.assertAllEqual(keras_class_out[level], legacy_class_out[level])
         self.assertAllEqual(keras_box_out[level], legacy_box_out[level])