コード例 #1
0
    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Propagate shapes to a block that may be bidirectional.

        Sets ``block._shape`` from the first input's output shape to
        ``[auto_tag, block.output_feats]`` and derives ``block.hidden_size``
        (``output_feats`` halved when bidirectional). The first output
        dimension is then copied from the input.

        Args:
            from_blocks: The input blocks; only the first one is used.
            block: The block whose shapes get propagated.

        Raises:
            ValueError: When bidirectional==True and output_feats not even.
        """
        # A block without the attribute is treated as unidirectional.
        if not hasattr(block, 'bidirectional'):
            block.bidirectional = False
        directions = 2 if block.bidirectional else 1

        in_shape = get_shape('out', from_blocks[0])
        feats = block.output_feats
        block._shape = create_shape(in_shape, [auto_tag, feats])

        # In the bidirectional case output_feats is split into two equal halves.
        if directions == 2 and feats % 2 != 0:
            raise ValueError(
                f'For bidirectional {block._class} expected output_feats to be even, but got {feats}.'
            )
        block.hidden_size = feats // directions

        # The leading output dimension is taken directly from the input.
        set_shape_dim('out', block, 0, in_shape[0])
コード例 #2
0
ファイル: conv.py プロジェクト: omni-us/narchi
    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Method that propagates shapes to a block.

        The first output dimension is taken either from the input shape
        (``num_features_source == 'from_shape'``) or from
        ``block.output_feats`` (``num_features_source == 'output_feats'``).
        The remaining ``self.conv_dims`` output dimensions start as
        ``auto_tag`` and are then computed with ``conv_out_length``. The
        computation uses the block's kernel_size, stride (default 1),
        padding (default 0) and dilation (default 1).

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When block.output_feats not valid.
            NotImplementedError: If num_features_source is not one of {"from_shape", "output_feats"}.
        """
        ## Set default values ##
        kernel = block.kernel_size
        stride = getattr(block, 'stride', 1)
        padding = getattr(block, 'padding', 0)
        dilation = getattr(block, 'dilation', 1)

        ## Initialize block._shape ##
        auto_dims = [auto_tag] * self.conv_dims
        from_shape = get_shape('out', from_blocks[0])
        if self.num_features_source == 'from_shape':
            num_features = from_shape[0]
        elif self.num_features_source == 'output_feats':
            check_output_feats_dims(1, self.block_class, block)
            num_features = block.output_feats
        else:
            # Fail explicitly (as documented) instead of leaving block._shape
            # unset, which would only surface later as an unrelated error.
            raise NotImplementedError(
                f'num_features_source={self.num_features_source!r} not one of {{"from_shape", "output_feats"}}.'
            )
        block._shape = create_shape(from_shape, [num_features] + auto_dims)

        ## Calculate and set <<auto>> output dimensions ##
        for dim, val in enumerate(get_shape('out', block)):
            if val == auto_tag:
                in_length = get_shape('in', block)[dim]
                out_length = conv_out_length(in_length, kernel, stride,
                                             padding, dilation)
                set_shape_dim('out', block, dim, out_length)
