def _build(self, s_ins: ShapeList, features_mul=1, features_fixed=-1) -> ShapeList:
    """
    Build this cell: two preprocess layers align the two input shapes, the inner
    blocks are built in sequence, and their outputs are concatenated.

    :param s_ins: input shapes; length must equal self.num_inputs() (asserted)
    :param features_mul: multiplier for the number of output features
    :param features_fixed: fixed number of output features, -1 to ignore
    :return: the last self._num_outputs shapes (inputs + the concat output)
    """
    assert len(s_ins) == self.num_inputs()
    c_out = self._num_output_features(s_ins, features_mul, features_fixed)
    # each inner block produces c_out // num features, restored to c_out by the concat
    c_inner = c_out // self.concat.num
    # NOTE(review): compares dim 1 of the two input shapes; presumably the spatial
    # size, so a larger first input means the previous cell reduced — confirm Shape layout
    is_prev_reduce = s_ins[0].shape[1] > s_ins[1].shape[1]
    base_kwargs = dict(use_bn=True, bn_affine=True, act_fun='relu', order='act_w_bn')
    # if the previous layer reduces the spatial size, the layer before that has larger sizes than this one!
    if is_prev_reduce:
        self.preprocess.append(
            FactorizedReductionLayer(stride=2, **base_kwargs))
    else:
        self.preprocess.append(
            ConvLayer(k_size=1, dilation=1, stride=1, **base_kwargs))
    # preprocess[0] handles the older input, preprocess[1] the newer one
    s_inner_p0 = self.preprocess[0].build(s_ins[0], c_inner)
    self.preprocess.append(
        ConvLayer(k_size=1, dilation=1, stride=1, **base_kwargs))
    s_inner_p1 = self.preprocess[1].build(s_ins[1], c_inner)
    inner_shapes = [s_inner_p0, s_inner_p1]
    # each block is built on the shapes of both preprocessed inputs and all prior blocks
    for m in self.blocks:
        s = m.build(inner_shapes, c_inner)
        inner_shapes.append(s)
    # NOTE(review): this appends to the caller-provided s_ins list (mutation is
    # visible to the caller unless a copy was passed)
    s_ins.append(self.concat.build(inner_shapes, c_out))
    return ShapeList(s_ins[-self._num_outputs:])
def _build(self, s_ins: ShapeList, features_mul=1, features_fixed=-1) -> ShapeList:
    """
    Build every inner block in order, each one seeing the shapes of the inputs
    and of all previously built blocks.

    :param s_ins: input shapes; length must equal self.num_inputs() (asserted)
    :param features_mul: multiplier for the number of output features
    :param features_fixed: fixed number of output features, -1 to ignore
    :return: a ShapeList containing only the last given input shape
    """
    assert len(s_ins) == self.num_inputs()
    num_features = self._num_output_features(s_ins, features_mul, features_fixed)
    known_shapes = s_ins.copy()
    for block in self.blocks:
        known_shapes.append(block.build(known_shapes, num_features))
    # NOTE(review): intentionally returns the last *input* shape, presumably
    # because the blocks preserve it — confirm
    return ShapeList([s_ins[-1]])
def _get_cell_output_shapes(self) -> ShapeList:
    """ output shape(s) of each cell in order """
    # temporarily switch to eval mode for the probing forward passes
    was_training = self.training
    self.train(False)
    tensors = self.shape_in.random_tensor(batch_size=2)
    tensors = self.get_stem()(tensors)
    shapes = ShapeList([])
    for cell_idx in range(self.num_cells()):
        # run exactly one cell at a time, feeding its output to the next
        tensors = self.specific_forward(tensors, start_cell=cell_idx, end_cell=cell_idx)
        shapes.append(ShapeList.from_tensors(tensors))
    self.train(was_training)
    return shapes
