Example #1
0
    def testWrapModelLossFnStateDict(self):
        """Run one ORTTrainer train step on a wrapped model + loss_fn and
        verify state_dict exposes exactly the linear layer's parameters."""
        # Seed before any parameter init / randn so the run is reproducible.
        torch.manual_seed(1)
        device = torch.device("cuda")

        class LinearModel(torch.nn.Module):
            """2->4 linear layer; adds y to the output when given, else ones."""

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 4)

            def forward(self, y=None, x=None):
                out = self.linear(x)
                if y is None:
                    return out + torch.ones(2, 4)
                return out + y

        pt_model = LinearModel()
        data = torch.randn(2, 2)
        label = torch.tensor([0, 1], dtype=torch.int64)

        # Describe the training graph: two inputs (features + labels),
        # two outputs (scalar loss + raw model output).
        model_desc = ModelDescription(
            [IODescription('x', [2, 2], torch.float32),
             IODescription('label', [2, ], torch.int64, num_classes=4)],
            [IODescription('loss', [], torch.float32),
             IODescription('output', [2, 4], torch.float32)])

        # NOTE: parameter names (x, label) must match the model_desc input
        # names — ORTTrainer presumably inspects this signature; keep them.
        def loss_fn(x, label):
            return F.nll_loss(F.log_softmax(x, dim=1), label)

        def get_lr_this_step(global_step):
            # Constant learning rate, wrapped as a 1-element tensor.
            return torch.tensor([0.02])

        ort_trainer = ORTTrainer(
            pt_model, loss_fn, model_desc, "SGDOptimizer", None,
            IODescription('Learning_Rate', [1, ], torch.float32), device,
            get_lr_this_step=get_lr_this_step)
        ort_trainer.train_step(x=data, label=label)
        assert ort_trainer.state_dict().keys() == {'linear.bias', 'linear.weight'}
Example #2
0
                     None,
                     model_desc,
                     "LambOptimizer",
                     None,
                     IODescription('Learning_Rate', [
                         1,
                     ], torch.float32),
                     device,
                     gradient_accumulation_steps=1,
                     world_rank=world_rank,
                     world_size=world_size,
                     use_mixed_precision=False,
                     allreduce_post_accumulation=True)
# Compare the ORT-trained model against the PyTorch reference model.
print('\nBuild ort model done.')

# Snapshot the trainer's current weights; presumably a name->tensor mapping
# like a torch state_dict — confirm against ORTTrainer.state_dict().
ort_sd = trainer.state_dict()
#print(ort_sd)

# Print per-parameter sizes for both models (printSizes defined elsewhere).
printSizes(torch_model, "PyTorch")
printSizes(trainer, "ORT")

# Diff the two models' weights (compareModels semantics not visible here).
compareModels(torch_model, trainer)

# Dump a human-readable view of the exported ONNX graph.
print(onnx.helper.printable_graph(model.graph))

print("one weights")
# Build a replacement state dict from torch_model (setWeights defined
# elsewhere — behavior not visible here), load it into a fresh Net, then
# compare the reset model against the reference again.
new_dict = setWeights(torch_model)
reset_model = Net()
reset_model.load_state_dict(new_dict)

compareModels(torch_model, reset_model)