def test_against_base_model(self):
    """DistDlrm must produce output identical to DlrmJointEmbedding once weights match.

    Builds both models from the same merged config, copies every parameter from
    the reference model into the distributed model, then runs one forward pass
    on each and asserts exact equality.
    """
    model_config = copy(_DUMMY_BOTTOM_CONFIG)
    model_config.update(_DUMMY_TOP_CONFIG)
    # DistDlrm's config carries this key but DlrmJointEmbedding does not accept it.
    model_config.pop('num_interaction_inputs')
    ref_model = model.DlrmJointEmbedding(**model_config)
    ref_model.to("cuda")
    test_model = dist_model.DistDlrm(**model_config)
    test_model.to("cuda")
    print(test_model)

    # Copy weights to make the two models identical
    test_model.bottom_model.joint_embedding.embedding.weight.data.copy_(
        ref_model.embeddings[0].embedding.weight)
    for i in range(len(test_model.bottom_model.bottom_mlp)):
        if isinstance(ref_model.bottom_mlp[i], nn.Linear):
            test_model.bottom_model.bottom_mlp[i].weight.data.copy_(
                ref_model.bottom_mlp[i].weight)
            test_model.bottom_model.bottom_mlp[i].bias.data.copy_(
                ref_model.bottom_mlp[i].bias)
    for i in range(len(test_model.top_model.top_mlp)):
        # BUG FIX: the original guarded this branch with
        # isinstance(ref_model.bottom_mlp[i], nn.Linear) — a copy-paste error
        # from the loop above. The top-MLP copy must inspect top_mlp itself,
        # otherwise layers are skipped/mis-copied when the two MLPs' layer
        # layouts differ.
        if isinstance(ref_model.top_mlp[i], nn.Linear):
            test_model.top_model.top_mlp[i].weight.data.copy_(
                ref_model.top_mlp[i].weight)
            test_model.top_model.top_mlp[i].bias.data.copy_(
                ref_model.top_mlp[i].bias)

    test_numerical_input = torch.randn(2, 13, device="cuda")
    test_sparse_inputs = torch.tensor([[1, 1], [2, 2]], device="cuda")  # pylint:disable=not-callable

    test_top_out = test_model(test_numerical_input, test_sparse_inputs)
    # The reference model expects the sparse batch transposed relative to DistDlrm.
    ref_top_out = ref_model(test_numerical_input, test_sparse_inputs.t())
    assert (test_top_out == ref_top_out).all()
def test_hash(self):
    """Colliding hashed indices must map to the same embedding rows.

    With ``hash_indices=True`` the model hashes sparse indices before lookup;
    two index tensors that hash to the same buckets must therefore produce
    byte-identical outputs.
    """
    # Build the model with index hashing enabled and move it to the GPU.
    hashed_model = model.DlrmJointEmbedding(**_DUMMY_CONFIG, hash_indices=True)
    hashed_model.to("cuda")

    # Reference forward pass.
    dense_batch = torch.randn(2, 13, device="cuda")
    base_indices = torch.tensor([[1, 2], [2, 3]], device="cuda")  # pylint:disable=not-callable
    expected = hashed_model(dense_batch, base_indices)

    # Different raw indices chosen so that they hash to the same values as
    # base_indices — the output must not change.
    colliding_indices = torch.tensor([[1, 7], [9, 3]], device="cuda")  # pylint:disable=not-callable
    actual = hashed_model(dense_batch, colliding_indices)
    assert (expected == actual).all()
def test_against_base(self):
    """DlrmJointEmbedding must match plain Dlrm once weights are synchronized.

    Copies the concatenated per-feature embedding tables and every Linear
    layer's parameters from the reference Dlrm into the joint-embedding
    variant, then asserts both models produce identical forward outputs.
    """
    torch.set_printoptions(precision=4, sci_mode=False)
    ref_model = model.Dlrm(**_DUMMY_CONFIG)
    test_model = model.DlrmJointEmbedding(**_DUMMY_CONFIG)
    ref_model.set_devices("cuda")
    test_model.to("cuda")

    # The joint embedding's single table is the concatenation of the
    # reference model's per-feature embedding tables.
    test_model.embeddings[0].embedding.weight.data = torch.cat(
        [embedding.weight for embedding in ref_model.embeddings]).clone()

    # Mirror every Linear layer's parameters by matching module names.
    name_to_module = dict(test_model.named_modules())
    for name, module in ref_model.named_modules():
        if isinstance(module, torch.nn.Linear):
            target = name_to_module[name]
            target.weight.data.copy_(module.weight)
            target.bias.data.copy_(module.bias)

    dense_batch = torch.randn(3, 13, device="cuda")
    sparse_batch = torch.randint(0, 3, (2, 3), device="cuda")  # pylint:disable=not-callable
    ref_out = ref_model(dense_batch, sparse_batch)
    test_out = test_model(dense_batch, sparse_batch)
    assert (ref_out == test_out).all()