def test_lstm_export_with_constantofshape(tmpdir):
    """Export a small batch-first LSTM from torch to ONNX and check that
    popart inference reproduces the torch output.

    The exported graph must contain at least one ConstantOfShape node,
    which is the op this regression test exercises.
    """
    np.random.seed(42)
    torch.manual_seed(43)

    class RNNNet(torch.nn.Module):
        def __init__(self):
            super(RNNNet, self).__init__()
            self.lstm = torch.nn.LSTM(input_size=18,
                                      hidden_size=8,
                                      batch_first=True)

        def forward(self, x):
            out, (hidden, cell) = self.lstm(x)
            return out

    model = RNNNet()
    host_data = np.random.rand(1, 100, 18).astype(np.float32)
    expected = model(torch.from_numpy(host_data)).detach().numpy()

    export_name = str(tmpdir / "lstm_small_repro.onnx")
    torch.onnx.export(model,
                      torch.from_numpy(host_data),
                      export_name,
                      verbose=True,
                      input_names=['data'],
                      output_names=['tag'])

    # The exported proto must contain a ConstantOfShape op.
    loaded = onnx.load(export_name)
    assert len([n for n in loaded.graph.node
                if n.op_type == 'ConstantOfShape']) > 0

    shape_info = popart.InputShapeInfo()
    shape_info.add("data", popart.TensorInfo("FLOAT", [1, 100, 18]))

    session = popart.InferenceSession(
        export_name,
        popart.DataFlow(1, {"tag": popart.AnchorReturnType("All")}),
        tu.create_test_device(),
        inputShapeInfo=shape_info)
    session.prepareDevice()

    anchors_out = session.initAnchorArrays()
    session.run(popart.PyStepIO({"data": host_data}, anchors_out))

    actual = anchors_out['tag']
    assert expected.shape == actual.shape
    assert np.allclose(expected, actual, atol=1e-07)
def test_model_with_specified_dim_params(tmpdir):
    """Build an InferenceSession for a model whose inputs use symbolic
    dim_params, supplying the concrete shape through InputShapeInfo."""
    proto, outId = get_model_with_dim_param(tmpdir)

    shape_info = popart.InputShapeInfo()
    shape_info.add("input_0", popart.TensorInfo("FLOAT", [1, 2, 32, 32]))

    anchors = {outId: popart.AnchorReturnType("All")}
    session = popart.InferenceSession(
        fnModel=proto,
        dataFlow=popart.DataFlow(1, anchors),
        deviceInfo=tu.create_test_device(),
        inputShapeInfo=shape_info)
def __init__(self,
             fnModel: bytes,
             dataFlow: Dict[int, Dict],
             deviceInfo: popart.DeviceInfo,
             inputShapeInfo: popart.InputShapeInfo = None,
             patterns: popart.Patterns = None,
             userOptions: popart.SessionOptions = None,
             name: str = "inference") -> None:
    """Construct an inference session.

    Args:
        fnModel: serialized ONNX model proto (or a path, per popart).
        dataFlow: batches-per-step / anchor configuration.
        deviceInfo: device the session will be compiled for.
        inputShapeInfo: shape/type info for inputs missing from the model;
            a fresh popart.InputShapeInfo() when omitted.
        patterns: graph transformation patterns; popart defaults when omitted.
        userOptions: session options; a fresh popart.SessionOptions() when
            omitted.
        name: session name passed through to the base class.
    """
    # Build default objects per call rather than in the signature: defaults
    # in the signature are evaluated once at def time and would be shared
    # (and mutable) across every session constructed without them.
    if inputShapeInfo is None:
        inputShapeInfo = popart.InputShapeInfo()
    if userOptions is None:
        userOptions = popart.SessionOptions()
    # `is None` (identity), not `== None`: pybind-backed objects may define
    # __eq__, and identity is the idiomatic missing-argument test.
    if patterns is None:
        patterns = popart.Patterns()
    super(InferenceSession, self).__init__(fnModel, dataFlow, deviceInfo,
                                           inputShapeInfo, userOptions,
                                           patterns, name)
