Ejemplo n.º 1
0
    def _build(self, batch_size):
        """Assemble the fixtures used by the decoder tests.

        :param batch_size: size of the leading (batch) dimension of the
            generated tensors
        :return: tuple ``(src_mask, emb, decoder, encoder_output,
            encoder_hidden)`` where ``src_mask`` is an all-True boolean tensor
            of shape ``(batch_size, 1, 4)``, ``encoder_output`` is a random
            tensor of shape ``(batch_size, 4, self.hidden_size)`` and
            ``encoder_hidden`` is always ``None``
        """
        n_src_steps, n_tokens = 4, 7

        # The construction order below is kept fixed: torch.rand and
        # uniform_ draw from torch's global RNG (and the module constructors
        # presumably do as well), so reordering would change the values.
        embedding = Embeddings(
            embedding_dim=self.emb_size,
            vocab_size=n_tokens,
            padding_idx=self.pad_index,
        )

        decoder_kwargs = dict(
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            hidden_size=self.hidden_size,
            ff_size=self.ff_size,
            dropout=self.dropout,
            emb_dropout=self.dropout,
            vocab_size=n_tokens,
        )
        transformer = TransformerDecoder(**decoder_kwargs)

        memory = torch.rand(size=(batch_size, n_src_steps, self.hidden_size))

        # Overwrite every decoder parameter with values from U(-0.5, 0.5).
        for weight in transformer.parameters():
            torch.nn.init.uniform_(weight, -0.5, 0.5)

        # Comparing ones against 1 yields a boolean mask that is True
        # everywhere.
        all_visible = torch.ones(size=(batch_size, 1, n_src_steps)) == 1

        # The last slot (encoder_hidden) is unused and therefore None.
        return all_visible, embedding, transformer, memory, None
Ejemplo n.º 2
0
    def test_transformer_decoder_layers(self):
        """Check that the decoder builds the configured stack of layers.

        Verifies the number of layers, the type of each layer, the presence
        of the ``src_attn``, ``self_attn`` and ``feed_forward`` attributes,
        and the input/output width of the first feed-forward sub-layer.
        """
        torch.manual_seed(self.seed)
        vocab_size = 7

        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     vocab_size=vocab_size)

        self.assertEqual(len(decoder.layers), self.num_layers)

        for layer in decoder.layers:
            # assertIsInstance reports the offending type on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(layer, TransformerDecoderLayer)
            for attr in ("src_attn", "self_attn", "feed_forward"):
                self.assertTrue(hasattr(layer, attr),
                                f"decoder layer lacks attribute {attr!r}")
            self.assertEqual(layer.size, self.hidden_size)

            first_linear = layer.feed_forward.layer[0]
            self.assertEqual(first_linear.in_features, self.hidden_size)
            self.assertEqual(first_linear.out_features, self.ff_size)
Ejemplo n.º 3
0
    def test_transformer_decoder_output_size(self):
        """The decoder must expose the ``vocab_size`` it was built with."""
        expected_vocab = 11
        decoder_kwargs = dict(
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            hidden_size=self.hidden_size,
            ff_size=self.ff_size,
            dropout=self.dropout,
            vocab_size=expected_vocab,
        )
        decoder = TransformerDecoder(**decoder_kwargs)

        # Fail with a readable message instead of an AttributeError when the
        # attribute is absent.
        exposes_vocab_size = hasattr(decoder, "vocab_size")
        if not exposes_vocab_size:
            self.fail("Missing vocab_size property.")

        self.assertEqual(decoder.vocab_size, expected_vocab)
Ejemplo n.º 4
0
    def test_transformer_decoder_layers(self):
        """Check that the decoder builds the configured stack of layers.

        Verifies the number of layers, the type of each layer, the presence
        of the ``src_trg_att``, ``trg_trg_att`` and ``feed_forward``
        attributes, and the input/output width of the first position-wise
        feed-forward sub-layer.
        """
        vocab_size = 7

        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     vocab_size=vocab_size)

        self.assertEqual(len(decoder.layers), self.num_layers)

        for layer in decoder.layers:
            # assertIsInstance reports the offending type on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(layer, TransformerDecoderLayer)
            for attr in ("src_trg_att", "trg_trg_att", "feed_forward"):
                self.assertTrue(hasattr(layer, attr),
                                f"decoder layer lacks attribute {attr!r}")
            self.assertEqual(layer.size, self.hidden_size)

            first_linear = layer.feed_forward.pwff_layer[0]
            self.assertEqual(first_linear.in_features, self.hidden_size)
            self.assertEqual(first_linear.out_features, self.ff_size)
Ejemplo n.º 5
0
    def test_transformer_decoder_forward(self):
        """Run one decoder forward pass and compare it with stored targets.

        Seeds torch with ``self.seed``, feeds random target embeddings and a
        random encoder output through a freshly built ``TransformerDecoder``
        and checks the output shape and values, the argmax predictions and
        the returned states against hard-coded reference tensors.
        """
        # The torch.rand calls below draw from the seeded global RNG, so the
        # statement order must stay fixed for the reference values to match.
        torch.manual_seed(self.seed)
        batch_size = 2
        src_time_dim = 4
        trg_time_dim = 5
        vocab_size = 7

        trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))

        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     vocab_size=vocab_size)

        encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                          self.hidden_size))

        # All-ones masks converted to uint8.
        # NOTE(review): recent PyTorch releases prefer boolean masks for
        # masking ops -- confirm that the decoder still accepts uint8 here.
        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)).byte()
        trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)).byte()

        encoder_hidden = None  # unused
        decoder_hidden = None  # unused
        unrol_steps = None  # unused

        # The decoder returns four values; only the first two are checked.
        output, states, _, _ = decoder(trg_embed, encoder_output,
                                       encoder_hidden, src_mask, unrol_steps,
                                       decoder_hidden, trg_mask)

        # Reference output of shape (2, 5, 7), i.e.
        # (batch_size, trg_time_dim, vocab_size).
        output_target = torch.Tensor(
            [[[-0.0805, 0.4592, 0.0718, 0.7900, 0.5230, 0.5067, -0.5715],
              [0.0711, 0.3738, -0.1151, 0.5634, 0.0394, 0.2720, -0.4201],
              [0.2361, 0.1380, -0.2817, 0.0559, 0.0591, 0.2231, -0.0882],
              [0.1779, 0.2605, -0.1604, -0.1684, 0.1802, 0.0476, -0.3675],
              [0.2059, 0.1267, -0.2322, -0.1361, 0.1820, -0.0788, -0.2393]],
             [[0.0538, 0.0175, -0.0042, 0.0384, 0.2151, 0.4149, -0.4311],
              [-0.0368, 0.1387, -0.3131, 0.3600, -0.1514, 0.4926, -0.2868],
              [0.1802, 0.0177, -0.4545, 0.2662, -0.3109, -0.0331, -0.0180],
              [0.3109, 0.2541, -0.3547, 0.0236, -0.3156, -0.0822, -0.0328],
              [0.3497, 0.2526, 0.1080, -0.5393, 0.2724, -0.4332, -0.3632]]])
        self.assertEqual(output_target.shape, output.shape)
        self.assertTensorAlmostEqual(output_target, output)

        # The argmax over the last dimension must agree with the reference.
        greedy_predictions = output.argmax(-1)
        expect_predictions = output_target.argmax(-1)
        self.assertTensorEqual(expect_predictions, greedy_predictions)

        # Reference states of shape (2, 5, 12).
        states_target = torch.Tensor(
            [[[
                -0.0755, 1.4055, -1.1602, 0.6213, 0.0544, 1.3840, -1.2356,
                -1.9077, -0.6345, 0.5314, 0.4973, 0.5196
            ],
              [
                  0.3196, 0.7497, -0.7922, -0.2416, 0.5386, 1.0843, -1.4864,
                  -1.8824, -0.7546, 1.2005, 0.0748, 1.1896
              ],
              [
                  -0.3941, -0.1136, -0.9666, -0.4205, 0.2330, 0.7739, -0.4792,
                  -2.0162, -0.4363, 1.6525, 0.5820, 1.5851
              ],
              [
                  -0.6153, -0.4550, -0.8141, -0.8289, 0.3393, 1.1795, -1.0093,
                  -1.0871, -0.8108, 1.4794, 1.1199, 1.5025
              ],
              [
                  -0.6611, -0.6822, -0.7189, -0.6791, -0.1858, 1.5746, -0.5461,
                  -1.0275, -0.9931, 1.5337, 1.3765, 1.0090
              ]],
             [[
                 -0.5529, 0.5892, -0.5661, -0.0163, -0.1006, 0.8997, -0.9661,
                 -1.7280, -1.2770, 1.3293, 1.0589, 1.3298
             ],
              [
                  0.5863, 0.2046, -0.9396, -0.5605, -0.4051, 1.3006, -0.9817,
                  -1.3750, -1.2850, 1.2806, 0.9258, 1.2487
              ],
              [
                  0.1955, -0.3549, -0.4581, -0.8584, 0.0424, 1.1371, -0.7769,
                  -1.8383, -0.6448, 1.8183, 0.4338, 1.3043
              ],
              [
                  -0.0227, -0.8035, -0.5716, -0.9380, 0.3337, 1.2892, -0.7494,
                  -1.5868, -0.5518, 1.5482, 0.5330, 1.5195
              ],
              [
                  -1.7046, -0.7190, 0.0613, -0.5847, 1.0075, 0.7987, -1.0774,
                  -1.0810, -0.1800, 1.2212, 0.8317, 1.4263
              ]]])

        self.assertEqual(states_target.shape, states.shape)
        self.assertTensorAlmostEqual(states_target, states)
