def test_lstm_export_with_constantofshape(tmpdir):
    """Export a small batch-first LSTM from torch to ONNX, confirm the
    exported graph contains a ConstantOfShape op, then run it through a
    popart InferenceSession and check the output matches torch."""
    np.random.seed(42)
    torch.manual_seed(43)

    class RNNNet(torch.nn.Module):
        def __init__(self):
            super(RNNNet, self).__init__()
            # 18 input features -> 8 hidden units, batch dim first.
            self.lstm = torch.nn.LSTM(input_size=18,
                                      hidden_size=8,
                                      batch_first=True)

        def forward(self, x):
            out, _ = self.lstm(x)
            return out

    model = RNNNet()
    host_data = np.random.rand(1, 100, 18).astype(np.float32)
    torch_input = torch.from_numpy(host_data)
    reference = model(torch_input).detach().numpy()

    export_name = str(tmpdir / "lstm_small_repro.onnx")
    torch.onnx.export(model,
                      torch_input,
                      export_name,
                      verbose=True,
                      input_names=['data'],
                      output_names=['tag'])

    # The exported graph must contain at least one ConstantOfShape op,
    # otherwise this test is not exercising the intended codepath.
    onnx_model = onnx.load(export_name)
    cos_nodes = [n for n in onnx_model.graph.node
                 if n.op_type == 'ConstantOfShape']
    assert len(cos_nodes) > 0

    shape_info = popart.InputShapeInfo()
    shape_info.add("data", popart.TensorInfo("FLOAT", [1, 100, 18]))

    session = popart.InferenceSession(
        export_name,
        popart.DataFlow(1, {"tag": popart.AnchorReturnType("All")}),
        tu.create_test_device(),
        inputShapeInfo=shape_info)

    session.prepareDevice()

    anchor_arrays = session.initAnchorArrays()
    session.run(popart.PyStepIO({"data": host_data}, anchor_arrays))
    popart_out = anchor_arrays['tag']

    assert reference.shape == popart_out.shape
    assert np.allclose(reference, popart_out, atol=1e-07)
# Example #2
def test_model_with_specified_dim_params(tmpdir):
    """Check that a model using symbolic dim_params is accepted by an
    InferenceSession once the input shape is pinned via InputShapeInfo."""
    proto, outId = get_model_with_dim_param(tmpdir)

    shape_info = popart.InputShapeInfo()
    shape_info.add("input_0", popart.TensorInfo("FLOAT", [1, 2, 32, 32]))

    anchors = {outId: popart.AnchorReturnType("All")}
    session = popart.InferenceSession(fnModel=proto,
                                      dataFlow=popart.DataFlow(1, anchors),
                                      deviceInfo=tu.create_test_device(),
                                      inputShapeInfo=shape_info)
# Example #3
    def __init__(
            self,
            fnModel: bytes,
            dataFlow: Dict[int, Dict],
            deviceInfo: popart.DeviceInfo,
            inputShapeInfo: popart.InputShapeInfo = popart.InputShapeInfo(),
            patterns: popart.Patterns = None,
            userOptions: popart.SessionOptions = popart.SessionOptions(),
            name: str = "inference") -> None:
        """Create an inference session over the given model.

        Args:
            fnModel: Serialized model protobuf (or whatever the underlying
                popart session accepts for ``fnModel``).
            dataFlow: Batches-per-step / anchor specification.
            deviceInfo: Device the session compiles and runs on.
            inputShapeInfo: Explicit shape/type info for model inputs.
            patterns: Graph transformation patterns; a fresh
                ``popart.Patterns()`` is substituted when ``None``.
            userOptions: Session options.
            name: Session name (default "inference").

        NOTE(review): the ``inputShapeInfo`` and ``userOptions`` defaults are
        mutable objects evaluated once at definition time and shared across
        calls — confirm this sharing is intended.
        """
        # Use an identity check (`is None`) rather than `== None`: PEP 8
        # idiom, and consistent with TrainingSession.__init__ in this file.
        if patterns is None:
            patterns = popart.Patterns()

        super(InferenceSession,
              self).__init__(fnModel, dataFlow, deviceInfo, inputShapeInfo,
                             userOptions, patterns, name)
