def _interaction(self, bottom_mlp_output, embedding_outputs, batch_size):
    """Interaction

    "dot" interaction is a bit tricky to implement and test. Break it out
    from forward so that it can be tested independently.

    Args:
        bottom_mlp_output (Tensor): output of the bottom MLP, shape [batch_size, embedding_dim]
        embedding_outputs (list): Sequence of embedding tensors, each [batch_size, embedding_dim]
        batch_size (int): number of samples in the batch
    """
    concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
    if self._interaction_op == "dot" and not self._self_interaction:
        concat = concat.view((batch_size, -1, self._embedding_dim))
        if concat.dtype == torch.half:
            interaction_output = dotBasedInteract(concat, bottom_mlp_output)
        else:  # Legacy path
            interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
            tril_indices_row, tril_indices_col = torch.tril_indices(
                interaction.shape[1], interaction.shape[2], offset=-1)
            interaction_flat = interaction[:, tril_indices_row, tril_indices_col]

            # concatenate dense features and interactions
            zero_padding = torch.zeros(
                concat.shape[0], 1, dtype=concat.dtype, device=concat.device)
            interaction_output = torch.cat(
                (bottom_mlp_output, interaction_flat, zero_padding), dim=1)
    elif self._interaction_op == "cat":
        interaction_output = concat
    else:
        raise NotImplementedError

    return interaction_output
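# For reference, a minimal CPU-only sketch of the legacy fp32 "dot" path above.
# This is a standalone illustration: all sizes are made up for the example, and
# plain constants stand in for the class attributes (self._embedding_dim etc.).
import torch

batch_size, num_embeddings, embedding_dim = 4, 3, 8
bottom_mlp_output = torch.randn(batch_size, embedding_dim)
embedding_outputs = [torch.randn(batch_size, embedding_dim) for _ in range(num_embeddings)]

# Stack the dense (bottom MLP) and sparse (embedding) features:
# [batch_size, 1 + num_embeddings, embedding_dim]
concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
concat = concat.view(batch_size, -1, embedding_dim)
num_features = concat.shape[1]

# Pairwise dot products between all feature vectors; keep the strict lower triangle
interaction = torch.bmm(concat, concat.transpose(1, 2))
rows, cols = torch.tril_indices(num_features, num_features, offset=-1)
interaction_flat = interaction[:, rows, cols]

# Dense features, flattened interactions, and one zero padding column
zero_padding = torch.zeros(batch_size, 1)
output = torch.cat((bottom_mlp_output, interaction_flat, zero_padding), dim=1)

expected_width = embedding_dim + num_features * (num_features - 1) // 2 + 1
assert output.shape == (batch_size, expected_width)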
def interact(self, bottom_output, bottom_mlp_output):
    """
    :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
    :param bottom_mlp_output: [batch_size, embedding_dim]
    :return: interaction output computed by the custom dot-based kernel
    """
    return dotBasedInteract(bottom_output, bottom_mlp_output)
def dot_interaction(bottom_mlp_output, embedding_outputs, batch_size):
    """Dot interaction between the bottom MLP output and the embedding outputs."""
    concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
    concat = concat.view((batch_size, -1, EMBED_DIM))
    if FLAGS.fp16:
        interaction_output = dotBasedInteract(concat, bottom_mlp_output)
    else:
        interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
        tril_indices_row, tril_indices_col = torch.tril_indices(
            interaction.shape[1], interaction.shape[2], offset=-1)
        interaction_flat = interaction[:, tril_indices_row, tril_indices_col]

        # Zero padding column; use the actual batch size and matching dtype/device
        # rather than FLAGS.batch_size and uninitialized memory.
        padding = torch.zeros(batch_size, 1,
                              dtype=bottom_mlp_output.dtype,
                              device=bottom_mlp_output.device)

        # concatenate dense features and interactions
        interaction_output = torch.cat(
            (bottom_mlp_output, interaction_flat, padding), dim=1)

    return interaction_output
def dot_based_interact_test(num_rows,
                            num_cols,
                            batch_size,
                            synthesize_mode,
                            upstream_grad_synthesize_mode,
                            direction,
                            linear_output,
                            decimal,
                            max_value=MAX_INT_VALUE,
                            verbose=VERBOSE):
    """Computes the forward and backward for custom dot and checks the result."""
    # Input tensor configuration and initialization
    if synthesize_mode == 'seq':
        bottom_mlp_output_np = np.arange(batch_size * num_cols).reshape(
            batch_size, num_cols)
        bottom_mlp_output_np = bottom_mlp_output_np % max_value
        embedding_outputs_np = []
        for i in range(num_rows - 1):  # num_rows - 1 embeddings plus one bottom MLP output
            tmp = np.arange(batch_size * num_cols).reshape(batch_size, num_cols)
            tmp = tmp % max_value
            embedding_outputs_np.append(tmp)
    elif synthesize_mode == 'rand':
        bottom_mlp_output_np = np.random.randn(batch_size, num_cols)
        bottom_mlp_output_np = bottom_mlp_output_np * SCALE
        embedding_outputs_np = []
        for i in range(num_rows - 1):
            tmp = np.random.randn(batch_size, num_cols)
            tmp = tmp * SCALE
            embedding_outputs_np.append(tmp)
    elif synthesize_mode == 'ones':
        bottom_mlp_output_np = np.ones((batch_size, num_cols))
        embedding_outputs_np = []
        for i in range(num_rows - 1):
            tmp = np.ones((batch_size, num_cols))
            embedding_outputs_np.append(tmp)
    else:
        print('Invalid synthesize_mode {}'.format(synthesize_mode))
        raise NotImplementedError

    # Identical inputs for reference and test
    ref_bottom_mlp_output = torch.Tensor(
        bottom_mlp_output_np).half().cuda().requires_grad_()
    test_bottom_mlp_output = torch.Tensor(
        bottom_mlp_output_np).half().cuda().requires_grad_()
    ref_embedding_outputs = []
    test_embedding_outputs = []
    for elem in embedding_outputs_np:
        ref_embedding_outputs.append(
            torch.Tensor(elem).half().cuda().requires_grad_())
        test_embedding_outputs.append(
            torch.Tensor(elem).half().cuda().requires_grad_())

    assert ref_bottom_mlp_output.shape == test_bottom_mlp_output.shape
    assert ref_bottom_mlp_output.shape[0] == batch_size
    assert ref_bottom_mlp_output.shape[1] == num_cols
    assert ref_embedding_outputs[0].shape == test_embedding_outputs[0].shape
    assert len(ref_embedding_outputs) == len(test_embedding_outputs)
    assert len(ref_embedding_outputs) == num_rows - 1
    assert ref_embedding_outputs[0].shape[0] == batch_size
    assert ref_embedding_outputs[0].shape[1] == num_cols

    reference_input = torch.cat([ref_bottom_mlp_output] + ref_embedding_outputs, dim=1)
    test_input = torch.cat([test_bottom_mlp_output] + test_embedding_outputs, dim=1)
    reference_input = reference_input.view((batch_size, -1, num_cols))
    test_input = test_input.view((batch_size, -1, num_cols))

    assert reference_input.shape == test_input.shape
    assert reference_input.shape[0] == batch_size
    assert reference_input.shape[1] == num_rows
    assert reference_input.shape[2] == num_cols

    ref_pad = torch.zeros(batch_size, 1,
                          dtype=ref_bottom_mlp_output.dtype,
                          device=ref_bottom_mlp_output.device)

    # FWD path in reference
    interaction = torch.bmm(reference_input, torch.transpose(reference_input, 1, 2))
    # Strict lower-triangle indices (offset -1), equivalent to torch.tril_indices
    tril_indices_row = [i for i in range(interaction.shape[1]) for j in range(i)]
    tril_indices_col = [j for i in range(interaction.shape[2]) for j in range(i)]
    interaction_flat = interaction[:, tril_indices_row, tril_indices_col]
    reference_output = torch.cat(
        (ref_bottom_mlp_output, interaction_flat, ref_pad), dim=1)

    num_output_elems = ((num_rows * (num_rows - 1)) >> 1) + num_cols + PADDING_SIZE
    assert reference_output.shape[0] == batch_size
    assert reference_output.shape[1] == num_output_elems

    if linear_output:
        reference_output = torch.sum(reference_output, dim=1)
        reference_output = torch.sum(reference_output, dim=0)
    # New FWD path
    test_output = dotBasedInteract(test_input, test_bottom_mlp_output)
    if linear_output:
        test_output = torch.sum(test_output, dim=1)
        test_output = torch.sum(test_output, dim=0)
    assert test_output.shape == reference_output.shape

    # FWD path test
    if direction in ['fwd', 'both']:
        log(verbose, 'Starting FWD Test ...')
        print_differences(test_output.detach().cpu().numpy(),
                          reference_output.detach().cpu().numpy(), decimal)
        np.testing.assert_almost_equal(
            test_output.detach().cpu().numpy(),
            desired=reference_output.detach().cpu().numpy(),
            decimal=decimal)
        log(verbose, 'FWD test ended successfully.')
    if direction == 'fwd':
        return

    # BWD path
    test_input.retain_grad()
    reference_input.retain_grad()
    if linear_output:
        reference_output.backward()
        test_output.backward()
    else:
        # Synthesize upstream gradient
        if upstream_grad_synthesize_mode == 'ones':
            upstream_grad = np.ones(reference_output.shape)
        elif upstream_grad_synthesize_mode == 'seq':
            upstream_grad = np.arange(reference_output.numel()).reshape(
                reference_output.shape)
            upstream_grad = upstream_grad % max_value
        elif upstream_grad_synthesize_mode == 'rand':
            upstream_grad = np.random.randn(reference_output.numel()).reshape(
                reference_output.shape)
            upstream_grad = upstream_grad * SCALE
        else:
            print('Invalid upstream_grad_synthesize_mode {}'.format(
                upstream_grad_synthesize_mode))
            raise NotImplementedError

        reference_upstream_grad = torch.Tensor(upstream_grad).half().cuda()
        test_upstream_grad = torch.Tensor(upstream_grad).half().cuda()
        reference_output.backward(reference_upstream_grad)
        test_output.backward(test_upstream_grad)

    log(verbose, 'Starting BWD Test ...')
    print_differences(test_input.grad.detach().cpu().numpy(),
                      reference_input.grad.detach().cpu().numpy(), decimal)
    print_differences(test_bottom_mlp_output.grad.detach().cpu().numpy(),
                      ref_bottom_mlp_output.grad.detach().cpu().numpy(), decimal)
    np.testing.assert_almost_equal(
        test_input.grad.detach().cpu().numpy(),
        desired=reference_input.grad.detach().cpu().numpy(),
        decimal=decimal)
    np.testing.assert_almost_equal(
        test_bottom_mlp_output.grad.detach().cpu().numpy(),
        desired=ref_bottom_mlp_output.grad.detach().cpu().numpy(),
        decimal=decimal)
    log(verbose, 'BWD test ended successfully.')
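# Example invocation (a sketch, not part of the test suite): the sizes below are
# illustrative DLRM-like dimensions (26 embeddings + 1 bottom MLP output, embedding
# dim 128) and assume the dotBasedInteract CUDA extension is built and a GPU is
# available. The comparison precision for fp16 is a rough choice, not a prescription.
if __name__ == '__main__':
    dot_based_interact_test(num_rows=27,
                            num_cols=128,
                            batch_size=128,
                            synthesize_mode='rand',
                            upstream_grad_synthesize_mode='rand',
                            direction='both',
                            linear_output=False,
                            decimal=1)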