Ejemplo n.º 6
0
 def test_transformer_decoder_freeze(self):
     """With ``freeze=True`` no decoder parameter may require gradients."""
     torch.manual_seed(self.seed)
     frozen_decoder = TransformerDecoder(freeze=True)
     for _name, param in frozen_decoder.named_parameters():
         self.assertFalse(param.requires_grad)
Ejemplo n.º 7
0
def build_model(cfg: dict = None,
                src_vocab: Vocabulary = None,
                trg_vocab: Vocabulary = None) -> Model:
    """
    Build and initialize the model according to the configuration.

    The encoder and decoder are each either a Transformer or a recurrent
    module, selected by the ``type`` entry of their config section (default
    ``"recurrent"``).

    :param cfg: dictionary configuration containing model specifications;
        must provide ``cfg["encoder"]["embeddings"]`` and
        ``cfg["decoder"]["embeddings"]``
    :param src_vocab: source vocabulary
    :param trg_vocab: target vocabulary
    :return: built and initialized model
    :raises ConfigurationError: if ``tied_embeddings`` is set but the source
        and target vocabularies differ
    """
    src_padding_idx = src_vocab.stoi[PAD_TOKEN]
    trg_padding_idx = trg_vocab.stoi[PAD_TOKEN]

    # TODO if continue-us
    src_embed = PretrainedEmbeddings(src_vocab,
                                     trg_vocab,
                                     **cfg["encoder"]["embeddings"],
                                     vocab_size=len(src_vocab),
                                     padding_idx=src_padding_idx)

    # this ties source and target embeddings
    # for softmax layer tying, see further below
    if cfg.get("tied_embeddings", False):
        if src_vocab.itos == trg_vocab.itos:
            # share embeddings for src and trg
            trg_embed = src_embed
        else:
            raise ConfigurationError(
                "Embedding cannot be tied since vocabularies differ.")
    else:
        # Build separate target embeddings. Previously this branch
        # re-created ``src_embed`` and left ``trg_embed`` unbound, so every
        # untied configuration crashed at ``trg_embed.embedding_dim`` below.
        # NOTE(review): the positional (src_vocab, trg_vocab) arguments
        # mirror the source-side call -- confirm against the
        # PretrainedEmbeddings signature.
        trg_embed = PretrainedEmbeddings(src_vocab,
                                         trg_vocab,
                                         **cfg["decoder"]["embeddings"],
                                         vocab_size=len(trg_vocab),
                                         padding_idx=trg_padding_idx)

    # build encoder
    # embedding dropout falls back to the encoder dropout when not configured
    enc_dropout = cfg["encoder"].get("dropout", 0.)
    enc_emb_dropout = cfg["encoder"]["embeddings"].get("dropout", enc_dropout)
    if cfg["encoder"].get("type", "recurrent") == "transformer":
        assert cfg["encoder"]["embeddings"]["embedding_dim"] == \
               cfg["encoder"]["hidden_size"], \
               "for transformer, emb_size must be hidden_size"

        encoder = TransformerEncoder(**cfg["encoder"],
                                     emb_size=src_embed.embedding_dim,
                                     emb_dropout=enc_emb_dropout)
    else:
        encoder = RecurrentEncoder(**cfg["encoder"],
                                   emb_size=src_embed.embedding_dim,
                                   emb_dropout=enc_emb_dropout)

    # build decoder
    # embedding dropout falls back to the decoder dropout when not configured
    dec_dropout = cfg["decoder"].get("dropout", 0.)
    dec_emb_dropout = cfg["decoder"]["embeddings"].get("dropout", dec_dropout)
    if cfg["decoder"].get("type", "recurrent") == "transformer":
        decoder = TransformerDecoder(**cfg["decoder"],
                                     encoder=encoder,
                                     vocab_size=len(trg_vocab),
                                     emb_size=trg_embed.embedding_dim,
                                     emb_dropout=dec_emb_dropout)
    else:
        decoder = RecurrentDecoder(**cfg["decoder"],
                                   encoder=encoder,
                                   vocab_size=len(trg_vocab),
                                   emb_size=trg_embed.embedding_dim,
                                   emb_dropout=dec_emb_dropout)

    model = Model(encoder=encoder,
                  decoder=decoder,
                  src_embed=src_embed,
                  trg_embed=trg_embed,
                  src_vocab=src_vocab,
                  trg_vocab=trg_vocab)

    # tie softmax layer with trg embeddings
    # (currently disabled: the block below is a bare string literal and is
    # never executed)
    """
    if cfg.get("tied_softmax", False):
        if trg_embed.lut.weight.shape == \
                model.decoder.output_layer.weight.shape:
            # (also) share trg embeddings and softmax layer:
            model.decoder.output_layer.weight = trg_embed.lut.weight
        else:
            raise ConfigurationError(
                "For tied_softmax, the decoder embedding_dim and decoder "
                "hidden_size must be the same."
                "The decoder must be a Transformer."
                f"shapes: output_layer.weight: {model.decoder.output_layer.weight.shape}; target_embed.lut.weight:{trg_embed.lut.weight.shape}")
    """
    # custom initialization of model parameters
    initialize_model(model, cfg, src_padding_idx, trg_padding_idx)

    return model
