Example #1
def get_mobilenet_v2(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    stem = get_stem_instance(MobileNetV2Stem, features=32, features1=16, act_fun='relu6', act_fun1='relu6')
    head = get_head_instance(FeatureMixClassificationHead, features=1280, act_fun='relu6')

    defaults = dict(k_size=3, stride=1, padding='same', expansion=6, dilation=1, bn_affine=True,
                    act_fun='relu6', act_inplace=True, att_dict=None)
    cell_partials, cell_order = get_passthrough_partials([
        (24, MobileInvertedConvLayer, defaults, dict(stride=2)),
        (24, MobileInvertedConvLayer, defaults, dict(stride=1)),

        (32, MobileInvertedConvLayer, defaults, dict(stride=2)),
        (32, MobileInvertedConvLayer, defaults, dict(stride=1)),
        (32, MobileInvertedConvLayer, defaults, dict(stride=1)),

        (64, MobileInvertedConvLayer, defaults, dict(stride=2)),
        (64, MobileInvertedConvLayer, defaults, dict(stride=1)),
        (64, MobileInvertedConvLayer, defaults, dict(stride=1)),
        (64, MobileInvertedConvLayer, defaults, dict(stride=1)),

        (96, MobileInvertedConvLayer, defaults, dict(stride=1)),
        (96, MobileInvertedConvLayer, defaults, dict(stride=1)),
        (96, MobileInvertedConvLayer, defaults, dict(stride=1)),

        (160, MobileInvertedConvLayer, defaults, dict(stride=2)),
        (160, MobileInvertedConvLayer, defaults, dict(stride=1)),
        (160, MobileInvertedConvLayer, defaults, dict(stride=1)),

        (320, MobileInvertedConvLayer, defaults, dict(stride=1)),
    ])

    return get_network(StackedCellsNetworkBody, stem, head, cell_partials, cell_order, s_in, s_out)
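
A minimal usage sketch for the builder above (assumes the same uninas imports the example itself relies on, plus torch; the batch and shapes are illustrative):

import torch

net = get_mobilenet_v2()                 # built for 3x224x224 inputs, 1000 classes
x = torch.randn(2, 3, 224, 224)          # a random batch of 2 images
with torch.no_grad():
    y = net(x)
# StackedCellsNetworkBody may return one output per head, so handle both cases
print([t.shape for t in y] if isinstance(y, (list, tuple)) else y.shape)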
Example #2
def get_mobilenet_v3_small100(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    stem = get_stem_instance(MobileNetV2Stem, features=16, features1=16, act_fun='hswish', act_fun1='relu',
                             stride1=2, se_cmul1=0.5)
    head = get_head_instance(FeatureMixClassificationHead, features=1024, act_fun='hswish', gap_first=True, bias=True)

    defaults = dict(padding='same', dilation=1, bn_affine=True, act_inplace=True)
    se = dict(att_cls='SqueezeExcitationChannelModule', use_c_substitute=False,
              c_mul=0.25, squeeze_act='relu', excite_act='sigmoid', divisible=8,
              squeeze_bias=True, excite_bias=True, squeeze_bn=False)

    cell_partials, cell_order = get_passthrough_partials([
        (24, MobileInvertedConvLayer, defaults, dict(stride=2, k_size=3, expansion=4.5, act_fun='relu')),
        (24, MobileInvertedConvLayer, defaults, dict(stride=1, k_size=3, expansion=3.5, act_fun='relu')),

        (40, MobileInvertedConvLayer, defaults, dict(stride=2, k_size=5, expansion=4, act_fun='hswish', att_dict=se)),
        (40, MobileInvertedConvLayer, defaults, dict(stride=1, k_size=5, expansion=6, act_fun='hswish', att_dict=se)),
        (40, MobileInvertedConvLayer, defaults, dict(stride=1, k_size=5, expansion=6, act_fun='hswish', att_dict=se)),

        (48, MobileInvertedConvLayer, defaults, dict(stride=1, k_size=5, expansion=3, act_fun='hswish', att_dict=se)),
        (48, MobileInvertedConvLayer, defaults, dict(stride=1, k_size=5, expansion=3, act_fun='hswish', att_dict=se)),

        (96, MobileInvertedConvLayer, defaults, dict(stride=2, k_size=5, expansion=6, act_fun='hswish', att_dict=se)),
        (96, MobileInvertedConvLayer, defaults, dict(stride=1, k_size=5, expansion=6, act_fun='hswish', att_dict=se)),
        (96, MobileInvertedConvLayer, defaults, dict(stride=1, k_size=5, expansion=6, act_fun='hswish', att_dict=se)),

        (576, ConvLayer, dict(), dict(k_size=1, bias=False, act_fun='hswish', act_inplace=True, order='w_bn_act',
                                      use_bn=True, bn_affine=True)),
    ])

    return get_network(StackedCellsNetworkBody, stem, head, cell_partials, cell_order, s_in, s_out)
Example #3
    def _build(self, s_in: Shape, s_out: Shape) -> Shape:
        before, after, squeeze = [], [], [
            nn.AdaptiveAvgPool2d(1), SqueezeModule()
        ]
        if self.gap_first:
            after = [
                nn.Linear(s_in.num_features(), self.features, bias=True),  # no affine bn -> use bias
                Register.act_funs.get(self.act_fun)(inplace=True)
            ]
            self.cached['shape_inner'] = Shape([self.features])
        else:
            before = [
                nn.Conv2d(s_in.num_features(), self.features, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.features, affine=True),
                Register.act_funs.get(self.act_fun)(inplace=True)
            ]
            self.cached['shape_inner'] = Shape([self.features, s_in.shape[1], s_in.shape[2]])
        ops = before + squeeze + after + [
            nn.Dropout(p=self.dropout),
            nn.Linear(self.features, s_out.num_features(), bias=self.bias)
        ]
        self.head_module = nn.Sequential(*ops)
        return self.probe_outputs(s_in)
Example #4
    @classmethod
    def from_args(cls, args: Namespace, index: int = None) -> AbstractDataSet:
        # set class attributes
        data_shape, target_shape = cls._parsed_arguments(['data_shape', 'target_shape'], args, index=index)
        cls.data_raw_shape = Shape(split(data_shape, int))
        cls.label_shape = Shape(split(target_shape, int))

        # default generation now
        return super().from_args(args, index)
Example #5
def example_export_network(path: str) -> AbstractUninasNetwork:
    """ create a new network and export it, does not require to have onnx installed """
    network = get_network("FairNasC",
                          Shape([3, 224, 224]),
                          Shape([1000]),
                          weights_path=None)
    network = network.cuda()
    network.export_onnx(path, export_params=True)
    return network
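
Usage is a one-liner; note that the function calls network.cuda(), so a CUDA device is required (the path below is illustrative):

net = example_export_network('/tmp/fairnas_c.onnx')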
Example #6
    def probe_outputs(self,
                      s_in: ShapeOrList,
                      module: nn.Module = None,
                      multiple_outputs=False) -> ShapeOrList:
        """ return the output shape(s) of one forward pass, using a small random input tensor """
        with torch.no_grad():
            if module is None:
                module = self
            x = s_in.random_tensor(batch_size=2)
            s = module(x)
            if multiple_outputs:
                return ShapeList([Shape(list(sx.shape)[1:]) for sx in s])
            return Shape(list(s.shape)[1:])
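
The same probing idea in plain PyTorch, as a self-contained sketch (probe_shape is an illustrative name, not part of uninas):

import torch
import torch.nn as nn

def probe_shape(module: nn.Module, in_shape: list) -> list:
    # push a tiny random batch through the module and drop the batch dimension
    with torch.no_grad():
        x = torch.randn([2] + in_shape)  # batch size 2, as in probe_outputs
        return list(module(x).shape)[1:]

print(probe_shape(nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1), [3, 224, 224]))  # [8, 112, 112]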
Example #7
class SubImagenet100Data(Imagenet1000Data):
    """
    Subset of the ImageNet data set with fewer classes, and fewer images per class
    http://image-net.org/
    https://github.com/microsoft/Cream/blob/main/tools/generate_subImageNet.py
    """

    length = (25000, 0, 5000)  # training, valid, test
    data_raw_shape = Shape([
        3, 300, 300
    ])  # channel height width, the shapes of the raw images actually vary
    label_shape = Shape([100])
    data_mean = (0.485, 0.456, 0.406)  # not recomputed for the subset
    data_std = (0.229, 0.224, 0.225)  # not recomputed for the subset

    can_download = False
Example #8
def get_network(
    net_cls: Type[AbstractNetworkBody],
    stem: AbstractModule,
    head: AbstractModule,
    cell_partials: dict,
    cell_order: [str],
    s_in=Shape([3, 224, 224]),
    s_out=Shape([1000])
) -> nn.Module:
    net_kwargs = net_cls.parsed_argument_defaults()
    net_kwargs.update(
        dict(cell_configs={},
             cell_partials=cell_partials,
             cell_order=cell_order))
    network = net_cls(stem=stem,
                      heads=nn.ModuleList([head]),
                      **net_kwargs)
    network.build(s_in=s_in, s_out=s_out)
    return network
Example #9
def get_shufflenet_v2plus_medium(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    stem = get_stem_instance(ConvStem, k_size=3, features=16, act_fun='hswish', stride=2, use_bn=True, bn_affine=True,
                            order='w_bn_act')
    head = get_head_instance(SeFeatureMixClassificationHead, se_cmul=0.25, se_act_fun='relu', se_squeeze_bias=True,
                            se_bn=True, se_excite_bias=False,
                            features=1280, act_fun='hswish', bias0=False, dropout=0.0, bias1=False)

    defaults = dict(padding='same', dilation=1, bn_affine=True, act_inplace=False, expansion=1)
    att = dict(att_cls='SqueezeExcitationChannelModule', use_c_substitute=False,
               c_mul=0.25, squeeze_act='relu', excite_act='relu6', divisible=8,
               squeeze_bias=False, excite_bias=False, squeeze_bn=True, squeeze_bn_affine=True)

    cell_partials, cell_order = get_passthrough_partials([
        (48, ShuffleNetV2Layer,          defaults, dict(stride=2, k_size=3, act_fun='relu')),
        (48, ShuffleNetV2Layer,          defaults, dict(stride=1, k_size=3, act_fun='relu')),
        (48, ShuffleNetV2XceptionLayer,  defaults, dict(stride=1, k_size=3, act_fun='relu')),
        (48, ShuffleNetV2Layer,          defaults, dict(stride=1, k_size=5, act_fun='relu')),

        (128, ShuffleNetV2Layer,         defaults, dict(stride=2, k_size=5, act_fun='hswish')),
        (128, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=5, act_fun='hswish')),
        (128, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=3, act_fun='hswish')),
        (128, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=3, act_fun='hswish')),

        (256, ShuffleNetV2Layer,         defaults, dict(stride=2, k_size=7, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=7, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=5, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=5, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=7, act_fun='hswish', att_dict=att)),
        (256, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),

        (512, ShuffleNetV2Layer,         defaults, dict(stride=2, k_size=7, act_fun='hswish', att_dict=att)),
        (512, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=5, act_fun='hswish', att_dict=att)),
        (512, ShuffleNetV2XceptionLayer, defaults, dict(stride=1, k_size=3, act_fun='hswish', att_dict=att)),
        (512, ShuffleNetV2Layer,         defaults, dict(stride=1, k_size=7, act_fun='hswish', att_dict=att)),

        (1280, ConvLayer, dict(), dict(k_size=1, bias=False, act_fun='hswish', act_inplace=True, order='w_bn_act',
                                       use_bn=True, bn_affine=True)),
    ])

    return get_network(StackedCellsNetworkBody, stem, head, cell_partials, cell_order, s_in, s_out)
Example #10
class AbstractCNNClassificationDataSet(AbstractDataSet):
    length = (0, 0, 0)  # training, valid, test
    data_raw_shape = Shape([-1, -1, -1])  # channel height width
    label_shape = Shape([-1])
    data_mean = (-1, -1, -1)
    data_std = (-1, -1, -1)

    def _get_train_data(self, used_transforms: transforms.Compose) -> torch.utils.data.Dataset:
        raise NotImplementedError

    def _get_test_data(self, used_transforms: transforms.Compose) -> torch.utils.data.Dataset:
        raise NotImplementedError

    def _get_fake_train_data(self, used_transforms: transforms.Compose) -> torch.utils.data.Dataset:
        return FakeData(self.length[0], self.data_raw_shape.shape, self.num_classes(), used_transforms)

    def _get_fake_valid_data(self, used_transforms: transforms.Compose) -> torch.utils.data.Dataset:
        return FakeData(self.length[1], self.data_raw_shape.shape, self.num_classes(), used_transforms)

    def _get_fake_test_data(self, used_transforms: transforms.Compose) -> torch.utils.data.Dataset:
        return FakeData(self.length[2], self.data_raw_shape.shape, self.num_classes(), used_transforms)
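
The fake variants rely on torchvision's FakeData, which synthesizes random images on the fly; a self-contained demo with illustrative parameters:

from torchvision import transforms
from torchvision.datasets import FakeData

fake = FakeData(size=100, image_size=(3, 32, 32), num_classes=10,
                transform=transforms.ToTensor())
img, label = fake[0]
print(img.shape, label)  # torch.Size([3, 32, 32]) and a class index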
Example #11
def _resnet(block: Type[AbstractResNetLayer], stages=(2, 2, 2, 2), inner_channels=(64, 128, 256, 512), expansion=1,
            s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    stem = get_stem_instance(ConvStem, features=inner_channels[0], stride=2, k_size=7, act_fun='relu')
    head = get_head_instance(ClassificationHead, bias=True, dropout=0.0)
    layers = [(inner_channels[0], PoolingLayer,
               dict(pool_type='max', k_size=3, padding='same', order='w', dropout_rate=0), dict(stride=2))]

    channels = [int(c*expansion) for c in inner_channels]
    defaults = dict(k_size=3, stride=1, padding='same', dilation=1, bn_affine=True, act_fun='relu', act_inplace=True,
                    expansion=1/expansion, has_first_act=False)
    for s, (num, cx) in enumerate(zip(stages, channels)):
        for i in range(num):
            if s > 0 and i == 0:
                layers.append((cx, block, defaults, dict(stride=2, shortcut_type='conv1x1')))
            elif i == 0 and expansion > 1:
                layers.append((cx, block, defaults, dict(stride=1, shortcut_type='conv1x1')))
            else:
                layers.append((cx, block, defaults, dict(stride=1, shortcut_type='id')))

    cell_partials, cell_order = get_passthrough_partials(layers)
    return get_network(StackedCellsNetworkBody, stem, head, cell_partials, cell_order, s_in, s_out)
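
To see which stride and shortcut each block gets, the selection logic above can be traced in isolation (a sketch using the ResNet-50-style settings stages=(3, 4, 6, 3), expansion=4):

stages, expansion = (3, 4, 6, 3), 4
for s, num in enumerate(stages):
    for i in range(num):
        if s > 0 and i == 0:
            kind = 'stride=2, shortcut=conv1x1'   # first block of later stages downsamples
        elif i == 0 and expansion > 1:
            kind = 'stride=1, shortcut=conv1x1'   # channel count changes, so no identity shortcut
        else:
            kind = 'stride=1, shortcut=id'
        print('stage %d, block %d: %s' % (s, i, kind))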
Example #12
class FashionMnistData(AbstractCNNClassificationDataSet):
    """
    """

    length = (60000, 0, 10000)  # training, valid, test
    data_raw_shape = Shape([1, 28, 28])  # channel height width
    label_shape = Shape([10])
    data_mean = (0.2860, )
    data_std = (0.3530, )

    def _get_train_data(self, used_transforms: transforms.Compose):
        return datasets.FashionMNIST(root=self.dir,
                                     train=True,
                                     download=self.download,
                                     transform=used_transforms)

    def _get_test_data(self, used_transforms: transforms.Compose):
        return datasets.FashionMNIST(root=self.dir,
                                     train=False,
                                     download=self.download,
                                     transform=used_transforms)
Example #13
class Cinic10Data(AbstractCNNClassificationDataSet):
    """
    CINIC-10: CINIC-10 Is Not ImageNet or CIFAR-10
    https://github.com/BayesWatch/cinic-10
    """

    length = (90000, 90000, 90000)  # training, valid, test
    data_raw_shape = Shape([3, 32, 32])
    label_shape = Shape([10])
    data_mean = (0.47889522, 0.47227842, 0.43047404)
    data_std = (0.24205776, 0.23828046, 0.25874835)

    def _get_train_data(self, used_transforms: transforms.Compose):
        return datasets.ImageFolder(os.path.join(self.dir, 'train'),
                                    transform=used_transforms)

    def _get_valid_data(self, used_transforms: transforms.Compose):
        return datasets.ImageFolder(os.path.join(self.dir, 'valid'),
                                    transform=used_transforms)

    def _get_test_data(self, used_transforms: transforms.Compose):
        return datasets.ImageFolder(os.path.join(self.dir, 'test'),
                                    transform=used_transforms)
Example #14
class Cifar10Data(AbstractCNNClassificationDataSet):
    """
    The popular CIFAR-10 data set
    https://www.cs.toronto.edu/~kriz/cifar.html
    """

    length = (50000, 0, 10000)  # training, valid, test
    data_raw_shape = Shape([3, 32, 32])  # channel height width
    label_shape = Shape([10])
    data_mean = (0.49139968, 0.48215827, 0.44653124)
    data_std = (0.24703233, 0.24348505, 0.26158768)

    def _get_train_data(self, used_transforms: transforms.Compose):
        return datasets.CIFAR10(root=self.dir,
                                train=True,
                                download=self.download,
                                transform=used_transforms)

    def _get_test_data(self, used_transforms: transforms.Compose):
        return datasets.CIFAR10(root=self.dir,
                                train=False,
                                download=self.download,
                                transform=used_transforms)
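
Normalization constants like data_mean/data_std above are per-channel statistics of the training images; a self-contained sketch of the computation (random tensors stand in for the real data):

import torch

images = torch.rand(1000, 3, 32, 32)   # stand-in for the CIFAR-10 training set
mean = images.mean(dim=(0, 2, 3))      # average over batch, height, width per channel
std = images.std(dim=(0, 2, 3))
print(mean.tolist(), std.tolist())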
Example #15
class Imagenet1000Data(AbstractCNNClassificationDataSet):
    """
    The ImageNet data set
    http://image-net.org/
    """

    length = (1281167, 0, 50000)  # training, valid, test
    data_raw_shape = Shape([
        3, 300, 300
    ])  # channel height width, the shapes of the raw images actually vary
    label_shape = Shape([1000])
    data_mean = (0.485, 0.456, 0.406)
    data_std = (0.229, 0.224, 0.225)

    can_download = False

    def _get_train_data(self, used_transforms: transforms.Compose):
        return datasets.ImageFolder(os.path.join(self.dir, 'train'),
                                    transform=used_transforms)

    def _get_test_data(self, used_transforms: transforms.Compose):
        return datasets.ImageFolder(os.path.join(self.dir, 'val'),
                                    transform=used_transforms)
Example #16
class SumToyData(AbstractDataSet):
    """
    Toy data set of data points: input vector of length 10 and its sum as target
    """

    length = (10000, 0, 5000)  # training, valid, test
    data_raw_shape = Shape([10])
    label_shape = Shape([1])
    data_mean = (0.0,)
    data_std = (1.0,)

    def _before_loading(self):
        """ called before loading training/validation/test data """
        # change the data shape of this class
        self.data_raw_shape.shape[0] = self.additional_args.get('vector_size')

    @classmethod
    def args_to_add(cls, index=None) -> [Argument]:
        """ list arguments to add to argparse when this class (or a child class) is chosen """
        return super().args_to_add(index) + [
            Argument('vector_size', default=10, type=int, help='size of the vectors'),
        ]

    def _get_train_data(self, used_transforms: transforms.Compose):
        return SumToyDataset(used_transforms, rows=self.length[0], columns=self.data_raw_shape.num_features())

    def _get_test_data(self, used_transforms: transforms.Compose):
        return SumToyDataset(used_transforms, rows=self.length[2], columns=self.data_raw_shape.num_features())

    def _get_fake_train_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError

    def _get_fake_valid_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError

    def _get_fake_test_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError
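
SumToyDataset itself lives elsewhere in uninas; a minimal self-contained sketch of what the calls above assume (TinySumDataset is an illustrative name):

import torch
from torch.utils.data import Dataset

class TinySumDataset(Dataset):
    def __init__(self, rows: int, columns: int):
        self.data = torch.randn(rows, columns)             # input vectors
        self.targets = self.data.sum(dim=1, keepdim=True)  # their sums as targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

ds = TinySumDataset(rows=4, columns=10)
x, y = ds[0]
print(x.shape, y.shape)  # torch.Size([10]) torch.Size([1])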
Example #17
    def test_rebuild(self):
        """
        getting finalized configs from which we can build modules
        """
        builder = Builder()
        StrategyManager().delete_strategy('default')
        StrategyManager().add_strategy(RandomChoiceStrategy(max_epochs=1))
        n, c, h, w = 2, 8, 16, 16
        x = torch.empty(size=[n, c, h, w])
        shape = Shape([c, h, w])
        layers = [
            FusedMobileInvertedConvLayer(name='mmicl', k_sizes=(3, 5, 7), expansions=(3, 6)),
            SuperConvThresholdLayer(k_sizes=(3, 5, 7)),
            SuperSepConvThresholdLayer(k_sizes=(3, 5, 7)),
            SuperMobileInvertedConvThresholdLayer(k_sizes=(3, 5, 7), expansions=(3, 6),
                                                  sse_dict=dict(c_muls=(0.0, 0.25, 0.5))),
            LinearTransformerLayer(),
            SuperConvLayer(k_sizes=(3, 5, 7), name='scl1'),
            SuperSepConvLayer(k_sizes=(3, 5, 7), name='scl2'),
            SuperMobileInvertedConvLayer(k_sizes=(3, 5, 7), name='scl3', expansions=(2, 3, 4, 6)),
        ]
        for layer in layers:
            assert layer.build(shape, c) == shape
        StrategyManager().build()
        StrategyManager().forward()
        for layer in layers:
            print('\n' * 2)
            print(layer.__class__.__name__)
            for i in range(3):
                StrategyManager().randomize_weights()
                StrategyManager().forward()
                for finalize in [False, True]:
                    cfg = layer.config(finalize=finalize)
                    print('\t', i, 'finalize', finalize)
                    print('\t\tconfig dct:', cfg)
                    cfg_layer = builder.from_config(cfg)
                    assert cfg_layer.build(shape, c) == shape
                    cfg_layer.forward(x)
                    print('\t\tmodule str:', cfg_layer.str()[1:])
                    del cfg, cfg_layer
Example #18
    def _build(self, s_in: Shape, s_out: Shape) -> Shape:
        ops = [nn.AdaptiveAvgPool2d(1)]
        if self.se_cmul > 0:
            ops.append(SqueezeExcitationChannelModule(
                s_in.num_features(),
                c_mul=self.se_cmul,
                squeeze_act=self.se_act_fun,
                squeeze_bias=self.se_squeeze_bias and not self.se_bn,
                excite_bias=self.se_excite_bias,
                squeeze_bn=self.se_bn,
                squeeze_bn_affine=self.se_squeeze_bias))
        ops.extend([
            SqueezeModule(),
            nn.Linear(s_in.num_features(), self.features, bias=self.bias0),
            Register.act_funs.get(self.act_fun)(inplace=True),
            nn.Dropout(p=self.dropout),
            nn.Linear(self.features, s_out.num_features(), bias=self.bias1)
        ])
        self.head_module = nn.Sequential(*ops)
        self.cached['shape_inner'] = Shape([self.features])
        return self.probe_outputs(s_in)
Example #19
    def test_output_shapes(self):
        """
        expected output shapes of standard layers
        """
        Builder()
        StrategyManager().delete_strategy('default')
        StrategyManager().add_strategy(RandomChoiceStrategy(max_epochs=1))

        bs, c1, c2, hw1, hw2 = 4, 4, 8, 32, 16
        s_in = Shape([c1, hw1, hw1])
        x = torch.empty(size=[bs] + s_in.shape)

        case_s1_c1 = (c1, 1, Shape([c1, hw1, hw1]))
        case_s1_c2 = (c2, 1, Shape([c2, hw1, hw1]))
        case_s2_c1 = (c1, 2, Shape([c1, hw2, hw2]))
        case_s2_c2 = (c2, 2, Shape([c2, hw2, hw2]))

        for cls, cases, kwargs in [
            (SkipLayer, [case_s1_c1, case_s1_c2], dict()),
            (ZeroLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                         case_s2_c2], dict()),
            (FactorizedReductionLayer, [case_s2_c1, case_s2_c2], dict()),
            (PoolingLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                            case_s2_c2], dict(k_size=3)),
            (ConvLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                         case_s2_c2], dict(k_size=3)),
            (SepConvLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                            case_s2_c2], dict(k_size=3)),
            (MobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2], dict(k_size=3)),
            (MobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_size=(3, ))),
            (MobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(k_size=(3, 5, 7), k_size_in=(1, 1), k_size_out=(1, 1))),
            (FusedMobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(name='mmicl1',
                  k_sizes=(3, 5, 7),
                  k_size_in=(1, 1),
                  k_size_out=(1, 1))),
            (FusedMobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(name='mmicl2',
                  k_sizes=((3, 5), (3, 5, 7)),
                  k_size_in=(1, 1),
                  k_size_out=(1, 1))),
            (ShuffleNetV2Layer, [case_s1_c1, case_s1_c2,
                                 case_s2_c2], dict(k_size=3)),
            (ShuffleNetV2XceptionLayer, [case_s1_c1, case_s1_c2,
                                         case_s2_c2], dict(k_size=3)),
            (LinearTransformerLayer, [case_s1_c1, case_s1_c2], dict()),
            (SuperConvThresholdLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_sizes=(3, 5, 7))),
            (SuperSepConvThresholdLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_sizes=(3, 5, 7))),
            (SuperMobileInvertedConvThresholdLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(k_sizes=(3, 5, 7),
                  expansions=(3, 6),
                  sse_dict=dict(c_muls=(0.0, 0.25, 0.5)))),
            (SuperConvLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                              case_s2_c2], dict(k_sizes=(3, 5, 7),
                                                name='scl')),
            (SuperSepConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_sizes=(3, 5, 7), name='sscl')),
            (SuperMobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(k_sizes=(3, 5, 7), name='smicl', expansions=(3, 6))),
            (AttentionLayer, [case_s1_c1],
             dict(att_dict=dict(att_cls='EfficientChannelAttentionModule'))),
            (AttentionLayer, [case_s1_c1],
             dict(att_dict=dict(att_cls='SqueezeExcitationChannelModule'))),
        ]:
            for c, stride, shape_out in cases:
                m1 = cls(stride=stride, **kwargs)
                s_out = m1.build(s_in, c)
                assert s_out == shape_out, 'Expected output shape does not match, %s, build=%s / expected=%s' %\
                                           (cls.__name__, s_out, shape_out)
                assert_output_shape(m1, x, [bs] + shape_out.shape)
                print('%s(stride=%d, c_in=%d, c_out=%d)' %
                      (cls.__name__, stride, c1, c))
Example #20
def get_resnet34(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    return _resnet(block=ResNetLayer, stages=(3, 4, 6, 3), expansion=1, s_in=s_in, s_out=s_out)
Example #21
class AbstractDataSet(ArgsInterface):
    length = (0, 0, 0)  # training, valid, test
    data_raw_shape = Shape([])
    label_shape = Shape([])
    data_mean = None
    data_std = None

    can_download = True

    def __init__(self, data_dir: str, save_dir: Union[str, None],
                 bs_train: int, bs_test: int,
                 train_transforms: transforms.Compose, test_transforms: transforms.Compose,
                 train_batch_aug: Union[BatchAugmentations, None], test_batch_aug: Union[BatchAugmentations, None],
                 num_workers: int, num_prefetch: int,
                 valid_split: Union[int, float], valid_shuffle: bool,
                 fake: bool, download: bool,
                 **additional_args):
        """

        :param data_dir: where to find (or download) the data set
        :param save_dir: global save dir, can store and reuse the info which data was used in the random valid split
        :param bs_train: batch size for the train loader
        :param bs_test: batch size for the test loader, <= 0 to have the same as bs_train
        :param train_transforms: train augmentations (on each data point individually)
        :param test_transforms: test augmentations (on each data point individually)
        :param train_batch_aug: train augmentations (across the entire batch)
        :param test_batch_aug: test augmentations (across the entire batch)
        :param num_workers: number of workers prefetching data
        :param num_prefetch: number of batches prefetched by every worker
        :param valid_split: absolute number of data points if int or >1, otherwise a fraction of the training set
        :param valid_shuffle: whether to shuffle validation data
        :param fake: use fake data instead (no need to provide either real data or enabling downloading)
        :param download: whether downloading is allowed
        :param additional_args: arguments that are added and used by child classes
        """
        super().__init__()
        logger = LoggerManager().get_logger()
        self.dir = data_dir
        self.bs_train = bs_train
        self.bs_test = bs_test if bs_test > 0 else self.bs_train
        self.num_workers, self.num_prefetch = num_workers, num_prefetch
        self.valid_shuffle = valid_shuffle
        self.additional_args = additional_args

        self.fake = fake
        self.download = download and not self.fake
        if self.download and (not self.can_download):
            logger.warning("The data set cannot be downloaded, but downloading was requested.")

        self.train_transforms = train_transforms
        self.test_transforms = test_transforms
        self.train_batch_augmentations = train_batch_aug
        self.test_batch_augmentations = test_batch_aug

        # load/create meta info dict
        if isinstance(save_dir, str) and len(save_dir) > 0:
            meta_path = '%s/data.meta.pt' % replace_standard_paths(save_dir)
            if os.path.isfile(meta_path):
                meta = torch.load(meta_path)
            else:
                meta = defaultdict(dict)
        else:
            meta, meta_path = defaultdict(dict), None

        # give subclasses a good spot to react to additional arguments
        self._before_loading()

        # data
        if self.fake:
            train_data = self._get_fake_train_data(self.train_transforms)
            self.test_data = self._get_fake_test_data(self.test_transforms)
        else:
            train_data = self._get_train_data(self.train_transforms)
            self.test_data = self._get_test_data(self.test_transforms)

        # split train into train+valid or using stand-alone valid set
        if valid_split > 0:
            s1 = int(valid_split) if valid_split >= 1 else int(len(train_data)*valid_split)
            if s1 >= len(train_data):
                logger.warning("Tried to set valid split larger than the training set size, setting to 0.5")
                s1 = len(train_data)//2
            s0 = len(train_data) - s1
            if meta['splits'].get((s0, s1), None) is None:
                meta['splits'][(s0, s1)] = torch.randperm(s0+s1).tolist()
            indices = meta['splits'][(s0, s1)]
            self.valid_data = torch.utils.data.Subset(train_data, np.array(indices[s0:]).astype(np.int64))
            train_data = torch.utils.data.Subset(train_data, np.array(indices[0:s0]).astype(np.int64))
            logger.info('Data Set: splitting training set, will use %s data points as validation set' % s1)
            if self.length[1] > 0:
                logger.info('Data Set: a dedicated validation set exists, but it will be replaced.')
        elif self.length[1] > 0:
            if self.fake:
                self.valid_data = self._get_fake_valid_data(self.test_transforms)
            else:
                self.valid_data = self._get_valid_data(self.test_transforms)
            logger.info('Data Set: using the dedicated validation set with test augmentations')
        else:
            self.valid_data = None
            logger.info('Data Set: not using a validation set at all.')
        self.train_data = train_data

        # shapes
        data, label = self.train_data[0]
        self.data_shape = Shape(list(data.shape))

        # save meta info dict
        if meta_path is not None:
            torch.save(meta, meta_path)

    def _before_loading(self):
        """ called before loading training/validation/test data """
        pass

    @classmethod
    def from_args(cls, args: Namespace, index: int = None) -> 'AbstractDataSet':
        # parsed arguments, and the global save dir
        all_args = cls._all_parsed_arguments(args, index=index)

        data_dir = replace_standard_paths(all_args.pop('dir'))
        fake = all_args.pop('fake')
        download = all_args.pop('download') and not fake

        try:
            _, save_dir = find_in_args(args, '.save_dir')
            save_dir = replace_standard_paths(save_dir)
        except ValueError:
            save_dir = ""

        # augmentations per data point and batch, for training and test
        tr_d, tr_b, te_d, te_b = [], [], [], []
        for i, aug_set in enumerate(cls._parsed_meta_arguments(Register.augmentation_sets, 'cls_augmentations', args, index=index)):
            tr_d_, tr_b_ = aug_set.get_train_transforms(args, i, cls)
            te_d_, te_b_ = aug_set.get_test_transforms(args, i, cls)
            tr_d.extend(tr_d_)
            tr_b.extend(tr_b_)
            te_d.extend(te_d_)
            te_b.extend(te_b_)
        if cls.is_on_images():
            final_transforms = [transforms.ToTensor(), transforms.Normalize(cls.data_mean, cls.data_std)]
        else:
            final_transforms = []
        train_transforms = transforms.Compose(tr_d + final_transforms)
        test_transforms = transforms.Compose(te_d + final_transforms)
        train_batch_aug = BatchAugmentations(tr_b) if len(tr_b) > 0 else None
        test_batch_aug = BatchAugmentations(te_b) if len(te_b) > 0 else None

        return cls(data_dir=data_dir, save_dir=save_dir,
                   bs_train=all_args.pop('batch_size_train'), bs_test=all_args.pop('batch_size_test'),
                   train_transforms=train_transforms, test_transforms=test_transforms,
                   train_batch_aug=train_batch_aug, test_batch_aug=test_batch_aug,
                   num_workers=all_args.pop('num_workers'), num_prefetch=all_args.pop('num_prefetch'),
                   valid_split=all_args.pop('valid_split'), valid_shuffle=all_args.pop('valid_shuffle'),
                   fake=fake, download=download, **all_args)

    def list_train_transforms(self) -> str:
        bfs = [] if self.train_batch_augmentations is None else self.train_batch_augmentations.batch_functions
        return ', '.join([t.__class__.__name__ for t in self.train_transforms.transforms + bfs])

    def list_test_transforms(self) -> str:
        bfs = [] if self.test_batch_augmentations is None else self.test_batch_augmentations.batch_functions
        return ', '.join([t.__class__.__name__ for t in self.test_transforms.transforms + bfs])

    def get_transforms(self, train=True, exclude_normalize=False) -> transforms.Compose:
        """
        get the transforms of training or test data,
        if 'exclude_normalize' is set, discard Normalize (good for visualization)
        """
        transform = self.train_transforms.transforms if train else self.test_transforms.transforms
        if exclude_normalize:
            transform = [t for t in transform if not isinstance(t, transforms.Normalize)]
        return transforms.Compose(transform)

    @classmethod
    def meta_args_to_add(cls) -> [MetaArgument]:
        """
        list meta arguments to add to argparse for when this class is chosen,
        classes specified in meta arguments may have their own respective arguments
        """
        kwargs = Register.get_my_kwargs(cls)
        aug_sets = Register.augmentation_sets.filter_match_all(on_images=kwargs.get('images'))
        return super().meta_args_to_add() + [
            MetaArgument('cls_augmentations', aug_sets, help_name='data augmentation'),
        ]

    @classmethod
    def args_to_add(cls, index=None) -> [Argument]:
        """ list arguments to add to argparse when this class (or a child class) is chosen """
        args = super().args_to_add(index) + [
            Argument('dir', default='{path_data}', type=str, help='data dir', is_path=True),
            Argument('download', default='False', type=str, help='allow downloading', is_bool=True),
            Argument('fake', default='False', type=str, help='use fake data', is_bool=True),
            Argument('batch_size_train', default=64, type=int, help='batch size for each train data loader (i.e. ddp may cause multiple instances)'),
            Argument('batch_size_test', default=-1, type=int, help='batch size for each eval/test data loader, same as train size if <0'),
            Argument('num_workers', default=4, type=int, help='number of workers for data loaders'),
            Argument('num_prefetch', default=2, type=int, help='number batches that each worker prefetches'),
            Argument('valid_split', default=0.0, type=float, help='num samples if >1, else fraction of the training set, for the validation set'),
            Argument('valid_shuffle', default='False', type=str, help='shuffle the validation set', is_bool=True),
        ]
        return args

    def get_batch_size(self, train=True) -> int:
        return self.bs_train if train else self.bs_test

    def get_data_shape(self) -> Shape:
        return self.data_shape

    def get_label_shape(self) -> Shape:
        return self.__class__.label_shape

    def _str_dict(self) -> dict:
        dct = super()._str_dict()
        dct.update({
            'data shape': self.get_data_shape(),
            'label shape': self.get_label_shape(),
        })
        dct.update({'fake': self.fake} if self.fake else {})
        dct.update({} if self.train_data is None else {'training data': len(self.train_data)})
        dct.update({} if self.valid_data is None else {'valid data': len(self.valid_data)})
        dct.update({} if self.test_data is None else {'test data': len(self.test_data)})
        dct.update({} if self.train_data is None else {'train batch size': self.bs_train})
        dct.update({} if self.test_data is None else {'test batch size': self.bs_test})
        return dct

    def sample_random_data(self, batch_size=1) -> torch.Tensor:
        """ get random data with correct size """
        size = [batch_size] + list(self.data_shape.shape)
        return torch.randn(size=size, dtype=torch.float32)

    def train_loader(self, dist=False) -> InfIterator:
        return self._loader(self.train_data, is_train=True, shuffle=True, dist=dist, wrap=True)

    def valid_loader(self, dist=False) -> InfIterator:
        return self._loader(self.valid_data, is_train=False, shuffle=self.valid_shuffle, dist=dist, wrap=True)

    def mixed_train_valid_loader(self, dist=False) -> MultiLoader:
        """ for having training/valid set both for training, in bi-optimization settings """
        assert self.train_data is not None, 'Training data must not be None when using mixed loading'
        assert self.valid_data is not None, 'Valid data must not be None when using mixed loading'
        return MultiLoader([
            self._loader(self.train_data, is_train=True, shuffle=True, dist=dist, wrap=False),
            self._loader(self.valid_data, is_train=True, shuffle=self.valid_shuffle, dist=dist, wrap=False)
        ])

    def interleaved_train_valid_loader(self, multiples=(1, 1), dist=False) -> InterleavedLoader:
        """ for having training/valid set both for training, in bi-optimization settings """
        assert self.train_data is not None, 'Training data must not be None when using mixed loading'
        assert self.valid_data is not None, 'Valid data must not be None when using mixed loading'
        return InterleavedLoader([
            self._loader(self.train_data, is_train=True, shuffle=True, dist=dist, wrap=False),
            self._loader(self.valid_data, is_train=True, shuffle=self.valid_shuffle, dist=dist, wrap=False)
        ], multiples)

    def test_loader(self, dist=False) -> InfIterator:
        return self._loader(self.test_data, is_train=False, shuffle=False, dist=dist, wrap=True)

    def _loader(self, data, is_train=True, shuffle=True, dist=False, wrap=True):
        if data is None:
            return None
        bs = self.bs_train if is_train else self.bs_test
        sampler = DistributedSampler(data) if dist else None
        # prefetch_factor may only be passed when there are worker processes
        loader_kwargs = dict(prefetch_factor=self.num_prefetch) if self.num_workers > 0 else {}
        loader = DataLoader(data, batch_size=bs, shuffle=shuffle and not dist, num_workers=self.num_workers,
                            pin_memory=True, sampler=sampler, **loader_kwargs)
        if wrap:
            return InfIterator(loader)
        return loader

    def _get_train_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError

    def _get_valid_data(self, used_transforms: transforms.Compose):
        return None

    def _get_test_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError

    def _get_fake_train_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError

    def _get_fake_valid_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError

    def _get_fake_test_data(self, used_transforms: transforms.Compose):
        raise NotImplementedError

    @classmethod
    def is_on_images(cls) -> bool:
        kwargs = Register.get_my_kwargs(cls)
        return kwargs.get('images', False)

    @classmethod
    def is_classification(cls) -> bool:
        kwargs = Register.get_my_kwargs(cls)
        return kwargs.get('classification', False)

    @classmethod
    def num_classes(cls) -> int:
        kwargs = Register.get_my_kwargs(cls)
        assert kwargs.get('classification', False)
        return cls.label_shape.num_features()
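
The valid_split arithmetic from __init__ in isolation, as a self-contained sketch (split_sizes is an illustrative name):

def split_sizes(n_train: int, valid_split: float) -> tuple:
    # absolute count if >= 1, otherwise a fraction of the training set
    s1 = int(valid_split) if valid_split >= 1 else int(n_train * valid_split)
    if s1 >= n_train:        # __init__ falls back to a 50/50 split in this case
        s1 = n_train // 2
    return n_train - s1, s1

print(split_sizes(50000, 0.1))   # (45000, 5000)
print(split_sizes(50000, 5000))  # (45000, 5000)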
Example #22
def get_resnet152(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    return _resnet(block=ResNetBottleneckLayer, stages=(3, 8, 36, 3), expansion=4, s_in=s_in, s_out=s_out)
Example #23
def get_mixnet_s(s_in=Shape([3, 224, 224]), s_out=Shape([1000])) -> nn.Module:
    stem = get_stem_instance(MobileNetV2Stem,
                             features=16,
                             features1=16,
                             act_fun='relu',
                             act_fun1='relu')
    head = get_head_instance(FeatureMixClassificationHead,
                             features=1536,
                             act_fun='relu')

    defaults = dict(k_size=(3, ),
                    k_size_in=1,
                    k_size_out=1,
                    padding='same',
                    dilation=1,
                    bn_affine=True,
                    act_inplace=True,
                    att_dict=None)
    se25 = dict(att_cls='SqueezeExcitationChannelModule',
                use_c_substitute=True,
                c_mul=0.25,
                squeeze_act='swish',
                excite_act='sigmoid',
                squeeze_bias=True,
                excite_bias=True,
                squeeze_bn=False)
    se5 = dict(att_cls='SqueezeExcitationChannelModule',
               use_c_substitute=True,
               c_mul=0.5,
               squeeze_act='swish',
               excite_act='sigmoid',
               squeeze_bias=True,
               excite_bias=True,
               squeeze_bn=False)

    cell_partials, cell_order = get_passthrough_partials([
        (24, MobileInvertedConvLayer, defaults,
         dict(stride=2,
              expansion=6,
              act_fun='relu',
              k_size=(3, ),
              k_size_in=(1, 1),
              k_size_out=(1, 1))),
        (24, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=3,
              act_fun='relu',
              k_size=(3, ),
              k_size_in=(1, 1),
              k_size_out=(1, 1))),
        (40, MobileInvertedConvLayer, defaults,
         dict(stride=2,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5, 7),
              att_dict=se5)),
        (40, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5),
              k_size_in=(1, 1),
              k_size_out=(1, 1),
              att_dict=se5)),
        (40, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5),
              k_size_in=(1, 1),
              k_size_out=(1, 1),
              att_dict=se5)),
        (40, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5),
              k_size_in=(1, 1),
              k_size_out=(1, 1),
              att_dict=se5)),
        (80, MobileInvertedConvLayer, defaults,
         dict(stride=2,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5, 7),
              k_size_out=(1, 1),
              att_dict=se25)),
        (80, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5),
              k_size_out=(1, 1),
              att_dict=se25)),
        (80, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5),
              k_size_out=(1, 1),
              att_dict=se25)),
        (120, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5, 7),
              k_size_in=(1, 1),
              k_size_out=(1, 1),
              att_dict=se5)),
        (120, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=3,
              act_fun='swish',
              k_size=(3, 5, 7, 9),
              k_size_in=(1, 1),
              k_size_out=(1, 1),
              att_dict=se5)),
        (120, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=3,
              act_fun='swish',
              k_size=(3, 5, 7, 9),
              k_size_in=(1, 1),
              k_size_out=(1, 1),
              att_dict=se5)),
        (200, MobileInvertedConvLayer, defaults,
         dict(stride=2,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5, 7, 9, 11),
              att_dict=se5)),
        (200, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5, 7, 9),
              k_size_out=(1, 1),
              att_dict=se5)),
        (200, MobileInvertedConvLayer, defaults,
         dict(stride=1,
              expansion=6,
              act_fun='swish',
              k_size=(3, 5, 7, 9),
              k_size_out=(1, 1),
              att_dict=se5)),
    ])

    return get_network(StackedCellsNetworkBody, stem, head, cell_partials,
                       cell_order, s_in, s_out)
Example #24
    def build_from_cache(self) -> ShapeList:
        """ build the network from cached input/output shapes """
        shape_in = Shape(self.shape_in_list)
        shape_out = Shape(self.shape_out_list)
        return self.build(shape_in, shape_out)