Пример #1
0
    def testListHparams(self):
        # Two plain hparams functions: one registered under its own function
        # name, one under an explicit name passed to the decorator.
        @registry.register_hparams
        def hp1():
            pass

        @registry.register_hparams("hp2_named")
        def hp2():
            pass

        # Ranged variants, registered the same two ways. Each takes a single
        # argument that is unused here.
        @registry.register_ranged_hparams
        def rhp1(_):
            pass

        @registry.register_ranged_hparams("rhp2_named")
        def rhp2(_):
            pass

        # NOTE(review): exact-set equality assumes the registry holds nothing
        # else at this point (e.g. it is reset in setUp) -- confirm.
        expected_plain = {"hp1", "hp2_named"}
        expected_ranged = {"rhp1", "rhp2_named"}
        self.assertSetEqual(expected_plain, set(registry.list_hparams()))
        self.assertSetEqual(expected_ranged,
                            set(registry.list_ranged_hparams()))
Пример #2
0
  def testListHparams(self):
    # Plain hparams: one keyed by function name, one by an explicit name.
    @registry.register_hparams
    def hp1():
      pass

    @registry.register_hparams("hp2_named")
    def hp2():
      pass

    # Ranged hparams, registered the same two ways; the single argument is
    # unused here.
    @registry.register_ranged_hparams
    def rhp1(_):
      pass

    @registry.register_ranged_hparams("rhp2_named")
    def rhp2(_):
      pass

    # NOTE(review): exact-set equality assumes the registry was empty before
    # this test (e.g. reset in setUp) -- confirm.
    want_plain = {"hp1", "hp2_named"}
    want_ranged = {"rhp1", "rhp2_named"}
    self.assertSetEqual(want_plain, set(registry.list_hparams()))
    self.assertSetEqual(want_ranged, set(registry.list_ranged_hparams()))
# -*- coding: utf-8 -*-
"""
@author: 代码医生工作室 
@公众号:xiangyuejiqiren   (内有更多优秀文章及学习资料)
@来源: <深度学习之TensorFlow工程化项目实战>配套代码 (700+页)
@配套代码技术支持:bbs.aianaconda.com      (有问必答)
"""

#6-19

import tensorflow as tf
from tensor2tensor import models

from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry

# Number of registered models, followed by their names. Fetch the list once
# and reuse it instead of querying the registry twice.
model_names = registry.list_models()
print(len(model_names), model_names)

# Whatever the registry holds under the name 'transformer'.
print(registry.model('transformer'))

# Hparams names selected by the 'transformer' argument (presumably a name
# prefix filter -- confirm against tensor2tensor's registry). Take the count
# from the same list that is printed so the number matches the names shown.
transformer_hparams = registry.list_hparams('transformer')
print(len(transformer_hparams), transformer_hparams)

# The hparams set registered as 'transformer_base_v1'.
print(registry.hparams('transformer_base_v1'))
Пример #4
0
 def testHParamsImported(self):
   # Check that an hparams set named "transformer_base" is registered.
   hparams = registry.list_hparams()
   # assertIn has the same pass/fail condition as `x in y`, but on failure it
   # reports the missing member and the container instead of only
   # "False is not true".
   self.assertIn("transformer_base", hparams)
Пример #5
0
 def testHParamsImported(self):
     # Check that an hparams set named "transformer_base" is registered.
     hparams = registry.list_hparams()
     # assertIn has the same pass/fail condition as `x in y`, but on failure
     # it reports the missing member and the container instead of only
     # "False is not true".
     self.assertIn("transformer_base", hparams)