Ejemplo n.º 8
0
def build_model(cfg: dict = None,
                src_vocab: Vocabulary = None,
                trg_vocab: Vocabulary = None,
                trv_vocab: Vocabulary = None,
                canonizer=None) -> Model:
    """
    Build and initialize the model according to the configuration.

    Creates the source/target (and knowledgebase key) embeddings, an encoder,
    one of four decoder variants selected by ``cfg["decoder"]["type"]``, the
    ``tfstyletf`` flag and the ``kb`` flag, a ``Generator`` output layer, and
    finally the ``Model`` itself, which is then passed to
    ``initialize_model``.

    :param cfg: dictionary configuration containing model specifications
    :param src_vocab: source vocabulary
    :param trg_vocab: target vocabulary
    :param trv_vocab: kb true value lookup vocabulary
    :param canonizer: optional callable; when not None it is called with
        ``copy_from_source=...`` and the result is handed to the model as
        ``canonize``
    :return: built and initialized model
    :raises NotImplementedError: if ``cfg`` contains ``"embedding_files"``
    :raises ConfigurationError: if ``tied_embeddings`` is set but the
        vocabularies differ, or if ``tied_softmax`` is set but the target
        embedding and output layer weight shapes differ
    """
    src_padding_idx = src_vocab.stoi[PAD_TOKEN]
    trg_padding_idx = trg_vocab.stoi[PAD_TOKEN]

    if "embedding_files" in cfg.keys():  #init from pretrained
        assert not cfg.get(
            "tied_embeddings", False
        ), "TODO implement tied embeddings along with pretrained initialization"
        raise NotImplementedError(
            "TODO implement kbsrc embed loading for embedding files")
        # NOTE(review): the remainder of this branch is unreachable because
        # of the raise above. It also never assigns ``kbsrc_embed``, which
        # is used further below.
        weight_tensors = []
        for weight_file in cfg["embedding_files"]:
            with open(weight_file, "r") as f:
                weight = []
                for line in f.readlines():
                    line = line.split()
                    line = [float(x) for x in line]
                    weight.append(line)

            weight = FloatTensor(weight)
            weight_tensors.append(weight)
        # Set source Embeddings to Pretrained Embeddings
        src_embed = Embeddings(
            int(weight_tensors[0][0].shape[0]),
            False,  #TODO transformer: change to True
            len(weight_tensors[0]),
        )
        src_embed.lut.weight.data = weight_tensors[0]

        # Set target Embeddings to Pretrained Embeddings
        trg_embed = Embeddings(
            int(weight_tensors[1][0].shape[0]),
            False,  #TODO transformer: change to True
            len(weight_tensors[1]),
        )
        trg_embed.lut.weight.data = weight_tensors[1]
    else:
        src_embed = Embeddings(**cfg["encoder"]["embeddings"],
                               vocab_size=len(src_vocab),
                               padding_idx=src_padding_idx)
        # knowledgebase keys get their own embedding table only when
        # requested; otherwise they share the source embeddings
        if cfg.get("kb_embed_separate", False):
            kbsrc_embed = Embeddings(**cfg["encoder"]["embeddings"],
                                     vocab_size=len(src_vocab),
                                     padding_idx=src_padding_idx)
        else:
            kbsrc_embed = src_embed

        # this ties source and target embeddings
        # for softmax layer tying, see further below
        if cfg.get("tied_embeddings", False):
            if src_vocab.itos == trg_vocab.itos:
                # share embeddings for src and trg
                trg_embed = src_embed
            else:
                raise ConfigurationError(
                    "Embedding cannot be tied since vocabularies differ.")
        else:
            # Latest TODO: init embeddings with vocab_size = len(trg_vocab joined with kb_vocab)
            trg_embed = Embeddings(**cfg["decoder"]["embeddings"],
                                   vocab_size=len(trg_vocab),
                                   padding_idx=trg_padding_idx)
    # build encoder
    # embedding dropout falls back to the encoder dropout when not configured
    enc_dropout = cfg["encoder"].get("dropout", 0.)
    enc_emb_dropout = cfg["encoder"]["embeddings"].get("dropout", enc_dropout)
    if cfg["encoder"].get("type", "recurrent") == "transformer":
        assert cfg["encoder"]["embeddings"]["embedding_dim"] == \
               cfg["encoder"]["hidden_size"], \
               "for transformer, emb_size must be hidden_size"

        encoder = TransformerEncoder(**cfg["encoder"],
                                     emb_size=src_embed.embedding_dim,
                                     emb_dropout=enc_emb_dropout)
    else:
        encoder = RecurrentEncoder(**cfg["encoder"],
                                   emb_size=src_embed.embedding_dim,
                                   emb_dropout=enc_emb_dropout)

    # retrieve kb task info
    # (each option below is read from the top level of cfg with a default)
    kb_task = bool(cfg.get("kb", False))
    k_hops = int(
        cfg.get("k_hops", 1)
    )  # k number of kvr attention layers in decoder (eric et al/default: 1)
    same_module_for_all_hops = bool(cfg.get("same_module_for_all_hops", False))
    do_postproc = bool(cfg.get("do_postproc", True))
    copy_from_source = bool(cfg.get("copy_from_source", True))
    canonization_func = None if canonizer is None else canonizer(
        copy_from_source=copy_from_source)
    kb_input_feeding = bool(cfg.get("kb_input_feeding", True))
    kb_feed_rnn = bool(cfg.get("kb_feed_rnn", True))
    kb_multihead_feed = bool(cfg.get("kb_multihead_feed", False))
    # NOTE(review): the config key is spelled "posEncdKBkeys" (with a "d")
    # while the local name is posEncKBkeys -- confirm which key config files
    # actually use.
    posEncKBkeys = cfg.get("posEncdKBkeys", False)
    tfstyletf = cfg.get("tfstyletf", True)
    infeedkb = bool(cfg.get("infeedkb", False))
    outfeedkb = bool(cfg.get("outfeedkb", False))
    add_kb_biases_to_output = bool(cfg.get("add_kb_biases_to_output", True))
    kb_max_dims = cfg.get("kb_max_dims", (16, 32))  # should be tuple
    double_decoder = cfg.get("double_decoder", False)
    tied_side_softmax = cfg.get(
        "tied_side_softmax",
        False)  # actually use separate linear layers, tying only the main one
    do_pad_kb_keys = cfg.get(
        "pad_kb_keys", True
    )  # doesnt need to be true for 1 hop (=>BIG PERFORMANCE SAVE), needs to be true for >= 2 hops

    # normalize kb_max_dims to a tuple: any iterable is converted, a single
    # int becomes a 1-tuple
    if hasattr(kb_max_dims, "__iter__"):
        kb_max_dims = tuple(kb_max_dims)
    else:
        assert type(kb_max_dims) == int, kb_max_dims
        kb_max_dims = (kb_max_dims, )

    # the decoder hidden size must be present and truthy (non-zero)
    assert cfg["decoder"]["hidden_size"]
    dec_dropout = cfg["decoder"].get("dropout", 0.)
    dec_emb_dropout = cfg["decoder"]["embeddings"].get("dropout", dec_dropout)

    # decoder selection: transformer (two variants, chosen by tfstyletf) or
    # recurrent (plain vs. key-value retrieval, chosen by kb_task)
    if cfg["decoder"].get("type", "recurrent") == "transformer":
        if tfstyletf:
            decoder = TransformerDecoder(
                **cfg["decoder"],
                encoder=encoder,
                vocab_size=len(trg_vocab),
                emb_size=trg_embed.embedding_dim,
                emb_dropout=dec_emb_dropout,
                kb_task=kb_task,
                kb_key_emb_size=kbsrc_embed.embedding_dim,
                feed_kb_hidden=kb_input_feeding,
                infeedkb=infeedkb,
                outfeedkb=outfeedkb,
                double_decoder=double_decoder)
        else:
            decoder = TransformerKBrnnDecoder(
                **cfg["decoder"],
                encoder=encoder,
                vocab_size=len(trg_vocab),
                emb_size=trg_embed.embedding_dim,
                emb_dropout=dec_emb_dropout,
                kb_task=kb_task,
                k_hops=k_hops,
                kb_max=kb_max_dims,
                same_module_for_all_hops=same_module_for_all_hops,
                kb_key_emb_size=kbsrc_embed.embedding_dim,
                kb_input_feeding=kb_input_feeding,
                kb_feed_rnn=kb_feed_rnn,
                kb_multihead_feed=kb_multihead_feed)
    else:
        if not kb_task:
            decoder = RecurrentDecoder(**cfg["decoder"],
                                       encoder=encoder,
                                       vocab_size=len(trg_vocab),
                                       emb_size=trg_embed.embedding_dim,
                                       emb_dropout=dec_emb_dropout)
        else:
            decoder = KeyValRetRNNDecoder(
                **cfg["decoder"],
                encoder=encoder,
                vocab_size=len(trg_vocab),
                emb_size=trg_embed.embedding_dim,
                emb_dropout=dec_emb_dropout,
                k_hops=k_hops,
                kb_max=kb_max_dims,
                same_module_for_all_hops=same_module_for_all_hops,
                kb_key_emb_size=kbsrc_embed.embedding_dim,
                kb_input_feeding=kb_input_feeding,
                kb_feed_rnn=kb_feed_rnn,
                kb_multihead_feed=kb_multihead_feed,
                do_pad_kb_keys=do_pad_kb_keys)

    # specify generator which is mostly just the output layer
    generator = Generator(dec_hidden_size=cfg["decoder"]["hidden_size"],
                          vocab_size=len(trg_vocab),
                          add_kb_biases_to_output=add_kb_biases_to_output,
                          double_decoder=double_decoder)

    model = Model(
                  encoder=encoder, decoder=decoder, generator=generator,
                  src_embed=src_embed, trg_embed=trg_embed,
                  src_vocab=src_vocab, trg_vocab=trg_vocab,\
                  kb_key_embed=kbsrc_embed,\
                  trv_vocab=trv_vocab,
                  k_hops=k_hops,
                  do_postproc=do_postproc,
                  canonize=canonization_func,
                  kb_att_dims=len(kb_max_dims),
                  posEncKBkeys=posEncKBkeys
                  )

    # tie softmax layer with trg embeddings
    # (only possible when both weight matrices have the same shape)
    if cfg.get("tied_softmax", False):
        if trg_embed.lut.weight.shape == \
                model.generator.output_layer.weight.shape:
            # (also) share trg embeddings and softmax layer:
            model.generator.output_layer.weight = trg_embed.lut.weight
            if model.generator.double_decoder:
                # (also also) share trg embeddings and side softmax layer
                assert hasattr(model.generator, "side_output_layer")
                if tied_side_softmax:
                    # because of distributivity this becomes O (x_1+x_2) instead of O_1 x_1 + O_2 x_2
                    model.generator.side_output_layer.weight = trg_embed.lut.weight
        else:
            raise ConfigurationError(
                "For tied_softmax, the decoder embedding_dim and decoder "
                "hidden_size must be the same."
                "The decoder must be a Transformer.")

    # custom initialization of model parameters
    initialize_model(model, cfg, src_padding_idx, trg_padding_idx)

    return model
