def __init__(self, net_def: dict, inputs_stack_name: str, output_names: list, in_chs: list, name: str):
    """Build the network's input ops and one StackModule per non-input stack.

    Args:
        net_def: mapping of stack name -> list of stages. The entry under
            ``inputs_stack_name`` holds the input stages (each a list of input
            block definitions). For every other stack, the first stage is its
            input stage, whose first item is used to find which inputs/stacks
            feed it; the remaining stages are built into a StackModule.
        inputs_stack_name: key in ``net_def`` of the input-definition stack.
        output_names: names of the evaluated inputs/stacks returned by ``call``.
        in_chs: channel count for each input block, in flattened
            (stage-by-stage, block-by-block) order.
        name: module name forwarded to the base class.
    """
    super(NetModule, self).__init__(name=name)
    self.execution_list = []
    self.inputs_stack_name = inputs_stack_name
    self.output_names = output_names

    # Flatten the input stack's stages into a single ordered list of block
    # definitions; in_chs is indexed in this same order.
    inputs_def = net_def[self.inputs_stack_name]
    input_blocks_defs = [inp_block for inp_stage in inputs_def for inp_block in inp_stage]

    self.inputs_ops = []
    for idx, inp_block_def in enumerate(input_blocks_defs):
        # Distinct local names so the ``name`` argument is not shadowed.
        input_name, op = self.get_input_oper(inp_block_def)
        self.inputs_ops.append(
            InputItem(input_name=input_name, oper=op, channels=in_chs[idx]))

    # Build the remaining stacks in net_def order.
    for stack_name, stack in net_def.items():
        # Exact comparison: ``in`` on a str is a substring test, which would also
        # skip e.g. a stack named "in" when the inputs stack is named "inputs".
        if stack_name == self.inputs_stack_name:
            continue
        # stack[0] is the input stage: it only describes where this stack's
        # inputs come from and is not processed inside the StackModule.
        inputs_refs = self.get_stack_inputs_refs(stack[0][0])
        inputs_chs = self.get_stack_input_channels(inputs_refs)
        module_name = clean_name(stack_name)
        sm = StackModule(stack[1:], inputs_chs, module_name)
        out_chs = sm.get_last_channels_num()
        setattr(self, module_name, sm)
        self.execution_list.append(
            StackExecutionItem(stack_name=stack_name,
                               inputs_refs=inputs_refs,
                               out_chs=out_chs))
def parse_stack(self):
    """Decode ``self.stack_def`` and register one execution item per block.

    Blocks are chained: each block is created with the output channels of the
    previous one (initially ``self.in_chs``). Non-operator blocks are also
    stored as attributes under their cleaned block name so they can be looked
    up by that name later.
    """
    for block_args in decode_arch_def(self.stack_def):
        item = self.create_execution_item(block_args, self.in_chs)
        # The next block is created with this block's output channels.
        self.in_chs = [item.out_chs]
        if not item.is_oper:
            setattr(self, clean_name(item.block_name), item.block)
        self.execution_list.append(item)
def make_block(arch_args, in_chs):
    """Create a block instance from its decoded architecture arguments.

    Args:
        arch_args: dict of block arguments. Must contain 'block_type' and
            'name'; may contain 'out_chs'. NOTE: arch_args['name'] is
            overwritten in place with its cleaned form.
        in_chs: input channels passed to the block type's argument transform.

    Returns:
        A ``(block, out_chs)`` tuple. ``out_chs`` is arch_args['out_chs'], or
        ``in_chs`` when that key is missing or None.
    """
    type_code = arch_args['block_type']
    declared_out_chs = arch_args.get('out_chs')
    out_chs = in_chs if declared_out_chs is None else declared_out_chs

    arch_args['name'] = clean_name(arch_args['name'])

    type_def = NetBuilderConfig.get_block_type(type_code)
    ctor_kwargs = type_def.transform_args_fn(in_chs, arch_args)
    return type_def.block_type_fn(**ctor_kwargs), out_chs
def call(self, x):
    """Run the execution items in order and return the last produced tensor.

    Operator items resolve their named input references (via
    ``self.resolve_reference``) against previously recorded outputs; when an
    operator declares no references it is applied to the running tensor.
    Module items are fetched by their cleaned block name and applied to the
    running tensor. Outputs of items whose name contains '#' are recorded so
    later operators can reference them.
    """
    # NOTE(review): the original source lost its indentation; the '#'-name
    # recording is assumed to apply to both operator and module items — confirm.
    named_outputs = {}
    for item in self.execution_list:
        if not item.is_oper:
            block_module = getattr(self, clean_name(item.block_name))
            x = block_module(x)
        else:
            ref_names = item.block.get_input_tenors_names()
            resolved = [self.resolve_reference(ref, named_outputs) for ref in ref_names]
            x = item.block(resolved) if resolved else item.block(x)

        if '#' in item.block_name:
            named_outputs[item.block_name] = x

    named_outputs.clear()
    return x
def call(self, x):
    """Evaluate the whole network and return the tensors named in ``self.output_names``.

    Args:
        x: either a single tensor, which is fed to every input op, or an
            indexable collection holding one input per entry of
            ``self.inputs_ops`` (same order).

    Returns:
        A tuple with one evaluated tensor per name in ``self.output_names``.

    Raises:
        KeyError: if an output name matches neither an input name nor a stack
            name (previously such outputs were silently returned as ``None``).
    """
    evaluated_stacks = {}

    # Prepare inputs. tf.is_tensor also accepts composite tensors such as
    # tf.SparseTensor / tf.RaggedTensor, which are not tf.Tensor instances and
    # would otherwise be indexed with x[idx] as if they were a list of inputs.
    single_input = tf.is_tensor(x)
    for idx, input_item in enumerate(self.inputs_ops):
        inputs = x if single_input else x[idx]
        evaluated_stacks[input_item.input_name] = input_item.oper(inputs)

    # Run the stacks in execution_list order; each reads already-evaluated
    # inputs/stacks through its references.
    for stack_exec_item in self.execution_list:
        stack_name = stack_exec_item.stack_name
        module = getattr(self, clean_name(stack_name))
        stack_inputs = [evaluated_stacks[ref] for ref in stack_exec_item.inputs_refs]
        if len(stack_inputs) == 1:
            # Single-input stacks receive the bare tensor rather than a list.
            stack_inputs = stack_inputs[0]
        evaluated_stacks[stack_name] = module(stack_inputs)

    # Gather outputs; an unknown name is a configuration error, so fail loudly
    # instead of returning None in its place.
    missing = [on for on in self.output_names if on not in evaluated_stacks]
    if missing:
        raise KeyError(
            "Output name(s) {} not found among evaluated inputs/stacks: {}".format(
                missing, list(evaluated_stacks)))
    return tuple(evaluated_stacks[on] for on in self.output_names)