def test_transconv_unknown_batchsize_shape(self):
        '''
            this func checks the below test case:
                - when a module is built without specifying the batch size
                  (the batch dimension of the input is left as None), check
                  whether the model output takes its batch size from the
                  input fed at run time.
            It also saves a checkpoint via save_pb_ckpt() and converts the
            graph to ./tflite_files/<TEST_MODULE_NAME>.tflite as a TFLite
            compatibility check.
        '''
        scope = 'unittest'

        model_config = DeconvModuleConfig()
        TEST_MODULE_NAME = 'conv2dtrans_unpool'
        batch_size = 1

        input_width = 2
        input_height = 2
        # [batch, height, width, channels]; batch is deliberately unknown at
        # graph-construction time.
        input_shape = [None, input_height, input_width, 1]
        unpool_rate = 3

        module_graph = tf.Graph()
        with module_graph.as_default():
            # NOTE(review): with batchsize=None this is presumably a
            # placeholder (it is used as a feed key below) -- confirm in
            # create_test_input.
            inputs = create_test_input(batchsize=input_shape[0],
                                       heightsize=input_shape[1],
                                       widthsize=input_shape[2],
                                       channelnum=input_shape[3])

            module_output, _ = get_deconv_module(
                inputs=inputs,
                unpool_rate=unpool_rate,
                module_type=TEST_MODULE_NAME,
                model_config=model_config,
                scope=scope)

            init = tf.global_variables_initializer()
            ckpt_saver = tf.train.Saver(tf.global_variables())

            expected_prefix = scope + '0'
            self.assertTrue(module_output.op.name.startswith(expected_prefix))
            # Statically the batch dim must stay unknown (None) while the
            # spatial dims are scaled by unpool_rate.
            self.assertListEqual(module_output.get_shape().as_list(), [
                None, input_shape[1] * unpool_rate,
                input_shape[2] * unpool_rate, input_shape[3]
            ])

            # From here on use a concrete batch size for the run-time input.
            input_shape[0] = batch_size
            expected_output_shape = [
                input_shape[0], input_shape[1] * unpool_rate,
                input_shape[2] * unpool_rate, input_shape[3]
            ]

            print('------------------------------------------------')
            print('[tfTest] run test_transconv_unknown_batchsize_shape()')
            print('[tfTest] unpool rate = %s' % unpool_rate)

            # Sample input with a known batch size; it is evaluated in the
            # session below and fed into `inputs`.
            images = create_test_input(batchsize=input_shape[0],
                                       heightsize=input_shape[1],
                                       widthsize=input_shape[2],
                                       channelnum=input_shape[3])

            # tensorboard graph summary =============
            now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
            tb_logdir_path = getcwd() + '/tf_logs'
            tb_logdir = "{}/run-{}/".format(tb_logdir_path, now)

            if not tf.gfile.Exists(tb_logdir_path):
                tf.gfile.MakeDirs(tb_logdir_path)

            # summary
            tb_summary_writer = tf.summary.FileWriter(logdir=tb_logdir)
            tb_summary_writer.add_graph(module_graph)
            tb_summary_writer.close()

        with self.test_session(graph=module_graph) as sess:
            # Reuse the initializer built with the graph rather than adding a
            # new initializer op to the graph inside the session.
            sess.run(init)
            output = sess.run(module_output, {inputs: images.eval()})
            self.assertListEqual(list(output.shape), expected_output_shape)
            print('[TfTest] output shape = %s' % list(output.shape))
            print('[TfTest] expected_output_shape = %s' %
                  expected_output_shape)

            save_pb_ckpt(module_name=TEST_MODULE_NAME,
                         init=init,
                         sess=sess,
                         ckpt_saver=ckpt_saver)

            # tflite compatibility check
            print('------------------------------------------------')
            tflitedir = getcwd() + '/tflite_files/'
            if not tf.gfile.Exists(tflitedir):
                tf.gfile.MakeDirs(tflitedir)
            tflitefilename = TEST_MODULE_NAME + '.tflite'

            toco = tf.contrib.lite.TocoConverter.from_session(
                sess=sess,
                input_tensors=[inputs],
                output_tensors=[module_output])
            tflite_model = toco.convert()
            # tflitedir already ends with '/'; the context manager guarantees
            # the file is flushed and closed.
            with open(tflitedir + tflitefilename, 'wb') as tflite_file:
                tflite_file.write(tflite_model)