Ejemplo n.º 9
0
    def test_transformer_decoder_forward(self):
        """Run one decoder forward pass and compare it with stored targets.

        Feeds random target embeddings and a random encoder output through a
        ``TransformerDecoder`` whose parameters were re-initialised uniformly
        in [-0.5, 0.5], then checks the output shape and values, the argmax
        predictions and the returned states against hard-coded reference
        tensors.
        """
        # NOTE(review): no seed is set inside this method; the hard-coded
        # targets presumably rely on seeding done elsewhere (e.g. setUp) --
        # confirm. The order of the RNG-consuming statements below must stay
        # fixed for the reference values to match.
        batch_size = 2
        src_time_dim = 4
        trg_time_dim = 5
        vocab_size = 7

        trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))

        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     emb_dropout=self.dropout,
                                     vocab_size=vocab_size)

        encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                          self.hidden_size))

        # Overwrite every decoder parameter with values from U(-0.5, 0.5).
        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        # Comparing ones against 1 yields boolean masks that are True
        # everywhere.
        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
        trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

        encoder_hidden = None  # unused
        decoder_hidden = None  # unused
        unrol_steps = None  # unused

        # The decoder returns four values; only the first two are checked.
        output, states, _, _ = decoder(trg_embed, encoder_output,
                                       encoder_hidden, src_mask, unrol_steps,
                                       decoder_hidden, trg_mask)

        # Reference output of shape (2, 5, 7), i.e.
        # (batch_size, trg_time_dim, vocab_size).
        output_target = torch.Tensor(
            [[[0.1946, 0.6144, -0.1925, -0.6967, 0.4466, -0.1085, 0.3400],
              [0.1857, 0.5558, -0.1314, -0.7783, 0.3980, -0.1736, 0.2347],
              [-0.0216, 0.3663, -0.2251, -0.5800, 0.2996, 0.0918, 0.2833],
              [0.0389, 0.4843, -0.1914, -0.6326, 0.3674, -0.0903, 0.2524],
              [0.0373, 0.3276, -0.2835, -0.6210, 0.2297, -0.0367, 0.1962]],
             [[0.0241, 0.4255, -0.2074, -0.6517, 0.3380, -0.0312, 0.2392],
              [0.1577, 0.4292, -0.1792, -0.7406, 0.2696, -0.1610, 0.2233],
              [0.0122, 0.4203, -0.2302, -0.6640, 0.2843, -0.0710, 0.2984],
              [0.0115, 0.3416, -0.2007, -0.6255, 0.2708, -0.0251, 0.2113],
              [0.0094, 0.4787, -0.1730, -0.6124, 0.4650, -0.0382, 0.1910]]])
        self.assertEqual(output_target.shape, output.shape)
        self.assertTensorAlmostEqual(output_target, output)

        # The argmax over the last dimension must agree with the reference.
        greedy_predictions = output.argmax(-1)
        expect_predictions = output_target.argmax(-1)
        self.assertTensorEqual(expect_predictions, greedy_predictions)

        # Reference states of shape (2, 5, 12).
        states_target = torch.Tensor([
            [[
                0.0491, 0.5322, 0.0327, -0.9208, -0.5646, -0.1138, 0.3416,
                -0.3235, 0.0350, -0.4339, 0.5837, 0.1022
            ],
             [
                 0.1838, 0.4832, -0.0498, -0.7803, -0.5348, -0.1162, 0.3667,
                 -0.3076, -0.0842, -0.4287, 0.6334, 0.1872
             ],
             [
                 0.0910, 0.3801, 0.0451, -0.7478, -0.4655, -0.1040, 0.6660,
                 -0.2871, 0.0544, -0.4561, 0.5823, 0.1653
             ],
             [
                 0.1064, 0.3970, -0.0691, -0.5924, -0.4410, -0.0984, 0.2759,
                 -0.3108, -0.0127, -0.4857, 0.6074, 0.0979
             ],
             [
                 0.0424, 0.3607, -0.0287, -0.5379, -0.4454, -0.0892, 0.4730,
                 -0.3021, -0.1303, -0.4889, 0.5257, 0.1394
             ]],
            [[
                0.1459, 0.4663, 0.0316, -0.7014, -0.4267, -0.0985, 0.5141,
                -0.2743, -0.0897, -0.4771, 0.5795, 0.1014
            ],
             [
                 0.2450, 0.4507, 0.0958, -0.6684, -0.4726, -0.0926, 0.4593,
                 -0.2969, -0.1612, -0.4224, 0.6054, 0.1698
             ],
             [
                 0.2137, 0.4132, 0.0327, -0.5304, -0.4519, -0.0934, 0.3898,
                 -0.2846, -0.0077, -0.4928, 0.6087, 0.1249
             ],
             [
                 0.1752, 0.3687, 0.0479, -0.5960, -0.4000, -0.0952, 0.5159,
                 -0.2926, -0.0668, -0.4628, 0.6031, 0.1711
             ],
             [
                 0.0396, 0.4577, -0.0789, -0.7109, -0.4049, -0.0989, 0.3596,
                 -0.2966, 0.0044, -0.4571, 0.6315, 0.1103
             ]]
        ])

        self.assertEqual(states_target.shape, states.shape)
        self.assertTensorAlmostEqual(states_target, states)
