Example #1
0
    class VisualTrans(nn.Module):
        """ViT-based regressor over a 1-D signal laid out as a 1 x W "image".

        Loads pretrained weights from ``file_path``; a shape mismatch while
        loading is reported and ignored (best-effort load) so the model can
        still run with a partially matching checkpoint.
        """

        def __init__(self, file_path):
            super().__init__()

            self.file_path = file_path

            # NOTE(review): input is assumed to be a 1 x 962 "image" cut into
            # 1 x 13 patches (962 = 74 * 13) -- confirm against the caller.
            self.model = ViT_modified(
                n_classes=1,
                image_size=(1, 962),  # (height, width)
                patch_size=(1, 13),   # (height, width)
                dim=16,
                depth=3,
                heads=16,
                mlp_dim=512,
                dropout=0.1,
                emb_dropout=0.1)

            state_dict = torch.load(self.file_path, map_location='cpu')

            try:
                self.model.load_state_dict(state_dict)
            except RuntimeError as e:
                # Best-effort load: report the mismatch but keep going.
                # (Original message said "test_dataset_size" -- a copy-paste
                # leftover unrelated to what actually failed.)
                print('Ignoring state_dict load error "' + str(e) + '"')

        def forward(self, inpt):
            """Concatenate (theta, right-padded x) along width and run the ViT.

            Args:
                inpt: tuple ``(theta, x)`` of tensors -- assumed batched 1-D
                      signals; TODO confirm exact shapes against the caller.

            Returns:
                The first element of the model's output.
            """
            theta, x = inpt
            # Out-of-place unsqueeze: the original used unsqueeze_, which
            # mutated the caller's tensors as a side effect.
            theta = theta.unsqueeze(1).unsqueeze(1)
            x = x.unsqueeze(1).unsqueeze(1)
            # Pad x by 2 on the right so widths line up for the concat.
            x = torch.nn.functional.pad(x, (0, 2))
            inp = torch.cat((theta, x), 3)

            out = self.model(inp)[0]  # another [0] - when the n=2
            return out