コード例 #3
0
    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Propagate shapes through a reshape block.

        The output shape is assembled from ``block.reshape_spec``. The special
        value ``'flatten'`` means a single group of all input dimensions.
        Otherwise the spec goes through ``norm_reshape_spec`` first. Each
        entry of the spec is one of:

        - ``int``: copy that input dimension.
        - ``list``: product of the listed input dimensions.
        - ``dict``: a single ``{in_idx: dims}`` item that expands input
          dimension ``int(in_idx)`` into ``dims``. If ``dims`` contains
          ``auto_tag``, its first occurrence is replaced by the input
          dimension divided by the product of the non-auto entries.

        Args:
            from_blocks: The input blocks; only the first one is used.
            block: The block whose shapes get propagated.
        """
        in_shape = get_shape('out', from_blocks[0])
        if block.reshape_spec == 'flatten':
            spec = [list(range(len(in_shape)))]
        else:
            spec = norm_reshape_spec(block.reshape_spec)

        out_shape = []
        for entry in spec:
            if isinstance(entry, int):
                out_shape.append(in_shape[entry])
            elif isinstance(entry, list):
                out_shape.append(prod([in_shape[i] for i in entry]))
            elif isinstance(entry, dict):
                key, split_dims = next(iter(entry.items()))
                in_dim = in_shape[int(key)]
                if auto_tag in split_dims:
                    auto_pos = split_dims.index(auto_tag)
                    known = prod([d for d in split_dims if d != auto_tag])
                    # NOTE(review): this writes into the list held by the spec
                    # itself. Whether block.reshape_spec changes in place depends
                    # on whether norm_reshape_spec copies it; confirm.
                    split_dims[auto_pos] = divide(in_dim, known)
                out_shape.extend(split_dims)
        block._shape = create_shape(in_shape, out_shape)
コード例 #4
0
    def test_class_type_with_default_config_files(self):
        """A subclass spec from a default config file shows up in defaults, dumps and parses."""
        cal_cfg = {
            'class_path': 'calendar.Calendar',
            'init_args': {'firstweekday': 3},
        }
        default_path = os.path.join(self.tmpdir, 'config.yaml')
        # Written as JSON (a YAML subset) despite the .yaml name.
        with open(default_path, 'w') as handle:
            json.dump({'data': {'cal': cal_cfg}}, handle)

        class MyClass:
            def __init__(self, cal: Optional[Calendar] = None, val: int = 2):
                self.cal = cal

        parser = ArgumentParser(error_handler=None, default_config_files=[default_path])
        parser.add_argument('--op', default='from default')
        parser.add_class_arguments(MyClass, 'data')

        # Defaults come from the config file and survive a dump.
        defaults = parser.get_defaults()
        self.assertEqual(default_path, str(defaults['__default_config__']))
        self.assertEqual(defaults.data.cal.as_dict(), cal_cfg)
        dumped = parser.dump(defaults)
        for expected in ('class_path: calendar.Calendar\n', 'firstweekday: 3\n'):
            self.assertIn(expected, dumped)

        # Parsing without arguments still picks up the file defaults ...
        parsed = parser.parse_args([])
        self.assertEqual(parsed.data.cal.as_dict(), cal_cfg)
        # ... while defaults=False keeps only what was given explicitly.
        parsed = parser.parse_args(['--data.cal.class_path=calendar.Calendar'], defaults=False)
        self.assertEqual(parsed.data.cal, Namespace(class_path='calendar.Calendar'))
コード例 #5
0
    def propagate(
        self,
        from_blocks: List[Namespace],
        block: Namespace,
        propagators: dict,
        ext_vars: dict,
        cwd: str = None,
    ):
        """Method that propagates shapes in the given block.

        Runs ``add_ids_prefix`` on the block, parses the inner graph and
        propagates shapes through ``block.blocks``. The block's own shape then
        goes from the first input's output shape to the output shape of the
        last inner block.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.
            propagators: Dictionary of propagators.
            ext_vars: Dictionary of external variables required to load jsonnet.
            cwd: Working directory to resolve relative paths.

        Raises:
            ValueError: If there are multiple blocks with the same id.
            ValueError: If no propagator found for some block.
            Exception: An error from propagating the inner blocks is re-raised
                with the same type and a ``block[id=<id>]: `` message prefix.
                If that type cannot be built from a single message, the
                original exception is re-raised unchanged.
        """
        add_ids_prefix(block, from_blocks)
        blocks = get_blocks_dict(from_blocks + block.blocks)
        topological_predecessors = parse_graph(from_blocks, block)
        try:
            propagate_shapes(blocks,
                             topological_predecessors,
                             propagators=propagators,
                             ext_vars=ext_vars,
                             cwd=cwd)
        except Exception as ex:
            # Keep the original type so callers' except clauses still match.
            # Some types (e.g. UnicodeDecodeError) reject a single message
            # argument. In that case re-raise the original error rather than
            # masking it with a TypeError from the constructor.
            try:
                wrapped = type(ex)(f'block[id={block._id}]: {ex}')
            except Exception:
                wrapped = None
            if wrapped is None:
                raise
            raise wrapped from ex
        in_shape = get_shape('out', from_blocks[0])
        out_shape = get_shape('out', block.blocks[-1])
        block._shape = create_shape(in_shape, out_shape)
コード例 #6
0
ファイル: same.py プロジェクト: omni-us/narchi
    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Propagate shapes to a block built from its input shape alone.

        ``block._shape`` is created by passing only the first input's output
        shape to ``create_shape`` (no explicit output shape is given).

        Args:
            from_blocks: The input blocks; only the first one is used.
            block: The block whose shapes get propagated.
        """
        in_shape = get_shape('out', from_blocks[0])
        block._shape = create_shape(in_shape)