Ejemplo n.º 10
0
 def test_transformer_decoder_freeze(self):
     """With ``freeze=True`` no decoder parameter may require gradients."""
     frozen_decoder = TransformerDecoder(freeze=True)
     for _name, param in frozen_decoder.named_parameters():
         self.assertFalse(param.requires_grad)
Ejemplo n.º 11
0
    def test_transformer_decoder_forward(self):
        """Run one decoder forward pass and compare it with stored targets.

        Feeds random target embeddings and a random encoder output through a
        ``TransformerDecoder`` whose parameters were re-initialised uniformly
        in [-0.5, 0.5], then checks the output shape and values, the argmax
        predictions and the returned states against hard-coded reference
        tensors.
        """
        # NOTE(review): no seed is set inside this method; the hard-coded
        # targets presumably rely on seeding done elsewhere (e.g. setUp) --
        # confirm. The order of the RNG-consuming statements below must stay
        # fixed for the reference values to match.
        batch_size = 2
        src_time_dim = 4
        trg_time_dim = 5
        vocab_size = 7

        trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))

        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     emb_dropout=self.dropout,
                                     vocab_size=vocab_size)

        encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                          self.hidden_size))

        # Overwrite every decoder parameter with values from U(-0.5, 0.5).
        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        # Comparing ones against 1 yields boolean masks that are True
        # everywhere.
        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
        trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

        encoder_hidden = None  # unused
        decoder_hidden = None  # unused
        unrol_steps = None  # unused

        # The decoder returns four values; only the first two are checked.
        output, states, _, _ = decoder(trg_embed, encoder_output,
                                       encoder_hidden, src_mask, unrol_steps,
                                       decoder_hidden, trg_mask)

        # Reference output of shape (2, 5, 7), i.e.
        # (batch_size, trg_time_dim, vocab_size).
        output_target = torch.Tensor(
            [[[0.1718, 0.5595, -0.1996, -0.6924, 0.4351, -0.0850, 0.2805],
              [0.0666, 0.4923, -0.1724, -0.6804, 0.3983, -0.1111, 0.2194],
              [-0.0315, 0.3673, -0.2320, -0.6100, 0.3019, 0.0422, 0.2514],
              [-0.0026, 0.3807, -0.2195, -0.6010, 0.3081, -0.0101, 0.2099],
              [-0.0172, 0.3384, -0.2853, -0.5799, 0.2470, 0.0312, 0.2518]],
             [[0.0284, 0.3918, -0.2010, -0.6472, 0.3646, -0.0296, 0.1791],
              [0.1017, 0.4387, -0.2031, -0.7084, 0.3051, -0.1354, 0.2511],
              [0.0155, 0.4274, -0.2061, -0.6702, 0.3085, -0.0617, 0.2830],
              [0.0227, 0.4067, -0.1697, -0.6463, 0.3277, -0.0423, 0.2333],
              [0.0133, 0.4409, -0.1186, -0.5694, 0.4450, 0.0290, 0.1643]]])
        self.assertEqual(output_target.shape, output.shape)
        self.assertTensorAlmostEqual(output_target, output)

        # The argmax over the last dimension must agree with the reference.
        greedy_predictions = output.argmax(-1)
        expect_predictions = output_target.argmax(-1)
        self.assertTensorEqual(expect_predictions, greedy_predictions)

        # Reference states of shape (2, 5, 12).
        states_target = torch.Tensor(
            [[[
                3.7535e-02, 5.3508e-01, 4.9478e-02, -9.1961e-01, -5.3966e-01,
                -1.0065e-01, 4.3053e-01, -3.0671e-01, -1.2724e-02, -4.1879e-01,
                5.9625e-01, 1.1887e-01
            ],
              [
                  1.3837e-01, 4.6963e-01, -3.7059e-02, -6.8479e-01,
                  -4.6042e-01, -1.0072e-01, 3.9374e-01, -3.0429e-01,
                  -5.4203e-02, -4.3680e-01, 6.4257e-01, 1.1424e-01
              ],
              [
                  1.0263e-01, 3.8331e-01, -2.5586e-02, -6.4478e-01,
                  -4.5860e-01, -1.0590e-01, 5.8806e-01, -2.8856e-01,
                  1.1084e-02, -4.7479e-01, 5.9094e-01, 1.6089e-01
              ],
              [
                  7.3408e-02, 3.7701e-01, -5.8783e-02, -6.2368e-01,
                  -4.4201e-01, -1.0237e-01, 5.2556e-01, -3.0821e-01,
                  -5.3345e-02, -4.5606e-01, 5.8259e-01, 1.2531e-01
              ],
              [
                  4.1206e-02, 3.6129e-01, -1.2955e-02, -5.8638e-01,
                  -4.6023e-01, -9.4267e-02, 5.5464e-01, -3.0029e-01,
                  -3.3974e-02, -4.8347e-01, 5.4088e-01, 1.2015e-01
              ]],
             [[
                 1.1017e-01, 4.7179e-01, 2.6402e-02, -7.2170e-01, -3.9778e-01,
                 -1.0226e-01, 5.3498e-01, -2.8369e-01, -1.1081e-01,
                 -4.6096e-01, 5.9517e-01, 1.3531e-01
             ],
              [
                  2.1947e-01, 4.6407e-01, 8.4276e-02, -6.3263e-01, -4.4953e-01,
                  -9.7334e-02, 4.0321e-01, -2.9893e-01, -1.0368e-01,
                  -4.5760e-01, 6.1378e-01, 1.3509e-01
              ],
              [
                  2.1437e-01, 4.1372e-01, 1.9859e-02, -5.7415e-01, -4.5025e-01,
                  -9.8621e-02, 4.1182e-01, -2.8410e-01, -1.2729e-03,
                  -4.8586e-01, 6.2318e-01, 1.4731e-01
              ],
              [
                  1.9153e-01, 3.8401e-01, 2.6096e-02, -6.2339e-01, -4.0685e-01,
                  -9.7387e-02, 4.1836e-01, -2.8648e-01, -1.7857e-02,
                  -4.7678e-01, 6.2907e-01, 1.7617e-01
              ],
              [
                  3.1713e-02, 3.7548e-01, -6.3005e-02, -7.9804e-01,
                  -3.6541e-01, -1.0398e-01, 4.2991e-01, -2.9607e-01,
                  2.1376e-04, -4.5897e-01, 6.1062e-01, 1.6142e-01
              ]]])

        self.assertEqual(states_target.shape, states.shape)
        self.assertTensorAlmostEqual(states_target, states)
