def gen_pb(weights_path='./model_cifar_wrn.pt',
           onnx_path='./model_cifar_wrn.onnx',
           pb_path='./model_cifar_wrn.pb',
           batch_size=64,
           device='cuda'):
    """Convert a trained WideResNet checkpoint into a TensorFlow ``.pb`` graph.

    The PyTorch model is exported to ONNX by tracing it on a dummy batch of
    zeros. The ONNX graph is then converted with ``onnx_tf`` and written out
    as a TensorFlow graph.

    Args:
        weights_path: Path of the PyTorch ``state_dict`` to load.
        onnx_path: Where the intermediate ONNX file is written.
        pb_path: Where the exported TensorFlow graph is written.
        batch_size: Batch dimension of the dummy tracing input. No
            ``dynamic_axes`` are passed to the exporter, so the exported
            input shape is fixed to this batch size.
        device: Device used to load the weights and trace the model.
    """
    from trades.models.wideresnet import WideResNet
    import torch
    import onnx
    from onnx_tf.backend import prepare

    device = torch.device(device)
    model = WideResNet().to(device)
    # map_location loads the checkpoint tensors onto the requested device,
    # regardless of which device they were saved from.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # Dummy (N, 3, 32, 32) float32 input, used only to trace the graph. It is
    # built with torch directly so this function does not depend on a numpy
    # import elsewhere in the module.
    dummy_input = torch.zeros((batch_size, 3, 32, 32),
                              dtype=torch.float32,
                              device=device)

    torch.onnx.export(model,
                      dummy_input,
                      onnx_path,
                      input_names=['input'],
                      output_names=['output'])

    model_onnx = onnx.load(onnx_path)

    tf_rep = prepare(model_onnx)

    # Print out tensors and placeholders in model (helpful during inference in TensorFlow)
    print(tf_rep.tensor_dict)

    # Export model as .pb file
    tf_rep.export_graph(pb_path)
Example #2
0
 def __init__(self):
     """Create the wrapped network and the per-channel normalisation tensors.

     Builds a ``WideResNet``, moves it to CUDA and wraps it in
     ``torch.nn.DataParallel``. Also stores mean/std tensors of shape
     (3, 1, 1) on the GPU as ``self._mean_torch`` / ``self._std_torch``.
     """
     torch.nn.Module.__init__(self)
     # Wrap the CUDA-resident network in DataParallel.
     self.model = torch.nn.DataParallel(WideResNet().cuda())
     # Per-channel statistics. The values match the commonly used CIFAR-10
     # mean/std. The (3, 1, 1) view presumably lets them broadcast over
     # channel-first images; TODO confirm against the forward pass.
     channel_mean = (0.4914, 0.4822, 0.4465)
     channel_std = (0.2471, 0.2435, 0.2616)
     self._mean_torch = torch.tensor(channel_mean).view(3, 1, 1).cuda()
     self._std_torch = torch.tensor(channel_std).view(3, 1, 1).cuda()
Example #3
0
 def __init__(self):
     torch.nn.Module.__init__(self)
     self.model = WideResNet().cuda()