def propagate(self, from_blocks: List[Namespace], block: Namespace):
    """Propagate shapes to a block whose output feature size is ``block.output_feats``.

    The block's shape is created from the input shape and
    ``[auto_tag, output_feats]``; the auto dimension is then set to the
    input's first dimension. ``block.hidden_size`` is set to
    ``output_feats``, halved when the block is bidirectional.

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.

    Raises:
        ValueError: When bidirectional==True and output_feats not even.
    """
    # Blocks that do not declare the attribute are treated as unidirectional.
    if not hasattr(block, 'bidirectional'):
        block.bidirectional = False

    in_shape = get_shape('out', from_blocks[0])
    feats = block.output_feats
    block._shape = create_shape(in_shape, [auto_tag, feats])

    # Each direction gets half of the output features, so they must split evenly.
    is_bidirectional = block.bidirectional
    if is_bidirectional and feats % 2 != 0:
        raise ValueError(
            f'For bidirectional {block._class} expected output_feats to be even, but got {feats}.'
        )
    num_directions = 2 if is_bidirectional else 1
    block.hidden_size = feats // num_directions

    # The leading output dimension is copied from the input.
    set_shape_dim('out', block, 0, in_shape[0])
def propagate(self, from_blocks: List[Namespace], block: Namespace):
    """Method that propagates shapes to a block.

    The first output dimension is taken either from the input shape
    (``num_features_source == 'from_shape'``) or from ``block.output_feats``
    (``num_features_source == 'output_feats'``). The remaining
    ``self.conv_dims`` dimensions start as ``auto_tag`` and are computed with
    ``conv_out_length`` from the block's kernel size, stride (default 1),
    padding (default 0) and dilation (default 1).

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.

    Raises:
        ValueError: When block.output_feats not valid.
        NotImplementedError: If num_features_source is not one of {"from_shape", "output_feats"}.
    """
    ## Set default values ##
    kernel = block.kernel_size
    stride = block.stride if hasattr(block, 'stride') else 1
    padding = block.padding if hasattr(block, 'padding') else 0
    dilation = block.dilation if hasattr(block, 'dilation') else 1

    ## Initialize block._shape ##
    auto_dims = [auto_tag] * self.conv_dims
    from_shape = get_shape('out', from_blocks[0])
    if self.num_features_source == 'from_shape':
        block._shape = create_shape(from_shape, [from_shape[0]] + auto_dims)
    elif self.num_features_source == 'output_feats':
        check_output_feats_dims(1, self.block_class, block)
        block._shape = create_shape(from_shape, [block.output_feats] + auto_dims)
    else:
        # Without this, block._shape would stay unset and the loop below would
        # fail with an unrelated error instead of the documented one.
        raise NotImplementedError(
            f'Unsupported num_features_source={self.num_features_source!r}; '
            'expected "from_shape" or "output_feats".'
        )

    ## Calculate and set <<auto>> output dimensions ##
    in_shape = get_shape('in', block)  # only 'out' dims are modified below
    for dim, val in enumerate(get_shape('out', block)):
        if val == auto_tag:
            out_length = conv_out_length(in_shape[dim], kernel, stride, padding, dilation)
            set_shape_dim('out', block, dim, out_length)