def __init__(self,
             fnModel: bytes,
             dataFlow: Dict[int, Dict],
             loss: str,
             optimizer: popart.Optimizer,
             deviceInfo: popart.DeviceInfo,
             inputShapeInfo: popart.InputShapeInfo = None,
             patterns: popart.Patterns = None,
             userOptions: popart.SessionOptions = None,
             name: str = "training") -> None:
    """Construct a training session.

    Args:
        fnModel: serialized ONNX model proto (or a path, per popart).
        dataFlow: batches-per-step / anchor configuration.
        loss: id of the loss tensor to minimise (annotation was the bogus
            literal ""; a tensor id is a string).
        optimizer: popart optimizer driving the weight updates.
        deviceInfo: device the session will be compiled for.
        inputShapeInfo: shape/type info for inputs missing from the model;
            a fresh popart.InputShapeInfo() when omitted.
        patterns: graph transformation patterns; popart defaults when omitted.
        userOptions: session options; a fresh popart.SessionOptions() when
            omitted.
        name: session name passed through to the base class.
    """
    # Per-call defaults: signature defaults are evaluated once at def time
    # and would be shared mutable state across all sessions.
    if inputShapeInfo is None:
        inputShapeInfo = popart.InputShapeInfo()
    if userOptions is None:
        userOptions = popart.SessionOptions()
    if patterns is None:
        patterns = popart.Patterns()
    super(TrainingSession, self).__init__(fnModel, dataFlow, loss, optimizer,
                                          deviceInfo, inputShapeInfo,
                                          userOptions, patterns, name)
def __init__(self,
             fnModel: bytes,
             dataFlow: Dict[int, Dict],
             deviceInfo: popart.DeviceInfo,
             inputShapeInfo: popart.InputShapeInfo = None,
             patterns: popart.Patterns = None,
             userOptions: popart.SessionOptions = None) -> None:
    """Construct an inference session and record the replication and
    gradient-accumulation factors implied by the session options.

    Args:
        fnModel: serialized ONNX model proto (or a path, per popart).
        dataFlow: batches-per-step / anchor configuration.
        deviceInfo: device the session will be compiled for.
        inputShapeInfo: shape/type info for inputs missing from the model;
            a fresh popart.InputShapeInfo() when omitted.
        patterns: graph transformation patterns; popart defaults when omitted.
        userOptions: session options; a fresh popart.SessionOptions() when
            omitted.
    """
    # Per-call defaults: signature defaults are evaluated once at def time
    # and would be shared mutable state across all sessions.
    if inputShapeInfo is None:
        inputShapeInfo = popart.InputShapeInfo()
    if userOptions is None:
        userOptions = popart.SessionOptions()
    # `is None` (identity), not `== None`.
    if patterns is None:
        patterns = popart.Patterns()
    super(InferenceSession, self).__init__(fnModel, dataFlow, deviceInfo,
                                           inputShapeInfo, userOptions,
                                           patterns)
    self.dataFlow = dataFlow
    # Factors default to 1 when the corresponding option is disabled; kept
    # so callers can size host buffers for the expanded step dimension.
    self.replicationFactor = (userOptions.replicatedGraphCount
                              if userOptions.enableReplicatedGraphs else 1)
    self.accumulationFactor = (userOptions.accumulationFactor
                               if userOptions.enableGradientAccumulation else 1)
