class Network_Multi_Path_Infer(nn.Module):
    def __init__(self, alphas, betas, ratios, num_classes=19, layers=9, criterion=nn.CrossEntropyLoss(ignore_index=-1),
                 Fch=12, width_mult_list=[1.,], stem_head_width=(1., 1.), ignore_skip=False):
        super(Network_Multi_Path_Infer, self).__init__()
        self._num_classes = num_classes
        assert layers >= 2
        self._layers = layers
        self._criterion = criterion
        self._Fch = Fch
        if ratios[0].size(1) == 1:
            if ignore_skip:
                self._width_mult_list = [1.,]
            else:
                self._width_mult_list = [4./12,]
        else:
            self._width_mult_list = width_mult_list
        self._stem_head_width = stem_head_width
        self.latency = 0
        self.stem = nn.Sequential(
            ConvNorm(3, self.num_filters(2, stem_head_width[0])*2, kernel_size=3, stride=2, padding=1, bias=False, groups=1, slimmable=False),
            BasicResidual2x(self.num_filters(2, stem_head_width[0])*2, self.num_filters(4, stem_head_width[0])*2, kernel_size=3, stride=2, groups=1, slimmable=False),
            BasicResidual2x(self.num_filters(4, stem_head_width[0])*2, self.num_filters(8, stem_head_width[0]), kernel_size=3, stride=2, groups=1, slimmable=False)
        )
        self.ops0, self.path0, self.downs0, self.widths0 = network_metas(alphas, betas, ratios, self._width_mult_list, layers, 0, ignore_skip=ignore_skip)
        self.ops1, self.path1, self.downs1, self.widths1 = network_metas(alphas, betas, ratios, self._width_mult_list, layers, 1, ignore_skip=ignore_skip)
        self.ops2, self.path2, self.downs2, self.widths2 = network_metas(alphas, betas, ratios, self._width_mult_list, layers, 2, ignore_skip=ignore_skip)

    def num_filters(self, scale, width=1.0):
        return int(np.round(scale * self._Fch * width))

    def build_structure(self, lasts):
        self._branch = len(lasts)
        self.lasts = lasts
        self.ops = [getattr(self, "ops%d" % last) for last in lasts]
        self.paths = [getattr(self, "path%d" % last) for last in lasts]
        self.downs = [getattr(self, "downs%d" % last) for last in lasts]
        self.widths = [getattr(self, "widths%d" % last) for last in lasts]
        self.branch_groups, self.cells = self.get_branch_groups_cells(self.ops, self.paths, self.downs, self.widths, self.lasts)
        self.build_arm_ffm_head()

    def build_arm_ffm_head(self):
        if self.training:
            if 2 in self.lasts:
                self.heads32 = Head(self.num_filters(32, self._stem_head_width[1]), self._num_classes, True, norm_layer=BatchNorm2d)
                if 1 in self.lasts:
                    self.heads16 = Head(self.num_filters(16, self._stem_head_width[1])+self.ch_16, self._num_classes, True, norm_layer=BatchNorm2d)
                else:
                    self.heads16 = Head(self.ch_16, self._num_classes, True, norm_layer=BatchNorm2d)
            else:
                self.heads16 = Head(self.num_filters(16, self._stem_head_width[1]), self._num_classes, True, norm_layer=BatchNorm2d)
        self.heads8 = Head(self.num_filters(8, self._stem_head_width[1]) * self._branch, self._num_classes,
                           Fch=self._Fch, scale=4, branch=self._branch, is_aux=False, norm_layer=BatchNorm2d)
        if 2 in self.lasts:
            self.arms32 = nn.ModuleList([
                ConvNorm(self.num_filters(32, self._stem_head_width[1]), self.num_filters(16, self._stem_head_width[1]), 1, 1, 0, slimmable=False),
                ConvNorm(self.num_filters(16, self._stem_head_width[1]), self.num_filters(8, self._stem_head_width[1]), 1, 1, 0, slimmable=False),
            ])
            self.refines32 = nn.ModuleList([
                ConvNorm(self.num_filters(16, self._stem_head_width[1])+self.ch_16, self.num_filters(16, self._stem_head_width[1]), 3, 1, 1, slimmable=False),
                ConvNorm(self.num_filters(8, self._stem_head_width[1])+self.ch_8_2, self.num_filters(8, self._stem_head_width[1]), 3, 1, 1, slimmable=False),
            ])
        if 1 in self.lasts:
            self.arms16 = ConvNorm(self.num_filters(16, self._stem_head_width[1]), self.num_filters(8, self._stem_head_width[1]), 1, 1, 0, slimmable=False)
            self.refines16 = ConvNorm(self.num_filters(8, self._stem_head_width[1])+self.ch_8_1, self.num_filters(8, self._stem_head_width[1]), 3, 1, 1, slimmable=False)
        self.ffm = FeatureFusion(self.num_filters(8, self._stem_head_width[1]) * self._branch,
                                 self.num_filters(8, self._stem_head_width[1]) * self._branch,
                                 reduction=1, Fch=self._Fch, scale=8, branch=self._branch, norm_layer=BatchNorm2d)

    def get_branch_groups_cells(self, ops, paths, downs, widths, lasts):
        num_branch = len(ops)
        layers = max([len(path) for path in paths])
        groups_all = []
        self.ch_16 = 0; self.ch_8_2 = 0; self.ch_8_1 = 0
        cells = nn.ModuleDict()  # layer-branch: op
        branch_connections = np.ones((num_branch, num_branch))  # maintain connections of heads of branches of different scales
        # all but the last layer
        # we determine branch-merging by comparing their next layer: if next-layer differs, then the "down" of current layer must differ
        for l in range(layers):
            connections = np.ones((num_branch, num_branch))  # if branch i/j share same scale & op in this layer
            for i in range(num_branch):
                for j in range(i+1, num_branch):
                    # we also add constraint on ops[i][l] != ops[j][l] since some skip-connect may already be shrinked/compacted => layers of branches may no longer be aligned in terms of alphas
                    # last layer won't merge
                    if len(paths[i]) <= l+1 or len(paths[j]) <= l+1 or paths[i][l+1] != paths[j][l+1] or ops[i][l] != ops[j][l] or widths[i][l] != widths[j][l]:
                        connections[i, j] = connections[j, i] = 0
            branch_connections *= connections
            branch_groups = []
            # build branch_group for processing
            for branch in range(num_branch):
                # also accept if this is the last layer of branch (len(paths[branch]) == l+1)
                if len(paths[branch]) < l+1: continue
                inserted = False
                for group in branch_groups:
                    if branch_connections[group[0], branch] == 1:
                        group.append(branch)
                        inserted = True
                        continue
                if not inserted: branch_groups.append([branch])
            for group in branch_groups:
                # branches in the same group must share the same op/scale/down/width
                if len(group) >= 2:
                    assert ops[group[0]][l] == ops[group[1]][l] and paths[group[0]][l+1] == paths[group[1]][l+1] and downs[group[0]][l] == downs[group[1]][l] and widths[group[0]][l] == widths[group[1]][l]
                if len(group) == 3:
                    assert ops[group[1]][l] == ops[group[2]][l] and paths[group[1]][l+1] == paths[group[2]][l+1] and downs[group[1]][l] == downs[group[2]][l] and widths[group[1]][l] == widths[group[2]][l]
                op = ops[group[0]][l]
                scale = 2**(paths[group[0]][l]+3)
                down = downs[group[0]][l]
                if l < len(paths[group[0]]) - 1:
                    assert down == paths[group[0]][l+1] - paths[group[0]][l]
                assert down in [0, 1]
                if l == 0:
                    cell = Cell(op, self.num_filters(scale, self._stem_head_width[0]), self.num_filters(scale*(down+1), widths[group[0]][l]), down)
                elif l == len(paths[group[0]]) - 1:
                    # last cell for this branch
                    assert down == 0
                    cell = Cell(op, self.num_filters(scale, widths[group[0]][l-1]), self.num_filters(scale, self._stem_head_width[1]), down)
                else:
                    cell = Cell(op, self.num_filters(scale, widths[group[0]][l-1]), self.num_filters(scale*(down+1), widths[group[0]][l]), down)
                # For Feature Fusion: keep record of dynamic #channel of last 1/16 and 1/8 of "1/32 branch"; last 1/8 of "1/16 branch"
                if 2 in self.lasts and self.lasts.index(2) in group and down and scale == 16: self.ch_16 = cell._C_in
                if 2 in self.lasts and self.lasts.index(2) in group and down and scale == 8: self.ch_8_2 = cell._C_in
                if 1 in self.lasts and self.lasts.index(1) in group and down and scale == 8: self.ch_8_1 = cell._C_in
                for branch in group:
                    cells[str(l)+"-"+str(branch)] = cell
            groups_all.append(branch_groups)
        return groups_all, cells

    def agg_ffm(self, outputs8, outputs16, outputs32):
        pred32 = []; pred16 = []; pred8 = []  # order of predictions is not important
        for branch in range(self._branch):
            last = self.lasts[branch]
            if last == 2:
                if self.training: pred32.append(outputs32[branch])
                out = self.arms32[0](outputs32[branch])
                out = F.interpolate(out, size=(outputs16[branch].size(2), outputs16[branch].size(3)), mode='bilinear', align_corners=True)
                # out = F.interpolate(out, size=(int(out.size(2))*2, int(out.size(3))*2), mode='bilinear', align_corners=True)
                out = self.refines32[0](torch.cat([out, outputs16[branch]], dim=1))
                if self.training: pred16.append(outputs16[branch])
                out = self.arms32[1](out)
                out = F.interpolate(out, size=(outputs8[branch].size(2), outputs8[branch].size(3)), mode='bilinear', align_corners=True)
                # out = F.interpolate(out, size=(int(out.size(2))*2, int(out.size(3))*2), mode='bilinear', align_corners=True)
                out = self.refines32[1](torch.cat([out, outputs8[branch]], dim=1))
                pred8.append(out)
            elif last == 1:
                if self.training: pred16.append(outputs16[branch])
                out = self.arms16(outputs16[branch])
                out = F.interpolate(out, size=(outputs8[branch].size(2), outputs8[branch].size(3)), mode='bilinear', align_corners=True)
                # out = F.interpolate(out, size=(int(out.size(2))*2, int(out.size(3))*2), mode='bilinear', align_corners=True)
                out = self.refines16(torch.cat([out, outputs8[branch]], dim=1))
                pred8.append(out)
            elif last == 0:
                pred8.append(outputs8[branch])
        if len(pred32) > 0:
            pred32 = self.heads32(torch.cat(pred32, dim=1))
        else:
            pred32 = None
        if len(pred16) > 0:
            pred16 = self.heads16(torch.cat(pred16, dim=1))
        else:
            pred16 = None
        pred8 = self.heads8(self.ffm(torch.cat(pred8, dim=1)))
        if self.training:
            return pred8, pred16, pred32
        else:
            return pred8

    def forward(self, input):
        _, _, H, W = input.size()
        stem = self.stem(input)
        # store the last feature map w. corresponding scale of each branch
        outputs8 = [stem] * self._branch
        outputs16 = [stem] * self._branch
        outputs32 = [stem] * self._branch
        outputs = [stem] * self._branch
        for layer in range(len(self.branch_groups)):
            for group in self.branch_groups[layer]:
                output = self.cells[str(layer)+"-"+str(group[0])](outputs[group[0]])
                scale = int(H // output.size(2))
                for branch in group:
                    outputs[branch] = output
                    if scale == 8: outputs8[branch] = output
                    elif scale == 16: outputs16[branch] = output
                    elif scale == 32: outputs32[branch] = output
        if self.training:
            pred8, pred16, pred32 = self.agg_ffm(outputs8, outputs16, outputs32)
            pred8 = F.interpolate(pred8, scale_factor=8, mode='bilinear', align_corners=True)
            if pred16 is not None: pred16 = F.interpolate(pred16, scale_factor=16, mode='bilinear', align_corners=True)
            if pred32 is not None: pred32 = F.interpolate(pred32, scale_factor=32, mode='bilinear', align_corners=True)
            return pred8, pred16, pred32
        else:
            pred8 = self.agg_ffm(outputs8, outputs16, outputs32)
            out = F.interpolate(pred8, size=(int(pred8.size(2))*8, int(pred8.size(3))*8), mode='bilinear', align_corners=True)
            return out

    def forward_latency(self, size):
        _, H, W = size
        latency_total = 0
        latency, size = self.stem[0].forward_latency(size); latency_total += latency
        latency, size = self.stem[1].forward_latency(size); latency_total += latency
        latency, size = self.stem[2].forward_latency(size); latency_total += latency
        # store the last feature map size w. corresponding scale of each branch
        outputs8 = [size] * self._branch
        outputs16 = [size] * self._branch
        outputs32 = [size] * self._branch
        outputs = [size] * self._branch
        for layer in range(len(self.branch_groups)):
            for group in self.branch_groups[layer]:
                latency, size = self.cells[str(layer)+"-"+str(group[0])].forward_latency(outputs[group[0]])
                latency_total += latency
                scale = int(H // size[1])
                for branch in group:
                    outputs[branch] = size
                    if scale == 8: outputs8[branch] = size
                    elif scale == 16: outputs16[branch] = size
                    elif scale == 32: outputs32[branch] = size
        for branch in range(self._branch):
            last = self.lasts[branch]
            if last == 2:
                latency, size = self.arms32[0].forward_latency(outputs32[branch]); latency_total += latency
                latency, size = self.refines32[0].forward_latency((size[0]+self.ch_16, size[1]*2, size[2]*2)); latency_total += latency
                latency, size = self.arms32[1].forward_latency(size); latency_total += latency
                latency, size = self.refines32[1].forward_latency((size[0]+self.ch_8_2, size[1]*2, size[2]*2)); latency_total += latency
                out_size = size
            elif last == 1:
                latency, size = self.arms16.forward_latency(outputs16[branch]); latency_total += latency
                latency, size = self.refines16.forward_latency((size[0]+self.ch_8_1, size[1]*2, size[2]*2)); latency_total += latency
                out_size = size
            elif last == 0:
                out_size = outputs8[branch]
        latency, size = self.ffm.forward_latency((out_size[0]*self._branch, out_size[1], out_size[2])); latency_total += latency
        latency, size = self.heads8.forward_latency(size); latency_total += latency
        return latency_total, size
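# Note: every channel width above comes from num_filters(scale, width) = int(round(scale * Fch * width)).
# A minimal standalone sketch of that arithmetic, assuming the default Fch=12 and width=1.0
# (the concrete widths in a searched network depend on stem_head_width and the per-layer ratios):
import numpy as np

def _num_filters_sketch(scale, Fch=12, width=1.0):
    # mirrors Network_Multi_Path_Infer.num_filters
    return int(np.round(scale * Fch * width))

print(_num_filters_sketch(8))             # 96 channels at 1/8 resolution
print(_num_filters_sketch(16))            # 192 channels at 1/16 resolution
print(_num_filters_sketch(32))            # 384 channels at 1/32 resolution
print(_num_filters_sketch(8, width=8/12)) # 64 channels with a slimmed width ratio of 8/12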
            (H // 8, W // 8, 8 * Fch + 8 * Fch_max * w_in, 8 * Fch, 3, 1)] = latency
        for branch in range(1, 4):
            lookup_table["ff_H%d_W%d_C%d" % (H // 8, W // 8, 8 * Fch * branch)] = FeatureFusion._latency(
                H // 8, W // 8, 8 * Fch * branch, 8 * Fch * branch)
        np.save(file_name, lookup_table)

    print("head......")
    Fch_range = [8, 12]
    for Fch in Fch_range:
        for branch in range(1, 4):
            print("Fch", Fch, "branch", branch)
            lookup_table["head_H%d_W%d_Cin%d_Cout%d" % (H // 8, W // 8, 8 * Fch * branch, 19)] = Head._latency(
                H // 8, W // 8, 8 * Fch * branch, 19)
            np.save(file_name, lookup_table)


def find_latency(name, info, lookup_table, H=1024, W=2048):
    if name == "stem":
        latency = lookup_table["stem_H%d_W%d_F%d" % (H, W, info["F"])]
    elif name == "head":
        latency = lookup_table["head_H%d_W%d_F%d_branch%d_19" % (H, W, info["F"], info["branch"])]
    elif name == "refines":
        latency = lookup_table["refines%d_H%d_W%d_F%d" % (info["scale"], H, W, info["F"])]
    elif name == "arms":
        latency = lookup_table["arms%d_H%d_W%d_F%d" % (info["scale"], H, W, info["F"])]
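# Minimal sketch of how find_latency composes a lookup key (illustration only; the
# entry below is made up, real values come from the saved .npy lookup table):
_example_table = {"arms32_H1024_W2048_F12": 0.42}  # hypothetical latency in ms
_info = {"scale": 32, "F": 12}
_key = "arms%d_H%d_W%d_F%d" % (_info["scale"], 1024, 2048, _info["F"])
print(_example_table[_key])  # -> 0.42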
    def __init__(self, num_classes=19, layers=16, criterion=nn.CrossEntropyLoss(ignore_index=-1),
                 Fch=12, width_mult_list=[1.,], prun_modes=['arch_ratio',], stem_head_width=[(1., 1.),]):
        super(Network_Multi_Path, self).__init__()
        self._num_classes = num_classes
        assert layers >= 3
        self._layers = layers
        self._criterion = criterion
        self._Fch = Fch
        self._width_mult_list = width_mult_list
        self._prun_modes = prun_modes
        self.prun_mode = None  # prun_mode is higher priority than _prun_modes
        self._stem_head_width = stem_head_width
        self._flops = 0
        self._params = 0
        self.stem = nn.ModuleList([
            nn.Sequential(
                ConvNorm(3, self.num_filters(2, stem_ratio)*2, kernel_size=3, stride=2, padding=1, bias=False, groups=1, slimmable=False),
                BasicResidual2x(self.num_filters(2, stem_ratio)*2, self.num_filters(4, stem_ratio)*2, kernel_size=3, stride=2, groups=1, slimmable=False),
                BasicResidual2x(self.num_filters(4, stem_ratio)*2, self.num_filters(8, stem_ratio), kernel_size=3, stride=2, groups=1, slimmable=False)
            ) for stem_ratio, _ in self._stem_head_width ])
        self.cells = nn.ModuleList()
        for l in range(layers):
            cells = nn.ModuleList()
            if l == 0:
                # first node has only one input (prev cell's output)
                cells.append(Cell(self.num_filters(8), width_mult_list=width_mult_list))
            elif l == 1:
                cells.append(Cell(self.num_filters(8), width_mult_list=width_mult_list))
                cells.append(Cell(self.num_filters(16), width_mult_list=width_mult_list))
            elif l < layers - 1:
                cells.append(Cell(self.num_filters(8), width_mult_list=width_mult_list))
                cells.append(Cell(self.num_filters(16), width_mult_list=width_mult_list))
                cells.append(Cell(self.num_filters(32), down=False, width_mult_list=width_mult_list))
            else:
                cells.append(Cell(self.num_filters(8), down=False, width_mult_list=width_mult_list))
                cells.append(Cell(self.num_filters(16), down=False, width_mult_list=width_mult_list))
                cells.append(Cell(self.num_filters(32), down=False, width_mult_list=width_mult_list))
            self.cells.append(cells)
        self.refine32 = nn.ModuleList([
            nn.ModuleList([
                ConvNorm(self.num_filters(32, head_ratio), self.num_filters(16, head_ratio), kernel_size=1, bias=False, groups=1, slimmable=False),
                ConvNorm(self.num_filters(32, head_ratio), self.num_filters(16, head_ratio), kernel_size=3, padding=1, bias=False, groups=1, slimmable=False),
                ConvNorm(self.num_filters(16, head_ratio), self.num_filters(8, head_ratio), kernel_size=1, bias=False, groups=1, slimmable=False),
                ConvNorm(self.num_filters(16, head_ratio), self.num_filters(8, head_ratio), kernel_size=3, padding=1, bias=False, groups=1, slimmable=False)
            ]) for _, head_ratio in self._stem_head_width ])
        self.refine16 = nn.ModuleList([
            nn.ModuleList([
                ConvNorm(self.num_filters(16, head_ratio), self.num_filters(8, head_ratio), kernel_size=1, bias=False, groups=1, slimmable=False),
                ConvNorm(self.num_filters(16, head_ratio), self.num_filters(8, head_ratio), kernel_size=3, padding=1, bias=False, groups=1, slimmable=False)
            ]) for _, head_ratio in self._stem_head_width ])
        self.head0 = nn.ModuleList([ Head(self.num_filters(8, head_ratio), num_classes, False) for _, head_ratio in self._stem_head_width ])
        self.head1 = nn.ModuleList([ Head(self.num_filters(8, head_ratio), num_classes, False) for _, head_ratio in self._stem_head_width ])
        self.head2 = nn.ModuleList([ Head(self.num_filters(8, head_ratio), num_classes, False) for _, head_ratio in self._stem_head_width ])
        self.head02 = nn.ModuleList([ Head(self.num_filters(8, head_ratio)*2, num_classes, False) for _, head_ratio in self._stem_head_width ])
        self.head12 = nn.ModuleList([ Head(self.num_filters(8, head_ratio)*2, num_classes, False) for _, head_ratio in self._stem_head_width ])

        # contains arch_param names: {"alphas": alphas, "betas": betas, "ratios": ratios}
        self._arch_names = []
        self._arch_parameters = []
        for i in range(len(self._prun_modes)):
            arch_name, arch_param = self._build_arch_parameters(i)
            self._arch_names.append(arch_name)
            self._arch_parameters.append(arch_param)
            self._reset_arch_parameters(i)
        # switch set of arch if we have more than 1 arch
        self.arch_idx = 0
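# Quick check on the cell grid built by the loop above (illustration only): it
# instantiates 1 cell at layer 0, 2 at layer 1, and 3 at every remaining layer,
# so the supernet holds 1 + 2 + 3*(layers-2) candidate cells across the 1/8, 1/16
# and 1/32 scales.
_layers = 16  # default value of the layers argument
print(1 + 2 + 3 * (_layers - 2))  # -> 45 cells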
    def build_arm_ffm_head(self):
        # 24, 40, 96, 320
        if self.training:
            self.heads32 = Head(self.f_channels[-1], self._num_classes, True, norm_layer=BatchNorm2d)
            self.heads16 = Head(self.f_channels[-2], self._num_classes, True, norm_layer=BatchNorm2d)
        self.heads8 = Decoder(self.num_filters(8, self._stem_head_width[1]), self.f_channels[0], self._num_classes,
                              Fch=self._Fch, scale=4, branch=1, is_aux=False, norm_layer=BatchNorm2d)
        self.arms32 = nn.ModuleList([
            ConvNorm(self.f_channels[-1], self.num_filters(16, self._stem_head_width[1]), 1, 1, 0, slimmable=False),
            ConvNorm(self.num_filters(16, self._stem_head_width[1]), self.num_filters(8, self._stem_head_width[1]), 1, 1, 0, slimmable=False),
        ])
        self.refines32 = nn.ModuleList([
            ConvNorm(self.num_filters(16, self._stem_head_width[1]) + self.f_channels[-2], self.num_filters(16, self._stem_head_width[1]), 3, 1, 1, slimmable=False),
            ConvNorm(self.num_filters(8, self._stem_head_width[1]) + self.f_channels[-3], self.num_filters(8, self._stem_head_width[1]), 3, 1, 1, slimmable=False),
        ])
        self.ffm = FeatureFusion(self.num_filters(8, self._stem_head_width[1]), self.num_filters(8, self._stem_head_width[1]),
                                 reduction=1, Fch=self._Fch, scale=8, branch=1, norm_layer=BatchNorm2d)
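# Sketch: per the "# 24, 40, 96, 320" note above, f_channels presumably lists the
# backbone feature channels from shallow to deep, so the indices used here resolve as:
_f_channels = [24, 40, 96, 320]  # assumed ordering, for illustration only
print(_f_channels[-1], _f_channels[-2], _f_channels[-3], _f_channels[0])  # 320 96 40 24
# i.e. heads32/arms32[0] consume the 320-ch feature, refines32 fuse in the 96-ch and
# 40-ch skip features, and the 1/8 Decoder takes the 24-ch low-level feature.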