def propagate(self, from_blocks: List[Namespace], block: Namespace):
    """Method that propagates shapes to a block.

    ``block.reshape_spec`` is either ``'flatten'`` (all input dimensions are
    merged into one) or a spec normalized by ``norm_reshape_spec`` whose items
    are handled as follows:

    * ``int``: keep the input dimension at that index;
    * ``list`` of ints: merge those input dimensions into one (their product);
    * ``dict`` ``{index: dims}``: replace input dimension ``index`` with
      ``dims``; the first ``auto_tag`` in ``dims`` is resolved by dividing the
      input size by the product of the other entries.

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.
    """
    shape_in = get_shape('out', from_blocks[0])
    if block.reshape_spec == 'flatten':
        reshape_spec = [list(range(len(shape_in)))]
    else:
        reshape_spec = norm_reshape_spec(block.reshape_spec)

    shape_out = []
    for val in reshape_spec:
        if isinstance(val, int):
            shape_out.append(shape_in[val])
        elif isinstance(val, list):
            shape_out.append(prod([shape_in[x] for x in val]))
        elif isinstance(val, dict):
            idx = next(iter(val.keys()))
            in_dim = shape_in[int(idx)]
            # Copy so that resolving the auto dimension does not rewrite the
            # spec (which may be block.reshape_spec itself) in place.
            dims = list(val[idx])
            if auto_tag in dims:
                auto_idx = dims.index(auto_tag)
                nonauto = prod([x for x in dims if x != auto_tag])
                dims[auto_idx] = divide(in_dim, nonauto)
            shape_out.extend(dims)
    block._shape = create_shape(shape_in, shape_out)
def test_class_type_with_default_config_files(self):
    # A class_path/init_args value for ``data.cal`` stored in a file that the
    # parser is told to load as a default config file.
    cal_config = {
        'class_path': 'calendar.Calendar',
        'init_args': {'firstweekday': 3},
    }
    default_path = os.path.join(self.tmpdir, 'config.yaml')
    with open(default_path, 'w') as config_file:
        config_file.write(json.dumps({'data': {'cal': cal_config}}))

    class MyClass:
        def __init__(self, cal: Optional[Calendar] = None, val: int = 2):
            self.cal = cal

    parser = ArgumentParser(error_handler=None, default_config_files=[default_path])
    parser.add_argument('--op', default='from default')
    parser.add_class_arguments(MyClass, 'data')

    # The defaults come from the config file and survive a dump.
    defaults = parser.get_defaults()
    self.assertEqual(default_path, str(defaults['__default_config__']))
    self.assertEqual(defaults.data.cal.as_dict(), cal_config)
    dumped = parser.dump(defaults)
    for expected_line in ('class_path: calendar.Calendar\n', 'firstweekday: 3\n'):
        self.assertIn(expected_line, dumped)

    # Parsing with no arguments applies the same defaults.
    self.assertEqual(parser.parse_args([]).data.cal.as_dict(), cal_config)

    # With defaults disabled, only the explicitly given class_path is present.
    parsed = parser.parse_args(['--data.cal.class_path=calendar.Calendar'], defaults=False)
    self.assertEqual(parsed.data.cal, Namespace(class_path='calendar.Calendar'))
def propagate(
    self,
    from_blocks: List[Namespace],
    block: Namespace,
    propagators: dict,
    ext_vars: dict,
    cwd: str = None,
):
    """Method that propagates shapes in the given block.

    ``add_ids_prefix`` is applied to the block, the inner graph is parsed and
    shapes are propagated through the inner blocks using the topological
    predecessors returned by ``parse_graph``. The block's shape is then
    created from the output shape of ``from_blocks[0]`` and the output shape
    of the last inner block.

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.
        propagators: Dictionary of propagators.
        ext_vars: Dictionary of external variables required to load jsonnet.
        cwd: Working directory to resolve relative paths.

    Raises:
        ValueError: If there are multiple blocks with the same id.
        ValueError: If no propagator found for some block.
        Exception: An error raised while propagating the inner blocks is
            re-raised with the same type and ``block[id=...]:`` prefixed to its
            message. If that type cannot be built from a single message, the
            original exception is re-raised unchanged.
    """
    add_ids_prefix(block, from_blocks)
    blocks = get_blocks_dict(from_blocks + block.blocks)
    topological_predecessors = parse_graph(from_blocks, block)
    try:
        propagate_shapes(blocks, topological_predecessors, propagators=propagators, ext_vars=ext_vars, cwd=cwd)
    except Exception as ex:
        # Not every exception class accepts a single message argument
        # (e.g. UnicodeDecodeError); rebuilding such an exception would raise a
        # TypeError that hides the real error.
        try:
            wrapped = type(ex)(f'block[id={block._id}]: {ex}')
        except Exception:
            wrapped = None
        if wrapped is None:
            raise  # re-raises the original exception with its traceback
        raise wrapped from ex
    in_shape = get_shape('out', from_blocks[0])
    out_shape = get_shape('out', block.blocks[-1])
    block._shape = create_shape(in_shape, out_shape)
