def __init__(self):
    """Build a DataParallel-wrapped WideResNet on the GPU and its normalisation tensors.

    ``self._mean_torch`` / ``self._std_torch`` are shaped (3, 1, 1) so they
    broadcast over a channels-first batch. The constants are presumably the
    CIFAR-10 per-channel statistics (NOTE(review): confirm against training code).
    """
    torch.nn.Module.__init__(self)
    backbone = WideResNet().cuda()
    self.model = torch.nn.DataParallel(backbone)

    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2471, 0.2435, 0.2616)
    self._mean_torch = torch.tensor(channel_mean).view(3, 1, 1).cuda()
    self._std_torch = torch.tensor(channel_std).view(3, 1, 1).cuda()
class WideResNet_TRADES(torch.nn.Module):
    """GPU WideResNet wrapper that applies no input normalisation.

    ``forward`` moves the last axis of a 4-D batch to position 1 (presumably
    NHWC -> NCHW; confirm with callers), runs the network on the GPU and hands
    the outputs back on the CPU. ``load`` expects ``MODEL_PATH`` to hold a raw
    ``state_dict`` for the wrapped network.
    """

    def __init__(self):
        torch.nn.Module.__init__(self)
        self.model = WideResNet().cuda()

    def forward(self, x):
        # Same axis order as transpose(1, 2) followed by transpose(1, 3).
        channels_first = x.permute(0, 3, 1, 2).contiguous()
        logits = self.model(channels_first.cuda())
        return logits.cpu()

    def load(self):
        """Load weights from ``MODEL_PATH`` and switch the module to eval mode."""
        self.model.load_state_dict(torch.load(MODEL_PATH))
        self.eval()
class Label_Smoothing(torch.nn.Module):
    """DataParallel WideResNet wrapper that normalises inputs by channel mean/std.

    ``forward`` moves the last axis of a 4-D batch to position 1 (presumably
    NHWC -> NCHW; confirm with callers), subtracts ``_mean_torch``, divides by
    ``_std_torch`` on the GPU, and returns the network outputs on the CPU.
    ``load`` expects ``MODEL_PATH`` to hold a dict with ``'state_dict'``,
    ``'test_robust_acc'`` and ``'test_acc'`` entries.
    """

    def __init__(self):
        torch.nn.Module.__init__(self)
        backbone = WideResNet().cuda()
        self.model = torch.nn.DataParallel(backbone)
        # (3, 1, 1) so the statistics broadcast over a channels-first batch;
        # presumably CIFAR-10 values (NOTE(review): confirm).
        self._mean_torch = torch.tensor((0.4914, 0.4822, 0.4465)).view(3, 1, 1).cuda()
        self._std_torch = torch.tensor((0.2471, 0.2435, 0.2616)).view(3, 1, 1).cuda()

    def forward(self, x):
        # Same axis order as transpose(1, 2) followed by transpose(1, 3).
        channels_first = x.permute(0, 3, 1, 2).contiguous()
        normalised = (channels_first.cuda() - self._mean_torch) / self._std_torch
        return self.model(normalised).cpu()

    def load(self):
        """Load the checkpoint, print its stored accuracies, and switch to eval mode."""
        ckpt = torch.load(MODEL_PATH)
        print(ckpt["test_robust_acc"], ckpt["test_acc"])
        self.model.load_state_dict(ckpt['state_dict'])
        self.eval()
def gen_pb(weights_path='./model_cifar_wrn.pt',
           onnx_path='./model_cifar_wrn.onnx',
           pb_path='./model_cifar_wrn.pb',
           batch_size=64):
    """Convert the TRADES WideResNet checkpoint into a TensorFlow ``.pb`` graph.

    The PyTorch weights are loaded into ``WideResNet``, exported to ONNX by
    tracing a zero-filled dummy batch, then converted with ``onnx_tf`` and
    written to disk.

    Args:
        weights_path: PyTorch ``state_dict`` file to load.
        onnx_path: where the intermediate ONNX model is written.
        pb_path: where the TensorFlow graph is written.
        batch_size: batch dimension of the dummy tracing input. The ONNX graph
            is traced with this fixed batch size (no dynamic axes are declared).
    """
    from trades.models.wideresnet import WideResNet
    import torch
    import onnx
    from onnx_tf.backend import prepare

    # Prefer the GPU, but allow the conversion to run on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = WideResNet().to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # float32 zeros created directly in torch; avoids relying on a module-level
    # numpy import. torch.onnx.export runs the forward pass itself, so no
    # separate warm-up call is needed.
    dummy_input = torch.zeros((batch_size, 3, 32, 32), dtype=torch.float32, device=device)
    torch.onnx.export(model, dummy_input, onnx_path,
                      input_names=['input'], output_names=['output'])

    model_onnx = onnx.load(onnx_path)
    tf_rep = prepare(model_onnx)
    # Print out tensors and placeholders in model (helpful during inference in TensorFlow)
    print(tf_rep.tensor_dict)
    # Export model as .pb file
    tf_rep.export_graph(pb_path)
def __init__(self):
    """Initialise the module with a single (not DataParallel) WideResNet on the GPU."""
    torch.nn.Module.__init__(self)
    network = WideResNet()
    self.model = network.cuda()