コード例 #1
0
    def test_against_base_model(self):
        """Check that ``DistDlrm`` matches ``DlrmJointEmbedding`` on one batch.

        Both models are built from the same merged bottom/top config. The
        reference model's joint-embedding table and every ``nn.Linear`` layer
        of its bottom and top MLPs are copied into the distributed model, and
        the outputs of one forward pass are required to be element-wise equal.
        """
        model_config = copy(_DUMMY_BOTTOM_CONFIG)
        model_config.update(_DUMMY_TOP_CONFIG)
        # Removed before construction; presumably neither constructor accepts
        # this key -- TODO confirm against the model signatures.
        model_config.pop('num_interaction_inputs')
        ref_model = model.DlrmJointEmbedding(**model_config)
        ref_model.to("cuda")

        test_model = dist_model.DistDlrm(**model_config)
        test_model.to("cuda")
        print(test_model)

        # Copy weights to make the two models identical
        test_model.bottom_model.joint_embedding.embedding.weight.data.copy_(
            ref_model.embeddings[0].embedding.weight)
        for i, test_layer in enumerate(test_model.bottom_model.bottom_mlp):
            ref_layer = ref_model.bottom_mlp[i]
            if isinstance(ref_layer, nn.Linear):
                test_layer.weight.data.copy_(ref_layer.weight)
                test_layer.bias.data.copy_(ref_layer.bias)
        for i, test_layer in enumerate(test_model.top_model.top_mlp):
            # The layer type must be looked up in the *top* MLP of the
            # reference model; checking bottom_mlp here would copy or skip
            # layers based on the wrong module.
            ref_layer = ref_model.top_mlp[i]
            if isinstance(ref_layer, nn.Linear):
                test_layer.weight.data.copy_(ref_layer.weight)
                test_layer.bias.data.copy_(ref_layer.bias)

        test_numerical_input = torch.randn(2, 13, device="cuda")
        test_sparse_inputs = torch.tensor([[1, 1], [2, 2]], device="cuda")  # pylint:disable=not-callable
        test_top_out = test_model(test_numerical_input, test_sparse_inputs)
        # The reference model is fed the transposed index tensor.
        ref_top_out = ref_model(test_numerical_input, test_sparse_inputs.t())
        assert (test_top_out == ref_top_out).all()
コード例 #2
0
    def test_hash(self):
        """Check that index hashing maps different raw indices to one output.

        A ``DlrmJointEmbedding`` is built with ``hash_indices=True`` and run
        twice on the same numerical input: once with baseline sparse indices
        and once with indices expected to hash to the same values. The two
        outputs must be element-wise equal.
        """
        # Model creation with hashing enabled
        hashed_model = model.DlrmJointEmbedding(**_DUMMY_CONFIG,
                                                hash_indices=True)
        hashed_model.to("cuda")

        # Baseline forward pass
        numerical_batch = torch.randn(2, 13, device="cuda")
        baseline_indices = torch.tensor([[1, 2], [2, 3]], device="cuda")  # pylint:disable=not-callable
        baseline_out = hashed_model(numerical_batch, baseline_indices)

        # Indices that are expected to be hashed to the same values as the
        # baseline ones
        colliding_indices = torch.tensor([[1, 7], [9, 3]], device="cuda")  # pylint:disable=not-callable
        colliding_out = hashed_model(numerical_batch, colliding_indices)

        assert torch.eq(baseline_out, colliding_out).all()
コード例 #3
0
    def test_against_base(self):
        """Check that ``DlrmJointEmbedding`` matches the base ``Dlrm`` model.

        The joint model's single embedding table is filled with the
        concatenation of the base model's embedding weights, every
        ``torch.nn.Linear`` is copied across by module name, and both models
        must then yield element-wise equal outputs for the same inputs.
        """
        torch.set_printoptions(precision=4, sci_mode=False)
        ref_model = model.Dlrm(**_DUMMY_CONFIG)
        test_model = model.DlrmJointEmbedding(**_DUMMY_CONFIG)
        ref_model.set_devices("cuda")
        test_model.to("cuda")

        # Build the joint table by stacking the reference tables in order
        ref_tables = [table.weight for table in ref_model.embeddings]
        stacked_tables = torch.cat(ref_tables).clone()
        test_model.embeddings[0].embedding.weight.data = stacked_tables

        # Mirror every Linear layer, matched by its qualified module name
        modules_by_name = dict(test_model.named_modules())
        for layer_name, ref_layer in ref_model.named_modules():
            if not isinstance(ref_layer, torch.nn.Linear):
                continue
            target_layer = modules_by_name[layer_name]
            target_layer.weight.data.copy_(ref_layer.weight)
            target_layer.bias.data.copy_(ref_layer.bias)

        test_numerical_input = torch.randn(3, 13, device="cuda")
        test_sparse_inputs = torch.randint(0, 3, (2, 3), device="cuda")  # pylint:disable=not-callable

        ref_out = ref_model(test_numerical_input, test_sparse_inputs)
        test_out = test_model(test_numerical_input, test_sparse_inputs)
        assert torch.eq(ref_out, test_out).all()