def propagate(self, from_blocks: List[Namespace], block: Namespace):
    """Set the block's shape from the output shape of its first input.

    ``create_shape`` is called with that shape only.

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.
    """
    input_shape = get_shape('out', from_blocks[0])
    block._shape = create_shape(input_shape)
def propagate(
    self,
    from_blocks: List[Namespace],
    block: Namespace,
    propagators: dict = None,
    ext_vars: Namespace = None,
    cwd: str = None,
):
    """Method that propagates shapes through a module.

    Loads ``block._path`` as a ``ModuleArchitecture``, connects it to the input
    blocks and propagates it. The resulting shape and architecture are then
    stored in ``block``.

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.
        propagators: Dictionary of propagators.
        ext_vars: External variables required to load jsonnet, given as a
            Namespace, a dict or None (treated as empty). The caller's object
            is not modified; a deep copy is used.
        cwd: Working directory to resolve relative paths.

    Raises:
        ValueError: If no propagator found for some block.
    """
    # A None default replaces the former mutable ``{}`` default; both yield an
    # empty Namespace. Work on a deep copy so the caller's ext_vars stay intact.
    if ext_vars is None:
        block_ext_vars = Namespace()
    elif isinstance(ext_vars, dict):
        block_ext_vars = Namespace(**deepcopy(ext_vars))
    else:
        block_ext_vars = deepcopy(ext_vars)
    # Variables attached to the block override the inherited ones.
    if hasattr(block, '_ext_vars'):
        vars(block_ext_vars).update(vars(block._ext_vars))

    cfg = {
        'ext_vars': block_ext_vars,
        'cwd': cwd,
        'parent_id': block._id,
        'propagate': False,
        'propagators': propagators,
    }
    module = ModuleArchitecture(block._path, cfg=cfg)
    self.connect_input(from_blocks, block, module)
    module.propagate()

    # Move the computed shape from the module's architecture onto the block.
    block._shape = module.architecture._shape
    delattr(module.architecture, '_shape')
    block.architecture = module.architecture
def propagate(self, from_blocks: List[Namespace], block: Namespace):
    """Propagate shapes to a block that combines its inputs along ``block.dim``.

    The input shape is that of the first input with dimension ``block.dim``
    set to ``None``. The output shape matches it except on that dimension,
    which is the sum of the inputs' sizes there (concatenation-style).

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.
    """
    first_shape = get_shape('out', from_blocks[0])
    concat_dim = block.dim

    shape_in = list(first_shape)
    shape_in[concat_dim] = None

    sizes = [get_shape('out', b)[concat_dim] for b in from_blocks]
    shape_out = list(shape_in)
    shape_out[concat_dim] = sum(sizes)

    block._shape = create_shape(shape_in, shape_out)
def propagate(self, from_blocks: List[Namespace], block: Namespace):
    """Build the output shape by appending ``block.output_feats`` to the input shape.

    When ``self.fixed_dims == 1``, ``output_feats`` is a single size that is
    appended. Otherwise it is concatenated directly, so it is presumably a
    list of sizes (TODO confirm against the block schema).

    Args:
        from_blocks: The input blocks.
        block: The block to propagate its shapes.
    """
    from_shape = get_shape('out', from_blocks[0])
    extra_dims = [block.output_feats] if self.fixed_dims == 1 else block.output_feats
    block._shape = create_shape(from_shape, from_shape + extra_dims)
