Example #1
def resnext_wsl(arch, pretrained, progress=True, **kwargs):
    """
    models trained in weakly-supervised fashion on 940 million public images with 1.5K hashtags matching with 1000 ImageNet1K synsets, followed by fine-tuning on ImageNet1K dataset.
    https://github.com/facebookresearch/WSL-Images/
    """
    from torch.hub import load_state_dict_from_url
    from torchvision.models.resnet import ResNet, Bottleneck

    model_args = {'resnext101_32x8d': dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=32, width_per_group=8),
                  'resnext101_32x16d': dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=32, width_per_group=16),
                  'resnext101_32x32d': dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=32, width_per_group=32),
                  'resnext101_32x48d': dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=32, width_per_group=48)}

    args = model_args[arch]
    args.update(kwargs)
    model = ResNet(**args)

    if pretrained:
        model_urls = {
            'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth',
            'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth',
            'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth',
            'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth',
        }
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)

    return model
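A minimal usage sketch for the helper above; the input size and eval-mode call are illustrative and not part of the original snippet.

import torch

model = resnext_wsl('resnext101_32x8d', pretrained=False)  # pretrained=True downloads the WSL weights
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])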
Example #2
def fresnet50_v3(**kwargs):
    """Constructs a pruned ResNet-50 model.

    The per-layer channel counts are taken from ``keep_dict`` below, and the
    weights are loaded from a hard-coded local checkpoint.
    """
    keep_dict = {
        'big': [
            64, 20, 64, 26, 64, 32, 64, 116, 128, 52, 128, 39, 128, 52, 128,
            231, 256, 77, 256, 231, 256, 205, 256, 103, 256, 180, 256, 128,
            256, 205, 256, 205, 256, 103, 256, 103, 256, 154, 256, 77, 256,
            103, 256, 410, 512, 359, 512, 154, 512
        ],
        'medium': [
            64, 13, 64, 13, 64, 13, 64, 52, 128, 26, 128, 26, 128, 26, 128,
            103, 256, 52, 256, 77, 256, 52, 256, 52, 256, 52, 256, 52, 256, 52,
            256, 52, 256, 52, 256, 52, 256, 52, 256, 52, 256, 52, 256, 154,
            512, 103, 512, 103, 512
        ],
        'small': [
            64, 7, 64, 7, 64, 7, 64, 26, 128, 13, 128, 13, 128, 13, 128, 52,
            256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256,
            26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 103, 512, 52,
            512, 52, 512
        ]
    }

    model = ResNet(BasicBlock_v3, [3, 4, 14, 3], keep_dict['small'], **kwargs)
    model.load_state_dict(
        torch.load(
            '/media/user1/Ubuntu 16.0/resnet50/model_resnet50_66.7M_0.1.pt'))

    return model
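The checkpoint path above is machine-specific. A hedged variant that takes the path as a parameter and maps the weights to CPU first; `checkpoint_path` and `size` are hypothetical names, and `ResNet`, `BasicBlock_v3`, and `keep_dict` are assumed to be available as in the snippet above.

import torch

def fresnet50_v3_from(checkpoint_path, size='small', **kwargs):
    # Sketch only: same pruned architecture as above, with a configurable checkpoint path.
    model = ResNet(BasicBlock_v3, [3, 4, 14, 3], keep_dict[size], **kwargs)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    return model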
Example #3
class FaceModel(nn.Module):
    def __init__(self, num_classes, pretrained=False, **kwargs):
        super(FaceModel, self).__init__()
        self.model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
        if pretrained:
            parameters = model_zoo.load_url(model_urls['resnet18'])
            self.model.load_state_dict(parameters)
        self.model.avgpool = None
        self.model.fc1 = nn.Linear(512 * 3 * 4, 512)
        self.model.fc2 = nn.Linear(512, 512)
        self.model.classifier = nn.Linear(512, num_classes)
        self.register_buffer('centers', torch.zeros(num_classes, 512))
        self.num_classes = num_classes

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.model.fc1(x)
        # feature used for the center loss
        x = self.model.fc2(x)
        self.features = x
        x = self.model.classifier(x)
        return F.log_softmax(x, dim=1)
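The `centers` buffer and the cached `self.features` suggest the model is meant to be trained with a center loss, which the snippet does not show. A minimal center-loss sketch under that assumption; the function name and the simplified update rule are illustrative, not taken from the original code.

import torch

def center_loss(features, labels, centers, alpha=0.5):
    # Pull each sample's 512-d feature toward the center of its class.
    loss = 0.5 * ((features - centers[labels]) ** 2).sum(dim=1).mean()
    # Simplified center update without gradients; duplicate labels in a batch overwrite each other.
    with torch.no_grad():
        centers[labels] -= alpha * (centers[labels] - features)
    return loss

The classification term would then be combined with it, e.g. `loss = F.nll_loss(model(x), y) + lambda_c * center_loss(model.features, y, model.centers)`.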
Example #4
class MyNet(nn.Module):
    def __init__(self, pretrained=False, **kwargs):
        super(MyNet, self).__init__()
        netname = args.arch.split('_')[0]
        if netname == 'resnet152':
            self.model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
        else:
            self.model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
        if pretrained:
            parameters = model_zoo.load_url(model_urls[netname])
            self.model.load_state_dict(parameters)
        self.model.avgpool = nn.AvgPool2d(8)
        self.model.fc = nn.Linear(1024, 1)
        self.model.sig = nn.Sigmoid()

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        # x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.model.fc(x)
        return self.model.sig(x)
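Since the head is a single output unit followed by `nn.Sigmoid`, a matching objective is `nn.BCELoss`. A minimal training-step sketch, assuming `model` is a constructed `MyNet` instance; the optimizer, learning rate, and tensor names are placeholders.

import torch
import torch.nn as nn

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_step(images, targets):
    # targets: float tensor of shape (B, 1) with values in {0., 1.}
    optimizer.zero_grad()
    probs = model(images)              # sigmoid probabilities in (0, 1)
    loss = criterion(probs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

A numerically safer variant drops the `nn.Sigmoid` layer and trains on raw logits with `nn.BCEWithLogitsLoss`.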
Example #5
def _resnext(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
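A hedged wrapper showing how such a helper is typically invoked for ResNeXt-50 (32x4d); the group and width settings mirror torchvision's configuration, and the call assumes `model_urls` contains a matching 'resnext50_32x4d' entry.

from torchvision.models.resnet import Bottleneck

def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    # Same block/layer configuration torchvision uses for resnext50_32x4d.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnext('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                    pretrained, progress, **kwargs)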
Example #6
 def __init__(self, stride=1, pretrained=False, **kwargs):
     super().__init__()
     #encoder
     m = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)
     if pretrained:
         stt = torch.hub.load(
             'facebookresearch/semi-supervised-ImageNet1K-models',
             'resnext50_32x4d_ssl').state_dict()
         m.load_state_dict(stt)
     self.enc0 = nn.Sequential(m.conv1, m.bn1, nn.ReLU(inplace=True))
     self.enc1 = nn.Sequential(
         nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1),
         m.layer1)  #256
     self.enc2 = m.layer2  #512
     self.enc3 = m.layer3  #1024
     self.enc4 = m.layer4  #2048
     #aspp with customized dilatations
     self.aspp = ASPP(
         2048,
         256,
         out_c=512,
         dilations=[stride * 1, stride * 2, stride * 3, stride * 4])
     self.drop_aspp = nn.Dropout2d(0.5)
     #decoder
     self.dec4 = UnetBlock(512, 1024, 256)
     self.dec3 = UnetBlock(256, 512, 128)
     self.dec2 = UnetBlock(128, 256, 64)
     self.dec1 = UnetBlock(64, 64, 32)
     self.fpn = FPN([512, 256, 128, 64], [16] * 4)
     self.drop = nn.Dropout2d(0.1)
     self.final_conv = ConvLayer(32 + 16 * 4,
                                 1,
                                 ks=1,
                                 norm_type=None,
                                 act_cls=None)
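Only `__init__` is shown above. A forward pass consistent with those attributes might look like the sketch below; this follows common U-Net/FPN wiring, and the exact `UnetBlock` and `FPN` call signatures are assumptions.

 def forward(self, x):
     enc0 = self.enc0(x)
     enc1 = self.enc1(enc0)
     enc2 = self.enc2(enc1)
     enc3 = self.enc3(enc2)
     enc4 = self.enc4(enc3)
     enc5 = self.drop_aspp(self.aspp(enc4))
     dec3 = self.dec4(enc5, enc3)
     dec2 = self.dec3(dec3, enc2)
     dec1 = self.dec2(dec2, enc1)
     dec0 = self.dec1(dec1, enc0)
     x = self.fpn([enc5, dec3, dec2, dec1], dec0)  # fuse multi-scale features (assumed signature)
     return self.final_conv(self.drop(x))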
Example #7
def get_resnext(layers, pretrained, progress, **kwargs):
    from torchvision.models.resnet import ResNet, Bottleneck
    model = ResNet(Bottleneck, layers, **kwargs)
    model.load_state_dict(
        torch.load(
            '../input/resnext-50-ssl/semi_supervised_resnext50_32x4-ddb3e555.pth'
        ))
    return model
Example #8
def resnet50_se():
    model = ResNet(SEBottleneck, [3, 4, 6, 3])

    state_dict = load_state_dict_from_url(model_urls["resnet50"],
                                          progress=True)
    model.load_state_dict(state_dict, strict=False)

    return model
Example #9
def resnet34_se():
    model = ResNet(SEBasicBlock, [3, 4, 6, 3])

    state_dict = load_state_dict_from_url(model_urls["resnet34"],
                                          progress=True)
    model.load_state_dict(state_dict, strict=False)

    return model
Example #10
def resnext50_32x4d_se(**kwargs):
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 4

    model = ResNet(SEBottleneck, [3, 4, 6, 3], **kwargs)
    state_dict = load_state_dict_from_url(model_urls["resnext50_32x4d"],
                                          progress=True)
    model.load_state_dict(state_dict, strict=False)

    return model
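Examples #8 to #10 rely on `SEBasicBlock` / `SEBottleneck` classes that are not shown. A squeeze-and-excitation layer typically looks like the sketch below; the exact block classes used in these snippets may differ.

import torch.nn as nn

class SELayer(nn.Module):
    """Squeeze-and-excitation: learn per-channel weights and rescale the feature map."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w

An `SEBottleneck` would apply such a layer to the residual branch of torchvision's `Bottleneck` before the skip connection; the `strict=False` flag in the loaders above is what tolerates the extra SE parameters in the state dict.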
Example #11
def resnet18(pretrained=False, model_dir=None, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18'], model_dir=model_dir))
    return model
Example #12
def _resnext(arch, num_classes, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features=in_features,
                         out_features=num_classes,
                         bias=True)
    return model
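When fine-tuning only the replaced head, a common pattern is to freeze everything except `model.fc`. A minimal sketch, assuming `model_urls` has a 'resnext50_32x4d' entry; the freezing strategy and hyperparameters are assumptions, not part of the snippet above.

import torch
from torchvision.models.resnet import Bottleneck

model = _resnext('resnext50_32x4d', num_classes=10, block=Bottleneck,
                 layers=[3, 4, 6, 3], pretrained=True, progress=True,
                 groups=32, width_per_group=4)
for name, param in model.named_parameters():
    param.requires_grad = name.startswith('fc.')   # train only the new classifier head
optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad),
                            lr=1e-3, momentum=0.9)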
Example #13
 def __init__(self, pretrained=False, **kwargs):
     super(Mymodel, self).__init__(BasicBlock, [3, 4, 6, 3])
     pre_model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
     if pretrained:
         pre_model.load_state_dict(
             model_zoo.load_url(model_urls['resnet34']))  # resnet34
     fc = t.nn.Linear(in_features=1000, out_features=120)
     for param in pre_model.parameters():  # the pretrained backbone is frozen
         param.requires_grad = False
     self.model = t.nn.Sequential(pre_model, fc)
     self.model_name = str(type(self))  # default name
Example #14
def _resnext(url, block, layers, pretrained, progress, **kwargs):
    # Accept both "imagenet" and True as the pretrained value, for signature compatibility.
    model = ResNet(block, layers, **kwargs)
    if pretrained == "imagenet" or pretrained is True:
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
    model.input_space = input_space
    model.input_range = input_range
    model.input_size = input_sizes
    model.mean = means
    model.std = stds
    return model
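The attributes attached at the end follow the `pretrainedmodels` convention, but the module-level constants they reference are not shown. Typical values would be the standard ImageNet preprocessing settings; the values below are an assumption, not taken from the original module.

input_space = 'RGB'
input_sizes = [3, 224, 224]
input_range = [0, 1]
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]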
Example #15
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    # change your path
    model_path = '/home/jiayunpei/SSDG_github/pretrained_model/resnet18-5c106cde.pth'
    if pretrained:
        model.load_state_dict(torch.load(model_path))
        print("loading model: ", model_path)
    # print(model)
    return model
Example #16
def resnet152(pretrained=False, model_path=None, **kwargs):
    """Constructs a ResNet-152 model. top1-acc-%   parameter-60.20M

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        if model_path is not None:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
Example #17
 def __init__(self, pretrained=False, **kwargs):
     super(Mymodel, self).__init__(BasicBlock, [3, 4, 6, 3])
     pre_model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
     if pretrained:
         pre_model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))  # resnet34
     fc = t.nn.Linear(in_features=1000, out_features=120)
     for param in pre_model.parameters():  # the pretrained backbone is frozen
         param.requires_grad = False
     self.model = t.nn.Sequential(
         pre_model,
         fc
     )
     self.model_name = str(type(self))  # default name
Example #18
def resnet18(pretrained=False, model_path=None, **kwargs):
    """Constructs a ResNet-18 model. top1-acc-69.758%  parameter-11.69M

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        if model_path is not None:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model
Example #19
def resnet34(pretrained=False, model_path=None, **kwargs):
    """Constructs a ResNet-34 model. top1-acc-73.314%  parameter-21.80M

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        if model_path is not None:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model
Example #20
def _resnet(arch, block, layers, pretrained=False, progress=True,
            imagenet_pretrained=False, num_classes=1, lin_features=512,
            dropout_prob=0.5, bn_final=False, concat_pool=True, **kwargs):

    # Model creation
    base_model = ResNet(block, layers, num_classes=num_classes, **kwargs)
    # Imagenet pretraining
    if imagenet_pretrained:
        if pretrained:
            raise ValueError('imagenet_pretrained cannot be set to True if pretrained=True')
        state_dict = load_state_dict_from_url(imagenet_urls[arch],
                                              progress=progress)
        # Remove FC params from dict
        for key in ('fc.weight', 'fc.bias'):
            state_dict.pop(key, None)
        missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
        if any(unexpected) or any(not elt.startswith('fc.') for elt in missing):
            raise KeyError(f"Missing parameters: {missing}\nUnexpected parameters: {unexpected}")

    # Cut at last conv layers
    model = cnn_model(base_model, model_cut, base_model.fc.in_features, num_classes,
                      lin_features, dropout_prob, bn_final=bn_final, concat_pool=concat_pool)

    # Parameter loading
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)

    return model
Example #21
def _fixmodel(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    pretrained_dict = load_state_dict_from_url(model_urls[arch],
                                               progress=progress,
                                               map_location='cpu')['model']
    model_dict = model.state_dict()
    count = 0
    count2 = 0
    for k in model_dict.keys():
        count = count + 1.0
        if (('module.' + k) in pretrained_dict.keys()):
            count2 = count2 + 1.0
            model_dict[k] = pretrained_dict.get(('module.' + k))

    assert int(count2 * 100 / count) == 100, "model loading error"

    model.load_state_dict(model_dict)
    return model
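The loop above copies weights whose keys carry the `module.` prefix that `nn.DataParallel` adds. An equivalent, more compact sketch of the same idea; the helper name is hypothetical.

def strip_module_prefix(state_dict):
    # Remove the 'module.' prefix that nn.DataParallel adds to every parameter key.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# model.load_state_dict(strip_module_prefix(pretrained_dict))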
Example #22
    def __init__(self, raw_model_dir, use_flow, logger):
        super(BackboneModel, self).__init__()
        self.use_flow = use_flow
        model = ResNet(Bottleneck, [3, 4, 6, 3])

        model.load_state_dict(
            model_zoo.load_url(model_urls['resnet50'],
                               model_dir=raw_model_dir))
        logger.info('Model restored from pretrained resnet50')

        self.feature = nn.Sequential(*list(model.children())[:-2])
        self.base = list(self.feature.parameters())

        if self.use_flow:
            self.flow_branch = self.get_flow_branch(model)
            self.rgb_branch = nn.Sequential(model.conv1, model.bn1, model.relu,
                                            model.maxpool)
            self.fuse_branch = nn.Sequential(*list(model.children())[4:-2])
        self.fea_dim = model.fc.in_features
Example #23
def get_efficientnet_encoder(in_channels,
                             out_channels=1024,
                             layers=None,
                             pretrained=None,
                             norm_layer=nn.Identity):
    # TODO: despite the name, this currently builds a ResNet encoder with Bottleneck blocks.
    if layers is None:
        layers = [3, 4, 6, 3]
    encoder = ResNet(Bottleneck, layers, norm_layer=norm_layer)

    if pretrained:
        if pretrained not in model_urls:
            raise RuntimeError('No pretrained weights for this model')
        state_dict = load_state_dict_from_url(model_urls[pretrained])
        encoder.load_state_dict(state_dict, strict=False)

    # replace first conv for different number of input channels
    if in_channels != 3:
        encoder.conv1 = nn.Conv2d(in_channels,
                                  64,
                                  kernel_size=7,
                                  stride=2,
                                  padding=3,
                                  bias=False)
        nn.init.kaiming_normal_(encoder.conv1.weight,
                                mode='fan_out',
                                nonlinearity='relu')

    if out_channels == 1024:
        end_layer = -3
    elif out_channels == 2048:
        end_layer = -2
    else:
        raise RuntimeError('Invalid out_channels value')

    return StitchedModel((encoder, 0, end_layer), nn.AdaptiveAvgPool2d(1),
                         nn.Flatten(1))
Example #24
def custom_resnet(layers,
                  pretrained=False,
                  progress=True,
                  arch='resnet',
                  **kwargs):
    """ Builds custom ResNet backbone
    Arguments:
        layers (list): configuration of layer-blocks (Bottlenecks)
        pretrained (bool): If True, returns a model pre-trained on ImageNet dataset
        progress (bool): If True, shows progress bar while downloading model
        arch (str): give architecture name if pretrained=True to fetch model params
    """
    model = ResNet(Bottleneck, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    model.conv1 = nn.Conv2d(200,
                            64,
                            kernel_size=7,
                            stride=2,
                            padding=3,
                            bias=False)  # adjust the stem for 200 input channels
    return model
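Note that `conv1` is replaced after the ImageNet weights are loaded, so the 200-channel stem always starts from its default initialization. A usage sketch with a 200-channel input; the batch and spatial sizes are illustrative.

import torch

model = custom_resnet([3, 4, 6, 3], pretrained=False)
x = torch.randn(2, 200, 224, 224)   # batch of 2, 200 input channels
out = model(x)                      # (2, 1000) with the default classification head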
Example #25
def fresnet100_v3(**kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    keep_dict = {'74.2M': [32, 32, 26, 20, 26, 32, 26, 77, 64, 52, 64, 39, 64, 39, 64, 64, 64, 52, 64, 52, 64, 52, 64, 64, 64, 52, 64, 52, 64, 52, 64, 39, 64, 205, 128, 103, 128, 103, 128, 154, 128, 77, 128, 128, 128, 103, 128, 103, 128, 103, 128, 103, 128, 103, 128, 103, 128, 128, 128, 77, 128, 103, 128, 77, 128, 77, 128, 103, 128, 103, 128, 103, 128, 103, 128, 77, 128, 77, 128, 77, 128, 103, 128, 103, 128, 103, 128, 103, 128, 103, 128, 103, 128, 410, 256, 205, 256, 154, 256],
                 '101M': [64, 7, 64, 7, 64, 7, 64, 13, 128, 13, 128, 13, 128, 13,
                          128, 13, 128, 13, 128, 13, 128, 13, 128, 13, 128, 13, 128,
                          13, 128, 13, 128, 13, 128, 52, 256, 26, 256, 26, 256, 26,
                          256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256,
                          26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26,
                          256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256,
                          26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26, 256, 26,
                          256, 52, 512, 52, 512, 52, 512],
                 '131.6M': [64, 26, 64, 20, 64, 26, 64, 77, 128, 39, 128, 39, 128,
                            39, 128, 52, 128, 64, 128, 52, 128, 52, 128, 52, 128,
                            52, 128, 39, 128, 39, 128, 39, 128, 205, 256, 103, 256,
                            103, 256, 103, 256, 77, 256, 103, 256, 77, 256, 77, 256,
                            77, 256, 103, 256, 103, 256, 103, 256, 103, 256, 77, 256,
                            77, 256, 77, 256, 77, 256, 77, 256, 77, 256, 103, 256, 77,
                            256, 77, 256, 77, 256, 77, 256, 103, 256, 103, 256, 77, 256,
                            77, 256, 77, 256, 103, 256, 359, 512, 205, 512, 154, 512],
                 '202.7M': [64, 58, 64, 45, 64, 45, 64, 103, 128, 90, 128, 77, 128, 77,
                         128, 103, 128, 77, 128, 103, 128, 116, 128, 103, 128, 116,
                         128, 90, 128, 116, 128, 77, 128, 231, 256, 154, 256, 180, 256,
                         231, 256, 103, 256, 205, 256, 231, 256, 154, 256, 103, 256, 205,
                         256, 180, 256, 205, 256, 231, 256, 231, 256, 154, 256, 231, 256,
                         103, 256, 205, 256, 231, 256, 231, 256, 180, 256, 154, 256, 103,
                         256, 103, 256, 205, 256, 231, 256, 231, 256, 231, 256, 180, 256,
                         180, 256, 410, 512, 461, 512, 205, 512]}

    model = ResNet(BasicBlock_v3, [3, 13, 30, 3], keep_dict['74.2M'], **kwargs)
    model.load_state_dict(torch.load('/home/user1/linx/program/pruning_tool/work_space/pruned_model/model_resnet100.pt'))

    return model
Example #26
def resnet50(pretrained=False, **kwargs):
    model = ResNet(MyBottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load(pretrained))
        model.eval()
    return model
class UNetRes34(nn.Module):
    """ https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65933
    """
    def __init__(self, n_classes=1, pretrained_resnet=True):
        super().__init__()
        self.n_classes = n_classes

        self.resnet = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=n_classes)
        del self.resnet.fc
        if pretrained_resnet:
            self.resnet.load_state_dict(model_zoo.load_url(
                model_urls['resnet34']),
                                        strict=False)
            print('Loaded pretrained resnet weights')

        self.encoder1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.encoder2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            self.resnet.layer1,
        )  # out: 64
        self.encoder3 = self.resnet.layer2  # out: 128
        self.encoder4 = self.resnet.layer3  # out: 256
        self.encoder5 = self.resnet.layer4  # out: 512

        self.center = nn.Sequential(
            ConvBn2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            ConvBn2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # Add
        )

        self.decoder5 = Decoder(256, 512, 512, 64)
        self.decoder4 = Decoder(64, 256, 256, 64)
        self.decoder3 = Decoder(64, 128, 128, 64)
        self.decoder2 = Decoder(64, 64, 64, 64)
        self.decoder1 = Decoder(64, 64, 32, 64)

        self.conv = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        batch_size, C, H, W = x.shape

        e1 = self.encoder1(x)  # shape(B, 64, 128, 128)
        e2 = self.encoder2(e1)  # shape(B, 64, 64, 64)
        e3 = self.encoder3(e2)  # shape(B, 128, 32, 32)
        e4 = self.encoder4(e3)  # shape(B, 256, 16, 16)
        e5 = self.encoder5(e4)  # shape(B, 512, 8, 8)

        f = self.center(e5)  # shape(B, 256, 4, 4)

        d5 = self.decoder5(f, e5)
        d4 = self.decoder4(d5, e4)
        d3 = self.decoder3(d4, e3)
        d2 = self.decoder2(d3, e2)
        d1 = self.decoder1(d2, e1)

        return self.conv(d1)
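A quick shape check for the network above, consistent with the 128x128 activation sizes annotated in `forward`; the input size is an assumption based on those comments.

import torch

net = UNetRes34(n_classes=1, pretrained_resnet=False)
x = torch.randn(2, 3, 128, 128)
mask_logits = net(x)   # expected shape (2, 1, 128, 128) given the shapes noted above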
def main(opt):
    start_epoch = 0
    acc_best = 0.
    glob_step = 0
    lr_now = opt.lr

    # save options
    log.save_options(opt, opt.ckpt)
    tb_logdir = f'./exp/{opt.name}'
    if os.path.exists(tb_logdir):
        shutil.rmtree(tb_logdir)
    writer = SummaryWriter(log_dir=f'./exp/{opt.name}')
    exp_dir_ = dirname(opt.load)

    # create model
    print(">>> creating model")
    # TODO: This is how to avoid weird data reshaping for non-3-channel inputs.
    # Have ResNet model take in grayscale rather than RGB
    #    model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    if opt.arch == 'cnn':
        model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=opt.num_classes)
    else:
        model = LinearModel()
    model = model.cuda()
    model.apply(weight_init)
    print(">>> total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    # load ckpt
    if opt.load:
        print(">>> loading ckpt from '{}'".format(opt.load))
        ckpt = torch.load(opt.load)
        start_epoch = ckpt['epoch']
        acc_best = ckpt['acc']
        glob_step = ckpt['step']
        lr_now = ckpt['lr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print(">>> ckpt loaded (epoch: {} | acc: {})".format(
            start_epoch, acc_best))
    if opt.resume:
        logger = log.Logger(os.path.join(opt.ckpt, 'log.txt'), resume=True)
    else:
        logger = log.Logger(os.path.join(opt.ckpt, 'log.txt'))
        logger.set_names([
            'epoch', 'lr', 'loss_train', 'err_train', 'acc_train', 'loss_test',
            'err_test', 'acc_test'
        ])

    transforms = [
        ToTensor(),
    ]

    train_datasets = []
    for dataset_name in opt.train_datasets:
        train_datasets.append(
            ClassificationDataset(name=dataset_name,
                                  num_kpts=opt.num_kpts,
                                  transforms=transforms,
                                  split='train',
                                  arch=opt.arch,
                                  gt=opt.gt))
    train_dataset = ConcatDataset(train_datasets)
    train_loader = DataLoader(train_dataset,
                              batch_size=opt.train_batch,
                              shuffle=True,
                              num_workers=opt.job)

    split = 'test' if opt.test else 'valid'

    test_dataset = ClassificationDataset(name=opt.test_dataset,
                                         num_kpts=opt.num_kpts,
                                         transforms=transforms,
                                         split=split,
                                         arch=opt.arch,
                                         gt=opt.gt)

    test_loader = DataLoader(test_dataset,
                             batch_size=opt.test_batch,
                             shuffle=False,
                             num_workers=opt.job)

    subset_loaders = {}
    for subset in test_dataset.create_subsets():
        subset_loaders[subset.split] = DataLoader(subset,
                                                  batch_size=opt.test_batch,
                                                  shuffle=False,
                                                  num_workers=opt.job)

    cudnn.benchmark = True

    for epoch in range(start_epoch, opt.epochs):
        torch.cuda.empty_cache()
        print('==========================')
        print('>>> epoch: {} | lr: {:.5f}'.format(epoch + 1, lr_now))

        if not opt.test:
            glob_step, lr_now, loss_train, err_train, acc_train = \
                    train(train_loader, model, criterion, optimizer,
                            num_kpts=opt.num_kpts, num_classes=opt.num_classes,
                            lr_init=opt.lr, lr_now=lr_now, glob_step=glob_step,
                            lr_decay=opt.lr_decay, gamma=opt.lr_gamma,
                            max_norm=opt.max_norm)

        loss_test, err_test, acc_test, auc_test, prec_test = \
                test(test_loader, model, criterion, num_kpts=opt.num_kpts,
                        num_classes=opt.num_classes, batch_size=opt.test_batch)

        ## Test subsets ##
        subset_losses = {}
        subset_errs = {}
        subset_accs = {}
        subset_aucs = {}
        subset_precs = {}
        subset_openpose = {}
        subset_missing = {}
        subset_grids = {}

        if len(subset_loaders) > 0:
            bar = Bar('>>>', fill='>', max=len(subset_loaders))

        for key_idx, key in enumerate(subset_loaders):
            loss_sub, err_sub, acc_sub, auc_sub, prec_sub = test(
                subset_loaders[key],
                model,
                criterion,
                num_kpts=opt.num_kpts,
                num_classes=opt.num_classes,
                batch_size=4,
                log=False)

            subset_losses[key] = loss_sub
            subset_errs[key] = err_sub
            subset_accs[key] = acc_sub
            subset_aucs[key] = auc_sub
            subset_precs[key] = prec_sub

            sub_dataset = subset_loaders[key].dataset
            if sub_dataset.gt_paths is not None:
                gt_X = load_gt(sub_dataset.gt_paths)
                subset_openpose[key] = mpjpe_2d_openpose(sub_dataset.X, gt_X)
                subset_missing[key] = mean_missing_parts(sub_dataset.X)
            else:
                subset_openpose[key] = 0.
                subset_missing[key] = 0.

            sample_idxs = extract_tb_sample(subset_loaders[key],
                                            model,
                                            batch_size=opt.test_batch)
            sample_X = sub_dataset.X[sample_idxs]
            sample_img_paths = [sub_dataset.img_paths[x] for x in sample_idxs]
            if opt.arch == 'cnn':
                subset_grids[key] = create_grid(sample_X, sample_img_paths)

            bar.suffix = f'({key_idx+1}/{len(subset_loaders)}) | {key}'
            bar.next()

        if len(subset_loaders) > 0:
            bar.finish()
        ###################

        if opt.test:
            subset_accs['all'] = acc_test
            subset_aucs['all'] = auc_test
            subset_precs['all'] = prec_test
            report_dict = {
                'acc': subset_accs,
                'auc': subset_aucs,
                'prec': subset_precs
            }

            report_idx = 0
            report_path = f'report/{opt.name}-{report_idx}.json'
            while os.path.exists(f'report/{opt.name}-{report_idx}.json'):
                report_idx += 1
            report_path = f'report/{opt.name}-{report_idx}.json'

            print(f'>>> Saving report to {report_path}...')
            with open(report_path, 'w') as acc_f:
                json.dump(report_dict, acc_f, indent=4)

            print('>>> Exiting (test mode)...')
            break

        # update log file
        logger.append([
            epoch + 1, lr_now, loss_train, err_train, acc_train, loss_test,
            err_test, acc_test
        ], [
            'int', 'float', 'float', 'float', 'float', 'float', 'float',
            'float'
        ])

        # save ckpt
        is_best = acc_test > acc_best
        acc_best = max(acc_test, acc_best)
        if is_best:
            log.save_ckpt(
                {
                    'epoch': epoch + 1,
                    'lr': lr_now,
                    'step': glob_step,
                    'acc': acc_best,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                ckpt_path=opt.ckpt,
                is_best=True)
        else:
            log.save_ckpt(
                {
                    'epoch': epoch + 1,
                    'lr': lr_now,
                    'step': glob_step,
                    'acc': acc_best,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                ckpt_path=opt.ckpt,
                is_best=False)

        writer.add_scalar('Loss/train', loss_train, epoch)
        writer.add_scalar('Loss/test', loss_test, epoch)
        writer.add_scalar('Error/train', err_train, epoch)
        writer.add_scalar('Error/test', err_test, epoch)
        writer.add_scalar('Accuracy/train', acc_train, epoch)
        writer.add_scalar('Accuracy/test', acc_test, epoch)
        for key in subset_losses:
            writer.add_scalar(f'Loss/Subsets/{key}', subset_losses[key], epoch)
            writer.add_scalar(f'Error/Subsets/{key}', subset_errs[key], epoch)
            writer.add_scalar(f'Accuracy/Subsets/{key}', subset_accs[key],
                              epoch)
            writer.add_scalar(f'OpenPose/Subsets/{key}', subset_openpose[key],
                              epoch)
            writer.add_scalar(f'Missing/Subsets/{key}', subset_missing[key],
                              epoch)
            if opt.arch == 'cnn':
                writer.add_images(f'Subsets/{key}',
                                  subset_grids[key],
                                  epoch,
                                  dataformats='NHWC')

    logger.close()
    writer.close()
def _resnext(url, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    state_dict = load_state_dict_from_url(url, progress=progress)
    model.load_state_dict(state_dict)
    return model
def _resnext(path, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    model.load_state_dict(torch.load(path))
    return model
def resnet50_from_checkpoint(file_like_obj):
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    model.load_state_dict(torch.load(file_like_obj))
    return model
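The last helper accepts anything `torch.load` understands, i.e. a filesystem path or a file-like object. A hedged usage sketch; the checkpoint filename is a placeholder.

import io
import torch

model = resnet50_from_checkpoint('resnet50_checkpoint.pth')        # from a path

with open('resnet50_checkpoint.pth', 'rb') as f:
    model = resnet50_from_checkpoint(io.BytesIO(f.read()))         # from an in-memory buffer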