from popart.torch import torchwriter #we require torch in this file to create the torch Module import torch args = cmdline.parse() nInChans = 3 nOutChans = 10 batchSize = 2 batchesPerStep = 4 anchors = { "loss": popart.AnchorReturnType("EveryN", 2), "image0": popart.AnchorReturnType("All") } dataFlow = popart.DataFlow(batchesPerStep, anchors) inputShapeInfo = popart.InputShapeInfo() inputShapeInfo.add("image0", popart.TensorInfo("FLOAT", [batchSize, nInChans, 32, 32])) inputShapeInfo.add("image1", popart.TensorInfo("FLOAT", [batchSize, nInChans, 32, 32])) inputShapeInfo.add("label", popart.TensorInfo("INT32", [batchSize])) inNames = ["image0", "image1"] cifarInIndices = {"image0": 0, "image1": 0, "label": 1} outNames = ["loss"] willowOptPatterns = popart.Patterns(popart.PatternsLevel.All) def nllloss(logprobs, targets): targets = targets.unsqueeze(1) loss = torch.gather(logprobs, 1, targets)
def __init__(self,
             torchModel,
             inputs,
             targets,
             losses,
             deviceInfo,
             batch_size=1,
             batches_per_step=1,
             inputShapeInfo=popart.InputShapeInfo(),
             patterns=popart.Patterns(),
             userOptions=popart.SessionOptions()):
    """Wrap a torch Module as a popart inference session.

    Runs the torch model once on a single micro-batch to discover the
    output arity, exports a proto via createProto(), attaches one loss and
    one target input per output, and hands everything to the base session.

    NOTE(review): the defaults `popart.InputShapeInfo()` / `popart.Patterns()`
    / `popart.SessionOptions()` are evaluated once at def time and shared
    across calls; `self.inputShapeInfo.add(...)` below mutates that shared
    default, so entries accumulate across instances constructed without an
    explicit inputShapeInfo — verify this is intended.
    """
    self.torchModel = torchModel
    self.batch_size = batch_size
    self.losses = losses
    self.deviceInfo = deviceInfo
    self.batches_per_step = batches_per_step
    self.inputShapeInfo = inputShapeInfo
    # anchor name -> AnchorReturnType, filled in the loss loop below.
    self.anchor_returns = {}
    self.inputs = tuple()
    self.targets = tuple()
    # Allow a single tensor where a tuple of tensors is expected.
    if isinstance(inputs, torch.Tensor):
        inputs = (inputs, )
    if isinstance(targets, torch.Tensor):
        targets = (targets, )
    for tensor in inputs:
        # Dim 0 must hold batch_size * batches_per_step samples (the check
        # is skipped when batch_size == 1 — presumably to allow broadcast;
        # confirm against callers).
        if (tensor.shape[0] !=
            (self.batch_size * self.batches_per_step)) and (self.batch_size
                                                            != 1):
            raise RuntimeError(
                f"Shape discrepancy in input tensor {tensor}, shape {tensor.shape}."
                + "Dim 0 should be equal to :" +
                f"batch size {self.batch_size} * bps {self.batches_per_step}")
        # Only the first of the batches_per_step slices is kept; it is used
        # solely to trace the model's output structure below.
        reshape = tensor.view(batches_per_step, batch_size,
                              *list(tensor.shape[1:]))
        self.inputs = self.inputs + (reshape[0, :], )
    for tensor in targets:
        reshape = tensor.view(batches_per_step, batch_size)
        self.targets = self.targets + (reshape[0, :], )
    # One forward pass on a single micro-batch to discover output arity.
    self.outputs = self.torchModel(*self.inputs)
    self.inputNames = [f"input_{i}" for i in range(len(self.inputs))]
    if isinstance(self.outputs, torch.Tensor):
        num_outputs = 1
    else:
        num_outputs = len(self.outputs)
    self.outputNames = [f"output_{i}" for i in range(num_outputs)]
    proto = self.createProto()
    # Rebind the parameter name: from here on `losses` is the list of
    # popart loss objects passed to the base class, built from self.losses.
    losses = []
    for idx, (out, tgt) in enumerate(zip(self.outputNames, self.targets)):
        # Each output gets a matching "target_{idx}" graph input typed from
        # the torch target tensor.
        self.inputShapeInfo.add(
            f"target_{idx}",
            popart.TensorInfo(torch_to_popart_type(tgt.type()),
                              list(tgt.shape)))
        losses.append(
            self.createLosses(self.losses, out, f"target_{idx}",
                              f"loss_{idx}"))
        # Anchor both the raw output and its loss on every batch.
        self.anchor_returns[out] = popart.AnchorReturnType("All")
        self.anchor_returns[f"loss_{idx}"] = popart.AnchorReturnType("All")
    # NOTE(review): with the def-time default above, patterns can only be
    # None if a caller passes None explicitly.
    if patterns is None:
        patterns = popart.Patterns()
    self.dataFlow = self.createdataFlow()
    super(InferenceSession, self).__init__(proto, self.dataFlow,
                                           self.deviceInfo, losses,
                                           self.inputShapeInfo, userOptions,
                                           patterns)
    # Factors default to 1 when the corresponding option is disabled.
    self.replicationFactor = userOptions.replicatedGraphCount if \
        userOptions.enableReplicatedGraphs else 1
    self.accumulationFactor = userOptions.accumulationFactor if \
        userOptions.enableGradientAccumulation else 1