def test_class_type_without_defaults(self):
    class MyCal(Calendar):
        def __init__(self, p1: int = 1, p2: str = '2'):
            pass

    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--op', type=MyCal)
    with mock_module(MyCal) as module:
        # Both the "--key=value" and the "--key value" forms must give the same
        # result; with defaults=False only the given init arg (p1) appears.
        expected = Namespace(class_path=f'{module}.MyCal', init_args=Namespace(p1=3))
        arg_variants = [
            [f'--op.class_path={module}.MyCal', '--op.init_args.p1=3'],
            ['--op.class_path', f'{module}.MyCal', '--op.init_args.p1', '3'],
        ]
        for args in arg_variants:
            cfg = parser.parse_args(args, defaults=False)
            self.assertEqual(cfg.op, expected)
def test_class_type_required_params(self):
    class MyCal(Calendar):
        def __init__(self, p1: int, p2: str):
            pass

    with mock_module(MyCal) as module:
        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--op', type=MyCal, default=lazy_instance(MyCal))

        # Required parameters without a value appear as None in the defaults.
        op_defaults = parser.get_defaults().op
        self.assertEqual(op_defaults.class_path, f'{module}.MyCal')
        self.assertEqual(op_defaults.init_args, Namespace(p1=None, p2=None))

        # Selecting the class without giving the required params must fail.
        with self.assertRaises(ParserError):
            parser.parse_args([f'--op={module}.MyCal'])
def test_mapping_class_typehint(self):
    class A:
        pass

    class B:
        def __init__(
            self,
            class_map: Mapping[str, A],
            int_list: List[int],
        ):
            self.class_map = class_map
            self.int_list = int_list

    with mock_module(A, B) as module:
        parser = ArgumentParser(error_handler=None)
        parser.add_class_arguments(B, 'b')

        a_spec = {'class_path': f'{module}.A'}
        config = {'b': {'class_map': {'one': a_spec}, 'int_list': [1]}}

        # Parsing keeps the mapping values as class_path namespaces.
        cfg = parser.parse_object(config)
        self.assertEqual(cfg.b.class_map, {'one': Namespace(class_path=f'{module}.A')})
        self.assertEqual(cfg.b.int_list, [1])

        # Instantiation builds B with a plain dict of A instances.
        init = parser.instantiate_classes(cfg)
        self.assertIsInstance(init.b, B)
        self.assertIsInstance(init.b.class_map, dict)
        self.assertIsInstance(init.b.class_map['one'], A)

        # A mapping of classes is not a valid value for List[int].
        config['b']['int_list'] = config['b']['class_map']
        with self.assertRaises(ParserError):
            parser.parse_object(config)
def test_class_type_subclass_nested_init_args(self):
    class Class:
        def __init__(self, cal: Calendar, p1: int = 0):
            self.cal = cal

    # Nested init args must be accepted both with the explicit "init_args."
    # prefix and in the short form without it.
    for prefix in ['init_args.', '']:
        label = 'full' if prefix else 'short'
        with self.subTest(label), mock_module(Class) as module:
            parser = ArgumentParser()
            parser.add_argument('--op', type=Class)
            cfg = parser.parse_args([
                f'--op={module}.Class',
                f'--op.{prefix}p1=1',
                f'--op.{prefix}cal=calendar.TextCalendar',
                f'--op.{prefix}cal.{prefix}firstweekday=2',
            ])
            op = cfg.op
            self.assertEqual(op.class_path, f'{module}.Class')
            self.assertEqual(op.init_args.p1, 1)
            cal = op.init_args.cal
            self.assertEqual(cal.class_path, 'calendar.TextCalendar')
            self.assertEqual(cal.init_args, Namespace(firstweekday=2))
def test_class_type_subclass_short_init_args(self):
    parser = ArgumentParser()
    parser.add_argument('--op', type=Calendar)
    # A bare class name resolves to its full import path, and init args may be
    # given without the "init_args." prefix.
    op = parser.parse_args(['--op=TextCalendar', '--op.firstweekday=2']).op
    self.assertEqual(op.class_path, 'calendar.TextCalendar')
    self.assertEqual(op.init_args, Namespace(firstweekday=2))