Beispiel #2
0
    def test_endpoint_name_shape(self):
        '''
            this func checks the below test cases:
                - the module built by get_module() registers its output in
                  end_points under 'unittest0/' + TEST_MODULE_NAME + '_out'
                - the module output has the expected static shape
                  (spatial dims divided by stride, ch_out_num channels)
            It then saves a checkpoint, generates a frozen graph and converts
            the graph to ./tflite_files/<TEST_MODULE_NAME>.tflite.
            TEST_MODULE_NAME is defined outside this method.
        '''
        ch_in_num = 256
        ch_out_num = 256
        model_config = ConvModuleConfig()
        scope = 'unittest'
        stride = 2
        kernel_size = 3

        # test input shape [batch, height, width, channels]
        input_shape = [1, 64, 64, ch_in_num]

        # expected output shape; floor division keeps the dims as ints (as
        # returned by TensorShape.as_list()) on Python 3 as well as Python 2.
        expected_output_height = input_shape[1] // stride
        expected_output_width = input_shape[2] // stride
        expected_output_shape = [
            input_shape[0], expected_output_height, expected_output_width,
            ch_out_num
        ]

        module_graph = tf.Graph()
        with module_graph.as_default():
            inputs = create_test_input(batchsize=input_shape[0],
                                       heightsize=input_shape[1],
                                       widthsize=input_shape[2],
                                       channelnum=input_shape[3])

            module_output, end_points = get_module(ch_in=inputs,
                                                   ch_out_num=ch_out_num,
                                                   model_config=model_config,
                                                   stride=stride,
                                                   kernel_size=kernel_size,
                                                   conv_type=TEST_MODULE_NAME,
                                                   scope=scope)

            init = tf.global_variables_initializer()
            ckpt_saver = tf.train.Saver(tf.global_variables())

        expected_output_name = 'unittest0/' + TEST_MODULE_NAME + '_out'

        print('------------------------------------------------')
        print('[tfTest] run test_endpoint_name_shape()')
        print('[tfTest] module name = %s' % TEST_MODULE_NAME)

        # dict.keys() is a non-subscriptable view on Python 3, so go through
        # list(); guard against an empty end_points so the assertion below
        # reports the failure instead of an IndexError here.
        last_end_point = list(end_points)[-1] if end_points else None
        print('[tfTest] model output name = %s' % last_end_point)
        print('[tfTest] expected output name = %s' % expected_output_name)

        print('[tfTest] model output shape = %s' %
              module_output.get_shape().as_list())
        print('[tfTest] expected output shape = %s' % expected_output_shape)

        # check name of the module output
        self.assertIn(expected_output_name, end_points)

        # check shape of the module output
        self.assertListEqual(module_output.get_shape().as_list(),
                             expected_output_shape)

        # tensorboard graph summary =============
        now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
        tb_logdir_path = getcwd() + '/tf_logs'
        tb_logdir = "{}/run-{}/".format(tb_logdir_path, now)

        if not tf.gfile.Exists(tb_logdir_path):
            tf.gfile.MakeDirs(tb_logdir_path)

        # summary
        tb_summary_writer = tf.summary.FileWriter(logdir=tb_logdir)
        tb_summary_writer.add_graph(module_graph)
        tb_summary_writer.close()

        with self.test_session(graph=module_graph) as sess:
            output_node_name = 'unittest0/' + TEST_MODULE_NAME + '/' + expected_output_name

            pbsavedir, pbfilename, ckptfilename = \
                save_pb_ckpt(module_name=TEST_MODULE_NAME,
                             init=init,
                             sess=sess,
                             ckpt_saver=ckpt_saver)

            # frozen graph generation
            convert_to_frozen_pb(module_name=TEST_MODULE_NAME,
                                 pbsavedir=pbsavedir,
                                 pbfilename=pbfilename,
                                 ckptfilename=ckptfilename,
                                 output_node_name=output_node_name,
                                 input_shape=input_shape)

            # check tflite compatibility
            print('------------------------------------------------')
            print('[tfTest] convert to tflite')
            tflitedir = getcwd() + '/tflite_files/'
            if not tf.gfile.Exists(tflitedir):
                tf.gfile.MakeDirs(tflitedir)
            tflitefilename = TEST_MODULE_NAME + '.tflite'

            toco = tf.contrib.lite.TocoConverter.from_session(
                sess=sess,
                input_tensors=[inputs],
                output_tensors=[module_output])
            tflite_model = toco.convert()
            # tflitedir already ends with '/'; the context manager guarantees
            # the file is flushed and closed.
            with open(tflitedir + tflitefilename, 'wb') as tflite_file:
                tflite_file.write(tflite_model)
            print('[tfTest] tflite conversion successful')