Ejemplo n.º 12
0
    def test_transformer_decoder_forward(self):
        """
        Run one forward pass of a TransformerDecoder under a fixed seed and
        compare its outputs and states against hard-coded reference tensors.

        The reference values depend on the order in which random numbers are
        drawn after seeding (target embeddings, decoder construction, encoder
        output, parameter re-initialization), so the statements below must
        not be reordered.
        """
        # seed first: every random draw below contributes to the references
        torch.manual_seed(self.seed)
        batch_size = 2
        src_time_dim = 4
        trg_time_dim = 5
        vocab_size = 7

        # random stand-in for already embedded target tokens
        trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))

        decoder = TransformerDecoder(
            num_layers=self.num_layers, num_heads=self.num_heads,
            hidden_size=self.hidden_size, ff_size=self.ff_size,
            dropout=self.dropout, emb_dropout=self.dropout,
            vocab_size=vocab_size)

        # random stand-in for the encoder states the decoder attends to
        encoder_output = torch.rand(
            size=(batch_size, src_time_dim, self.hidden_size))

        # overwrite every decoder parameter with uniform values in [-0.5, 0.5]
        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        # all-True boolean masks: no source or target position is masked out
        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
        trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

        # the last two return values are not checked by this test
        output, states, _, _ = decoder(
            trg_embed, encoder_output, src_mask, trg_mask)

        # reference output: 2 x 5 x 7, i.e.
        # (batch_size, trg_time_dim, vocab_size)
        output_target = torch.Tensor(
            [[[ 0.1765,  0.4578,  0.2345, -0.5303,  0.3862,  0.0964,  0.6882],
            [ 0.3363,  0.3907,  0.2210, -0.5414,  0.3770,  0.0748,  0.7344],
            [ 0.3275,  0.3729,  0.2797, -0.3519,  0.3341,  0.1605,  0.5403],
            [ 0.3081,  0.4513,  0.1900, -0.3443,  0.3072,  0.0570,  0.6652],
            [ 0.3253,  0.4315,  0.1227, -0.3371,  0.3339,  0.1129,  0.6331]],

            [[ 0.3235,  0.4836,  0.2337, -0.4019,  0.2831, -0.0260,  0.7013],
            [ 0.2800,  0.5662,  0.0469, -0.4156,  0.4246, -0.1121,  0.8110],
            [ 0.2968,  0.4777,  0.0652, -0.2706,  0.3146,  0.0732,  0.5362],
            [ 0.3108,  0.4910,  0.0774, -0.2341,  0.2873,  0.0404,  0.5909],
            [ 0.2338,  0.4371,  0.1350, -0.1292,  0.0673,  0.1034,  0.5356]]]
        )
        self.assertEqual(output_target.shape, output.shape)
        self.assertTensorAlmostEqual(output_target, output)

        # the arg-max over the last dimension must agree exactly, too
        greedy_predictions = output.argmax(-1)
        expect_predictions = output_target.argmax(-1)
        self.assertTensorEqual(expect_predictions, greedy_predictions)

        # reference decoder states: 2 x 5 x 12
        states_target = torch.Tensor(
            [[[ 8.3742e-01, -1.3161e-01,  2.1876e-01, -1.3920e-01, -9.1572e-01,
            2.3006e-01,  3.8328e-01, -1.6271e-01,  3.7370e-01, -1.2110e-01,
            -4.7549e-01, -4.0622e-01],
            [ 8.3609e-01, -2.9161e-02,  2.0583e-01, -1.3571e-01, -8.0510e-01,
            2.7630e-01,  4.8219e-01, -1.8863e-01,  1.1977e-01, -2.0179e-01,
            -4.4314e-01, -4.1228e-01],
            [ 8.5478e-01,  1.1368e-01,  2.0400e-01, -1.3059e-01, -8.1042e-01,
            1.6369e-01,  5.4244e-01, -2.9103e-01,  3.9919e-01, -3.3826e-01,
            -4.5423e-01, -4.2516e-01],
            [ 9.0388e-01,  1.1853e-01,  1.9927e-01, -1.1675e-01, -7.7208e-01,
            2.0686e-01,  4.6024e-01, -9.1610e-02,  3.9778e-01, -2.6214e-01,
            -4.7688e-01, -4.0807e-01],
            [ 8.9476e-01,  1.3646e-01,  2.0298e-01, -1.0910e-01, -8.2137e-01,
            2.8025e-01,  4.2538e-01, -1.1852e-01,  4.1497e-01, -3.7422e-01,
            -4.9212e-01, -3.9790e-01]],

            [[ 8.8745e-01, -2.5798e-02,  2.1483e-01, -1.8219e-01, -6.4821e-01,
            2.6279e-01,  3.9598e-01, -1.0423e-01,  3.0726e-01, -1.1315e-01,
            -4.7201e-01, -3.6979e-01],
            [ 7.5528e-01,  6.8919e-02,  2.2486e-01, -1.6395e-01, -7.9692e-01,
            3.7830e-01,  4.9367e-01,  2.4355e-02,  2.6674e-01, -1.1740e-01,
            -4.4945e-01, -3.6367e-01],
            [ 8.3467e-01,  1.7779e-01,  1.9504e-01, -1.6034e-01, -8.2783e-01,
            3.2627e-01,  5.0045e-01, -1.0181e-01,  4.4797e-01, -4.8046e-01,
            -3.7264e-01, -3.7392e-01],
            [ 8.4359e-01,  2.2699e-01,  1.9721e-01, -1.5768e-01, -7.5897e-01,
            3.3738e-01,  4.5559e-01, -1.0258e-01,  4.5782e-01, -3.8058e-01,
            -3.9275e-01, -3.8412e-01],
            [ 9.6349e-01,  1.6264e-01,  1.8207e-01, -1.6910e-01, -5.9304e-01,
            1.4468e-01,  2.4968e-01,  6.4794e-04,  5.4930e-01, -3.8420e-01,
            -4.2137e-01, -3.8016e-01]]]
        )

        self.assertEqual(states_target.shape, states.shape)
        self.assertTensorAlmostEqual(states_target, states)