コード例 #7
0
ファイル: module.py プロジェクト: omni-us/narchi
    def propagate(
        self,
        from_blocks: List[Namespace],
        block: Namespace,
        propagators: dict = None,
        ext_vars: Namespace = None,
        cwd: str = None,
    ):
        """Method that propagates shapes through a module.

        Loads the module referenced by ``block._path`` as a
        ``ModuleArchitecture`` and connects the inputs to it. It then
        propagates shapes and moves the resulting shape and architecture
        onto ``block``.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.
            propagators: Dictionary of propagators.
            ext_vars: External variables required to load jsonnet, as a
                Namespace or dict. None means no external variables. Values
                in ``block._ext_vars`` (if present) take precedence.
            cwd: Working directory to resolve relative paths.

        Raises:
            ValueError: If no propagator found for some block.
        """
        # Work on a deep copy so the update below never touches the caller's ext_vars.
        # (The default used to be a shared mutable ``{}``; None is equivalent.)
        if ext_vars is None:
            block_ext_vars = Namespace()
        elif isinstance(ext_vars, dict):
            block_ext_vars = Namespace(**deepcopy(ext_vars))
        else:
            block_ext_vars = deepcopy(ext_vars)
        if hasattr(block, '_ext_vars'):
            vars(block_ext_vars).update(vars(block._ext_vars))
        cfg = {
            'ext_vars': block_ext_vars,
            'cwd': cwd,
            'parent_id': block._id,
            'propagate': False,
            'propagators': propagators
        }
        module = ModuleArchitecture(block._path, cfg=cfg)
        self.connect_input(from_blocks, block, module)
        module.propagate()
        # Move the computed shape from the nested architecture onto the block.
        block._shape = module.architecture._shape
        delattr(module.architecture, '_shape')
        block.architecture = module.architecture
コード例 #8
0
    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Propagate shapes to a block that joins its inputs along ``block.dim``.

        The output size along ``block.dim`` is the sum of every input's size
        along that dimension. All other dimensions come from the first input.

        Args:
            from_blocks: The input blocks.
            block: The block whose shapes get propagated.
        """
        join_dim = block.dim
        in_shape = list(get_shape('out', from_blocks[0]))
        # Left unset in the input shape, presumably because it may differ
        # between inputs. TODO confirm against create_shape's handling of None.
        in_shape[join_dim] = None
        out_shape = in_shape.copy()
        out_shape[join_dim] = sum(get_shape('out', b)[join_dim] for b in from_blocks)
        block._shape = create_shape(in_shape, out_shape)
コード例 #9
0
ファイル: fixed.py プロジェクト: omni-us/narchi
    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Propagate shapes to a block that appends fixed output dimensions.

        The output shape is the first input's output shape followed by
        ``block.output_feats``. That is a single value when
        ``self.fixed_dims == 1`` and otherwise a sequence that gets
        concatenated.

        Args:
            from_blocks: The input blocks; only the first one is used.
            block: The block whose shapes get propagated.
        """
        from_shape = get_shape('out', from_blocks[0])
        appended = [block.output_feats] if self.fixed_dims == 1 else block.output_feats
        block._shape = create_shape(from_shape, from_shape + appended)
