Beispiel #1
0
def get_resnext(num_layers,
                cardinality=32,
                bottleneck_width=4,
                use_se=False,
                pretrained=False,
                ctx=cpu(),
                root=os.path.join('~', '.mxnet', 'models'),
                **kwargs):
    r"""Build a ResNext network and optionally load pretrained weights into it.

    Architecture from the `"Aggregated Residual Transformations for Deep Neural Network"
    <http://arxiv.org/abs/1611.05431>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers. Must be a key of ``resnext_spec``.
    cardinality : int, default 32
        Number of groups.
    bottleneck_width : int, default 4
        Width of the bottleneck block.
    use_se : bool, default False
        Forwarded to ``ResNext``. When pretrained weights are loaded it also selects
        the ``se_``-prefixed weight file name.
    pretrained : bool or str, default False
        Any truthy value triggers weight loading; the value itself is handed to
        ``get_model_file`` as ``tag``.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``ResNext``.

    Returns
    -------
    The constructed ``ResNext`` network.

    Raises
    ------
    AssertionError
        If ``num_layers`` is not a key of ``resnext_spec``.
    """
    assert num_layers in resnext_spec, \
        "Invalid number of layers: %d. Options are %s"%(
            num_layers, str(resnext_spec.keys()))
    net = ResNext(resnext_spec[num_layers],
                  cardinality,
                  bottleneck_width,
                  use_se=use_se,
                  **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    # The SE and plain variants differ only in the weight file name.
    name_template = 'se_resnext%d_%dx%dd' if use_se else 'resnext%d_%dx%dd'
    weights_file = get_model_file(
        name_template % (num_layers, cardinality, bottleneck_width),
        tag=pretrained,
        root=root)
    net.load_params(weights_file, ctx=ctx)
    return net
Beispiel #2
0
def get_resnet(version, num_layers, ctx=None, pretrained=False, remove_subsample=0, **kwargs):
    r"""Build a ResNet V1 or V2 network and optionally load pretrained weights.

    ResNet V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Parameters
    ----------
    version : int
        Version of ResNet. Options are 1, 2.
    num_layers : int
        Number of layers. Must be a key of ``resnet_spec``.
    ctx : Context or None, default None
        Passed to ``load_params`` when pretrained weights are loaded; unused otherwise.
        NOTE(review): how ``None`` is interpreted is decided by ``load_params`` - confirm.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    remove_subsample : default 0
        Forwarded as the fourth positional argument to the selected network class;
        its meaning is defined by that class.
    **kwargs
        Extra keyword arguments forwarded to the selected network class.

    Returns
    -------
    The constructed network.

    Raises
    ------
    ValueError
        If ``version`` is not 1 or 2.
    """
    # Validate first: ``version - 1`` is used as a sequence index below, and a value
    # such as 0 would otherwise wrap around to the last entry and silently build the
    # wrong network instead of failing.
    if version not in (1, 2):
        raise ValueError(
            "Invalid resnet version: %r. Options are 1 and 2." % (version,))

    block_type, layers, channels = resnet_spec[num_layers]
    resnet_class = resnet_net_versions[version - 1]
    block_class = resnet_block_versions[version - 1][block_type]
    net = resnet_class(block_class, layers, channels, remove_subsample, **kwargs)
    if pretrained:
        # Imported lazily so the model store is only required when weights are requested.
        from mxnet.gluon.model_zoo.model_store import get_model_file
        weights_file = get_model_file('resnet%d_v%d' % (num_layers, version))
        # Extra parameters in the file are tolerated, missing ones are not.
        net.load_params(weights_file, ctx=ctx, allow_missing=False, ignore_extra=True)
    return net
def get_squeezenet(version,
                   pretrained=False,
                   ctx=cpu(),
                   root=os.path.join(base.data_dir(), 'models'),
                   **kwargs):
    r"""Build a SqueezeNet network and optionally load pretrained weights.

    SqueezeNet model from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters
    and <0.5MB model size" <https://arxiv.org/abs/1602.07360>`_ paper.
    SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.

    Parameters
    ----------
    version : str
        Version of squeezenet. Options are '1.0', '1.1'.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``SqueezeNet``.

    Returns
    -------
    The constructed ``SqueezeNet`` network.
    """
    net = SqueezeNet(version, **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file('squeezenet%s' % version, root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(),
                     root=os.path.join(base.data_dir(), 'models'), **kwargs):
    r"""Build a MobileNetV2 network and optionally load pretrained weights.

    MobileNetV2 model from the
    `"Inverted Residuals and Linear Bottlenecks:
      Mobile Networks for Classification, Detection and Segmentation"
    <https://arxiv.org/abs/1801.04381>`_ paper.

    Parameters
    ----------
    multiplier : float
        The width multiplier for controlling the model size. Only multipliers that are no
        less than 0.25 are supported. The actual number of channels is equal to the original
        channel size multiplied by this multiplier.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``MobileNetV2``.

    Returns
    -------
    The constructed ``MobileNetV2`` network.
    """
    net = MobileNetV2(multiplier, **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    two_decimals = '{0:.2f}'.format(multiplier)
    # For 1.00 and 0.50 the weight name drops the trailing zero ('1.0', '0.5');
    # every other multiplier keeps both decimals.
    drop_trailing_zero = two_decimals in ('1.00', '0.50')
    version_suffix = two_decimals[:-1] if drop_trailing_zero else two_decimals
    weights_file = get_model_file('mobilenetv2_%s' % version_suffix, root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
                  root=os.path.join(base.data_dir(), 'models'), **kwargs):
    r"""Build a MobileNet network and optionally load pretrained weights.

    MobileNet model from the
    `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
    <https://arxiv.org/abs/1704.04861>`_ paper.

    Parameters
    ----------
    multiplier : float
        The width multiplier for controlling the model size. Only multipliers that are no
        less than 0.25 are supported. The actual number of channels is equal to the original
        channel size multiplied by this multiplier.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``MobileNet``.

    Returns
    -------
    The constructed ``MobileNet`` network.
    """
    net = MobileNet(multiplier, **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from ..model_store import get_model_file
    two_decimals = '{0:.2f}'.format(multiplier)
    # For 1.00 and 0.50 the weight name drops the trailing zero ('1.0', '0.5');
    # every other multiplier keeps both decimals.
    drop_trailing_zero = two_decimals in ('1.00', '0.50')
    version_suffix = two_decimals[:-1] if drop_trailing_zero else two_decimals
    weights_file = get_model_file('mobilenet%s' % version_suffix, root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
Beispiel #6
0
def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
               root=os.path.join('~', '.mxnet', 'models'), **kwargs):
    r"""Build a ResNet V1 or V2 network and optionally load pretrained weights.

    ResNet V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Parameters
    ----------
    version : int
        Version of ResNet. Options are 1, 2.
    num_layers : int
        Number of layers. Must be a key of ``resnet_spec``.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to the selected network class.

    Returns
    -------
    The constructed network.

    Raises
    ------
    AssertionError
        If ``num_layers`` is not a key of ``resnet_spec`` or ``version`` is outside 1..2.
    """
    assert num_layers in resnet_spec, \
        "Invalid number of layers: %d. Options are %s" % (
            num_layers, str(resnet_spec.keys()))
    block_type, layers, channels = resnet_spec[num_layers]
    assert 1 <= version <= 2, \
        "Invalid resnet version: %d. Options are 1 and 2." % version

    # Both lookup tables are indexed by zero-based version.
    version_index = version - 1
    block_class = resnet_block_versions[version_index][block_type]
    net = resnet_net_versions[version_index](block_class, layers, channels, **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file('resnet%d_v%d' % (num_layers, version),
                                  root=root)
    # Both mismatches are tolerated: parameters missing from the file and extra ones in it.
    net.load_params(weights_file, ctx=ctx, allow_missing=True, ignore_extra=True)
    return net
Beispiel #7
0
def get_vgg(num_layers,
            pretrained=False,
            ctx=cpu(),
            root=os.path.join(base.data_dir(), 'models'),
            **kwargs):
    r"""Build a VGG network and optionally load pretrained weights.

    VGG model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers for the variant of VGG. Must be a key of ``vgg_spec``.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``VGG``. A truthy ``batch_norm`` entry
        additionally selects the ``_bn`` weight file.

    Returns
    -------
    The constructed ``VGG`` network.
    """
    layers, filters = vgg_spec[num_layers]
    net = VGG(layers, filters, **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    if kwargs.get('batch_norm'):
        batch_norm_suffix = '_bn'
    else:
        batch_norm_suffix = ''
    weights_file = get_model_file('vgg%d%s' % (num_layers, batch_norm_suffix),
                                  root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
Beispiel #8
0
def _get_elmo_model(model_cls, model_name, dataset_name, pretrained, ctx, root, **kwargs):
    """Instantiate ``model_cls`` with an ELMo character vocabulary and optionally load weights.

    Parameters
    ----------
    model_cls : callable
        Called as ``model_cls(**kwargs)`` to build the network.
    model_name : str
    dataset_name : str
        Joined as ``'<model_name>_<dataset_name>'`` to name the parameter file.
    pretrained : bool
        When truthy, parameters are fetched with ``get_model_file`` and loaded.
    ctx : Context
        Context passed to ``load_parameters``.
    root : str
        Passed to ``get_model_file`` as ``root``.
    **kwargs
        Constructor arguments; ``char_vocab_size`` defaults to the vocabulary length.

    Returns
    -------
    tuple
        ``(net, vocab)``.
    """
    vocab = ELMoCharVocab()
    # Only fills in the size when the caller did not supply one.
    kwargs.setdefault('char_vocab_size', len(vocab))
    net = model_cls(**kwargs)
    if pretrained:
        file_key = '_'.join([model_name, dataset_name])
        net.load_parameters(get_model_file(file_key, root=root), ctx=ctx)
    return net, vocab
Beispiel #9
0
def _load_pretrained_params(net,
                            model_name,
                            dataset_name,
                            root,
                            ctx,
                            ignore_extra=False):
    """Fetch the parameter file for ``'<model_name>_<dataset_name>'`` and load it into ``net``.

    Parameters
    ----------
    net
        Object whose ``load_parameters`` method receives the file.
    model_name : str
    dataset_name : str
    root : str
        Passed to ``model_store.get_model_file`` as ``root``.
    ctx : Context
        Context passed to ``load_parameters``.
    ignore_extra : bool, default False
        Forwarded unchanged to ``load_parameters``.
    """
    model_file = model_store.get_model_file(
        '_'.join([model_name, dataset_name]), root=root)
    net.load_parameters(model_file, ctx=ctx, ignore_extra=ignore_extra)
Beispiel #10
0
def _get_elmo_model(model_cls, model_name, dataset_name, pretrained, ctx, root, **kwargs):
    """Build an ELMo-style model together with its character vocabulary.

    Parameters
    ----------
    model_cls : callable
        Called as ``model_cls(**kwargs)`` to build the network.
    model_name : str
    dataset_name : str
        Joined as ``'<model_name>_<dataset_name>'`` to name the parameter file.
    pretrained : bool
        When truthy, parameters are fetched with ``get_model_file`` and loaded.
    ctx : Context
        Context passed to ``load_parameters``.
    root : str
        Passed to ``get_model_file`` as ``root``.
    **kwargs
        Constructor arguments; ``char_vocab_size`` is filled in from the vocabulary
        length when absent.

    Returns
    -------
    tuple
        ``(net, vocab)``.
    """
    vocab = ELMoCharVocab()
    if 'char_vocab_size' not in kwargs:
        # Caller did not choose a size, so derive it from the vocabulary.
        kwargs['char_vocab_size'] = len(vocab)
    net = model_cls(**kwargs)
    if not pretrained:
        return net, vocab

    params_path = get_model_file('_'.join([model_name, dataset_name]), root=root)
    net.load_parameters(params_path, ctx=ctx)
    return net, vocab
Beispiel #11
0
def get_vgg(num_layers, pretrained=False, ctx=mx.cpu(),
            root=os.path.join('~', '.mxnet', 'models'), **kwargs):
    """Build a VGG network and optionally load pretrained weights.

    Parameters
    ----------
    num_layers : int
        Key into ``vgg_spec`` selecting the layer and filter configuration.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Passed to ``get_model_file`` as ``root``.
    **kwargs
        Extra keyword arguments forwarded to ``VGG``. A truthy ``batch_norm`` entry
        additionally selects the ``_bn`` weight file.

    Returns
    -------
    The constructed ``VGG`` network.
    """
    layers, filters = vgg_spec[num_layers]
    net = VGG(layers, filters, **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    if kwargs.get('batch_norm'):
        batch_norm_suffix = '_bn'
    else:
        batch_norm_suffix = ''
    weights_file = get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
                                  root=root)
    net.load_params(weights_file, ctx=ctx)
    return net
Beispiel #12
0
def _load_pretrained_params(net, model_name, dataset_name, root, ctx):
    """Load pretrained parameters into ``net`` after re-prefixing their names.

    The stored parameter dictionary is read with ``mx.nd.load``; every key that does
    not contain ``'src_embed'`` gets a ``'transformer_'`` prefix, and the renamed
    dictionary is written to a scratch file that ``net.load_params`` then reads.

    Parameters
    ----------
    net
        Object whose ``load_params`` method receives the renamed parameter file.
    model_name : str
    dataset_name : str
        Joined as ``'<model_name>_<dataset_name>'`` to name the stored parameter file.
    root : str
        Passed to ``get_model_file`` as ``root``.
    ctx : Context
        Context passed to ``load_params``.
    """
    import os
    import shutil
    import tempfile

    model_file = get_model_file('_'.join([model_name, dataset_name]),
                                root=root)
    stored_params = mx.nd.load(model_file)
    renamed_params = {
        (key if 'src_embed' in key else 'transformer_' + key): value
        for key, value in stored_params.items()
    }

    # Use a private scratch directory so nothing is left behind in the working
    # directory, the working directory need not be writable, and concurrent
    # callers cannot overwrite each other's file.
    scratch_dir = tempfile.mkdtemp()
    try:
        scratch_file = os.path.join(
            scratch_dir, '_'.join(['temporal', model_name, dataset_name]))
        mx.nd.save(scratch_file, renamed_params)
        net.load_params(scratch_file, ctx=ctx)
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)
Beispiel #13
0
def get_vgg(num_layers,
            pretrained=False,
            ctx=mx.cpu(),
            root=os.path.join('~', '.mxnet', 'models'),
            **kwargs):
    """Construct a VGG network, loading pretrained weights on request.

    Parameters
    ----------
    num_layers : int
        Key into ``vgg_spec`` selecting the layer and filter configuration.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Passed to ``get_model_file`` as ``root``.
    **kwargs
        Extra keyword arguments forwarded to ``VGG``. A truthy ``batch_norm`` entry
        additionally selects the ``_bn`` weight file.

    Returns
    -------
    The constructed ``VGG`` network.
    """
    layers, filters = vgg_spec[num_layers]
    net = VGG(layers, filters, **kwargs)
    if pretrained:
        # Imported lazily so the model store is only required when weights are requested.
        from mxnet.gluon.model_zoo.model_store import get_model_file
        uses_batch_norm = kwargs.get('batch_norm')
        batch_norm_suffix = '_bn' if uses_batch_norm else ''
        weights_file = get_model_file(
            'vgg%d%s' % (num_layers, batch_norm_suffix), root=root)
        net.load_params(weights_file, ctx=ctx)
    return net
Beispiel #14
0
def _load_pretrained_params(net,
                            model_name,
                            dataset_name,
                            root,
                            ctx,
                            ignore_extra=False,
                            allow_missing=False):
    """Fetch the parameter file for ``'<model_name>_<dataset_name>'`` and load it into ``net``.

    Parameters
    ----------
    net
        Object whose ``load_parameters`` method receives the file.
    model_name : str
    dataset_name : str
        Must be a ``str``; checked with an assertion before anything is fetched.
    root : str
        Passed to ``model_store.get_model_file`` as ``root``.
    ctx : Context
        Context passed to ``load_parameters``.
    ignore_extra : bool, default False
    allow_missing : bool, default False
        Both flags are forwarded unchanged to ``load_parameters``.

    Raises
    ------
    AssertionError
        If ``dataset_name`` is not a ``str``.
    """
    assert isinstance(dataset_name, str), \
      'dataset_name(str) is required when loading pretrained models. Got {}'.format(dataset_name)
    model_file = model_store.get_model_file(
        '_'.join([model_name, dataset_name]), root=root)
    load_options = dict(ctx=ctx,
                        ignore_extra=ignore_extra,
                        allow_missing=allow_missing)
    net.load_parameters(model_file, **load_options)
def get_fpn_resnet(num_layers,
                   pretrained=False,
                   ctx=mx.cpu(),
                   root=os.path.join(os.getcwd(), 'models'),
                   dummy=False):
    r"""Build a ``ResNetV2``-based FPN backbone and initialise or load its weights.

    ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers. Must be a key of ``resnet_spec``.
    pretrained : bool, default False
        When truthy, ``resnet<num_layers>_v2`` weights are loaded and the three
        ``*_init`` helpers are then run on the network; otherwise the whole network
        is initialised with ``net.initialize``.
    ctx : Context, default CPU
        The context in which to load or initialise the weights.
    root : str, default '<cwd>/models'
        Location for keeping the model parameters. The default is computed from
        ``os.getcwd()`` once, when the function is defined.
    dummy : bool, default False
        When truthy, the informational log message is suppressed.

    Returns
    -------
    The constructed ``ResNetV2`` network.

    Raises
    ------
    AssertionError
        If ``num_layers`` is not a key of ``resnet_spec``.
    """
    assert num_layers in resnet_spec, \
        "Invalid number of layers: %d. Options are %s" % (
            num_layers, str(resnet_spec.keys()))

    block_type, layers, channels = resnet_spec[num_layers]
    net = ResNetV2(resnet_block_versions[block_type], layers, channels)

    if not pretrained:
        if not dummy:
            logging.info(f"resnet{num_layers} weight init 완료")
        net.initialize(ctx=ctx)
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file(f'resnet{num_layers}_v2', root=root)
    # allow_missing must always be True here.
    # NOTE(review): presumably because the layers set up by the *_init helpers below
    # have no counterpart in the pretrained file - confirm.
    net.load_parameters(weights_file,
                        ctx=ctx,
                        allow_missing=True,
                        ignore_extra=True)
    if not dummy:
        logging.info(f"resnet{num_layers} pretrained weight load 완료")
    lateral_init(net, ctx)
    extraconv_init(net, ctx)
    upsampleconv_init(net, ctx)
    return net
Beispiel #16
0
def alexnet(pretrained=False,
            ctx=cpu(),
            root=os.path.join(base.data_dir(), 'models'),
            **kwargs):
    r"""Build an AlexNet network and optionally load pretrained weights.

    AlexNet model from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``AlexNet``.

    Returns
    -------
    The constructed ``AlexNet`` network.
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file('alexnet', root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
def eco_full(pretrained=False,
             ctx=gpu(),
             root=os.path.join(base.data_dir(), '/path/to/json'),
             **kwargs):
    r"""Build the ECO_Full network and optionally load pretrained weights.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the ``eco_full_kinetics`` weights.
    ctx : Context, default GPU
        The context in which to load the pretrained weights.
    root : str
        Location for keeping the model parameters.
        NOTE(review): the default's second component is an absolute path, so on POSIX
        ``os.path.join`` discards ``base.data_dir()`` and the default resolves to
        '/path/to/json'. This looks like a placeholder - callers should pass ``root``.
    **kwargs
        Extra keyword arguments forwarded to ``Eco``.

    Returns
    -------
    The constructed ``Eco`` network.
    """
    net = Eco(**kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file('eco_full_kinetics', root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
Beispiel #18
0
def get_densenet(num_layers, pretrained=False, ctx=cpu(),
                 root=os.path.join(base.data_dir(), 'models'), **kwargs):
    r"""Build a Densenet-BC network and optionally load pretrained weights.

    Densenet-BC model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers for the variant of densenet. Must be a key of ``densenet_spec``.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``DenseNet``.

    Returns
    -------
    The constructed ``DenseNet`` network.
    """
    num_init_features, growth_rate, block_config = densenet_spec[num_layers]
    net = DenseNet(num_init_features, growth_rate, block_config, **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file('densenet%d'%(num_layers), root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
def inception_v3(pretrained=False,
                 ctx=cpu(),
                 root=os.path.join(base.data_dir(), 'models'),
                 **kwargs):
    r"""Build an Inception v3 network and optionally load pretrained weights.

    Inception v3 model from
    `"Rethinking the Inception Architecture for Computer Vision"
    <http://arxiv.org/abs/1512.00567>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``Inception3``.

    Returns
    -------
    The constructed ``Inception3`` network.
    """
    net = Inception3(**kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file('inceptionv3', root=root)
    net.load_parameters(weights_file, ctx=ctx)
    return net
Beispiel #20
0
def get_nasnet(repeat=6,
               penultimate_filters=4032,
               pretrained=False,
               ctx=cpu(),
               root=os.path.join('~', '.mxnet', 'models'),
               **kwargs):
    r"""Build a NASNet-A network and optionally load pretrained weights.

    NASNet A model from
    `"Learning Transferable Architectures for Scalable Image Recognition"
    <https://arxiv.org/abs/1707.07012>`_ paper

    Parameters
    ----------
    repeat : int, default 6
        Number of cell repeats. Must be at least 2.
    penultimate_filters : int, default 4032
        Number of filters in the penultimate layer of the network.
    pretrained : bool or str, default False
        Any truthy value triggers weight loading; the value itself is handed to
        ``get_model_file`` as ``tag``.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to ``NASNetALarge``.

    Returns
    -------
    The constructed ``NASNetALarge`` network.

    Raises
    ------
    AssertionError
        If ``repeat`` is smaller than 2.
    """
    assert repeat >= 2, \
        "Invalid number of repeat: %d. It should be at least two"%(repeat)
    net = NASNetALarge(repeat=repeat,
                       penultimate_filters=penultimate_filters,
                       **kwargs)
    if not pretrained:
        return net

    # Imported lazily so the model store is only required when weights are requested.
    from mxnet.gluon.model_zoo.model_store import get_model_file
    weights_file = get_model_file(
        'nasnet_%d_%d' % (repeat, penultimate_filters),
        tag=pretrained,
        root=root)
    net.load_params(weights_file, ctx=ctx)
    return net
Beispiel #21
0
def _load_pretrained_params(net, model_name, dataset_name, root, ctx):
    """Fetch the parameter file for ``'<model_name>_<dataset_name>'`` and load it into ``net``.

    ``root`` is passed to ``get_model_file`` and ``ctx`` to ``net.load_params``.
    """
    file_key = '_'.join([model_name, dataset_name])
    net.load_params(get_model_file(file_key, root=root), ctx=ctx)
Beispiel #22
0
 def load_base_model(self, ctx):
     """Load the parameter file resolved for 'resnet50_v1' into ``self.resnet``.

     ``ctx`` is passed to ``load_params``. Extra parameters in the file are
     ignored; parameters missing from the file are not allowed.
     """
     # Imported lazily so the model store is only required when this method runs.
     from mxnet.gluon.model_zoo.model_store import get_model_file
     weights_file = get_model_file('resnet50_v1')
     self.resnet.load_params(weights_file,
                             ctx=ctx,
                             allow_missing=False,
                             ignore_extra=True)
Beispiel #23
0
def _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=False):
    """Fetch the parameter file for ``'<model_name>_<dataset_name>'`` and load it into ``net``.

    ``root`` is passed to ``model_store.get_model_file``; ``ctx`` and ``ignore_extra``
    are forwarded to ``net.load_parameters``.
    """
    model_file = model_store.get_model_file(
        '_'.join([model_name, dataset_name]), root=root)
    net.load_parameters(model_file, ctx=ctx, ignore_extra=ignore_extra)
def build_network(name, classes, checkpoint=None, ctx=mx.cpu(), **kwargs):
    """Create the network called ``name``, or restore one from ``checkpoint``.

    Parameters
    ----------
    name : str
        'drml', 'r50', 'mobileface', 'vggface2', 'dpn68' or 'd121' select the models
        handled explicitly here; any other value is passed to ``vision.get_model``.
    classes
        Passed to the model constructor (or to ``vision.get_model``).
    checkpoint : str, optional
        Path of a parameter file. When truthy it takes precedence over ``name``: the
        network is imported from the checkpoint and nothing else is built.
    ctx : Context, default CPU
        Context used for importing, loading and initialising parameters.
    **kwargs
        Only ``grid`` is read (default ``(8, 8)``), and only for ``name == 'drml'``.

    Returns
    -------
    The restored or newly built network.
    """
    # Every architecture restores a checkpoint the same way, so handle it once.
    if checkpoint:
        print(Fore.GREEN + 'Restoring params from checkpoint: {}'.format(
            os.path.basename(checkpoint)))
        # The symbol file name is the checkpoint path with its last 11 characters
        # replaced - presumably a '0000.params'-style suffix; verify against the
        # checkpoint naming used by the training code.
        symbol_file = checkpoint[:-11] + 'symbol.json'
        return gluon.SymbolBlock.imports(symbol_file, ['data'], checkpoint,
                                         ctx)

    if name == 'drml':
        net = DRML(classes, kwargs.get('grid', (8, 8)))
        net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
        return net

    # Models shipped next to this package: sub-directory of '../model' and the
    # parameter file inside it. The symbol file is always 'model-symbol.json'.
    bundled = {
        'r50': (('model-r50-am-lfw',), 'model-0000.params'),
        'mobileface': (('model-y1-test2',), 'model-0000.params'),
        'vggface2': (('model-r50-vggface2',), 'model-0000.params'),
        'dpn68': (('model-dpn68-vggface2', 'new'), 'model-0009.params'),
        'd121': (('model-d121-vggface2', 'new'), 'model-0000.params'),
    }
    if name in bundled:
        subdirs, params_name = bundled[name]
        model_dir = os.path.join(os.path.dirname(__file__), '..', 'model',
                                 *subdirs)
        symbol_file = os.path.join(model_dir, 'model-symbol.json')
        params_file = os.path.join(model_dir, params_name)
        # Two of the bundled models use PretrainedModel, the rest use VGGFace.
        if name in ('r50', 'mobileface'):
            net = PretrainedModel(classes, symbol_file, params_file, ctx)
        else:
            net = VGGFace(classes, symbol_file, params_file, ctx)
        # Only the output part is initialised here.
        net.output.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
        return net

    # Anything else: take the model from `vision`, load parameters into its
    # `features` part and initialise the output part.
    net = vision.get_model(name=name, classes=classes)
    net.features.load_parameters(model_store.get_model_file(name),
                                 ctx=ctx,
                                 ignore_extra=True)
    net.output.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
    return net
Beispiel #25
0
def download_qa_ckpt():
    """Register a SHA-1 for 'bert_qa' and fetch that checkpoint into the current directory.

    Returns
    -------
    The value returned by ``model_store.get_model_file`` (also printed).
    """
    # Writes to the model store's private hash table before requesting the file.
    # NOTE(review): presumably needed because 'bert_qa' is not registered there by
    # default - confirm.
    model_store._model_sha1['bert_qa'] = '7eb11865ecac2a412457a7c8312d37a1456af7fc'
    ckpt_path = model_store.get_model_file('bert_qa', root='.')
    print('Downloaded checkpoint to {}'.format(ckpt_path))
    return ckpt_path
Beispiel #26
0
def _load_pretrained_params(net, model_name, dataset_name, root, ctx):
    """Load the parameter file named ``'<model_name>_<dataset_name>'`` into ``net``.

    ``root`` is passed to ``get_model_file`` and ``ctx`` to ``net.load_params``.
    """
    params_path = get_model_file('_'.join([model_name, dataset_name]), root=root)
    net.load_params(params_path, ctx=ctx)