def testListHparams(self):
    """Check that plain and ranged hparams registrations are listed by key.

    The expectations below list ``hp1`` and ``rhp1`` (the decorated function
    names, registered without an explicit name) alongside ``hp2_named`` and
    ``rhp2_named`` (the names passed to the decorators).

    NOTE(review): exact set equality assumes the registry holds nothing else
    when this test runs (e.g. it is reset in setUp) — confirm.
    """
    # Registered without an argument: the function name is the expected key,
    # so these inner names must not be renamed.
    @registry.register_hparams
    def hp1():
        pass

    # Registered under the explicit string given to the decorator.
    @registry.register_hparams("hp2_named")
    def hp2():
        pass

    @registry.register_ranged_hparams
    def rhp1(_):
        pass

    @registry.register_ranged_hparams("rhp2_named")
    def rhp2(_):
        pass

    expected_hparams = {"hp1", "hp2_named"}
    expected_ranged_hparams = {"rhp1", "rhp2_named"}
    self.assertSetEqual(expected_hparams, set(registry.list_hparams()))
    self.assertSetEqual(expected_ranged_hparams,
                        set(registry.list_ranged_hparams()))
def testListHparams(self):
    """Verify list_hparams / list_ranged_hparams report every registration.

    Two ordinary hparams functions and two ranged ones are registered: one
    of each relies on its function name (``hp1``, ``rhp1``) and one of each
    supplies an explicit name (``hp2_named``, ``rhp2_named``).

    NOTE(review): the exact-set comparison presumes an otherwise empty
    registry at test start — confirm the fixture resets it.
    """
    # Function names hp1/rhp1 are the expected keys; keep them as-is.
    @registry.register_hparams
    def hp1():
        pass

    @registry.register_hparams("hp2_named")
    def hp2():
        pass

    @registry.register_ranged_hparams
    def rhp1(_):
        pass

    @registry.register_ranged_hparams("rhp2_named")
    def rhp2(_):
        pass

    # Each pair: the names we expect, and the registry listing to check.
    # The listing is only invoked when its turn comes, so a failure on the
    # first pair stops before the second listing is queried.
    checks = (
        ({"hp1", "hp2_named"}, registry.list_hparams),
        ({"rhp1", "rhp2_named"}, registry.list_ranged_hparams),
    )
    for expected_names, list_registered in checks:
        self.assertSetEqual(expected_names, set(list_registered()))
# -*- coding: utf-8 -*-
"""
@author: Code Doctor Studio
@WeChat official account: xiangyuejiqiren (more articles and learning
    materials are published there)
@source: companion code for the book "Deep Learning with TensorFlow:
    Engineering Project Practice" (700+ pages)
@companion-code support: bbs.aianaconda.com (all questions answered)

Listing 6-19: print what the tensor2tensor registry knows about -- the
registered model names, the entry for 'transformer', and registered
hyperparameter sets.
"""
#6-19
import tensorflow as tf  # not referenced below; kept as in the listing
# Not referenced by name below; presumably imported so that its models
# register themselves with `registry` before the queries run.
from tensor2tensor import models
from tensor2tensor.utils import t2t_model  # not referenced below
from tensor2tensor.utils import registry

# How many models are registered, followed by their names.
model_count = len(registry.list_models())
print(model_count, registry.list_models())

# Whatever the registry returns for the model name 'transformer'.
transformer_entry = registry.model('transformer')
print(transformer_entry)

# NOTE(review): the count is taken over *all* registered hparams, while the
# list printed next to it is list_hparams('transformer') -- presumably only
# the names with that prefix. Confirm the two are meant to differ.
hparams_count = len(registry.list_hparams())
transformer_hparams = registry.list_hparams('transformer')
print(hparams_count, transformer_hparams)

# The hyperparameter set registered as 'transformer_base_v1'.
base_v1_hparams = registry.hparams('transformer_base_v1')
print(base_v1_hparams)
def testHParamsImported(self):
    """Check that an hparams set named ``transformer_base`` is registered.

    NOTE(review): the test name suggests it guards that importing the model
    modules registers their hparams; the import itself is not visible here.
    """
    hparams = registry.list_hparams()
    # assertIn (instead of assertTrue(x in y)) makes a failure report the
    # missing name and the listed names rather than "False is not true".
    self.assertIn("transformer_base", hparams)