示例#1
0
def main():
    """Evaluate the pretrained ResNet-encoder model with its mask network.

    Parses the command-line options into the module-level ``args``, builds
    both networks, wraps each in ``DataParallel`` on the GPU, restores their
    saved weights and passes them to ``test``.
    """
    global args
    args = parser.parse_args()

    encoder_name = 'resnet'

    # Build both networks before moving anything to the GPU.
    base_net = define_model(encoder=encoder_name)
    mask_backbone = net_mask.drn_d_22(pretrained=True)
    mask_net = net_mask.AutoED(mask_backbone)

    # The weights are loaded into the DataParallel wrappers, so the
    # checkpoint keys are expected to carry the wrapper's "module." prefix.
    base_net = torch.nn.DataParallel(base_net).cuda()
    mask_net = torch.nn.DataParallel(mask_net).cuda()

    base_weights = torch.load('./pretrained_model/model_' + encoder_name)
    base_net.load_state_dict(base_weights)
    mask_weights = torch.load('./net_mask/mask_' + encoder_name)
    mask_net.load_state_dict(mask_weights)

    # getTestingData is called with 1 (presumably the batch size).
    test(loaddata.getTestingData(1), base_net, mask_net,
         'mask_' + encoder_name)
示例#2
0
def main():
    """Run ``test_G_adv`` on the saved ``N`` and ``G_adv`` networks.

    Builds ``N`` on a ResNet-50 encoder and ``G_adv`` on a DRN-D-22 backbone,
    wraps both in ``DataParallel`` on the GPU, restores their weights from
    ``./models`` and runs ``test_G_adv`` once per epsilon value with 10
    iterations each.
    """
    backbone = resnet.resnet50(pretrained=True)
    encoder = modules.E_resnet(backbone)
    n_net = net.model(
        encoder,
        num_features=2048,
        block_channel=[256, 512, 1024, 2048],
    )
    g_adv_net = net_mask.G(net_mask.drn_d_22(pretrained=True))

    n_net = torch.nn.DataParallel(n_net).cuda()
    g_adv_net = torch.nn.DataParallel(g_adv_net).cuda()

    n_weights = torch.load('./models/N')
    n_net.load_state_dict(n_weights)
    g_adv_weights = torch.load('./models/G_adv')
    g_adv_net.load_state_dict(g_adv_weights)

    cudnn.benchmark = True

    loader = loaddata.getTestingData(8)

    # Test N(x*G_adv(x*)) for each epsilon, in increasing order.
    for eps in (0.05, 0.1, 0.15, 0.2):
        test_G_adv(loader, n_net, g_adv_net, epsilon=eps, iteration=10)
示例#3
0
def main():
    """Train the mask network ``model2`` and save its weights.

    Parses the command-line options into the module-level ``args``, builds
    the ResNet-encoder model and the ``AutoED`` mask network, wraps both in
    ``DataParallel`` on the GPU and restores the pretrained weights of the
    first model. Only ``model2``'s parameters are handed to the Adam
    optimizer, so ``model2`` is the network whose state is saved once the
    last epoch has finished.

    The batch size depends on the number of visible GPUs: 64 for eight,
    32 for four, and 8 otherwise.
    """
    global args
    args = parser.parse_args()

    model_selection = 'resnet'
    model = define_model(encoder=model_selection)

    original_model2 = net_mask.drn_d_22(pretrained=True)
    model2 = net_mask.AutoED(original_model2)

    # DataParallel already spans every visible device by default, so only
    # the batch size needs to vary with the GPU count.
    batch_size = {8: 64, 4: 32}.get(torch.cuda.device_count(), 8)
    model = torch.nn.DataParallel(model).cuda()
    model2 = torch.nn.DataParallel(model2).cuda()

    model.load_state_dict(
        torch.load('./pretrained_model/model_' + model_selection))

    cudnn.benchmark = True
    optimizer = torch.optim.Adam(model2.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)

    train_loader = loaddata.getTrainingData(batch_size)
    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, model2, optimizer, epoch)

    # Persist the network that was actually optimized (model2), using the
    # relative './net_mask/mask_<encoder>' path rather than a directory at
    # the filesystem root.
    torch.save(model2.state_dict(), './net_mask/mask_' + model_selection)
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048])

    return model
   
def convert_to_onnx(net, output_name):
    """Export ``net`` to an ONNX file written to ``output_name``.

    The network is switched to eval mode and exported with a random
    ``1 x 3 x 224 x 224`` tensor as the example input. The exported graph
    names its input ``data`` and its output ``output`` and uses opset 9.
    """
    sample = torch.randn(1, 3, 224, 224)
    net.eval()
    torch.onnx.export(
        net,
        sample,
        output_name,
        verbose=True,
        input_names=['data'],
        output_names=['output'],
        opset_version=9,
    )

# Export the pretrained SENet-encoder model to ONNX.
model_selection = 'senet'
model = define_model(encoder=model_selection)
# NOTE(review): model2 is built but not used by the export below; it is kept
# so the module-level names stay available — confirm whether it is needed.
original_model2 = net_mask.drn_d_22(pretrained=True)
model2 = net_mask.AutoED(original_model2)

# The model stays on the CPU and is not wrapped in DataParallel, so the
# checkpoint is mapped to the CPU explicitly; this also lets it load on a
# host without CUDA.
dic = torch.load('./pretrained_model/model_' + model_selection,
                 map_location='cpu')
print(dic.keys())

# Drop only a leading "module." (the prefix DataParallel adds to parameter
# names); keys that merely contain "module." further in are left untouched.
D2 = {}
for k, v in dic.items():
    D2[k[len("module."):] if k.startswith("module.") else k] = v
print(D2.keys())

model.load_state_dict(D2)
convert_to_onnx(model, "model.onnx")