def _evaluate_tuple(self, values: tuple):
    if not self.is_set_up:
        self.is_set_up = True
        checkpoint = torch.load(self.load_path)
        state_dict = checkpoint.get('state_dict', None)
        added_state = checkpoint.get('net_add_state', None)
        count_only_trainable = self.kwargs['count_only_trainable']
        assert (state_dict is not None) and (added_state is not None)

        # stem / head weights are in every model
        if count_only_trainable:
            self.const += count_parameters(self.method.get_network().get_network().get_stem())
            self.const += count_parameters(self.method.get_network().get_network().get_heads())
        else:
            for k, v in state_dict.items():
                if '.stem.' in k or '.heads.' in k:
                    self.const += torch.numel(v)

        # variable num params depending on gene
        for choices in added_state.get('cells', list()):
            num_params = [0] * len(choices)
            for j, choice in enumerate(choices):
                for name, shape, trainable in choice:
                    if trainable or not count_only_trainable:
                        num_params[j] += torch.numel(state_dict[name])
            self.choices.append(num_params)

    num_params = sum(self.get_params(i, g) for i, g in enumerate(values)) + self.const
    return num_params
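# A minimal, self-contained sketch of the lookup that _evaluate_tuple relies on (assumption:
# get_params(i, g) simply indexes the per-cell counts cached in self.choices; the names below
# are illustrative, not the actual uninas API):
def sketch_count_for_gene(choices: list, const: int, values: tuple) -> int:
    # choices[i][g] = cached parameter count of choice g in cell i, const = fixed stem/head parameters
    return sum(choices[i][g] for i, g in enumerate(values)) + const

# e.g. two cells with three candidate ops each and a fixed 1000-parameter stem/head:
# sketch_count_for_gene([[120, 240, 360], [480, 600, 720]], 1000, (0, 2)) == 1840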
def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """ build the network, count params, log, maybe load pretrained weights """
    s_in_net = s_in.copy(copy_id=True)
    super()._build(s_in, s_out)
    rows = [('cell index', 'input shapes', 'output shapes', '#params'),
            ('stem', s_in.str(), self.get_stem_output_shape(), count_parameters(self.get_stem()))]
    LoggerManager().get_logger().info('%s (%s):' % (self.__class__.__name__, self.model_name))
    for i, (s_in, s_out, cell) in enumerate(zip(self.get_cell_input_shapes(flatten=False),
                                                self.get_cell_output_shapes(flatten=False),
                                                self.get_cells())):
        rows.append((i, s_in.str(), s_out.str(), count_parameters(cell)))
    rows.append(('head(s)', self.get_heads_input_shapes(), self.get_network_output_shapes(flatten=False),
                 count_parameters(self.get_heads())))
    rows.append(('complete network', s_in_net.str(), self.get_network_output_shapes(flatten=False),
                 count_parameters(self)))
    log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
    return self.get_network_output_shapes(flatten=False)
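# Hedged sketch of column-aligned logging in the spirit of log_in_columns above (an assumption
# about its behavior, not the uninas implementation): pad every column to its widest entry and
# log one line per row.
def sketch_log_in_columns(logger, rows: list, start_space: int = 0):
    str_rows = [[str(v) for v in row] for row in rows]
    widths = [max(len(r[i]) for r in str_rows) for i in range(len(str_rows[0]))]
    for row in str_rows:
        logger.info(' ' * start_space + '  '.join(v.ljust(w) for v, w in zip(row, widths)))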
def assert_stats_match(name, task_cfg, cfg: dict, num_params=None, num_macs=None):
    cfg_dir = replace_standard_paths('{path_tmp}/tests/cfgs/')
    cfg_path = Builder.save_config(cfg, cfg_dir, name)
    exp = Main.new_task(task_cfg, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 2,
        '{cls_data}.batch_size_test': -1,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/tests/workdir/',
        '{cls_network}.config_path': cfg_path,
        '{cls_trainer}.ema_decay': -1,
        'cls_network_heads': 'ClassificationHead',  # necessary for the DARTS search space to disable the aux heads
    }, raise_unparsed=False)
    net = exp.get_method().get_network()
    macs = exp.get_method().profile_macs()
    net.eval()
    # print(net)
    cp = count_parameters(net)
    if num_params is not None:
        assert cp == num_params, 'Got unexpected num params for %s: %d, expected %d, diff: %d'\
                                 % (name, cp, num_params, abs(cp - num_params))
    if num_macs is not None:
        assert macs == num_macs, 'Got unexpected num macs for %s: %d, expected %d, diff: %d'\
                                 % (name, macs, num_macs, abs(macs - num_macs))
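# Hedged usage sketch for the helper above: the config dict, task config name and reference
# numbers are hypothetical placeholders, not values taken from the repository.
def test_example_network_stats():
    example_cfg = {}  # would normally be a network config dict created via Builder
    assert_stats_match('example_net', 'example_task_config', example_cfg,
                       num_params=1_000_000, num_macs=50_000_000)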
        (256, ShuffleNetV2Layer, defaults, dict(stride=2, k_size=7, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=7, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=5, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=5, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=7, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),
        (512, ShuffleNetV2Layer, defaults, dict(stride=2, k_size=7, act_fun='hswish', att_dict=att)),
        (512, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=5, act_fun='hswish', att_dict=att)),
        (512, ShuffleNetV2XceptionLayer, defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),
        (512, ShuffleNetV2Layer, defaults, dict(stride=1, k_size=7, act_fun='hswish', att_dict=att)),
        (1280, ConvLayer, dict(), dict(k_size=1, bias=False, act_fun='hswish', act_inplace=True,
                                       order='w_bn_act', use_bn=True, bn_affine=True)),
    ])
    return get_network(StackedCellsNetworkBody, stem, head, cell_partials, cell_order, s_in, s_out)


if __name__ == '__main__':
    from uninas.utils.torch.misc import count_parameters
    from uninas.builder import Builder
    Builder()
    net = get_shufflenet_v2plus_medium().cuda()
    net.eval()
    print(net)
    print(count_parameters(net), count_parameters(net) - count_parameters(net.cells[:-1]))
def get_resnet34(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    return _resnet(block=ResNetLayer, stages=(3, 4, 6, 3), expansion=1, s_in=s_in, s_out=s_out)


def get_resnet50(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    return _resnet(block=ResNetBottleneckLayer, stages=(3, 4, 6, 3), expansion=4, s_in=s_in, s_out=s_out)


def get_resnet101(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    return _resnet(block=ResNetBottleneckLayer, stages=(3, 4, 23, 3), expansion=4, s_in=s_in, s_out=s_out)


def get_resnet152(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    return _resnet(block=ResNetBottleneckLayer, stages=(3, 8, 36, 3), expansion=4, s_in=s_in, s_out=s_out)


if __name__ == '__main__':
    from uninas.utils.torch.misc import count_parameters
    from uninas.builder import Builder
    Builder()
    net = get_resnet50().cuda()
    net.eval()
    print(net)
    print('params', count_parameters(net))
    print('cell params', count_parameters(net.cells))
    for j, cell in enumerate(net.cells):
        print(j, count_parameters(cell))
def get_num_parameters(self) -> int:
    return count_parameters(self)
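# For reference, a hedged sketch of what a count_parameters helper such as the one used throughout
# these snippets commonly does (the actual implementation in uninas.utils.torch.misc may differ,
# e.g. in how it treats non-trainable parameters):
import torch.nn as nn

def sketch_count_parameters(module: nn.Module) -> int:
    # sum the number of elements over all trainable parameter tensors of the module
    return sum(p.numel() for p in module.parameters() if p.requires_grad)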
def get_row(idx, name: str, obj: AbstractModule) -> tuple:
    s_in_str = obj.get_shape_in().str()
    s_inner = obj.get_cached('shape_inner')
    s_inner_str = '' if s_inner is None else s_inner.str()
    s_out_str = obj.get_shape_out().str()
    return str(idx), name, obj.__class__.__name__, s_in_str, s_inner_str, s_out_str, count_parameters(obj)
def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    LoggerManager().get_logger().info('Building %s:' % self.__class__.__name__)
    rows = [('cell index', 'name', 'class', 'input shapes', '', 'output shapes', '#params')]

    def get_row(idx, name: str, obj: AbstractModule) -> tuple:
        s_in_str = obj.get_shape_in().str()
        s_inner = obj.get_cached('shape_inner')
        s_inner_str = '' if s_inner is None else s_inner.str()
        s_out_str = obj.get_shape_out().str()
        return str(idx), name, obj.__class__.__name__, s_in_str, s_inner_str, s_out_str, count_parameters(obj)

    s_out_data = s_out.copy()
    out_shapes = self.stem.build(s_in)
    final_out_shapes = []
    rows.append(get_row('', '-', self.stem))

    # cells and (aux) heads
    updated_cell_order = []
    for i, cell_name in enumerate(self.cell_order):
        strategy_name, cell = self._get_cell(name=cell_name, cell_index=i)
        assert self.stem.num_outputs() == cell.num_inputs() == cell.num_outputs(), 'Cell does not fit the network!'
        updated_cell_order.append(cell.name)
        s_ins = out_shapes[-cell.num_inputs():]
        with StrategyManagerDefault(strategy_name):
            s_out = cell.build(s_ins.copy(),
                               features_mul=self.features_mul,
                               features_fixed=self.features_first_cell if i == 0 else -1)
        out_shapes.extend(s_out)
        rows.append(get_row(i, cell_name, cell))
        self.cells.append(cell)

        # optional (aux) head after every cell
        head = self._head_positions.get(i, None)
        if head is not None:
            if head.weight > 0:
                final_out_shapes.append(head.build(s_out[-1], s_out_data))
                rows.append(get_row('', '-', head))
            else:
                LoggerManager().get_logger().info('not adding head after cell %d, weight <= 0' % i)
                del self._head_positions[i]
        else:
            assert i != len(self.cell_order) - 1, "Must have a head after the final cell"

    # remove heads that are impossible to add
    for i in self._head_positions.keys():
        if i >= len(self.cells):
            LoggerManager().get_logger().warning('Can not add a head after cell %d which does not exist, deleting the head!' % i)
            head = self._head_positions.get(i)
            for j, head2 in enumerate(self.heads):
                if head is head2:
                    self.heads.__delitem__(j)
                    break

    s_out = ShapeList(final_out_shapes)
    rows.append(('complete network', '', '', self.get_shape_in().str(), '', s_out.str(), count_parameters(self)))
    log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
    self.set(cell_order=updated_cell_order)
    return s_out