# Example #4
    def __init__(
            self,
            fnModel: bytes,
            dataFlow: Dict[int, Dict],
            loss: str,  # was annotated `""`; presumably a loss tensor id — TODO confirm
            optimizer: popart.Optimizer,
            deviceInfo: popart.DeviceInfo,
            inputShapeInfo: popart.InputShapeInfo = popart.InputShapeInfo(),
            patterns: popart.Patterns = None,
            userOptions: popart.SessionOptions = popart.SessionOptions(),
            name: str = "training") -> None:
        """Create a training session over the given model.

        Args:
            fnModel: Serialized model protobuf (or whatever the underlying
                popart session accepts for ``fnModel``).
            dataFlow: Batches-per-step / anchor specification.
            loss: Loss passed through to the base session; the original
                annotation was the empty string ``""``.
            optimizer: Optimizer used for training.
            deviceInfo: Device the session compiles and runs on.
            inputShapeInfo: Explicit shape/type info for model inputs.
            patterns: Graph transformation patterns; a fresh
                ``popart.Patterns()`` is substituted when ``None``.
            userOptions: Session options.
            name: Session name (default "training").

        NOTE(review): the ``inputShapeInfo`` and ``userOptions`` defaults are
        mutable objects evaluated once at definition time and shared across
        calls — confirm this sharing is intended.
        """

        if patterns is None:
            patterns = popart.Patterns()

        super(TrainingSession,
              self).__init__(fnModel, dataFlow, loss, optimizer, deviceInfo,
                             inputShapeInfo, userOptions, patterns, name)
# Example #5
    def __init__(
            self,
            fnModel: bytes,
            dataFlow: Dict[int, Dict],
            deviceInfo: popart.DeviceInfo,
            inputShapeInfo: popart.InputShapeInfo = popart.InputShapeInfo(),
            patterns: popart.Patterns = None,
            userOptions: popart.SessionOptions = popart.SessionOptions()
    ) -> None:
        """Create an inference session and record replication/accumulation
        factors derived from the session options.

        Args:
            fnModel: Serialized model protobuf (or whatever the underlying
                popart session accepts for ``fnModel``).
            dataFlow: Batches-per-step / anchor specification; also kept on
                ``self.dataFlow`` for later use.
            deviceInfo: Device the session compiles and runs on.
            inputShapeInfo: Explicit shape/type info for model inputs.
            patterns: Graph transformation patterns; a fresh
                ``popart.Patterns()`` is substituted when ``None``.
            userOptions: Session options; replication and gradient
                accumulation factors are read from here.

        NOTE(review): the ``inputShapeInfo`` and ``userOptions`` defaults are
        mutable objects evaluated once at definition time and shared across
        calls — confirm this sharing is intended.
        """
        # Use an identity check (`is None`) rather than `== None`: PEP 8
        # idiom, and consistent with TrainingSession.__init__ in this file.
        if patterns is None:
            patterns = popart.Patterns()

        super(InferenceSession,
              self).__init__(fnModel, dataFlow, deviceInfo, inputShapeInfo,
                             userOptions, patterns)

        self.dataFlow = dataFlow
        # Factors default to 1 when the corresponding feature is disabled.
        self.replicationFactor = userOptions.replicatedGraphCount if \
            userOptions.enableReplicatedGraphs else 1
        self.accumulationFactor = userOptions.accumulationFactor if \
            userOptions.enableGradientAccumulation else 1
# Example #6
from popart.torch import torchwriter
#we require torch in this file to create the torch Module
import torch

# Parse command-line options; `cmdline` is presumably imported earlier in
# this file — TODO confirm.
args = cmdline.parse()

# Dimensions for the CIFAR-style example model and data pipeline.
nInChans = 3
nOutChans = 10
batchSize = 2
batchesPerStep = 4
# Anchor "loss" every 2nd batch, "image0" for every batch.
anchors = {
    "loss": popart.AnchorReturnType("EveryN", 2),
    "image0": popart.AnchorReturnType("All")
}
dataFlow = popart.DataFlow(batchesPerStep, anchors)
# Pin the graph-input shapes/types explicitly.
inputShapeInfo = popart.InputShapeInfo()
inputShapeInfo.add("image0",
                   popart.TensorInfo("FLOAT", [batchSize, nInChans, 32, 32]))
inputShapeInfo.add("image1",
                   popart.TensorInfo("FLOAT", [batchSize, nInChans, 32, 32]))
inputShapeInfo.add("label", popart.TensorInfo("INT32", [batchSize]))
inNames = ["image0", "image1"]
# Both image inputs map to dataset column 0; the label to column 1.
cifarInIndices = {"image0": 0, "image1": 0, "label": 1}
outNames = ["loss"]

willowOptPatterns = popart.Patterns(popart.PatternsLevel.All)


def nllloss(logprobs, targets):
    """Gather each row's log-probability at its target class index.

    Assumes ``logprobs`` is (N, C) and ``targets`` is (N,) of class
    indices — TODO confirm against callers.

    NOTE(review): this definition appears truncated in this excerpt — the
    gathered ``loss`` is never reduced or returned here.
    """
    # Reshape targets to (N, 1) so gather picks one column per row.
    targets = targets.unsqueeze(1)
    loss = torch.gather(logprobs, 1, targets)
