Exemplo n.º 1
0
 def __init__(self, node_id, num_prev_nodes, channels,
              num_downsample_connect):
     """Build one searchable node of a DARTS-style cell.

     One ``mutables.LayerChoice`` is created per incoming edge (one per
     predecessor node), each offering seven candidate operations. A
     ``mutables.InputChoice`` over the per-edge keys with ``n_chosen=2``
     then decides which edges are used.

     Args:
         node_id: prefix for the mutable keys; edge ``i`` is keyed
             ``"<node_id>_p<i>"`` and the input choice ``"<node_id>_switch"``.
         num_prev_nodes: number of predecessor nodes, i.e. incoming edges.
         channels: used as both input and output channel count of every
             candidate op.
         num_downsample_connect: edges whose index is below this value use
             stride 2 (with ``ops.FactorizedReduce`` in the skip slot); all
             other edges use stride 1 (with ``nn.Identity`` in that slot).
     """
     super().__init__()

     def build_candidates(stride):
         # The list order fixes each op's index inside the LayerChoice and
         # the order in which the modules are constructed -- keep it as is.
         return [
             ops.PoolBN('max', channels, 3, stride, 1, affine=False),
             ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
             (nn.Identity() if stride == 1 else
              ops.FactorizedReduce(channels, channels, affine=False)),
             ops.SepConv(channels, channels, 3, stride, 1, affine=False),
             ops.SepConv(channels, channels, 5, stride, 2, affine=False),
             ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
             ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
         ]

     self.ops = nn.ModuleList()
     edge_keys = []
     for edge_idx in range(num_prev_nodes):
         downsample = edge_idx < num_downsample_connect
         edge_key = "{}_p{}".format(node_id, edge_idx)
         edge_keys.append(edge_key)
         self.ops.append(
             mutables.LayerChoice(build_candidates(2 if downsample else 1),
                                  key=edge_key))

     # Built with default arguments; presumably applied to edge outputs in
     # forward() (not shown here) -- confirm the drop probability there.
     self.drop_path = ops.DropPath()
     # Selects n_chosen=2 of the incoming edges, referenced by their keys.
     self.input_switch = mutables.InputChoice(
         choose_from=edge_keys,
         n_chosen=2,
         key="{}_switch".format(node_id))
Exemplo n.º 2
0
 def __init__(self, node_id, num_prev_nodes, channels,
              num_downsample_connect):
     """Build one searchable node of a DARTS-style cell.

     Each of the ``num_prev_nodes`` incoming edges gets a ``LayerChoice``
     over seven named candidate ops. An ``InputChoice`` with ``n_chosen=2``
     over those edges decides which of them feed the node.

     Args:
         node_id: prefix for the choice labels; edge ``i`` is labelled
             ``"<node_id>_p<i>"`` and the input choice ``"<node_id>_switch"``.
         num_prev_nodes: number of predecessor nodes, i.e. incoming edges.
         channels: used as both input and output channel count of every
             candidate op.
         num_downsample_connect: edges whose index is below this value use
             stride 2 (with ``ops.FactorizedReduce`` as "skipconnect"); all
             other edges use stride 1 (with ``nn.Identity``).
     """
     super().__init__()
     self.ops = nn.ModuleList()
     edge_labels = []
     for edge_idx in range(num_prev_nodes):
         if edge_idx < num_downsample_connect:
             stride = 2
         else:
             stride = 1
         edge_labels.append("{}_p{}".format(node_id, edge_idx))

         # (name, module) pairs; the list order fixes both the construction
         # order and the candidate order inside the LayerChoice.
         named_candidates = [
             ("maxpool",
              ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
             ("avgpool",
              ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
             ("skipconnect",
              ops.FactorizedReduce(channels, channels, affine=False)
              if stride != 1 else nn.Identity()),
             ("sepconv3x3",
              ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
             ("sepconv5x5",
              ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
             ("dilconv3x3",
              ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
             ("dilconv5x5",
              ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)),
         ]
         self.ops.append(
             LayerChoice(OrderedDict(named_candidates),
                         label=edge_labels[-1]))

     # Built with default arguments; presumably applied to edge outputs in
     # forward() (not shown here) -- confirm the drop probability there.
     self.drop_path = ops.DropPath()
     # Chooses n_chosen=2 out of the len(edge_labels) incoming edges.
     self.input_switch = InputChoice(n_candidates=len(edge_labels),
                                     n_chosen=2,
                                     label="{}_switch".format(node_id))
Exemplo n.º 3
0
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        """Build one searchable node of a DARTS-style cell.

        Example call site (from the enclosing cell builder, not shown here)::

            Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                 depth, channels, 2 if reduction else 0)

        Args:
            node_id: prefix of the mutable keys; edge ``i`` is keyed
                ``"<node_id>_p<i>"`` and the input choice
                ``"<node_id>_switch"``.
            num_prev_nodes: number of preceding nodes, i.e. incoming edges;
                each edge gets its own ``mutables.LayerChoice``.
            channels: input and output channel count of every candidate op.
            num_downsample_connect: edges whose index is below this value
                use stride 2 (with ``ops.FactorizedReduce`` as
                "skipconnect"); all remaining edges use stride 1. Per the
                example above this is 2 in a reduction cell and 0 in a normal
                cell, so only the first two edges of a reduction cell
                downsample.
        """
        super().__init__()
        self.ops = nn.ModuleList()

        # Candidate op names, in the order the modules are constructed and
        # indexed inside each LayerChoice.
        op_names = ("maxpool", "avgpool", "skipconnect", "sepconv3x3",
                    "sepconv5x5", "dilconv3x3", "dilconv5x5")
        # One key per incoming edge; the InputChoice below selects among them.
        edge_keys = []

        for prev_idx in range(num_prev_nodes):
            stride = 2 if prev_idx < num_downsample_connect else 1
            edge_key = "{}_p{}".format(node_id, prev_idx)
            edge_keys.append(edge_key)

            candidate_modules = [
                ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                nn.Identity() if stride == 1 else ops.FactorizedReduce(
                    channels, channels, affine=False),
                ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
            ]
            self.ops.append(mutables.LayerChoice(
                OrderedDict(zip(op_names, candidate_modules)), key=edge_key))

        # Built with default arguments; the drop probability is not visible
        # here -- TODO confirm where it is set.
        self.drop_path = ops.DropPath()

        # Selects n_chosen=2 of the incoming edges; edge_keys is collected
        # above precisely so this choice can reference each edge by key.
        self.input_switch = mutables.InputChoice(
            choose_from=edge_keys, n_chosen=2, key="{}_switch".format(node_id))