def main():
    """Build the requested classifier, load trained weights, export to ONNX.

    CLI arguments:
        --model_train  architecture name (e.g. 'efficientnet-b0', 'resnet50',
                       'ViT', 'RepVGG-A0', 'SwinTransformer', ...)
        --model_path   path to a saved state_dict checkpoint
        --num_classes  number of output classes
        --output_path  destination .onnx file

    Raises:
        ValueError: if --model_train names an unknown architecture (the
        original code silently fell through and crashed later on an
        unbound ``netD``).
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_train', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--num_classes', type=int, required=True)
    parser.add_argument('--output_path', type=str, required=True)
    args = parser.parse_args()

    # BUG FIX: several branches below previously used the bare names
    # `model_train` / `num_classes` (NameError) and assigned `self.netD`
    # (no `self` at function scope). Bind locals once and use them everywhere.
    model_train = args.model_train
    num_classes = args.num_classes

    if model_train in {'efficientnet-b%d' % i for i in range(8)}:
        # b0..b7 all share the same factory call; dispatch on the name.
        netD = EfficientNet.from_pretrained(model_train, num_classes=num_classes)

    elif model_train == 'mobilenetv3_small':
        from arch.mobilenetv3_arch import MobileNetV3
        netD = MobileNetV3(n_class=num_classes, mode='small', input_size=256)
    elif model_train == 'mobilenetv3_large':
        from arch.mobilenetv3_arch import MobileNetV3
        netD = MobileNetV3(n_class=num_classes, mode='large', input_size=256)

    elif model_train == 'resnet50':
        from arch.resnet_arch import resnet50
        netD = resnet50(num_classes=num_classes, pretrain=True)
    elif model_train == 'resnet101':
        from arch.resnet_arch import resnet101
        netD = resnet101(num_classes=num_classes, pretrain=True)
    elif model_train == 'resnet152':
        from arch.resnet_arch import resnet152
        netD = resnet152(num_classes=num_classes, pretrain=True)

    #############################################
    elif model_train == 'ViT':
        from vit_pytorch import ViT
        netD = ViT(
            image_size=256,
            patch_size=32,
            num_classes=num_classes,
            dim=1024,
            depth=6,
            heads=16,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1,
        )
    elif model_train == 'DeepViT':
        from vit_pytorch.deepvit import DeepViT
        netD = DeepViT(
            image_size=256,
            patch_size=32,
            num_classes=num_classes,
            dim=1024,
            depth=6,
            heads=16,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1,
        )

    #############################################
    elif model_train in {
            'RepVGG-A0', 'RepVGG-A1', 'RepVGG-A2',
            'RepVGG-B0', 'RepVGG-B1', 'RepVGG-B1g2', 'RepVGG-B1g4',
            'RepVGG-B2', 'RepVGG-B2g2', 'RepVGG-B2g4',
            'RepVGG-B3', 'RepVGG-B3g2', 'RepVGG-B3g4'}:
        # The factory name mirrors the CLI name:
        # 'RepVGG-B1g2' -> create_RepVGG_B1g2.
        from arch import RepVGG_arch
        factory = getattr(RepVGG_arch,
                          'create_RepVGG_' + model_train.split('-', 1)[1])
        netD = factory(deploy=False, num_classes=num_classes)

    #############################################
    elif model_train == 'squeezenet_1_0':
        from arch.squeezenet_arch import SqueezeNet
        netD = SqueezeNet(num_classes=num_classes, version='1_0')
    elif model_train == 'squeezenet_1_1':
        from arch.squeezenet_arch import SqueezeNet
        netD = SqueezeNet(num_classes=num_classes, version='1_1')

    #############################################
    elif model_train == 'vgg11':
        from arch.vgg_arch import create_vgg11
        netD = create_vgg11(num_classes, pretrained=True)
    elif model_train == 'vgg13':
        from arch.vgg_arch import create_vgg13
        netD = create_vgg13(num_classes, pretrained=True)
    elif model_train == 'vgg16':
        from arch.vgg_arch import create_vgg16
        netD = create_vgg16(num_classes, pretrained=True)
    elif model_train == 'vgg19':
        from arch.vgg_arch import create_vgg19
        netD = create_vgg19(num_classes, pretrained=True)

    #############################################
    elif model_train == 'SwinTransformer':
        from swin_transformer_pytorch import SwinTransformer

        netD = SwinTransformer(
            hidden_dim=96,
            layers=(2, 2, 6, 2),
            heads=(3, 6, 12, 24),
            channels=3,
            num_classes=num_classes,
            head_dim=32,
            window_size=8,
            downscaling_factors=(4, 2, 2, 2),
            relative_pos_embedding=True,
        )

    else:
        raise ValueError('Unknown --model_train architecture: %r' % model_train)

    from torch.autograd import Variable

    import torch.onnx
    import torchvision
    import torch

    dummy_input = Variable(torch.randn(1, 3, 256, 256)) # don't set it too high, will run out of RAM
    state_dict = torch.load(args.model_path)
    print("Loaded model from model path into state_dict.")

    netD.load_state_dict(state_dict)
    # opset 11 is needed by several of the above architectures' ops.
    torch.onnx.export(netD, dummy_input, args.output_path, opset_version=11)
    print("Done.")
 def to_vit(self):
     """Return a plain ``ViT`` built from this model's saved constructor
     args/kwargs, loaded with a copy of this model's current weights.
     """
     # assumes self.args / self.kwargs were captured at construction time
     # and that the state dicts are key-compatible -- TODO confirm.
     v = ViT(*self.args, **self.kwargs)
     v.load_state_dict(self.state_dict())
     return v
        # NOTE(review): fragment -- `preds`, `labels`, `i` and the
        # TP/TN/FP/FN counters are defined in the unseen enclosing
        # function/loop; this is the per-sample tally for a binary
        # confusion matrix where class 1 is "positive".
        pred = torch.argmax(preds[i])  # predicted class index for sample i
        if pred == labels[i]:
            # Correct prediction: class 1 -> true positive, else true negative.
            if pred == 1:
                TP += 1
            else:
                TN += 1
        else:
            # Wrong prediction: predicted 1 -> false positive, else false negative.
            if pred == 1:
                FP += 1
            else:
                FN += 1

    # Note the (TP, FN, FP, TN) ordering -- callers must unpack in this order.
    return TP, FN, FP, TN


# Resume training from a saved checkpoint. NOTE(review): fragment --
# `v` (the model), `epoch`, `train_loader`, `criterion` and `opt` are
# defined earlier in the unseen part of the script.
v.load_state_dict(torch.load('./best_model_551480.pt'))
# Starting at 100 presumably matches the epoch the checkpoint was saved
# at -- confirm against the training log.
for i in range(100, epoch):  # continue training
    ##### TRAIN #####
    for j, (img, label) in enumerate(train_loader):
        #train
        #TP, FN, FP, TN = 0, 0, 0, 0
        loss = torch.tensor(0.0).cuda()
        v.zero_grad()  # clear gradients accumulated from the previous step
        train_img = Variable(img).cuda()
        #video_label = label
        labels = Variable(label).cuda()
        preds = v(train_img)  # input:(10,3,256,256) #(1, 1000)
        loss += criterion(preds, labels)

        loss.backward()
        opt.step()
Example #5
0
    # out
    # Copy the output-projection ("out") weights of this block's
    # MultiHeadDotProductAttention_1 from the pretrained JAX/TF checkpoint
    # into the torch state dict. NOTE(review): fragment -- `pretain_tf_model`,
    # `tf_key`, `torch_key_prefix` and `tf_dict` come from the unseen
    # enclosing loop.
    weight_shape = pretain_tf_model['Transformer'][tf_key][
        'MultiHeadDotProductAttention_1']['out']['kernel'].shape
    # The TF kernel has 3 axes (presumably heads x head_dim x dim -- confirm);
    # flatten the first two so it matches a single torch Linear weight.
    weight = pretain_tf_model['Transformer'][tf_key][
        'MultiHeadDotProductAttention_1']['out']['kernel'].reshape(
            weight_shape[0] * weight_shape[1], weight_shape[2])
    # Transpose to torch's (out_features, in_features) Linear layout and
    # materialize the JAX array as numpy so torch.from_numpy can wrap it.
    weight = np.array(jnp.transpose(weight))
    tf_dict[torch_key_prefix +
            '.attention.to_out.0.weight'] = torch.from_numpy(weight)
    # Bias needs no reshaping -- copy it through directly.
    tf_dict[torch_key_prefix + '.attention.to_out.0.bias'] = torch.from_numpy(
        pretain_tf_model['Transformer'][tf_key]
        ['MultiHeadDotProductAttention_1']['out']['bias'])

# Sanity check: run the model on a random image before and after loading
# the converted weights -- the two print-outs should differ once tf_dict
# is loaded. NOTE(review): fragment -- `v`, `input_size`, `patch_size`
# and `tf_dict` are defined in the unseen part of the script.
img = torch.randn(1, 3, input_size, input_size)
mask = torch.ones(
    1, input_size // patch_size, input_size //
    patch_size).bool()  # optional mask, designating which patch to attend to
preds = v(img, mask=mask)  # (1, 1000)

print(preds.flatten()[0:10])

# Load the converted TF/JAX checkpoint into the torch model.
v.load_state_dict(tf_dict)

preds = v(img, mask=mask)  # (1, 1000)

print(preds.flatten()[0:10])

# print(pretain_tf_model(img))

# torch.save
# Persist the converted checkpoint for later reuse.
torch.save(v.state_dict(), "imagenet21k+imagenet2012_ViT-B_16-224.pth")