예제 #1
0
    def __init__(self, net_def: dict, inputs_stack_name: str,
                 output_names: list, in_chs: list, name: str):
        """Build the network's input operators and one ``StackModule`` per stack.

        Args:
            net_def: Mapping from stack name to stack definition. A stack
                definition is a sequence of stages, and each stage is a sequence of
                block definitions. The entry under ``inputs_stack_name`` describes
                the network inputs. In every other stack, the first stage is that
                stack's input stage and is not passed to the ``StackModule``.
            inputs_stack_name: Key in ``net_def`` of the inputs stack.
            output_names: Names whose evaluated results are returned by ``call``.
            in_chs: Channel count for each network input, indexed in the order
                the input blocks appear when the inputs stack's stages are
                flattened.
            name: Name of this module.
        """
        super(NetModule, self).__init__(name=name)

        self.execution_list = []
        self.inputs_stack_name = inputs_stack_name
        self.output_names = output_names

        # Flatten the inputs stack's stages into one ordered list of block defs;
        # the i-th block is paired with in_chs[i] below.
        input_blocks_defs = [inp_block
                             for inp_stage in net_def[self.inputs_stack_name]
                             for inp_block in inp_stage]

        self.inputs_ops = []
        for idx, inp_block_def in enumerate(input_blocks_defs):
            # Distinct local name so the ``name`` parameter is not shadowed.
            input_name, op = self.get_input_oper(inp_block_def)
            self.inputs_ops.append(
                InputItem(input_name=input_name, oper=op, channels=in_chs[idx]))

        # Build one StackModule per non-input stack, in net_def iteration order.
        for stack_name, stack in net_def.items():
            # Exact match: ``in`` against a str is a substring test and would also
            # skip stacks whose names merely occur inside the inputs stack name.
            if stack_name == self.inputs_stack_name:
                continue

            # The first item of the stack is its input stage; it shouldn't be
            # processed inside the StackModule. Its first block is used to
            # obtain the references this stack takes as inputs.
            inputs_refs = self.get_stack_inputs_refs(stack[0][0])
            inputs_chs = self.get_stack_input_channels(inputs_refs)

            sm = StackModule(stack[1:], inputs_chs, clean_name(stack_name))
            out_chs = sm.get_last_channels_num()

            self.__setattr__(clean_name(stack_name), sm)

            self.execution_list.append(
                StackExecutionItem(stack_name=stack_name,
                                   inputs_refs=inputs_refs,
                                   out_chs=out_chs))
예제 #2
0
    def parse_stack(self):
        """Decode ``self.stack_def`` and append one execution item per block.

        Blocks are chained. Each block is created with the current
        ``self.in_chs``. Afterwards ``self.in_chs`` becomes a one-element list
        holding that block's ``out_chs``, so the next block consumes it.
        Non-operator blocks are also bound on ``self`` under
        ``clean_name(block_name)``, presumably so they are tracked as
        sub-modules and can be looked up by name later.
        """
        for block_args in decode_arch_def(self.stack_def):
            item = self.create_execution_item(block_args, self.in_chs)
            self.in_chs = [item.out_chs]

            if not item.is_oper:
                setattr(self, clean_name(item.block_name), item.block)

            self.execution_list.append(item)
예제 #3
0
def make_block(arch_args, in_chs):
    """
    Instantiate a block from its architecture arguments.

    The block-type definition registered for ``arch_args['block_type']`` is
    fetched from ``NetBuilderConfig``. Its ``transform_args_fn`` maps
    ``(in_chs, arch_args)`` to constructor keyword arguments, and
    ``block_type_fn`` builds the block from them.

    Note: ``arch_args['name']`` is overwritten in place with its
    ``clean_name`` form, so the caller's dict is modified.

    Returns:
        ``(block, out_chs)``. ``out_chs`` is ``arch_args['out_chs']`` when
        that key is present and not None, otherwise ``in_chs``.
    """
    type_code = arch_args['block_type']
    declared_out = arch_args.get('out_chs')
    result_chs = in_chs if declared_out is None else declared_out

    arch_args['name'] = clean_name(arch_args['name'])

    type_def = NetBuilderConfig.get_block_type(type_code)
    ctor_kwargs = type_def.transform_args_fn(in_chs, arch_args)
    new_block = type_def.block_type_fn(**ctor_kwargs)

    return new_block, result_chs
예제 #4
0
    def call(self, x):
        """Run every execution item in order and return the final output.

        Operator items (``is_oper``) may declare input tensor names. If they do,
        each name is resolved with ``self.resolve_reference`` against the
        outputs cached so far, and the operator receives that list. Otherwise
        the operator receives the running value ``x``.

        Non-operator items are fetched from ``self`` by their cleaned block
        name and applied to ``x``.

        The output of any item whose ``block_name`` contains ``'#'`` is cached
        under that name for later resolution.
        """
        named_outputs = {}
        for item in self.execution_list:
            if not item.is_oper:
                block_module = getattr(self, clean_name(item.block_name))
                x = block_module(x)
            else:
                ref_names = item.block.get_input_tenors_names()
                resolved = [self.resolve_reference(ref, named_outputs)
                            for ref in ref_names]
                x = item.block(resolved) if resolved else item.block(x)

            if '#' in item.block_name:
                named_outputs[item.block_name] = x

        named_outputs.clear()
        return x
예제 #5
0
    def call(self, x):
        """Evaluate the network: input operators first, then each stack in order.

        Args:
            x: Either a single input shared by every input operator, or a
                list/tuple holding one input per entry of ``self.inputs_ops``
                (matched by position).

        Returns:
            A tuple with one entry per name in ``self.output_names``: the value
            evaluated under that name (an input operator or a stack), or
            ``None`` if nothing was evaluated under it.
        """
        evaluated_stacks = {}

        # Only an explicit list/tuple is treated as one-input-per-operator.
        # Any other value is a single input and is not indexed. This covers a
        # tf.Tensor, and also a tf.Variable, tf.RaggedTensor or tf.SparseTensor,
        # none of which subclass tf.Tensor; indexing those would slice the data.
        multi_input = isinstance(x, (list, tuple))
        for idx, op in enumerate(self.inputs_ops):
            inputs = x[idx] if multi_input else x
            evaluated_stacks[op.input_name] = op.oper(inputs)

        # Stacks run in execution_list order; every ref in a stack's inputs_refs
        # must already be in evaluated_stacks (otherwise KeyError).
        for stack_exec_item in self.execution_list:
            stack_name = stack_exec_item.stack_name
            module = getattr(self, clean_name(stack_name))

            stack_inputs = [evaluated_stacks[ref]
                            for ref in stack_exec_item.inputs_refs]
            # A single input is passed bare rather than as a one-element list.
            if len(stack_inputs) == 1:
                stack_inputs = stack_inputs[0]

            evaluated_stacks[stack_name] = module(stack_inputs)

        # NOTE(review): an output name that was never evaluated yields None
        # instead of raising; confirm that is intended.
        return tuple(evaluated_stacks.get(on) for on in self.output_names)