def testRegister(self):
    """Exercises `registry.register` / `registry.lookup` for classes and methods.

    Covers registration under the default name and under an explicit
    `lookup_name`, rejection of invalid subclasses / methods, rejection of a
    repeated registration under the same name, and rejection of arguments of
    the wrong type. The statements are order-dependent: each successful
    `register` call mutates the shared registry that later assertions rely on.
    """
    # Registered with no lookup_name: found under the class's own name.
    sub_expected = registry.register(Base)(ValidSub)
    sub_registered = registry.lookup("ValidSub", Base)
    self.assertEqual(sub_registered, sub_expected)
    # The same class can additionally be registered under an explicit alias.
    alias_expected = registry.register(Base, lookup_name="alias")(ValidSub)
    alias_registered = registry.lookup("alias", Base)
    self.assertEqual(alias_registered, alias_expected)
    # Methods are registered against a base method the same way.
    sub_method_expected = registry.register(base_method)(valid_sub_method)
    sub_method_registered = registry.lookup("valid_sub_method", base_method)
    self.assertEqual(sub_method_registered, sub_method_expected)
    # NOTE: `registry.register(...)` is evaluated before assertRaises runs, so
    # only applying the returned decorator is asserted to raise here.
    # InvalidSub / invalid_sub_method are fixtures defined elsewhere; they
    # presumably fail the registry's compatibility check — confirm there.
    self.assertRaises(ValueError, registry.register(Base), InvalidSub)
    self.assertRaises(ValueError, registry.register(base_method), invalid_sub_method)
    # Cannot register the same class / method twice under the same (default)
    # lookup name; the alias registration above shows a different name is
    # accepted for the same class.
    self.assertRaises(ValueError, registry.register(Base), ValidSub)
    self.assertRaises(ValueError, registry.register(base_method), valid_sub_method)
    # Both base and the subclass / method should be of the right type: a
    # string base makes `register` itself raise, while a string argument
    # makes the returned decorator raise.
    self.assertRaises(ValueError, registry.register, "foo")
    self.assertRaises(ValueError, registry.register(Base), "foo")
    self.assertRaises(ValueError, registry.register(base_method), "foo")
def testDecorator(self):
    """Checks registry lookups for entries registered via the decorator form.

    The Decorator* fixtures and `valid_decorated_sub_method` are defined
    elsewhere (presumably decorated with `registry.register` at import time).
    Known (name, base) pairs must resolve to those fixtures; unknown pairs
    must resolve to None.
    """
    expected_hits = (
        ("DecoratorSub", DecoratorBase, DecoratorSub),
        ("alias", DecoratorBase, DecoratorAlias),
        ("valid_decorated_sub_method", base_method, valid_decorated_sub_method),
    )
    for lookup_name, base, wanted in expected_hits:
        found = registry.lookup(lookup_name, base)
        self.assertEqual(found, wanted)

    # First an unregistered base, then a name never registered under
    # DecoratorBase: neither lookup may return anything.
    expected_misses = (
        ("InvalidSub", InvalidSub),
        ("InvalidSub", DecoratorBase),
    )
    for lookup_name, base in expected_misses:
        self.assertIsNone(registry.lookup(lookup_name, base))
def __init__(self):
    """Initializes (constructs) the `Blocks` class.

    Builds `self._block_builders`, a dict keyed by every `BlockType` member
    except `BlockType.EMPTY_BLOCK`, whose value is whatever the registry
    returns for that member's name under the `blocks.Block` base.
    """
    self._block_builders = {
        kind: registry.lookup(kind.name, blocks.Block)
        for kind in BlockType
        if kind != BlockType.EMPTY_BLOCK
    }
def get_dataset_provider():
    """Helper function to get the data provider.

    If exactly one `ms_data.Provider` is registered, it is returned directly.
    Otherwise (zero or several registered providers) the provider named by
    `FLAGS.phoenix_dataset` is looked up in the registry.

    Returns:
      The single registered provider, or the result of
      `registry.lookup(FLAGS.phoenix_dataset, ms_data.Provider)`. That result
      is presumably None when nothing is registered under the flag's name —
      confirm against the registry implementation.
    """
    logging.info("Getting the registered data provider")
    # Registration API: every provider registered for the ms_data.Provider base.
    data_providers = registry.lookup_all(ms_data.Provider)
    if len(data_providers) == 1:
        return data_providers[0]
    if not data_providers:
        # Nothing registered at all: the flag lookup below is unlikely to
        # succeed, so surface the misconfiguration instead of failing silently.
        logging.warning(
            "No data provider is registered; falling back to lookup of %s",
            FLAGS.phoenix_dataset)
    # Zero or several providers registered: the flag selects which one to use.
    logging.info("Registered data provider: %s", FLAGS.phoenix_dataset)
    return registry.lookup(FLAGS.phoenix_dataset, ms_data.Provider)
def search_space(blocks_to_use=None):
    """Returns required search space for all blocks.

    Args:
      blocks_to_use: Optional container of `BlockType` names to include. When
        None, every block type except `BlockType.EMPTY_BLOCK` is considered.

    Returns:
      An `ms_hparameters.Hyperparameters` into which each selected block's
      non-empty `requires_hparams()` result is merged, with names prefixed by
      the block type's name followed by an underscore.
    """
    space = ms_hparameters.Hyperparameters()
    for kind in BlockType:
        # The empty block contributes nothing; skip it before any filtering.
        if kind == BlockType.EMPTY_BLOCK:
            continue
        if blocks_to_use is not None and kind.name not in blocks_to_use:
            continue
        block_cls = registry.lookup(kind.name, blocks.Block)
        required = block_cls.requires_hparams()
        if not required:
            continue
        space.merge(required, name_prefix=kind.name + '_')
    return space