Ejemplo n.º 13
0
def build_model(cfg: dict = None,
                src_vocab: Vocabulary = None,
                trg_vocab: Vocabulary = None) -> Model:
    """
    Build and initialize the model according to the configuration.

    Source/target embeddings, encoder and decoder are created from the
    "encoder" and "decoder" sections of `cfg`. Embeddings and the softmax
    layer are tied if requested, the custom parameter initialization is
    applied, and finally pretrained embeddings are loaded from file if
    "load_pretrained" paths are configured.

    Despite their `None` defaults, all three arguments are required: they
    are dereferenced right away.

    :param cfg: dictionary configuration containing model specifications
    :param src_vocab: source vocabulary
    :param trg_vocab: target vocabulary
    :return: built and initialized model
    :raises ConfigurationError: if "tied_embeddings" is set but the two
        vocabularies differ, or if "tied_softmax" is set but the target
        embedding matrix and the decoder output layer differ in shape
    """
    logger.info("Building an encoder-decoder model...")
    enc_cfg = cfg["encoder"]
    dec_cfg = cfg["decoder"]

    src_padding_idx = src_vocab.stoi[PAD_TOKEN]
    trg_padding_idx = trg_vocab.stoi[PAD_TOKEN]

    src_embed = Embeddings(**enc_cfg["embeddings"],
                           vocab_size=len(src_vocab),
                           padding_idx=src_padding_idx)

    # this ties source and target embeddings
    # for softmax layer tying, see further below
    tied_embeddings = cfg.get("tied_embeddings", False)
    if tied_embeddings:
        if src_vocab.itos == trg_vocab.itos:
            # share embeddings for src and trg
            trg_embed = src_embed
        else:
            raise ConfigurationError(
                "Embedding cannot be tied since vocabularies differ.")
    else:
        trg_embed = Embeddings(**dec_cfg["embeddings"],
                               vocab_size=len(trg_vocab),
                               padding_idx=trg_padding_idx)

    # build encoder
    # embedding dropout falls back to the encoder dropout if not configured
    enc_dropout = enc_cfg.get("dropout", 0.)
    enc_emb_dropout = enc_cfg["embeddings"].get("dropout", enc_dropout)
    if enc_cfg.get("type", "recurrent") == "transformer":
        # kept as an assert so callers see the same exception type as before
        assert enc_cfg["embeddings"]["embedding_dim"] == \
               enc_cfg["hidden_size"], \
               "for transformer, emb_size must be hidden_size"

        encoder = TransformerEncoder(**enc_cfg,
                                     emb_size=src_embed.embedding_dim,
                                     emb_dropout=enc_emb_dropout)
    else:
        encoder = RecurrentEncoder(**enc_cfg,
                                   emb_size=src_embed.embedding_dim,
                                   emb_dropout=enc_emb_dropout)

    # build decoder
    # embedding dropout falls back to the decoder dropout if not configured
    dec_dropout = dec_cfg.get("dropout", 0.)
    dec_emb_dropout = dec_cfg["embeddings"].get("dropout", dec_dropout)
    if dec_cfg.get("type", "recurrent") == "transformer":
        decoder = TransformerDecoder(**dec_cfg,
                                     encoder=encoder,
                                     vocab_size=len(trg_vocab),
                                     emb_size=trg_embed.embedding_dim,
                                     emb_dropout=dec_emb_dropout)
    else:
        decoder = RecurrentDecoder(**dec_cfg,
                                   encoder=encoder,
                                   vocab_size=len(trg_vocab),
                                   emb_size=trg_embed.embedding_dim,
                                   emb_dropout=dec_emb_dropout)

    model = Model(encoder=encoder,
                  decoder=decoder,
                  src_embed=src_embed,
                  trg_embed=trg_embed,
                  src_vocab=src_vocab,
                  trg_vocab=trg_vocab)

    # tie softmax layer with trg embeddings
    if cfg.get("tied_softmax", False):
        if trg_embed.lut.weight.shape == \
                model.decoder.output_layer.weight.shape:
            # (also) share trg embeddings and softmax layer:
            model.decoder.output_layer.weight = trg_embed.lut.weight
        else:
            raise ConfigurationError(
                "For tied_softmax, the decoder embedding_dim and decoder "
                "hidden_size must be the same. "
                "The decoder must be a Transformer.")

    # custom initialization of model parameters
    initialize_model(model, cfg, src_padding_idx, trg_padding_idx)

    # initialize embeddings from file; this happens after the custom
    # initialization so the loaded vectors are not overwritten by it
    pretrained_enc_embed_path = enc_cfg["embeddings"].get(
        "load_pretrained", None)
    pretrained_dec_embed_path = dec_cfg["embeddings"].get(
        "load_pretrained", None)
    if pretrained_enc_embed_path:
        logger.info("Loading pretrained src embeddings...")
        model.src_embed.load_from_file(pretrained_enc_embed_path, src_vocab)
    # with tied embeddings trg_embed is src_embed, so it is not loaded again
    if pretrained_dec_embed_path and not tied_embeddings:
        logger.info("Loading pretrained trg embeddings...")
        model.trg_embed.load_from_file(pretrained_dec_embed_path, trg_vocab)

    logger.info("Enc-dec model built.")
    return model
Ejemplo n.º 14
0
def build_pretrained_model(cfg: dict = None,
                           pretrained_model: Model = None,
                           pretrained_src_vocab: Vocabulary = None,
                           src_vocab: Vocabulary = None,
                           trg_vocab: Vocabulary = None) -> Model:
    """
    Build a model that reuses the encoder of an already built model.

    New source embeddings are created for `src_vocab`; the row of every
    word of `pretrained_src_vocab` is filled with that word's source
    embedding from `pretrained_model`, all other rows stay zero. The encoder
    object of `pretrained_model` is reused as-is (not copied) and put into
    training mode. Target embeddings and the decoder are built from `cfg`.
    No custom parameter initialization is applied to the resulting model.

    Despite their `None` defaults, all arguments are required.

    :param cfg: dictionary configuration containing model specifications
    :param pretrained_model: model providing the encoder and the source
        embedding weights to copy from
    :param pretrained_src_vocab: source vocabulary of `pretrained_model`
    :param src_vocab: source vocabulary of the new model
    :param trg_vocab: target vocabulary
    :return: built model
    :raises ConfigurationError: if "tied_softmax" is set but the target
        embedding matrix and the decoder output layer differ in shape
    """
    src_padding_idx = src_vocab.stoi[PAD_TOKEN]
    trg_padding_idx = trg_vocab.stoi[PAD_TOKEN]

    src_embed = Embeddings(**cfg["encoder"]["embeddings"],
                           vocab_size=len(src_vocab),
                           padding_idx=src_padding_idx)

    # Convert the pretrained embedding table to numpy once, instead of
    # moving/detaching a single row per vocabulary entry inside the loop.
    pretrained_weights = \
        pretrained_model.src_embed.lut.weight.cpu().detach().numpy()

    embedding_matrix = np.zeros((len(src_vocab), src_embed.embedding_dim))
    for word in pretrained_src_vocab.itos:
        try:
            pre_ix = pretrained_src_vocab.stoi[word]
            ix = src_vocab.stoi[word]
        except KeyError:
            # word is missing from one of the vocabularies: keep the zero row
            # NOTE(review): if `stoi` is a defaultdict this never fires and
            # unknown words map to its default index instead -- confirm
            continue
        embedding_matrix[ix] = pretrained_weights[pre_ix]

    src_embed.lut.weight = torch.nn.Parameter(
        torch.tensor(embedding_matrix, dtype=torch.float32))

    trg_embed = Embeddings(**cfg["decoder"]["embeddings"],
                           vocab_size=len(trg_vocab),
                           padding_idx=trg_padding_idx)

    # embedding dropout falls back to the decoder dropout if not configured
    dec_dropout = cfg["decoder"].get("dropout", 0.)
    dec_emb_dropout = cfg["decoder"]["embeddings"].get("dropout", dec_dropout)

    # The encoder is taken over from the pretrained model instead of being
    # built from cfg["encoder"].
    encoder = pretrained_model.encoder
    encoder.train()
    # presumably re-enables gradients for all encoder parameters -- confirm
    # against set_requires_grad
    set_requires_grad(encoder, True)

    # build decoder
    if cfg["decoder"].get("type", "recurrent") == "transformer":
        decoder = TransformerDecoder(**cfg["decoder"],
                                     encoder=encoder,
                                     vocab_size=len(trg_vocab),
                                     emb_size=trg_embed.embedding_dim,
                                     emb_dropout=dec_emb_dropout)
    else:
        decoder = RecurrentDecoder(**cfg["decoder"],
                                   encoder=encoder,
                                   vocab_size=len(trg_vocab),
                                   emb_size=trg_embed.embedding_dim,
                                   emb_dropout=dec_emb_dropout)

    # NOTE(review): the model is given the *pretrained* source vocabulary
    # although src_embed is sized and indexed by `src_vocab` -- confirm that
    # this is intended.
    model = Model(encoder=encoder,
                  decoder=decoder,
                  src_embed=src_embed,
                  trg_embed=trg_embed,
                  src_vocab=pretrained_model.src_vocab,
                  trg_vocab=trg_vocab)

    # tie softmax layer with trg embeddings
    if cfg.get("tied_softmax", False):
        if trg_embed.lut.weight.shape == \
                model.decoder.output_layer.weight.shape:
            # (also) share trg embeddings and softmax layer:
            model.decoder.output_layer.weight = trg_embed.lut.weight
        else:
            raise ConfigurationError(
                "For tied_softmax, the decoder embedding_dim and decoder "
                "hidden_size must be the same. "
                "The decoder must be a Transformer.")

    # Deliberately no initialize_model() call here: it is disabled in this
    # builder, which keeps the copied embeddings and reused encoder intact.
    return model
