Example #1
    def _interaction(self, bottom_mlp_output, embedding_outputs, batch_size):
        """Interaction

        "dot" interaction is a bit tricky to implement and test. Break it out from forward so that it can be tested
        independently.

        Args:
            bottom_mlp_output (Tensor): output of the bottom MLP, shape [batch_size, embedding_dim].
            embedding_outputs (list): sequence of embedding tensors, each of shape [batch_size, embedding_dim].
            batch_size (int): number of samples in the batch.
        """
        concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
        if self._interaction_op == "dot" and not self._self_interaction:
            concat = concat.view((batch_size, -1, self._embedding_dim))
            if concat.dtype == torch.half:
                interaction_output = dotBasedInteract(concat, bottom_mlp_output)
            else:  # Legacy path
                interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
                tril_indices_row, tril_indices_col = torch.tril_indices(
                    interaction.shape[1], interaction.shape[2], offset=-1)
                interaction_flat = interaction[:, tril_indices_row, tril_indices_col]

                # concatenate dense features and interactions
                zero_padding = torch.zeros(
                    concat.shape[0], 1, dtype=concat.dtype, device=concat.device)
                interaction_output = torch.cat((bottom_mlp_output, interaction_flat, zero_padding), dim=1)

        elif self._interaction_op == "cat":
            interaction_output = concat
        else:
            raise NotImplementedError('Unsupported interaction op {}'.format(self._interaction_op))

        return interaction_output
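
For reference, a minimal, self-contained sketch of the fp32 ("legacy") path above, written as a free function. The function name and the smoke test at the bottom are illustrative additions, not part of the original repository.

import torch

def dot_interaction_reference(bottom_mlp_output, embedding_outputs, batch_size, embedding_dim):
    concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
    concat = concat.view(batch_size, -1, embedding_dim)
    # Pairwise dot products between all feature rows: [batch, rows, rows]
    interaction = torch.bmm(concat, concat.transpose(1, 2))
    # Strictly lower triangle (offset=-1) drops self-interactions and duplicates
    rows, cols = torch.tril_indices(interaction.shape[1], interaction.shape[2], offset=-1)
    interaction_flat = interaction[:, rows, cols]
    zero_padding = torch.zeros(batch_size, 1, dtype=concat.dtype, device=concat.device)
    return torch.cat((bottom_mlp_output, interaction_flat, zero_padding), dim=1)

mlp = torch.randn(2, 4)                       # one dense row of width 4
embs = [torch.randn(2, 4) for _ in range(3)]  # three embedding rows
out = dot_interaction_reference(mlp, embs, batch_size=2, embedding_dim=4)
assert out.shape == (2, 4 + 4 * 3 // 2 + 1)   # num_cols + R*(R-1)/2 + padding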
Example #2
    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: Tensor of shape [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output: Tensor of shape [batch_size, embedding_dim]
        :return: interaction output of the fused dot kernel
        """
        return dotBasedInteract(bottom_output, bottom_mlp_output)
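
A quick shape check for the layout that the docstring above describes; the sizes are made up, and since dotBasedInteract is the custom fused kernel, only the input packing is sketched here.

import torch

batch_size, num_embeddings, embedding_dim = 2, 3, 4  # illustrative sizes
bottom_mlp_output = torch.randn(batch_size, embedding_dim)
embedding_outputs = [torch.randn(batch_size, embedding_dim)
                     for _ in range(num_embeddings)]

# Pack the MLP output and the embeddings into [batch_size, 1 + #embeddings, embedding_dim]
bottom_output = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
bottom_output = bottom_output.view(batch_size, 1 + num_embeddings, embedding_dim)
assert bottom_output.shape == (batch_size, 1 + num_embeddings, embedding_dim)
# interact(bottom_output, bottom_mlp_output) would then invoke the fused kernel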
Example #3
def dot_interaction(bottom_mlp_output, embedding_outputs, batch_size):
    """Pairwise-dot interaction between the bottom-MLP output and the
    embedding vectors, with the dense features passed through."""
    concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
    concat = concat.view((batch_size, -1, EMBED_DIM))
    if FLAGS.fp16:
        interaction_output = dotBasedInteract(concat, bottom_mlp_output)
    else:
        interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
        tril_indices_row, tril_indices_col = torch.tril_indices(
            interaction.shape[1], interaction.shape[2], offset=-1)
        interaction_flat = interaction[:, tril_indices_row, tril_indices_col]

        # Concatenate dense features and interactions; a zero column pads the
        # output to the same width as the fused kernel's.
        padding = torch.zeros(batch_size, 1, dtype=concat.dtype, device=concat.device)
        interaction_output = torch.cat([bottom_mlp_output, interaction_flat, padding],
                                       dim=1)

    return interaction_output
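
Both branches produce the same output width, which is worth spelling out. The helper below is a hypothetical name used only for this arithmetic check; the trailing +1 is the padding column, presumably kept so the fp16 kernel and the fp32 path agree on width.

# Output width of the "dot" interaction for R feature rows of width C:
# C (dense passthrough) + R*(R-1)//2 (pairwise dots) + 1 (padding column)
def dot_output_width(num_rows, num_cols, padding=1):
    return num_cols + num_rows * (num_rows - 1) // 2 + padding

# e.g. a DLRM-like setup: 26 embeddings + 1 bottom-MLP row, width 128
assert dot_output_width(27, 128) == 480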
def dot_based_interact_test(num_rows,
                            num_cols,
                            batch_size,
                            synthesize_mode,
                            upstream_grad_synthesize_mode,
                            direction,
                            linear_output,
                            decimal,
                            max_value=MAX_INT_VALUE,
                            verbose=VERBOSE):
    """Computes the forward and backward for custom dot and checks the result."""
    # Input tensor configuration and initialization
    if synthesize_mode == 'seq':
        bottom_mlp_output_np = np.arange(batch_size * num_cols).reshape(
            batch_size, num_cols)
        bottom_mlp_output_np = bottom_mlp_output_np % max_value
        embedding_outputs_np = []
        for i in range(num_rows - 1):  # num_rows - 1 embeddings plus the MLP output give num_rows rows
            tmp = np.arange(batch_size * num_cols).reshape(
                batch_size, num_cols)
            tmp = tmp % max_value
            embedding_outputs_np.append(tmp)
    elif synthesize_mode == 'rand':
        bottom_mlp_output_np = np.random.randn(batch_size, num_cols)
        bottom_mlp_output_np = bottom_mlp_output_np * SCALE
        embedding_outputs_np = []
        for i in range(num_rows - 1):
            tmp = np.random.randn(batch_size, num_cols)
            tmp = tmp * SCALE
            embedding_outputs_np.append(tmp)
    elif synthesize_mode == 'ones':
        bottom_mlp_output_np = np.ones((batch_size, num_cols))
        embedding_outputs_np = []
        for i in range(num_rows - 1):
            tmp = np.ones((batch_size, num_cols))
            embedding_outputs_np.append(tmp)
    else:
        raise NotImplementedError(
            'Invalid synthesize_mode {}'.format(synthesize_mode))

    # Identical inputs for reference and test
    ref_bottom_mlp_output = torch.Tensor(
        bottom_mlp_output_np).half().cuda().requires_grad_()
    test_bottom_mlp_output = torch.Tensor(
        bottom_mlp_output_np).half().cuda().requires_grad_()

    ref_embedding_outputs = []
    test_embedding_outputs = []
    for elem in embedding_outputs_np:
        ref_embedding_outputs.append(
            torch.Tensor(elem).half().cuda().requires_grad_())
        test_embedding_outputs.append(
            torch.Tensor(elem).half().cuda().requires_grad_())

    assert ref_bottom_mlp_output.shape == test_bottom_mlp_output.shape
    assert ref_bottom_mlp_output.shape[0] == batch_size
    assert ref_bottom_mlp_output.shape[1] == num_cols

    assert ref_embedding_outputs[0].shape == test_embedding_outputs[0].shape
    assert len(ref_embedding_outputs) == len(test_embedding_outputs)
    assert len(ref_embedding_outputs) == num_rows - 1
    assert ref_embedding_outputs[0].shape[0] == batch_size
    assert ref_embedding_outputs[0].shape[1] == num_cols

    reference_input = torch.cat([ref_bottom_mlp_output] +
                                ref_embedding_outputs,
                                dim=1)
    test_input = torch.cat([test_bottom_mlp_output] + test_embedding_outputs,
                           dim=1)

    reference_input = reference_input.view((batch_size, -1, num_cols))
    test_input = test_input.view((batch_size, -1, num_cols))

    assert reference_input.shape == test_input.shape
    assert reference_input.shape[0] == batch_size
    assert reference_input.shape[1] == num_rows
    assert reference_input.shape[2] == num_cols

    ref_pad = torch.zeros(batch_size,
                          1,
                          dtype=ref_bottom_mlp_output.dtype,
                          device=ref_bottom_mlp_output.device)

    # FWD path in reference
    interaction = torch.bmm(reference_input,
                            torch.transpose(reference_input, 1, 2))
    # Strictly-lower-triangle indices, equivalent to torch.tril_indices with
    # offset=-1; `interaction` is square, so shape[1] == shape[2].
    tril_indices_row = [
        i for i in range(interaction.shape[1]) for j in range(i)
    ]
    tril_indices_col = [
        j for i in range(interaction.shape[1]) for j in range(i)
    ]
    interaction_flat = interaction[:, tril_indices_row, tril_indices_col]
    reference_output = torch.cat(
        (ref_bottom_mlp_output, interaction_flat, ref_pad), dim=1)

    # Expected width: R*(R-1)/2 pairwise dots + dense width + padding
    num_output_elems = ((num_rows * (num_rows - 1)) >> 1) + num_cols + PADDING_SIZE
    assert reference_output.shape[0] == batch_size
    assert reference_output.shape[1] == num_output_elems

    if linear_output:
        reference_output = torch.sum(reference_output, dim=1)
        reference_output = torch.sum(reference_output, dim=0)

    # New FWD path
    test_output = dotBasedInteract(test_input, test_bottom_mlp_output)

    if linear_output:
        test_output = torch.sum(test_output, dim=1)
        test_output = torch.sum(test_output, dim=0)

    assert test_output.shape == reference_output.shape
    # FWD path test
    if direction in ['fwd', 'both']:
        log(verbose, 'Starting FWD Test ...')
        print_differences(test_output.detach().cpu().numpy(),
                          reference_output.detach().cpu().numpy(), decimal)
        np.testing.assert_almost_equal(
            test_output.detach().cpu().numpy(),
            desired=reference_output.detach().cpu().numpy(),
            decimal=decimal)
        log(verbose, 'FWD test ended successfully.')
    if direction == 'fwd':
        return

    # BWD path
    test_input.retain_grad()
    reference_input.retain_grad()
    if linear_output:
        reference_output.backward()
        test_output.backward()
    else:
        # Synthesize upstream gradient
        if upstream_grad_synthesize_mode == 'ones':
            upstream_grad = np.ones(reference_output.shape)
        elif upstream_grad_synthesize_mode == 'seq':
            upstream_grad = np.arange(reference_output.numel()).reshape(
                reference_output.shape)
            upstream_grad = upstream_grad % max_value
        elif upstream_grad_synthesize_mode == 'rand':
            upstream_grad = np.random.randn(reference_output.numel()).reshape(
                reference_output.shape)
            upstream_grad = upstream_grad * SCALE
        else:
            raise NotImplementedError(
                'Invalid upstream_grad_synthesize_mode {}'.format(
                    upstream_grad_synthesize_mode))

        reference_upstream_grad = torch.Tensor(upstream_grad).half().cuda()
        test_upstream_grad = torch.Tensor(upstream_grad).half().cuda()
        reference_output.backward(reference_upstream_grad)
        test_output.backward(test_upstream_grad)

    # BWD path test (runs for both the linear and full output paths)
    log(verbose, 'Starting BWD Test ...')
    print_differences(test_input.grad.detach().cpu().numpy(),
                      reference_input.grad.detach().cpu().numpy(), decimal)
    print_differences(test_bottom_mlp_output.grad.detach().cpu().numpy(),
                      ref_bottom_mlp_output.grad.detach().cpu().numpy(),
                      decimal)
    np.testing.assert_almost_equal(
        test_input.grad.detach().cpu().numpy(),
        desired=reference_input.grad.detach().cpu().numpy(),
        decimal=decimal)
    np.testing.assert_almost_equal(
        test_bottom_mlp_output.grad.detach().cpu().numpy(),
        desired=ref_bottom_mlp_output.grad.detach().cpu().numpy(),
        decimal=decimal)
    log(verbose, 'BWD test ended successfully.')
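
A hypothetical invocation with DLRM-like dimensions; the argument values below are illustrative, not taken from the original test suite.

if __name__ == '__main__':
    dot_based_interact_test(num_rows=27,  # 26 embeddings + the bottom-MLP row
                            num_cols=128,
                            batch_size=64,
                            synthesize_mode='rand',
                            upstream_grad_synthesize_mode='rand',
                            direction='both',
                            linear_output=False,
                            decimal=1)    # loose tolerance for fp16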