Exemplo n.º 1
0
        input_mask = np.random.randint(low=0,
                                       high=2,
                                       size=(1, spec.seq_len),
                                       dtype=np.int32)
        input_type_ids = np.random.randint(low=0,
                                           high=2,
                                           size=(1, spec.seq_len),
                                           dtype=np.int32)
        random_inputs = (input_word_ids, input_mask, input_type_ids)

        self.assertTrue(
            test_util.is_same_output(tflite_output_file,
                                     model.model,
                                     random_inputs,
                                     model.model_spec,
                                     atol=atol))

    def _test_export_to_saved_model(self, model):
        save_model_output_path = os.path.join(self.get_temp_dir(),
                                              'saved_model')
        model.export(self.get_temp_dir(),
                     export_format=ExportFormat.SAVED_MODEL)

        self.assertTrue(os.path.isdir(save_model_output_path))
        self.assertNotEmpty(os.listdir(save_model_output_path))


if __name__ == '__main__':
    compat.setup_tf_behavior(tf_version=2)
    tf.test.main()
Exemplo n.º 2
0
Arquivo: cli.py Projeto: bqi1/PoseNet
 def __init__(self, tf=2):
   compat.setup_tf_behavior(tf)