Ejemplo n.º 15
0
def build_unsupervised_nmt_model(
        cfg: dict = None,
        src_vocab: Vocabulary = None,
        trg_vocab: Vocabulary = None) -> UnsupervisedNMTModel:
    """
    Build an UnsupervisedNMTModel.

    The model consists of one encoder shared by both languages, one decoder
    per language, a frozen pretrained and a trainable embedding layer per
    language. Each of the four translators exposed by the model
    (src2src, src2trg, trg2src, trg2trg) is initialized with the padding
    indices of its own input and output language.

    Despite their `None` defaults, all three arguments are required.

    :param cfg: model configuration
    :param src_vocab: Vocabulary for the src language
    :param trg_vocab: Vocabulary for the trg language
    :return: Unsupervised NMT model as specified in cfg
    """
    src_padding_idx = src_vocab.stoi[PAD_TOKEN]
    trg_padding_idx = trg_vocab.stoi[PAD_TOKEN]

    # build source and target embedding layers
    # embeddings in the encoder are pretrained and stay fixed
    loaded_src_embed = PretrainedEmbeddings(**cfg["encoder"]["embeddings"],
                                            vocab_size=len(src_vocab),
                                            padding_idx=src_padding_idx,
                                            vocab=src_vocab,
                                            freeze=True)

    loaded_trg_embed = PretrainedEmbeddings(**cfg["decoder"]["embeddings"],
                                            vocab_size=len(trg_vocab),
                                            padding_idx=trg_padding_idx,
                                            vocab=trg_vocab,
                                            freeze=True)

    # embeddings in the decoder are randomly initialised and will be learned
    src_embed = Embeddings(**cfg["encoder"]["embeddings"],
                           vocab_size=len(src_vocab),
                           padding_idx=src_padding_idx,
                           freeze=False)

    trg_embed = Embeddings(**cfg["decoder"]["embeddings"],
                           vocab_size=len(trg_vocab),
                           padding_idx=trg_padding_idx,
                           freeze=False)

    # build shared encoder
    # embedding dropout falls back to the encoder dropout if not configured
    enc_dropout = cfg["encoder"].get("dropout", 0.)
    enc_emb_dropout = cfg["encoder"]["embeddings"].get("dropout", enc_dropout)
    if cfg["encoder"].get("type", "recurrent") == "transformer":
        assert cfg["encoder"]["embeddings"]["embedding_dim"] == \
               cfg["encoder"]["hidden_size"], \
               "for transformer, emb_size must be hidden_size"

        shared_encoder = TransformerEncoder(**cfg["encoder"],
                                            emb_size=src_embed.embedding_dim,
                                            emb_dropout=enc_emb_dropout)
    else:
        shared_encoder = RecurrentEncoder(**cfg["encoder"],
                                          emb_size=src_embed.embedding_dim,
                                          emb_dropout=enc_emb_dropout)

    # build src and trg language decoder
    # both decoders use the same cfg["decoder"] settings and differ only in
    # vocabulary size and embedding size
    dec_dropout = cfg["decoder"].get("dropout", 0.)
    dec_emb_dropout = cfg["decoder"]["embeddings"].get("dropout", dec_dropout)
    if cfg["decoder"].get("type", "recurrent") == "transformer":
        src_decoder = TransformerDecoder(**cfg["decoder"],
                                         encoder=shared_encoder,
                                         vocab_size=len(src_vocab),
                                         emb_size=src_embed.embedding_dim,
                                         emb_dropout=dec_emb_dropout)
        trg_decoder = TransformerDecoder(**cfg["decoder"],
                                         encoder=shared_encoder,
                                         vocab_size=len(trg_vocab),
                                         emb_size=trg_embed.embedding_dim,
                                         emb_dropout=dec_emb_dropout)
    else:
        src_decoder = RecurrentDecoder(**cfg["decoder"],
                                       encoder=shared_encoder,
                                       vocab_size=len(src_vocab),
                                       emb_size=src_embed.embedding_dim,
                                       emb_dropout=dec_emb_dropout)
        trg_decoder = RecurrentDecoder(**cfg["decoder"],
                                       encoder=shared_encoder,
                                       vocab_size=len(trg_vocab),
                                       emb_size=trg_embed.embedding_dim,
                                       emb_dropout=dec_emb_dropout)

    # build unsupervised NMT model
    model = UnsupervisedNMTModel(loaded_src_embed, loaded_trg_embed, src_embed,
                                 trg_embed, shared_encoder, src_decoder,
                                 trg_decoder, src_vocab, trg_vocab)

    # initialise model
    # embed_initializer should be none so loaded encoder embeddings won't be
    # overwritten
    initialize_model(model.src2src_translator, cfg, src_padding_idx,
                     src_padding_idx)
    initialize_model(model.src2trg_translator, cfg, src_padding_idx,
                     trg_padding_idx)
    initialize_model(model.trg2src_translator, cfg, trg_padding_idx,
                     src_padding_idx)
    # Fixed copy-paste error: this call used to target trg2src_translator a
    # second time (with trg/trg padding indices), leaving the trg->trg
    # translator without its initialization.
    initialize_model(model.trg2trg_translator, cfg, trg_padding_idx,
                     trg_padding_idx)

    return model