コード例 #10
0
    def test_class_type_without_defaults(self):
        """With defaults=False only explicitly given init_args end up in the result."""
        class MyCal(Calendar):
            def __init__(self, p1: int = 1, p2: str = '2'):
                pass

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--op', type=MyCal)

        with mock_module(MyCal) as module:
            class_path = f'{module}.MyCal'
            expected = Namespace(class_path=class_path, init_args=Namespace(p1=3))
            # Same arguments, first as "--key=value", then as "--key value".
            for argv in (
                [f'--op.class_path={class_path}', '--op.init_args.p1=3'],
                ['--op.class_path', class_path, '--op.init_args.p1', '3'],
            ):
                cfg = parser.parse_args(argv, defaults=False)
                self.assertEqual(cfg.op, expected)
コード例 #11
0
    def test_class_type_required_params(self):
        """Required init params default to None, and parsing without them fails."""
        class MyCal(Calendar):
            def __init__(self, p1: int, p2: str):
                pass

        with mock_module(MyCal) as module:
            parser = ArgumentParser(error_handler=None)
            parser.add_argument('--op', type=MyCal, default=lazy_instance(MyCal))

            defaults = parser.get_defaults()
            self.assertEqual(defaults.op.class_path, f'{module}.MyCal')
            self.assertEqual(defaults.op.init_args, Namespace(p1=None, p2=None))
            with self.assertRaises(ParserError):
                parser.parse_args([f'--op={module}.MyCal'])
コード例 #12
0
    def test_mapping_class_typehint(self):
        """A Mapping[str, A] parameter parses and instantiates class specs; wrong types fail."""
        class A:
            pass

        class B:
            def __init__(
                self,
                class_map: Mapping[str, A],
                int_list: List[int],
            ):
                self.class_map = class_map
                self.int_list = int_list

        with mock_module(A, B) as module:
            parser = ArgumentParser(error_handler=None)
            parser.add_class_arguments(B, 'b')

            a_spec = {'class_path': f'{module}.A'}
            config = {'b': {'class_map': {'one': a_spec}, 'int_list': [1]}}

            cfg = parser.parse_object(config)
            self.assertEqual(cfg.b.class_map, {'one': Namespace(class_path=f'{module}.A')})
            self.assertEqual(cfg.b.int_list, [1])

            init = parser.instantiate_classes(cfg)
            self.assertIsInstance(init.b, B)
            self.assertIsInstance(init.b.class_map, dict)
            self.assertIsInstance(init.b.class_map['one'], A)

            # A mapping given where a list of ints is expected must be rejected.
            config['b']['int_list'] = config['b']['class_map']
            with self.assertRaises(ParserError):
                parser.parse_object(config)
コード例 #13
0
    def test_class_type_subclass_nested_init_args(self):
        """Nested subclass init args work with and without the 'init_args.' prefix."""
        class Class:
            def __init__(self, cal: Calendar, p1: int = 0):
                self.cal = cal

        for prefix in ('init_args.', ''):
            label = 'full' if prefix else 'short'
            with self.subTest(label), mock_module(Class) as module:
                parser = ArgumentParser()
                parser.add_argument('--op', type=Class)
                cfg = parser.parse_args([
                    f'--op={module}.Class',
                    f'--op.{prefix}p1=1',
                    f'--op.{prefix}cal=calendar.TextCalendar',
                    f'--op.{prefix}cal.{prefix}firstweekday=2',
                ])
                op = cfg.op
                self.assertEqual(op.class_path, f'{module}.Class')
                self.assertEqual(op.init_args.p1, 1)
                nested = op.init_args.cal
                self.assertEqual(nested.class_path, 'calendar.TextCalendar')
                self.assertEqual(nested.init_args, Namespace(firstweekday=2))
コード例 #14
0
 def test_class_type_subclass_short_init_args(self):
     """A bare subclass name resolves to its full class_path and accepts short init args."""
     parser = ArgumentParser()
     parser.add_argument('--op', type=Calendar)
     op = parser.parse_args(['--op=TextCalendar', '--op.firstweekday=2']).op
     self.assertEqual(op.class_path, 'calendar.TextCalendar')
     self.assertEqual(op.init_args, Namespace(firstweekday=2))