def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    Build the network via `_build2`, check the resulting head shapes, optionally load weights.

    :param s_in: shape of the network input
    :param s_out: shape that every network output is compared against
    :return: the ShapeList returned by `_build2`
    :raises ValueError: if an output shape differs from `s_out` while `self.assert_output_match` is set
    """
    assert isinstance(s_out, Shape), "Attempting to build a network with an output that is not a Shape!"

    # both copies are taken before `_build2` is handed the original shapes
    expected_out = s_out.copy(copy_id=True)
    self.shape_in = s_in.copy(copy_id=True)

    net_out_shapes = self._build2(s_in, s_out)
    LoggerManager().get_logger().info('Network built, it has %d parameters!' % self.get_num_parameters())

    # validate output shape sizes
    assert isinstance(net_out_shapes, ShapeList), \
        "The network must output a list of Shapes, one shape per head! (ShapeList)"
    for head_shape in net_out_shapes.shapes:
        if expected_out == head_shape:
            continue
        message = "One or more output shapes mismatch: %s, expected: %s" % (net_out_shapes, expected_out)
        if self.assert_output_match:
            raise ValueError(message)
        # the message already lists all output shapes, so warn once and stop checking
        LoggerManager().get_logger().warning(message)
        break

    # load weights? only when a checkpoint path is configured
    has_checkpoint = len(self.checkpoint_path) > 0
    if has_checkpoint:
        weights_path = CheckpointCallback.find_pretrained_weights_path(
            self.checkpoint_path, self.model_name, raise_missing=has_checkpoint)
        # NOTE(review): 1 vs. 999 is forwarded to load_network as its third argument;
        # what the number limits is not visible here - confirm against CheckpointCallback
        if self.is_external():
            num_replacements = 1
        else:
            num_replacements = 999
        self.loaded_weights(CheckpointCallback.load_network(weights_path, self.get_network(), num_replacements))

    # remember the input shape and the first output shape, both as Shape and as their `.shape`
    self.shape_out = net_out_shapes.shapes[0].copy(copy_id=True)
    self.shape_in_list = self.shape_in.shape
    self.shape_out_list = self.shape_out.shape
    return net_out_shapes
def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    Build the stem, every cell named in `self.cell_order` and the heads placed after cells,
    then log one table row per built module.

    :param s_in: input shape, passed to the stem's `build`
    :param s_out: output shape, a copy of which is passed to every head's `build`
    :return: ShapeList with one entry per head that was built
    """
    LoggerManager().get_logger().info('Building %s:' % self.__class__.__name__)
    table = [('cell index', 'name', 'class', 'input shapes', '', 'output shapes', '#params')]

    def describe(idx, name: str, module: AbstractModule) -> tuple:
        """ table row for a built module: index, name, class, shapes (in, inner, out), #params """
        in_str = module.get_shape_in().str()
        inner = module.get_cached('shape_inner')
        # the inner-shape column stays blank when nothing is cached under 'shape_inner'
        inner_str = inner.str() if inner is not None else ''
        out_str = module.get_shape_out().str()
        return str(idx), name, module.__class__.__name__, in_str, inner_str, out_str, count_parameters(module)

    head_target = s_out.copy()
    shapes_so_far = self.stem.build(s_in)
    head_out_shapes = []
    table.append(describe('', '-', self.stem))

    # cells and (aux) heads
    new_cell_order = []
    for idx, cell_name in enumerate(self.cell_order):
        strategy_name, cell = self._get_cell(name=cell_name, cell_index=idx)
        assert self.stem.num_outputs() == cell.num_inputs() == cell.num_outputs(), 'Cell does not fit the network!'
        new_cell_order.append(cell.name)

        # a cell consumes the most recent `num_inputs` shapes produced so far
        cell_inputs = shapes_so_far[-cell.num_inputs():]
        # only the very first cell gets a fixed feature count, all others get -1
        features_fixed = self.features_first_cell if idx == 0 else -1
        with StrategyManagerDefault(strategy_name):
            cell_outputs = cell.build(cell_inputs.copy(),
                                      features_mul=self.features_mul,
                                      features_fixed=features_fixed)
        shapes_so_far.extend(cell_outputs)
        table.append(describe(idx, cell_name, cell))
        self.cells.append(cell)

        # optional (aux) head after every cell
        head = self._head_positions.get(idx, None)
        if head is None:
            assert idx != len(self.cell_order) - 1, "Must have a head after the final cell"
            continue
        if head.weight > 0:
            head_out_shapes.append(head.build(cell_outputs[-1], head_target))
            table.append(describe('', '-', head))
        else:
            LoggerManager().get_logger().info('not adding head after cell %d, weight <= 0' % idx)
            del self._head_positions[idx]

    # remove heads that are impossible to add
    for position in self._head_positions:
        if position >= len(self.cells):
            LoggerManager().get_logger().warning(
                'Can not add a head after cell %d which does not exist, deleting the head!' % position)
            orphan = self._head_positions.get(position)
            # identity comparison: drop exactly this head object, and only its first occurrence
            found_at = next((j for j, candidate in enumerate(self.heads) if orphan is candidate), None)
            if found_at is not None:
                del self.heads[found_at]

    net_out_shapes = ShapeList(head_out_shapes)
    table.append(('complete network', '', '', self.get_shape_in().str(), '',
                  net_out_shapes.str(), count_parameters(self)))
    log_in_columns(LoggerManager().get_logger(), table, start_space=4)
    # store the names reported by the built cells, which replace the requested `cell_order`
    self.set(cell_order=new_cell_order)
    return net_out_shapes
def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    Run the parent class' `_build`, then log a table of shapes and parameter counts
    for the stem, each cell, the head(s) and the complete network.

    :param s_in: shape of the network input
    :param s_out: output shape, forwarded to the parent's `_build`
    :return: the result of `self.get_network_output_shapes(flatten=False)`
    """
    # copy taken before the parent build is handed `s_in`; used for the summary row
    net_in = s_in.copy(copy_id=True)
    super()._build(s_in, s_out)

    table = [
        ('cell index', 'input shapes', 'output shapes', '#params'),
        ('stem', s_in.str(), self.get_stem_output_shape(), count_parameters(self.get_stem())),
    ]
    LoggerManager().get_logger().info('%s (%s):' % (self.__class__.__name__, self.model_name))

    # one row per cell, pairing its input shapes, output shapes and the cell itself
    per_cell = zip(self.get_cell_input_shapes(flatten=False),
                   self.get_cell_output_shapes(flatten=False),
                   self.get_cells())
    table.extend(
        (idx, cell_in.str(), cell_out.str(), count_parameters(cell))
        for idx, (cell_in, cell_out, cell) in enumerate(per_cell)
    )

    table.append(('head(s)',
                  self.get_heads_input_shapes(),
                  self.get_network_output_shapes(flatten=False),
                  count_parameters(self.get_heads())))
    table.append(("complete network",
                  net_in.str(),
                  self.get_network_output_shapes(flatten=False),
                  count_parameters(self)))
    log_in_columns(LoggerManager().get_logger(), table, start_space=4)
    return self.get_network_output_shapes(flatten=False)