Example #1
def test_layer_norm():
    sl = 10
    bs = 2
    in_features = 32
    inputs = to_gpu(V(tr.randn([sl, bs, in_features])))
    layernorm = to_gpu(LayerNorm(in_features))
    outputs = layernorm(inputs)
    assert_dims(outputs, [sl, bs, in_features])
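These examples lean on a few helpers from the surrounding codebase: to_gpu, V, T, and assert_dims (fastai-era PyTorch utilities). A minimal sketch of plausible stand-ins, assuming modern PyTorch; the real versions differ (for instance, the library's assert_dims also accepts lists of tensors and tuples of permitted sizes):

import torch

def to_gpu(x):
    # Move a tensor or module to CUDA when a GPU is available; otherwise no-op.
    return x.cuda() if torch.cuda.is_available() else x

def V(x):
    # Historically wrapped a tensor in an autograd Variable; a no-op today.
    return x

T = torch.as_tensor  # coerce lists/arrays to tensors

def assert_dims(x, dims):
    # Check each dimension against the expected list; None acts as a wildcard.
    assert x.dim() == len(dims), f"expected {dims}, got {tuple(x.shape)}"
    for actual, expected in zip(x.shape, dims):
        assert expected is None or actual == expected, \
            f"expected {dims}, got {tuple(x.shape)}"
    return x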
Example #2
def attention_setup(request):
    sl, bs = 3, 2
    edq, edk = request.param

    # query would be the hidden state of the decoder
    keys = to_gpu(V(T(np.random.rand(sl, bs, edk))))
    query = to_gpu(V(T(np.random.rand(bs, edq))))
    return keys, query
Example #3
def test_transformer_layer():
    sl = 10
    bs = 2
    in_features = 32
    inputs = tr.randn([sl, bs, in_features])
    inputs = to_gpu(V(T(inputs)))
    transformer = to_gpu(TransformerLayer(in_features=in_features, num_heads=8))
    outputs = transformer(inputs)
    assert_dims(outputs, [sl, bs, in_features])
Example #4
def test_transformer_encoder():
    sl = 10
    bs = 2
    in_features = 300
    num_layers = 5
    inputs = tr.randn([sl, bs, in_features])
    inputs = to_gpu(V(T(inputs)))
    transformer = to_gpu(
        TransformerEncoderLayers(input_size=in_features,
                                 num_heads=8,
                                 nhid=512,
                                 num_layers=num_layers))
    layer_outputs = transformer(inputs)
    assert_dims(layer_outputs, [num_layers, sl, bs, in_features])
Example #5
 def reparameterize(self, mu, logvar):
     if self.training:
         std = torch.exp(0.5 * logvar)
         eps = to_gpu(V(torch.randn(self.latent_dim)))
         return mu + eps * std
     else:
         return mu
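The reparameterization trick above rewrites a sample z ~ N(mu, sigma^2) as z = mu + eps * sigma with eps ~ N(0, 1) and sigma = exp(0.5 * logvar), so the stochastic node becomes differentiable with respect to mu and logvar. A standalone check in plain PyTorch (not the class above):

import torch

mu = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)  # sigma = exp(0.5 * 0) = 1
eps = torch.randn(4)                         # noise is sampled outside the graph
z = mu + eps * torch.exp(0.5 * logvar)
z.sum().backward()
print(mu.grad)      # all ones: dz/dmu = 1
print(logvar.grad)  # 0.5 * eps * sigma: the noise path is differentiable too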
Example #6
def test_transformer_decoder(num_beams, decoder_inputs_transformer):
    batch_size, emb_size, nlayers, sl, vin, ven = decoder_inputs_transformer
    ntokens, nhid, max_tokens = 10, 2, 20
    embedding = TransformerEmbeddings(ntokens=ntokens,
                                      emb_size=emb_size,
                                      dropout=0.0,
                                      pad_token=1)

    encoder = TransformerDecoderLayers(nlayers=nlayers,
                                       input_size=emb_size,
                                       num_heads=2,
                                       nhid=emb_size)
    projection_layer = Projection(output_size=ntokens,
                                  input_size=emb_size,
                                  tie_encoder=None,
                                  dropout=0.0)
    decoder = TransformerDecoder(decoder_layer=encoder,
                                 projection_layer=projection_layer,
                                 pad_token=1,
                                 eos_token=2,
                                 max_tokens=max_tokens,
                                 embedding_layer=embedding)
    decoder = to_gpu(decoder)
    outputs = decoder(vin, ven, num_beams=num_beams)
    if num_beams > 0:
        assert_dims(outputs,
                    [None, num_beams * batch_size, (emb_size, ntokens)])
        # actual beam outputs can be found in beam_outputs
        assert decoder.beam_outputs is not None
        assert_dims(decoder.beam_outputs, [None, batch_size, num_beams])
        # sl can go up to max_tokens + 1 (for the extra 0 token at the end)
        assert 0 < decoder.beam_outputs.shape[0] <= max_tokens + 1
    else:
        assert_dims(outputs, [None, batch_size, (emb_size, ntokens)])
        assert decoder.beam_outputs is None
Example #7
    def _greedy_forward(self, inputs, hidden=None, constraints=None):
        inputs = inputs[:1]  # inputs should be only the first token initially [1, bs]
        sl, bs = inputs.size()
        finished = to_gpu(torch.zeros(bs).byte())
        iteration = 0
        self.beam_outputs = inputs.clone()
        layer_outputs = [[] for _ in range(self.nlayers)]
        while not finished.all() and iteration < self.max_iterations:
            # output should be List[[sl, bs, layer_dim], ...] sl should be one
            output = self.forward(inputs, hidden=hidden, num_beams=0)
            for layer_index in range(self.nlayers):
                layer_outputs[layer_index].append(output[layer_index])

            # step_inputs have shape [1,bs]
            _, step_inputs = output[-1][-1:].max(dim=-1)
            iteration += 1
            self.beam_outputs = assert_dims(
                torch.cat([self.beam_outputs, step_inputs], dim=0),
                [iteration + 1, bs])
            new_finished = step_inputs.data == self.eos_token
            inputs = torch.cat([inputs, step_inputs], dim=0)
            assert_dims(inputs, [iteration + 1, bs])
            finished = finished | new_finished

        self.beam_outputs = self.beam_outputs.view(-1, bs, 1)
        outputs = [torch.cat(i, dim=0) for i in layer_outputs]
        return outputs
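The loop above is plain greedy decoding: feed the running prefix back in, take the argmax of the newest step, and stop once every sequence has emitted the eos token or the iteration cap is reached. A self-contained toy version, with a stub standing in for self.forward (the vocab size and token ids are made up):

import torch

def step(prefix):
    # stand-in for self.forward: logits for the next token, [1, bs, vocab]
    return torch.randn(1, prefix.size(1), 5)

bs, eos_token, max_iterations = 3, 2, 10
inputs = torch.zeros(1, bs, dtype=torch.long)   # the initial <sos> row
finished = torch.zeros(bs, dtype=torch.bool)
for _ in range(max_iterations):
    step_inputs = step(inputs)[-1:].argmax(dim=-1)  # [1, bs]
    inputs = torch.cat([inputs, step_inputs], dim=0)
    finished |= step_inputs[0] == eos_token
    if finished.all():
        break
print(inputs.shape)  # [generated_length, bs]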
Example #8
 def to_model(self, m, opt_fn):
     model = CVAEModel(to_gpu(m))
     learner = EncoderDecoderLearner(self, model, opt_fn=opt_fn)
     learner.crit = partial(
         cvae_loss,
         pad_idx=learner.data.pad_idx)  # change loss to auxiliary loss
     return learner
Example #9
    def cvae_loss_sigmoid(input, target, step=0, max_kld_step=None, **kwargs):
        predictions, recog_mu, recog_log_var, prior_mu, prior_log_var, bow_logits = input
        vocab = predictions.size(-1)
        # predictions dims are [sl-1, bs, vocab]
        dec_input = predictions[:target.size(0)].view(-1, vocab).contiguous()
        bow_targets = torch.zeros_like(bow_logits).scatter(
            1, target.transpose(1, 0), 1)
        # mask pad token
        weights = to_gpu(V(torch.ones(bow_logits.size(-1)).unsqueeze_(0)))
        weights[0, pad_idx] = 0
        bow_loss = F.binary_cross_entropy_with_logits(bow_logits,
                                                      bow_targets,
                                                      weight=weights)

        # targets are [sl-1, bs] (one label for every word)
        kld_loss = gaussian_kld(recog_mu, recog_log_var, prior_mu,
                                prior_log_var)
        target = target.view(-1).contiguous()
        decoder_loss = F.cross_entropy(
            input=dec_input,
            target=target,
            ignore_index=pad_idx,
        )
        kld_weight = 1.0 if max_kld_step is None else min(
            (step + 1) / max_kld_step, 1)
        nonlocal STEP
        if step > STEP:
            if step == 0: STEP = 0
            print(
                f"losses: decoder {decoder_loss}, bow: {bow_loss}, kld x weight: {kld_loss} x {kld_weight}"
            )
            STEP += 1

        return decoder_loss + bow_loss + kld_loss * kld_weight
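gaussian_kld is not shown here; for two diagonal Gaussians the KL divergence has the closed form KL(q || p) = 0.5 * sum(log(sigma_p^2 / sigma_q^2) + (sigma_q^2 + (mu_q - mu_p)^2) / sigma_p^2 - 1). A hypothetical implementation under that assumption (the actual reduction over the batch may differ):

import torch

def gaussian_kld(recog_mu, recog_log_var, prior_mu, prior_log_var):
    # Closed-form KL(q || p) for diagonal Gaussians, summed over the latent
    # dimension and averaged over the batch (an assumed reduction).
    kld = -0.5 * torch.sum(
        1 + recog_log_var - prior_log_var
        - (recog_log_var.exp() + (recog_mu - prior_mu) ** 2) / prior_log_var.exp(),
        dim=-1)
    return kld.mean()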
Example #10
    def _greedy_forward(self, inputs, hidden=None, constraints=None):
        dec_inputs = inputs
        max_iterations = min(dec_inputs.size(0), self.MAX_STEPS_ALLOWED) if self.training else self.max_iterations
        inputs = V(inputs[:1].data)  # inputs should be only first token initially [1,bs]
        sl, bs = inputs.size()
        finished = to_gpu(torch.zeros(bs).byte())
        iteration = 0
        self.beam_outputs = inputs.clone()
        final_outputs = []
        while not finished.all() and iteration < max_iterations:
            # output should be List[[sl, bs, layer_dim], ...] sl should be one
            if 0 < iteration and self.training and 0. < self.random() < self.pr_force:
                inputs = dec_inputs[iteration].unsqueeze(0)
            output = self.forward(inputs, hidden=hidden, num_beams=0, constraints=constraints)
            hidden = self.decoder_layer.hidden
            final_outputs.append(output)  # dim should be [sl=1, bs, nt]
            # inputs are the indices, dims [1, bs]; repackage the Variable so gradients don't flow back
            inputs = assert_dims(V(output.data.max(dim=-1)[1]), [1, bs])
            iteration += 1
            self.beam_outputs = assert_dims(torch.cat([self.beam_outputs, inputs], dim=0), [iteration + 1, bs])
            new_finished = inputs.data == self.eos_token
            finished = finished | new_finished
            # stop if the output is too big to fit in memory

        self.beam_outputs = self.beam_outputs.view(-1, bs, 1)
        # outputs should be [sl, bs, nt]
        outputs = torch.cat(final_outputs, dim=0)
        return outputs
Example #11
    def _greedy_forward(self, inputs):
        inputs = inputs[:1]  # inputs should be only the first token initially [1, bs]
        sl, bs = inputs.size()
        finished = to_gpu(torch.zeros(bs).byte())
        iteration = 0
        self.beam_outputs = inputs.clone()
        layer_outputs = [[] for _ in range(self.nlayers)]
        raw_layer_outputs = [[] for _ in range(self.nlayers)]
        while not finished.all() and iteration < self.max_iterations:
            # output should be List[[sl, bs, layer_dim], ...] sl should be one
            raw_output, output = self.forward(inputs, 0)
            for layer_index in range(self.nlayers):
                layer_outputs[layer_index].append(output[layer_index])
                raw_layer_outputs[layer_index].append(raw_output[layer_index])

            # inputs are the indices, dims [1, bs]
            _, inputs = output[-1].max(dim=-1)
            assert_dims(inputs, [1, bs])
            iteration += 1
            self.beam_outputs = assert_dims(
                torch.cat([self.beam_outputs, inputs], dim=0),
                [iteration + 1, bs])
            new_finished = inputs.data == self.eos_token
            finished = finished | new_finished

        self.beam_outputs = self.beam_outputs.view(-1, bs, 1)
        # ensure the outputs are a list of layers where each layer is [sl,bs,layerdim]
        raw_outputs = [torch.cat(i, dim=0) for i in raw_layer_outputs]
        outputs = [torch.cat(i, dim=0) for i in layer_outputs]
        return raw_outputs, outputs
Example #12
def test_MLPAttention(attention_setup):
    keys, query = attention_setup
    ed = keys.size(2)
    bs = query.size(0)
    in_features = keys.size(2) + query.size(1)
    attention = to_gpu(MLPAttention(in_features=in_features, nhid=200))
    result = attention(query=V(query), keys=V(keys), values=V(keys))
    assert (bs, ed) == result.shape
Example #13
    def _topk_forward(self, inputs, hidden, num_beams, constraints=None):
        sl, bs = inputs.size()
        # initial logprobs are zero (the probability of the <sos> token at the start is 1)
        logprobs = torch.zeros_like(inputs[:1]).view(1, bs, 1).float()  # shape will be [sl, bs, 1]
        inputs = inputs[:1].repeat(1, num_beams)  # only the first token initially [1, bs x num_beams]
        finished = to_gpu(torch.zeros(bs * num_beams).byte())
        iteration = 0
        layer_outputs = [[] for _ in range(self.nlayers)]
        self.beam_outputs = inputs.clone()
        hidden = repeat_cell_state(hidden, num_beams)
        while not finished.all() and iteration < self.max_iterations:
            # output should be List[[sl, bs * num_beams, layer_dim], ...] sl should be one
            output = self.forward(inputs, hidden=hidden, num_beams=0)
            for layer_index in range(self.nlayers):
                layer_outputs[layer_index].append(output[layer_index])

            # we take the output of the last layer, with dims [1, bs, output_dim],
            # and get the indices of the top k for every bs
            new_logprobs = F.log_softmax(output[-1][-1:], dim=-1)  # [1, bs x num_beams, nt]
            num_tokens = new_logprobs.size(2)
            new_logprobs = new_logprobs.view(1, bs, num_beams, num_tokens) + logprobs.unsqueeze(-1)  # [1, bs, nb, nt]
            # mask logprobs if they are finished or it's the first iteration
            new_logprobs = self.mask_logprobs(bs, finished, iteration,
                                              logprobs, new_logprobs,
                                              num_beams, num_tokens)

            # TODO take into account sequence_length for getting the top logprobs and their indices
            logprobs, beams = torch.topk(new_logprobs, k=num_beams, dim=-1)  # [1, bs, num_beams]
            parents = beams / num_tokens
            step_inputs = beams % num_tokens
            parent_indices = reshape_parent_indices(parents.view(-1),
                                                    bs=bs,
                                                    num_beams=num_beams)
            finished = torch.index_select(finished, 0, parent_indices.data)
            step_inputs = step_inputs.view(1, -1).contiguous()

            self.beam_outputs = torch.index_select(self.beam_outputs,
                                                   dim=1,
                                                   index=parent_indices)
            self.beam_outputs = torch.cat([self.beam_outputs, step_inputs],
                                          dim=0)
            new_finished = (step_inputs.data == self.eos_token).view(-1)
            inputs = torch.index_select(inputs, dim=1, index=parent_indices)
            inputs = torch.cat([inputs, step_inputs], dim=0)
            finished = finished | new_finished
            iteration += 1

        # ensure the outputs are a list of layers where each layer is [sl,bs,layerdim]
        outputs = [torch.cat(i, dim=0) for i in layer_outputs]
        self.beam_outputs = self.beam_outputs.view(-1, bs, num_beams)
        return outputs
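The parent bookkeeping works because (after mask_logprobs, which presumably flattens the beam and token axes) the topk indices run over a num_beams x num_tokens range: integer division recovers the parent beam and the remainder recovers the token id. A compact illustration, using //, the modern spelling of the integer division that beams / num_tokens performed on older LongTensors:

import torch

bs, num_beams, num_tokens = 2, 3, 7
scores = torch.randn(1, bs, num_beams * num_tokens)
logprobs, beams = torch.topk(scores, k=num_beams, dim=-1)  # [1, bs, num_beams]
parents = beams // num_tokens     # which parent beam each pick extends
step_inputs = beams % num_tokens  # which token id it appends
print(parents.shape, step_inputs.shape)  # both [1, bs, num_beams]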
Example #14
 def get_transformer_model(self, opt_fn, emb_sz, max_seq_len, **kwargs):
     m = get_transformer_language_model(self.n_tok,
                                        max_seq_len,
                                        self.trn_dl.target_length,
                                        emb_sz,
                                        pad_token=self.pad_idx,
                                        **kwargs)
     model = TransformerLanguageModel(to_gpu(m))
     return TransformerLearner(self, model, opt_fn=opt_fn)
Example #15
def test_cell(cell_type, hidden_type):
    sl, bs, input_size, output_size = 8, 10, 12, 14
    cell = Cell(cell_type, input_size, output_size, dropout=0.0, wdrop=0.0)
    cell = to_gpu(cell)
    inputs = V(tr.rand(sl, bs, input_size))
    hidden = cell.hidden_state(bs)
    outputs, hidden = cell(inputs, hidden)
    assert (sl, bs, output_size) == outputs.shape
    assert isinstance(hidden, hidden_type)
Example #16
def get_predictions(model: SequentialRNN, input_field: Field, prepared_input: Union[List[str], List[List[str]]],
                    max_n_predictions: int) -> Tuple[Variable, Variable]:
    t = to_gpu(input_field.numericalize(prepared_input, -1))
    res, *_ = model(t)
    last_res = res[-1]
    n_predictions = min(max_n_predictions, last_res.size()[0])
    outputs, labels = torch.topk(last_res, n_predictions)
    probs = F.softmax(outputs, dim=-1)
    return probs, labels
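Note that softmax is applied after topk, so the returned probs are renormalized over only the kept n_predictions logits rather than the full vocabulary. A tiny check of that behaviour:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1, -3.0])
vals, idx = torch.topk(logits, 2)
print(F.softmax(vals, dim=-1))         # renormalized over the top-2: sums to 1
print(F.softmax(logits, dim=-1)[idx])  # full-vocab probabilities: sum to < 1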
Example #17
def test_attention_layer():
    sl = 2
    bs = 2
    in_features = 32
    tr.random.manual_seed(0)
    inputs = to_gpu(V(tr.randn([sl, bs, in_features])))
    layer = to_gpu(AttentionLayer(input_size=32, num_heads=4, dropout=0.0))
    outputs1 = layer(inputs, inputs, inputs, mask=True)
    assert_dims(outputs1, [sl, bs, in_features])

    outputs2 = layer(inputs[:1], inputs[:1], inputs[:1])
    assert_dims(outputs2, [1, bs, in_features])
    assert ((outputs1[0] - outputs2[0]).abs() < 1E-6).all()

    outputs = layer(inputs, inputs, inputs, mask=False)
    assert_dims(outputs, [sl, bs, in_features])
    assert (outputs[0] != outputs1[0]).all()
    assert (outputs[0] != outputs2[0]).all()
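The first assertion holds because mask=True presumably applies a causal (lower-triangular) mask: position 0 attends only to itself, so feeding a longer sequence cannot change its output. With mask=False every position attends to the whole sequence and the outputs diverge. The mask shapes used elsewhere in these examples suggest something like:

import torch

sl = 4
causal = torch.tril(torch.ones(sl, sl))  # row i attends to columns <= i only
print(causal)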
Example #18
def test_transformer_layer_decoder():
    sl = 10
    bs = 2
    in_features = 32
    tr.random.manual_seed(0)
    encoder_inputs = tr.randn([sl, bs, in_features])
    decoder_inputs = tr.randn([sl, bs, in_features])
    encoder_inputs = to_gpu(V(T(encoder_inputs)))
    decoder_inputs = to_gpu(V(T(decoder_inputs)))
    transformer = to_gpu(
        TransformerLayerDecoder(input_size=in_features,
                                num_heads=8,
                                nhid=64,
                                dropout=0))
    outputs = transformer(encoder_inputs, decoder_inputs)
    assert_dims(outputs, [sl, bs, in_features])
    outputs1 = transformer(encoder_inputs, decoder_inputs[:1])
    assert_dims(outputs1, [1, bs, in_features])
    assert ((outputs[0] - outputs1[0]).abs() < 1E-6).all()
Example #19
def gen_text(learner: RNN_Learner, starting_words_list: List[str], how_many_to_gen: int) -> List[str]:
    text = []
    t = to_gpu(learner.text_field.numericalize([starting_words_list], -1))
    res, *_ = learner.model(t)
    for i in range(how_many_to_gen):
        n = torch.multinomial(res[-1].exp(), 1)
        # n = n[1] if n.data[0] == 0 else n[0]
        text.append(learner.text_field.vocab.itos[n.data[0]])
        res, *_ = learner.model(n[0].unsqueeze(0))
    return text
Example #20
def test_MultiHeadAttention_with_mask(self_attention_setup):
    keys, query = self_attention_setup
    slk, bs, ek = keys.size()
    slq, bs, eq = query.size()
    num_heads = 4
    nhid = 10
    attention = to_gpu(
        MultiHeadAttention(num_heads=num_heads, nhid=nhid, keys_dim=ek, query_dim=eq, values_dim=ek, dropout=0.3))
    mask = T(np.tril(np.ones((bs, num_heads, slq, slk)))).float()
    result = attention(query=V(query), keys=V(keys), values=V(keys), mask=mask)
    assert_dims(result, [slq, bs, num_heads * nhid])
Example #21
def test_MultiHeadAttention(self_attention_setup):
    keys, query = self_attention_setup
    slk, bs, ek = keys.size()
    slq, bs, eq = query.size()
    num_heads = 4
    nhid = 10
    attention = to_gpu(
        MultiHeadAttention(num_heads=num_heads, nhid=nhid, keys_dim=ek, query_dim=eq, values_dim=ek, dropout=0.3))

    result = attention(query=V(query), keys=V(keys), values=V(keys))
    assert_dims(result, [slq, bs, num_heads * nhid])
Example #22
def test_SDPAttention(attention_setup):
    keys, query = attention_setup
    bs = query.size(0)
    ed = keys.size(2)
    eq = query.size(1)
    attention = to_gpu(SDPAttention(in_features=ed))
    if ed != eq:
        with pytest.raises(RuntimeError):
            result = attention(query=V(query), keys=V(keys), values=V(keys))
    else:
        result = attention(query=V(query), keys=V(keys), values=V(keys))
        assert (bs, ed) == result.shape
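SDPAttention is presumably the standard scaled dot-product attention, softmax(q K^T / sqrt(d)) V, which is why a query dimension different from the key dimension makes the dot product raise a RuntimeError. A minimal sketch for the shapes this test uses (keys [sl, bs, ed], query [bs, ed]); the function name is illustrative:

import torch
import torch.nn.functional as F

def sdp_attention(query, keys, values):
    scores = torch.einsum("bd,sbd->bs", query, keys) / query.size(-1) ** 0.5
    weights = F.softmax(scores, dim=-1)  # attention weights over the sl axis
    return torch.einsum("bs,sbd->bd", weights, values)

keys = torch.randn(5, 2, 8)
query = torch.randn(2, 8)
print(sdp_attention(query, keys, keys).shape)  # torch.Size([2, 8])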
Example #23
def model(hredmodel, request):
    emb_size = 300
    nh = 1024
    ntoken = hredmodel.nt
    model = HRED(ntoken=ntoken,
                 nhid=nh,
                 nlayers=2,
                 emb_sz=emb_size,
                 pad_token=hredmodel.pad_idx,
                 eos_token=hredmodel.eos_idx,
                 bidir=request.param)
    model = to_gpu(model)
    return model
Example #24
def test_transformer_decoder_layers():
    sl = 10
    bs = 2
    in_features = 32
    num_layers = 5
    inputs = tr.randn([sl, bs, in_features])
    encoder_inputs = to_gpu(V(T(tr.randn([num_layers, sl, bs, in_features]))))
    inputs = to_gpu(V(T(inputs)))
    transformer = to_gpu(
        TransformerDecoderLayers(input_size=in_features,
                                 num_heads=8,
                                 nhid=512,
                                 nlayers=num_layers,
                                 dropout=0.0))
    assert transformer.hidden is None
    layer_outputs = transformer(inputs, encoder_inputs)
    assert_dims(layer_outputs, [num_layers, sl, bs, in_features])
    assert transformer.hidden is None
    # Passing only the first timestep through the decoder layers should yield the same first output
    layer_outputs2 = transformer(inputs[:1], encoder_inputs)
    assert_dims(layer_outputs2, [num_layers, 1, bs, in_features])
    for layer1, layer2 in zip(layer_outputs, layer_outputs2):
        assert ((layer1[0] - layer2[0]).abs() < 1E-6).all()
Example #25
 def get_model(self, opt_fn, emb_sz, n_hid, n_layers, **kwargs):
     """ Method returns a RNN_Learner object, that wraps an instance of the RNN_Encoder module.
     Args:
         opt_fn (Optimizer): the torch optimizer function to use
         emb_sz (int): embedding size
         n_hid (int): number of hidden units per layer
         n_layers (int): number of hidden layers
         kwargs: other arguments
     Returns:
         An instance of the RNN_Learner class.
     """
     m = get_model(emb_sz, n_hid, n_layers, **kwargs)
     model = SingleModel(to_gpu(m))
     return RNN_Learner(self, model, opt_fn=opt_fn, crit=F.mse_loss)
Example #26
def rnn_decoder(decoder_params):
    decoder_embedding_layer = DropoutEmbeddings(
        ntokens=decoder_params.ntokens,
        emb_size=decoder_params.emb_size,
    )

    if decoder_params.attention:
        # the attention decoder needs double the input_size to accommodate the attention concat
        decoder_rnn = RNNLayers(input_size=decoder_params.emb_size * 2,
                                output_size=decoder_params.emb_size,
                                nhid=decoder_params.nhid,
                                bidir=False,
                                nlayers=decoder_params.nlayers,
                                cell_type="gru")
        projection_layer = AttentionProjection(
            output_size=decoder_params.ntokens,
            input_size=decoder_params.emb_size,
            att_nhid=decoder_params.att_hid,
            tie_encoder=None,
            dropout=0.0)
        decoder = AttentionDecoder(decoder_layer=decoder_rnn,
                                   embedding_layer=decoder_embedding_layer,
                                   projection_layer=projection_layer,
                                   pad_token=1,
                                   eos_token=2,
                                   max_tokens=decoder_params.max_tokens)

    else:

        decoder_rnn = RNNLayers(input_size=decoder_params.emb_size,
                                output_size=decoder_params.emb_size,
                                nhid=decoder_params.nhid,
                                bidir=False,
                                nlayers=decoder_params.nlayers,
                                cell_type="gru")
        projection_layer = Projection(output_size=decoder_params.ntokens,
                                      input_size=decoder_params.emb_size,
                                      dropout=0.0,
                                      tie_encoder=None)
        decoder = Decoder(
            decoder_layer=decoder_rnn,
            projection_layer=projection_layer,
            embedding_layer=decoder_embedding_layer,
            pad_token=0,
            eos_token=1,
            max_tokens=decoder_params.max_tokens,
        )
    decoder = to_gpu(decoder)
    decoder.reset(decoder_params.batch_size)
    return decoder, decoder_params
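The input_size = emb_size * 2 in the attention branch exists because, at each decoding step, the token embedding is concatenated with the previous attention context (assumed here to also have emb_size features) before entering the RNN. In shapes:

import torch

emb_size, bs = 8, 2
embedded = torch.randn(1, bs, emb_size)  # current token embedding
attn_ctx = torch.randn(1, bs, emb_size)  # previous attention output
rnn_input = torch.cat([embedded, attn_ctx], dim=-1)
print(rnn_input.shape)  # [1, 2, 16], i.e. emb_size * 2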
Example #27
def test_attention_projection(attention_projection_setup):
    encoder_outputs, decoder_output, params = attention_projection_setup
    module = to_gpu(AttentionProjection(**params))
    # When I reset the module
    module.reset(keys=encoder_outputs)
    # the attention output will be a zeros array with shape equal to the input
    assert to_np(module.get_attention_output(decoder_output)).sum() == 0
    assert module.get_attention_output(decoder_output) is not module._attention_output
    # when I pass an input for the decoder output
    results = module(decoder_output)
    assert_dims(results, [1, 2, params['n_out']])
    # the new attention_output is calculated from the attention module and is no longer zero
    assert to_np(module.get_attention_output(decoder_output)).sum() != 0
    assert module.get_attention_output(decoder_output) is module._attention_output
    assert_dims(module._attention_output, [2, params['n_in']])
Example #28
 def one_hidden(self, bs=1, cell_state=False):
     ndir = 2 if self.bidir else 1
     if not self.train_init:
         return to_gpu(torch.zeros(ndir, bs, self.output_size))
     if cell_state:
         init_state = F.dropout(self.init_cell_state,
                                p=self.dropoutinit,
                                training=self.training)
     else:
         init_state = F.dropout(self.init_state,
                                p=self.dropoutinit,
                                training=self.training)
     # repeat allocates a new tensor, so its result must be returned
     return init_state.repeat(1, bs, 1)
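When train_init is set, the initial state is a learned parameter of shape [ndir, 1, output_size]; dropout regularizes it and repeat expands it across the batch. Since repeat allocates a new tensor rather than working in place, its result has to be returned, which is the bug fixed above (the original cell_state branch discarded it). A standalone shape check:

import torch
import torch.nn.functional as F

ndir, bs, output_size = 1, 4, 6
init_state = torch.zeros(ndir, 1, output_size, requires_grad=True)
hidden0 = F.dropout(init_state, p=0.1, training=True).repeat(1, bs, 1)
print(hidden0.shape)  # torch.Size([1, 4, 6])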
Example #29
def model(hredmodel, request):
    emb_size = 300
    nh = 1024
    ntoken = hredmodel.nt
    model = CVAE(ntoken=ntoken,
                 nhid=nh,
                 nlayers=2,
                 emb_sz=emb_size,
                 pad_token=hredmodel.pad_idx,
                 eos_token=hredmodel.eos_idx,
                 latent_dim=100,
                 bow_nhid=400,
                 bidir=request.param)
    model = to_gpu(model)
    return model
Example #30
def test_MultiHeadAttention(attention_setup):
    keys, query = attention_setup
    bs = query.size(0)
    ed = keys.size(2)
    eq = query.size(1)
    num_heads = 4
    nhid = 10
    attention = to_gpu(
        MultiHeadAttention(num_heads=num_heads,
                           nhid=nhid,
                           keys_dim=ed,
                           query_dim=eq,
                           values_dim=eq))

    result = attention(query=V(query), keys=V(keys), values=V(keys))
    assert_dims(result, [bs, num_heads * nhid])