Пример #1
0
            print(
                'Invalid pool type %s specified. Defaulting to average pooling.'
                % pool_type)
        x = F.avg_pool3d(x,
                         kernel_size=(x.size(2), x.size(3), x.size(4)),
                         padding=padding,
                         count_include_pad=count_include_pad)
    return x


if __name__ == '__main__':
    # Smoke test: run the 2-D DPN-107 on a 4-D batch of ones and the
    # I3D DPN-107 (i3d_dpn107, defined elsewhere in this module) on a 5-D
    # batch of ones, printing both outputs.
    import numpy as np
    import torch
    from dpn import dpn107

    dpn107_2d = dpn107()
    dpn107_i3d = i3d_dpn107()

    # Inference only: no_grad() skips building an autograd graph for either
    # forward pass. Tensors are passed directly; torch.autograd.Variable has
    # been a no-op wrapper since PyTorch 0.4.
    with torch.no_grad():
        inputs = torch.from_numpy(np.ones((1, 3, 224, 224), dtype=np.float32))
        out1 = dpn107_2d(inputs)
        print(out1)

        # NOTE(review): presumably laid out as (N, C, T, H, W) for the
        # 3-D model — confirm against i3d_dpn107's expected input.
        inputs2 = torch.from_numpy(
            np.ones((1, 3, 32, 224, 224), dtype=np.float32))
        out2 = dpn107_i3d(inputs2)
        print(out2)
Пример #2
0
def create_model(model_name, num_classes=1000, pretrained=False, **kwargs):
    """Build a classification model by architecture name.

    Args:
        model_name: One of 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131',
            'dpn107', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152', 'densenet121', 'densenet161', 'densenet169',
            'densenet201' or 'inception_v3'.
        num_classes: Forwarded to the selected constructor as ``num_classes``.
        pretrained: Forwarded to the selected constructor as ``pretrained``.
        **kwargs: Two keys are removed from ``kwargs`` here:
            ``test_time_pool`` (default True) is passed to every DPN builder,
            and ``extra`` (default True) is passed only to ``dpn92``.
            All remaining keys are forwarded to the ResNet, DenseNet and
            Inception-v3 builders and are not passed to the DPN builders.

    Returns:
        Whatever the selected model constructor returns.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Pop (rather than just read) so these keys never reach the builders
    # that receive **kwargs below.
    test_time_pool = kwargs.pop('test_time_pool', True)
    extra = kwargs.pop('extra', True)

    if model_name == 'dpn68':
        model = dpn68(num_classes=num_classes,
                      pretrained=pretrained,
                      test_time_pool=test_time_pool)
    elif model_name == 'dpn68b':
        model = dpn68b(num_classes=num_classes,
                       pretrained=pretrained,
                       test_time_pool=test_time_pool)
    elif model_name == 'dpn92':
        # dpn92 is the only builder that receives ``extra``.
        model = dpn92(num_classes=num_classes,
                      pretrained=pretrained,
                      test_time_pool=test_time_pool,
                      extra=extra)
    elif model_name == 'dpn98':
        model = dpn98(num_classes=num_classes,
                      pretrained=pretrained,
                      test_time_pool=test_time_pool)
    elif model_name == 'dpn131':
        model = dpn131(num_classes=num_classes,
                       pretrained=pretrained,
                       test_time_pool=test_time_pool)
    elif model_name == 'dpn107':
        model = dpn107(num_classes=num_classes,
                       pretrained=pretrained,
                       test_time_pool=test_time_pool)
    elif model_name == 'resnet18':
        model = resnet18(num_classes=num_classes,
                         pretrained=pretrained,
                         **kwargs)
    elif model_name == 'resnet34':
        model = resnet34(num_classes=num_classes,
                         pretrained=pretrained,
                         **kwargs)
    elif model_name == 'resnet50':
        model = resnet50(num_classes=num_classes,
                         pretrained=pretrained,
                         **kwargs)
    elif model_name == 'resnet101':
        model = resnet101(num_classes=num_classes,
                          pretrained=pretrained,
                          **kwargs)
    elif model_name == 'resnet152':
        model = resnet152(num_classes=num_classes,
                          pretrained=pretrained,
                          **kwargs)
    elif model_name == 'densenet121':
        model = densenet121(num_classes=num_classes,
                            pretrained=pretrained,
                            **kwargs)
    elif model_name == 'densenet161':
        model = densenet161(num_classes=num_classes,
                            pretrained=pretrained,
                            **kwargs)
    elif model_name == 'densenet169':
        model = densenet169(num_classes=num_classes,
                            pretrained=pretrained,
                            **kwargs)
    elif model_name == 'densenet201':
        model = densenet201(num_classes=num_classes,
                            pretrained=pretrained,
                            **kwargs)
    elif model_name == 'inception_v3':
        # transform_input is pinned to False; passing it in kwargs as well
        # would be a duplicate keyword argument.
        model = inception_v3(num_classes=num_classes,
                             pretrained=pretrained,
                             transform_input=False,
                             **kwargs)
    else:
        # Explicit raise instead of ``assert False``: asserts are stripped
        # under ``python -O``, which left ``model`` unbound and surfaced as a
        # confusing UnboundLocalError at the return below.
        raise ValueError("Unknown model architecture (%s)" % model_name)
    return model