def _get_stem_output_shape(self) -> ShapeList:
    """ output shapes of the stem """
    # probe with a random tensor in eval mode, then restore the training flag
    was_training = self.training
    self.train(False)
    dummy = self.shape_in.random_tensor(batch_size=2)
    stem_out = self.specific_forward(dummy, start_cell=-1, end_cell=-1)
    self.train(was_training)
    return ShapeList.from_tensors(stem_out)
def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    Build the stem, all cells and all (auxiliary) heads, logging a table of the
    built modules and their shapes.

    Fix vs. original: uses the idiomatic `del self.heads[j]` instead of calling
    `self.heads.__delitem__(j)` directly, and iterates the dict directly instead
    of `.keys()`; behavior is unchanged.

    :param s_in: input shape of the network
    :param s_out: target output shape (of the data set)
    :return: ShapeList of the output shapes of all heads that were added
    """
    LoggerManager().get_logger().info('Building %s:' % self.__class__.__name__)
    rows = [('cell index', 'name', 'class', 'input shapes', '', 'output shapes', '#params')]

    def get_row(idx, name: str, obj: AbstractModule) -> tuple:
        # one row of the log table for a built module
        s_in_str = obj.get_shape_in().str()
        s_inner = obj.get_cached('shape_inner')
        s_inner_str = '' if s_inner is None else s_inner.str()
        s_out_str = obj.get_shape_out().str()
        return str(idx), name, obj.__class__.__name__, s_in_str, s_inner_str, s_out_str, count_parameters(obj)

    # s_out is reassigned in the loop below; keep the data set's target shape around
    s_out_data = s_out.copy()
    out_shapes = self.stem.build(s_in)
    final_out_shapes = []
    rows.append(get_row('', '-', self.stem))
    # cells and (aux) heads
    updated_cell_order = []
    for i, cell_name in enumerate(self.cell_order):
        strategy_name, cell = self._get_cell(name=cell_name, cell_index=i)
        assert self.stem.num_outputs() == cell.num_inputs() == cell.num_outputs(), 'Cell does not fit the network!'
        updated_cell_order.append(cell.name)
        # each cell consumes the most recent output shapes
        s_ins = out_shapes[-cell.num_inputs():]
        with StrategyManagerDefault(strategy_name):
            s_out = cell.build(s_ins.copy(),
                               features_mul=self.features_mul,
                               features_fixed=self.features_first_cell if i == 0 else -1)
        out_shapes.extend(s_out)
        rows.append(get_row(i, cell_name, cell))
        self.cells.append(cell)
        # optional (aux) head after every cell
        head = self._head_positions.get(i, None)
        if head is not None:
            if head.weight > 0:
                final_out_shapes.append(head.build(s_out[-1], s_out_data))
                rows.append(get_row('', '-', head))
            else:
                LoggerManager().get_logger().info('not adding head after cell %d, weight <= 0' % i)
                del self._head_positions[i]
        else:
            assert i != len(self.cell_order) - 1, "Must have a head after the final cell"
    # remove heads that are impossible to add
    for i in self._head_positions:
        if i >= len(self.cells):
            LoggerManager().get_logger().warning('Can not add a head after cell %d which does not exist, deleting the head!' % i)
            head = self._head_positions.get(i)
            for j, head2 in enumerate(self.heads):
                if head is head2:
                    del self.heads[j]
                    break
    s_out = ShapeList(final_out_shapes)
    rows.append(('complete network', '', '', self.get_shape_in().str(), '', s_out.str(), count_parameters(self)))
    log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
    self.set(cell_order=updated_cell_order)
    return s_out
def _get_network_output_shapes(self) -> ShapeList:
    """ output shapes of the network """
    # probe in eval mode: feed a random tensor of the last cell's output shape
    # through everything after the cells, then restore the training flag
    was_training = self.training
    self.train(False)
    cell_shapes = self.get_cell_output_shapes()
    dummy = cell_shapes.shapes[-1].random_tensor(batch_size=2)
    net_out = self.specific_forward(dummy, start_cell=len(cell_shapes), end_cell=None)
    self.train(was_training)
    return ShapeList.from_tensors(net_out)
def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """ build the network """
    assert s_in.num_dims() == s_out.num_dims() == 1
    # stem produces the first width, each cell maps to the next width
    cur_shape = self.stem.build(s_in, c_out=self._layer_widths[0])
    for idx, width in enumerate(self._layer_widths[1:]):
        cur_shape = self.cells[idx].build(cur_shape, c_out=width)
    # every head maps the final cell shape to the target feature count
    head_shapes = []
    for head in self.heads:
        head_shapes.append(head.build(cur_shape, c_out=s_out.num_features()))
    return ShapeList(head_shapes)
def _build(self, s_ins: ShapeList, features_mul=1, features_fixed=-1) -> ShapeList:
    """
    Build the wrapped op on the first input shape.

    :param s_ins: input shapes; length must equal self.num_inputs() (asserted)
    :param features_mul: multiplier for the number of output features
    :param features_fixed: fixed number of output features, -1 to ignore
    :return: a ShapeList with the op's single output shape
    """
    assert len(s_ins) == self.num_inputs()
    num_features = self._num_output_features(s_ins, features_mul, features_fixed)
    op_shape = self.op.build(s_ins[0], num_features)
    return ShapeList([op_shape])
def probe_outputs(self, s_in: ShapeOrList, module: nn.Module = None, multiple_outputs=False) -> ShapeOrList:
    """ returning the output shape of one forward pass using zero tensors """
    with torch.no_grad():
        target = self if module is None else module
        result = target(s_in.random_tensor(batch_size=2))
        if not multiple_outputs:
            # single tensor: drop the batch dimension
            return Shape(list(result.shape)[1:])
        # several tensors: one Shape per output, batch dimension dropped
        return ShapeList([Shape(list(t.shape)[1:]) for t in result])
def _get_cell_input_shapes_uncached(self) -> ShapeList:
    """ input shape(s) of each cell: self.input_shape followed by all but the last cell output """
    result = ShapeList([self.input_shape])
    result.extend(self._get_cell_output_shapes())
    # each cell's input is the previous entry, so the last output is dropped
    return result[:-1]
def _get_cell_output_shapes(self) -> ShapeList:
    """ output shape of every cell, in cell order """
    out_shapes = []
    for cell in self.get_cells():
        out_shapes.append(cell.get_shape_out())
    return ShapeList(out_shapes)
def _get_cell_input_shapes(self) -> ShapeList:
    """ input shape(s) of each cell in order """
    all_shapes = ShapeList([self._get_stem_output_shape()])
    all_shapes.extend(self._get_cell_output_shapes())
    # cell i consumes the shape at index i, so the final output is not an input
    return all_shapes[:-1]
def _build(self, s_in: Shape) -> ShapeList:
    """
    Build the three stem modules in sequence; the outputs of the last two
    are returned as the two cell inputs.
    """
    half_features = self.features // 2
    shape0 = self.stem00.build(s_in, half_features)
    shape1 = self.stem01.build(shape0, self.features)
    shape2 = self.stem1.build(shape1, self.features)
    return ShapeList([shape1, shape2])
def _build(self, s_in: Shape) -> ShapeList:
    """ build the single stem module; its output shape serves as both cell inputs """
    out_shape = self.stem_module.build(s_in, self.features)
    return ShapeList([out_shape, out_shape])
def _build(self, s_in: Shape) -> ShapeList:
    """ build the stem for the data set, return list of output feature shapes """
    # remember the input shape for later lookup via get_cached
    self.cached['shape_in'] = s_in
    out_shape = self.stem_module.build(s_in, self.features)
    return ShapeList([out_shape])
def _build(self, s_in: Shape) -> ShapeList:
    """ build the stem for the data set, return list of output feature sizes """
    inner = self.stem0.build(s_in, self.features)
    # remember the intermediate shape for later lookup via get_cached
    self.cached['shape_inner'] = inner
    final = self.stem1.build(inner, self.features1)
    return ShapeList([final])