# Example #7
    def __init__(self,
                 torchModel,
                 inputs,
                 targets,
                 losses,
                 deviceInfo,
                 batch_size=1,
                 batches_per_step=1,
                 inputShapeInfo=popart.InputShapeInfo(),
                 patterns=popart.Patterns(),
                 userOptions=popart.SessionOptions()):
        """Build a popart session directly from a torch Module.

        Runs the torch model once on a single batch to discover the number
        of outputs, exports a proto via ``self.createProto()``, registers
        one loss and anchor per (output, target) pair, then initialises the
        underlying popart session.

        Args:
            torchModel: torch ``nn.Module`` to export and compare against.
            inputs: Tensor or tuple of tensors; dim 0 must equal
                ``batch_size * batches_per_step`` unless ``batch_size == 1``.
            targets: Tensor or tuple of tensors, one per model output.
            losses: Loss specification consumed by ``self.createLosses``.
            deviceInfo: popart device to compile and run on.
            batch_size: Per-run batch size.
            batches_per_step: Batches folded into one step.
            inputShapeInfo: Explicit input shape/type info; this method adds
                the target tensors' info to it.
            patterns: Graph transformation patterns.
            userOptions: popart session options.

        NOTE(review): the ``inputShapeInfo``/``patterns``/``userOptions``
        defaults are mutable objects evaluated once at definition time and
        shared across calls — confirm this sharing is intended.
        """
        self.torchModel = torchModel
        self.batch_size = batch_size
        self.losses = losses
        self.deviceInfo = deviceInfo
        self.batches_per_step = batches_per_step
        self.inputShapeInfo = inputShapeInfo
        self.anchor_returns = {}

        self.inputs = tuple()
        self.targets = tuple()

        # Normalise single tensors to 1-tuples so the loops below are uniform.
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs, )
        if isinstance(targets, torch.Tensor):
            targets = (targets, )

        for tensor in inputs:
            # Dim 0 must carry batches_per_step * batch_size samples;
            # the check is skipped entirely when batch_size == 1.
            if (tensor.shape[0] !=
                (self.batch_size * self.batches_per_step)) and (self.batch_size
                                                                != 1):
                raise RuntimeError(
                    f"Shape discrepancy in input tensor {tensor}, shape {tensor.shape}."
                    + "Dim 0 should be equal to :" +
                    f"batch size {self.batch_size} * bps {self.batches_per_step}"
                )
            # Split dim 0 into (batches_per_step, batch_size, ...) and keep
            # only the first batch — enough to trace the model below.
            reshape = tensor.view(batches_per_step, batch_size,
                                  *list(tensor.shape[1:]))
            self.inputs = self.inputs + (reshape[0, :], )
        for tensor in targets:
            reshape = tensor.view(batches_per_step, batch_size)
            self.targets = self.targets + (reshape[0, :], )

        # One forward pass on a single batch to discover output arity.
        self.outputs = self.torchModel(*self.inputs)

        self.inputNames = [f"input_{i}" for i in range(len(self.inputs))]
        if isinstance(self.outputs, torch.Tensor):
            num_outputs = 1
        else:
            num_outputs = len(self.outputs)
        self.outputNames = [f"output_{i}" for i in range(num_outputs)]

        proto = self.createProto()

        # Rebinds the local parameter name; self.losses (the spec) was
        # already saved above. This list collects the constructed losses.
        losses = []
        for idx, (out, tgt) in enumerate(zip(self.outputNames, self.targets)):
            self.inputShapeInfo.add(
                f"target_{idx}",
                popart.TensorInfo(torch_to_popart_type(tgt.type()),
                                  list(tgt.shape)))

            losses.append(
                self.createLosses(self.losses, out, f"target_{idx}",
                                  f"loss_{idx}"))
            # Anchor both the model output and its loss every batch.
            self.anchor_returns[out] = popart.AnchorReturnType("All")
            self.anchor_returns[f"loss_{idx}"] = popart.AnchorReturnType("All")

        # Dead in practice: the default is a Patterns instance, so this only
        # fires if a caller explicitly passes patterns=None.
        if patterns is None:
            patterns = popart.Patterns()

        self.dataFlow = self.createdataFlow()

        super(InferenceSession,
              self).__init__(proto, self.dataFlow, self.deviceInfo, losses,
                             self.inputShapeInfo, userOptions, patterns)

        # Factors default to 1 when the corresponding feature is disabled.
        self.replicationFactor = userOptions.replicatedGraphCount if \
            userOptions.enableReplicatedGraphs else 1
        self.accumulationFactor = userOptions.accumulationFactor if \
            userOptions.enableGradientAccumulation else 1