示例#1
0
    def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """
        build the network, count params, log, maybe load pretrained weights

        :param s_in: shape of the network input, a copy is stored as self.shape_in
        :param s_out: expected shape of every network output (each output shape is compared against it)
        :return: the ShapeList returned by self._build2
        :raises ValueError: if the network produces no output shapes at all, or if an output shape
                            differs from s_out while self.assert_output_match is set
        """
        assert isinstance(s_out, Shape), "Attempting to build a network with an output that is not a Shape!"
        # copies are taken before _build2 receives the caller's objects (presumably because it may modify them)
        s_out_copy = s_out.copy(copy_id=True)
        self.shape_in = s_in.copy(copy_id=True)
        s_out_net = self._build2(s_in, s_out)
        LoggerManager().get_logger().info('Network built, it has %d parameters!' % self.get_num_parameters())

        # validate output shape sizes
        assert isinstance(s_out_net, ShapeList), "The network must output a list of Shapes, one shape per head! (ShapeList)"
        if not s_out_net.shapes:
            # the first output shape becomes self.shape_out below, so an empty result can never be valid;
            # fail here with a clear message instead of an IndexError after weights were already loaded
            raise ValueError("The network has no outputs, expected at least one of shape %s" % s_out_copy)
        for shape in s_out_net.shapes:
            if not s_out_copy == shape:
                text = "One or more output shapes mismatch: %s, expected: %s" % (s_out_net, s_out_copy)
                if self.assert_output_match:
                    raise ValueError(text)
                # otherwise only warn, once, and keep the mismatching network
                LoggerManager().get_logger().warning(text)
                break

        # load weights?
        if len(self.checkpoint_path) > 0:
            # a checkpoint path was explicitly given, so a missing checkpoint is always an error
            path = CheckpointCallback.find_pretrained_weights_path(self.checkpoint_path, self.model_name,
                                                                   raise_missing=True)
            # the meaning of this count is defined by CheckpointCallback.load_network
            num_replacements = 1 if self.is_external() else 999
            self.loaded_weights(CheckpointCallback.load_network(path, self.get_network(), num_replacements))

        self.shape_out = s_out_net.shapes[0].copy(copy_id=True)
        self.shape_in_list = self.shape_in.shape
        self.shape_out_list = self.shape_out.shape
        return s_out_net
示例#2
0
    def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """
        build the stem, every cell listed in self.cell_order and the (aux) heads placed after the cells,
        then log a summary table of all built modules

        side effects: appends the built cells to self.cells, may delete entries of self._head_positions
        and self.heads, and stores the names of the built cells via self.set(cell_order=...)

        :param s_in: network input shape, passed to self.stem.build
        :param s_out: expected output shape, a copy of it is passed to every head's build
        :return: ShapeList of the shapes returned by the heads that were actually built
        """
        LoggerManager().get_logger().info('Building %s:' % self.__class__.__name__)
        rows = [('cell index', 'name', 'class', 'input shapes', '', 'output shapes', '#params')]

        def get_row(idx, name: str, obj: AbstractModule) -> tuple:
            """ one row of the summary table, describing the already built module 'obj' """
            s_in_str = obj.get_shape_in().str()
            s_inner = obj.get_cached('shape_inner')
            s_inner_str = '' if s_inner is None else s_inner.str()
            s_out_str = obj.get_shape_out().str()
            return str(idx), name, obj.__class__.__name__, s_in_str, s_inner_str, s_out_str, count_parameters(obj)

        s_out_data = s_out.copy()
        # output shapes of the stem, extended in place by every cell; each cell consumes the trailing ones
        out_shapes = self.stem.build(s_in)
        final_out_shapes = []
        rows.append(get_row('', '-', self.stem))

        # cells and (aux) heads
        updated_cell_order = []
        for i, cell_name in enumerate(self.cell_order):
            strategy_name, cell = self._get_cell(name=cell_name, cell_index=i)
            assert self.stem.num_outputs() == cell.num_inputs() == cell.num_outputs(), 'Cell does not fit the network!'
            updated_cell_order.append(cell.name)
            s_ins = out_shapes[-cell.num_inputs():]
            with StrategyManagerDefault(strategy_name):
                cell_s_out = cell.build(s_ins.copy(),
                                        features_mul=self.features_mul,
                                        features_fixed=self.features_first_cell if i == 0 else -1)
            out_shapes.extend(cell_s_out)
            rows.append(get_row(i, cell_name, cell))
            self.cells.append(cell)

            # optional (aux) head after every cell
            head = self._head_positions.get(i, None)
            if head is not None:
                if head.weight > 0:
                    final_out_shapes.append(head.build(cell_s_out[-1], s_out_data))
                    rows.append(get_row('', '-', head))
                else:
                    # NOTE(review): a zero-weight head after the final cell is dropped here without tripping
                    # the "Must have a head after the final cell" assertion below - confirm this is intended
                    LoggerManager().get_logger().info('not adding head after cell %d, weight <= 0' % i)
                    del self._head_positions[i]
            else:
                assert i != len(self.cell_order) - 1, "Must have a head after the final cell"

        # remove heads that are impossible to add
        # NOTE(review): heads removed here stay in self._head_positions, while heads dropped above for
        # weight <= 0 stay in self.heads - confirm this asymmetry is intended
        for i, head in self._head_positions.items():
            if i >= len(self.cells):
                LoggerManager().get_logger().warning('Can not add a head after cell %d which does not exist, deleting the head!' % i)
                # compare by identity, not equality, to remove exactly this head object
                for j, head2 in enumerate(self.heads):
                    if head is head2:
                        del self.heads[j]
                        break

        s_out_net = ShapeList(final_out_shapes)
        rows.append(('complete network', '', '', self.get_shape_in().str(), '', s_out_net.str(), count_parameters(self)))
        log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
        self.set(cell_order=updated_cell_order)
        return s_out_net
示例#3
0
 def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
     """
     let the parent class build the network (super()._build), then log a summary table
     covering the stem, every cell, the head(s) and the complete network

     :param s_in: network input shape, forwarded to super()._build
     :param s_out: expected output shape, forwarded to super()._build
     :return: self.get_network_output_shapes(flatten=False)
     """
     # snapshot the input shape before the parent build gets to see it
     net_input = s_in.copy(copy_id=True)
     super()._build(s_in, s_out)

     header = ('cell index', 'input shapes', 'output shapes', '#params')
     stem_row = ('stem', s_in.str(), self.get_stem_output_shape(), count_parameters(self.get_stem()))
     table = [header, stem_row]
     LoggerManager().get_logger().info('%s (%s):' % (self.__class__.__name__, self.model_name))

     # one row per cell: index, its input shapes, its output shapes, its parameter count
     per_cell = zip(self.get_cell_input_shapes(flatten=False),
                    self.get_cell_output_shapes(flatten=False),
                    self.get_cells())
     table.extend((idx, cell_in.str(), cell_out.str(), count_parameters(cell))
                  for idx, (cell_in, cell_out, cell) in enumerate(per_cell))

     heads_row = ('head(s)',
                  self.get_heads_input_shapes(),
                  self.get_network_output_shapes(flatten=False),
                  count_parameters(self.get_heads()))
     table.append(heads_row)
     total_row = ("complete network",
                  net_input.str(),
                  self.get_network_output_shapes(flatten=False),
                  count_parameters(self))
     table.append(total_row)

     log_in_columns(LoggerManager().get_logger(), table, start_space=4)
     return self.get_network_output_shapes(flatten=False)