Beispiel #1
0
    def testRegister(self):
        """Exercises explicit `registry.register` calls and their failure modes."""
        # Each case: (lookup name, base, candidate, extra register() kwargs).
        # Looking the candidate up afterwards must yield exactly the object
        # that register(...)(candidate) returned.
        registration_cases = (
            ("ValidSub", Base, ValidSub, {}),
            ("alias", Base, ValidSub, {"lookup_name": "alias"}),
            ("valid_sub_method", base_method, valid_sub_method, {}),
        )
        for name, base, candidate, options in registration_cases:
            returned = registry.register(base, **options)(candidate)
            self.assertEqual(registry.lookup(name, base), returned)

        # Registering the invalid candidates raises ValueError.
        for base, invalid in ((Base, InvalidSub),
                              (base_method, invalid_sub_method)):
            self.assertRaises(ValueError, registry.register(base), invalid)

        # Cannot register the same class / method twice.
        for base, duplicate in ((Base, ValidSub),
                                (base_method, valid_sub_method)):
            self.assertRaises(ValueError, registry.register(base), duplicate)

        # Both base and the subclass / method should be of the right type.
        self.assertRaises(ValueError, registry.register, "foo")
        for base in (Base, base_method):
            self.assertRaises(ValueError, registry.register(base), "foo")
Beispiel #2
0
    def testDecorator(self):
        """Checks lookups of names registered elsewhere in this test module."""
        # (lookup name, base, expected object). Presumably these were
        # registered via the `register` decorator at definition time.
        expected_lookups = (
            ("DecoratorSub", DecoratorBase, DecoratorSub),
            ("alias", DecoratorBase, DecoratorAlias),
            ("valid_decorated_sub_method", base_method,
             valid_decorated_sub_method),
        )
        for name, base, target in expected_lookups:
            self.assertEqual(registry.lookup(name, base), target)

        # "InvalidSub" resolves to None both when InvalidSub itself is used
        # as the base and when DecoratorBase is used as the base.
        for name, base in (("InvalidSub", InvalidSub),
                           ("InvalidSub", DecoratorBase)):
            self.assertIsNone(registry.lookup(name, base))
Beispiel #3
0
 def __init__(self):
     """Initializes (constructs) the `Blocks` class.

     Builds `self._block_builders`, a dict keyed by every `BlockType` member
     except `EMPTY_BLOCK`. Each value is what `registry.lookup` returns for
     that member's name under the `blocks.Block` base.
     """
     # EMPTY_BLOCK is deliberately skipped, so it gets no builder entry.
     self._block_builders = {
         block_type: registry.lookup(block_type.name, blocks.Block)
         for block_type in BlockType
         if block_type != BlockType.EMPTY_BLOCK
     }
Beispiel #4
0
def get_dataset_provider():
    """Helper function to get the data provider.

    If exactly one `ms_data.Provider` is registered, that provider is returned.
    Otherwise (no providers or several), the provider is looked up by the name
    held in `FLAGS.phoenix_dataset`.

    Returns:
      The selected data provider, as returned by the registry.
    """
    logging.info("Getting the registered data provider")
    # Registration API: all providers registered against ms_data.Provider.
    data_providers = registry.lookup_all(ms_data.Provider)
    if len(data_providers) == 1:
        return data_providers[0]

    # Zero or several providers registered: select by the flag value.
    logging.info("Registered data provider: %s", FLAGS.phoenix_dataset)
    return registry.lookup(FLAGS.phoenix_dataset, ms_data.Provider)
Beispiel #5
0
    def search_space(blocks_to_use=None):
        """Collects the hyperparameter search space of the selected blocks.

        Args:
          blocks_to_use: optional container of `BlockType` names to include.
            When None, every block type except `EMPTY_BLOCK` is considered.

        Returns:
          An `ms_hparameters.Hyperparameters` into which each selected
          block's required hyperparameters were merged, each prefixed with
          the block type's name followed by an underscore.
        """
        merged = ms_hparameters.Hyperparameters()
        use_all = blocks_to_use is None
        for kind in BlockType:
            if kind == BlockType.EMPTY_BLOCK:
                continue
            if not (use_all or kind.name in blocks_to_use):
                continue
            block_cls = registry.lookup(kind.name, blocks.Block)
            required = block_cls.requires_hparams()
            if not required:
                continue
            merged.merge(required, name_prefix=(kind.name + '_'))
        return merged