Esempio n. 1
0
    def __init__(self, l1_vocab_size, l2_vocab_size, embedding_size,
                 hidden_size, rnn_type, n_layers_encoder, n_layers_decoder,
                 dropout):
        """Assemble a two-language model around one bidirectional encoder.

        Each language owns an embedding table, an input-feed attention
        decoder ('general' attention) and a linear projection to its
        vocabulary.  The single encoder is given both embedding tables.

        Args:
            l1_vocab_size: vocabulary size of language 1.
            l2_vocab_size: vocabulary size of language 2.
            embedding_size: width of both embedding tables.
            hidden_size: hidden size of the encoder, the decoders and the
                input width of the output projections.
            rnn_type: RNN type forwarded to encoder and decoders.
            n_layers_encoder: number of encoder layers.
            n_layers_decoder: number of layers of each decoder.
            dropout: dropout forwarded to encoder and decoders.
        """
        super(Model, self).__init__()

        # Record the hyper-parameters on the instance, in declaration order.
        hyper_params = (
            ("l1_vocab_size", l1_vocab_size),
            ("l2_vocab_size", l2_vocab_size),
            ("embedding_size", embedding_size),
            ("hidden_size", hidden_size),
            ("rnn_type", rnn_type),
            ("n_layers_encoder", n_layers_encoder),
            ("n_layers_decoder", n_layers_decoder),
            ("dropout", dropout),
        )
        for attr_name, attr_value in hyper_params:
            setattr(self, attr_name, attr_value)

        def build_decoder(target_embeddings):
            # Both decoders are configured identically except for the
            # embedding table they read from.
            return InputFeedRNNDecoder(
                rnn_type=rnn_type,
                bidirectional_encoder=True,
                num_layers=n_layers_decoder,
                hidden_size=hidden_size,
                attn_type='general',
                coverage_attn=None,  # Not supported.
                context_gate=None,
                copy_attn=None,  # Not supported.
                dropout=dropout,
                embeddings=target_embeddings)

        # Embeddings (one table per language).
        self.l1_embeddings = FixedEmbedding(l1_vocab_size, embedding_size)
        self.l2_embeddings = FixedEmbedding(l2_vocab_size, embedding_size)

        # Shared bidirectional encoder over both languages.
        self.encoder = RNNEncoder(rnn_type=rnn_type,
                                  bidirectional=True,
                                  num_layers=n_layers_encoder,
                                  hidden_size=hidden_size,
                                  dropout=dropout,
                                  l1_embeddings=self.l1_embeddings,
                                  l2_embeddings=self.l2_embeddings)

        # Per-language decoders, built in the same order as before.
        self.l1_decoder = build_decoder(self.l1_embeddings)
        self.l2_decoder = build_decoder(self.l2_embeddings)

        # Per-language projections from decoder states to vocabulary logits.
        self.l1_to_vocab = nn.Linear(hidden_size, l1_vocab_size)
        self.l2_to_vocab = nn.Linear(hidden_size, l2_vocab_size)
Esempio n. 2
0
def make_decoder(opt, embeddings):
    """Build the decoder requested by the options.

    Args:
        opt: the option in current environment; ``decoder_type`` picks the
            architecture and ``input_feed`` picks the RNN variant.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        ``TransformerDecoder`` or ``CNNDecoder`` for those decoder types;
        otherwise ``InputFeedRNNDecoder`` when ``opt.input_feed`` is truthy,
        else ``StdRNNDecoder``.
    """
    kind = opt.decoder_type
    if kind == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.global_attention, opt.copy_attn,
                                  opt.dropout, embeddings)
    if kind == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size, opt.global_attention,
                          opt.copy_attn, opt.cnn_kernel_width, opt.dropout,
                          embeddings)

    # Both RNN variants take the exact same positional arguments.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(opt.rnn_type, opt.brnn, opt.dec_layers,
                           opt.rnn_size, opt.global_attention,
                           opt.coverage_attn, opt.context_gate,
                           opt.copy_attn, opt.dropout, embeddings)
Esempio n. 3
0
def make_decoder(opt, embeddings):
    """
    Various decoder dispatcher function.

    Args:
        opt: the option in current environment. RNN decoders additionally
            read ``affective_attention`` and ``affective_attn_strength``;
            ``StdRNNDecoder`` also reads the optional ``word_freq``
            (default ``""``) and ``local_weights`` (default ``False``).
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        ``TransformerDecoder`` / ``CNNDecoder`` for those ``decoder_type``
        values, otherwise ``InputFeedRNNDecoder`` when ``opt.input_feed`` is
        truthy and ``StdRNNDecoder`` when it is not.
    """
    if opt.decoder_type == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.global_attention, opt.copy_attn,
                                  opt.dropout, embeddings)
    elif opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size, opt.global_attention,
                          opt.copy_attn, opt.cnn_kernel_width, opt.dropout,
                          embeddings)
    elif opt.input_feed:
        # The two trailing arguments configure affective attention.
        return InputFeedRNNDecoder(
            opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
            opt.global_attention, opt.coverage_attn, opt.context_gate,
            opt.copy_attn, opt.dropout, embeddings, opt.affective_attention,
            opt.affective_attn_strength)
    else:
        # ``word_freq`` / ``local_weights`` are optional options.  Use
        # ``getattr`` with a default instead of ``"name" in opt``: the
        # membership test only works when ``opt`` implements ``__contains__``
        # (argparse.Namespace does; types.SimpleNamespace and plain objects
        # raise TypeError), while getattr works for any options object and
        # gives the same result for an argparse.Namespace.
        return StdRNNDecoder(
            opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
            opt.global_attention, opt.coverage_attn, opt.context_gate,
            opt.copy_attn, opt.dropout, embeddings, opt.affective_attention,
            opt.affective_attn_strength,
            getattr(opt, "word_freq", ""),
            getattr(opt, "local_weights", False))
Esempio n. 4
0
def make_decoder(opt, embeddings):
    """Build the decoder requested by the options.

    Args:
        opt: the option in current environment. ``decoder_type`` selects
            transformer / cnn; otherwise ``input_feed`` and
            ``inference_network_type`` choose between the standard RNN,
            the input-feed RNN and the variational ``ViRNNDecoder``.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The constructed decoder module.
    """
    if opt.decoder_type == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.global_attention, opt.copy_attn,
                                  opt.dropout, embeddings)
    if opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size, opt.global_attention,
                          opt.copy_attn, opt.cnn_kernel_width, opt.dropout,
                          embeddings)

    if not opt.input_feed:
        return StdRNNDecoder(opt.rnn_type, opt.brnn, opt.dec_layers,
                             opt.rnn_size, opt.global_attention,
                             opt.coverage_attn, opt.context_gate,
                             opt.copy_attn, opt.dropout, embeddings,
                             opt.reuse_copy_attn)

    if opt.inference_network_type == "none":
        print("input feed")
        return InputFeedRNNDecoder(opt.rnn_type, opt.brnn, opt.dec_layers,
                                   opt.memory_size, opt.decoder_rnn_size,
                                   opt.attention_size, opt.global_attention,
                                   opt.coverage_attn, opt.context_gate,
                                   opt.copy_attn, opt.dropout, embeddings,
                                   opt.reuse_copy_attn)

    # Any other inference network type -> variational decoder.
    print("VARIATIONAL DECODER")
    # Look up the scoring function registered under the configured name.
    scores_fn = scoresF_dict[opt.alpha_transformation]

    return ViRNNDecoder(
        opt.rnn_type,
        opt.brnn,
        opt.dec_layers,
        memory_size=opt.memory_size,
        hidden_size=opt.decoder_rnn_size,
        attn_size=opt.attention_size,
        attn_type=opt.global_attention,
        coverage_attn=opt.coverage_attn,
        context_gate=opt.context_gate,
        copy_attn=opt.copy_attn,
        dropout=opt.dropout,
        embeddings=embeddings,
        reuse_copy_attn=opt.reuse_copy_attn,
        p_dist_type=opt.p_dist_type,
        q_dist_type=opt.q_dist_type,
        use_prior=opt.use_generative_model > 0,
        scoresF=scores_fn,
        n_samples=opt.n_samples,
        mode=opt.mode,
    )
Esempio n. 5
0
def make_decoder(opt, embeddings):
    """Build the decoder requested by the options.

    Args:
        opt: the option in current environment. ``decoder_type`` selects
            transformer / cnn; ``model_type == "hierarchical_text"`` selects
            the hierarchical input-feed decoder; otherwise ``input_feed``
            chooses between the input-feed and standard RNN decoders.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The constructed decoder module.
    """
    if opt.decoder_type == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.global_attention, opt.copy_attn,
                                  opt.dropout, embeddings)
    if opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size,
                          opt.global_attention, opt.copy_attn,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)
    if opt.model_type == "hierarchical_text":
        print("ModelConstructor line:96 hierarchical text decoder")
        return HierarchicalInputFeedRNNDecoder(
            opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
            opt.global_attention, opt.coverage_attn, opt.context_gate,
            opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn,
            opt.model_type, opt.hier_add_word_enc_input)

    # Input-feed and standard RNN decoders share one argument list.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
        opt.global_attention, opt.coverage_attn, opt.context_gate,
        opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn)
Esempio n. 6
0
def make_decoder(opt, embeddings, mmod_dcap=False):
    """Build the decoder requested by the options.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.
        mmod_dcap (bool): only consulted for the transformer decoder type;
            when truthy, ``multimodal.CapsuleTransformerDecoder`` (which also
            reads ``num_iterations``, ``num_capsules`` and ``num_regions``)
            is built instead of ``TransformerDecoder``.

    Returns:
        The constructed decoder module.
    """
    if opt.decoder_type == "transformer":
        if not mmod_dcap:
            return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                      opt.global_attention, opt.copy_attn,
                                      opt.dropout, embeddings)
        return multimodal.CapsuleTransformerDecoder(
            opt.dec_layers, opt.rnn_size,
            opt.global_attention, opt.copy_attn,
            opt.dropout, embeddings,
            opt.num_iterations, opt.num_capsules, opt.num_regions
        )

    if opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size,
                          opt.global_attention, opt.copy_attn,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)

    # Input-feed and standard RNN decoders share one argument list.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
        opt.global_attention, opt.coverage_attn, opt.context_gate,
        opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn)
Esempio n. 7
0
def make_decoder(opt, embeddings):
    """Build the decoder requested by the options.

    Args:
        opt: the option in current environment. With ``input_feed`` set,
            ``inference_network_type == 'none'`` gives the plain input-feed
            decoder and any other value gives ``ViInputFeedRNNDecoder``
            (which also reads ``prior_normalization``). All RNN decoders
            receive ``dist_type``.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The constructed decoder module.
    """
    if opt.decoder_type == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.global_attention, opt.copy_attn,
                                  opt.dropout, embeddings)
    if opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size, opt.global_attention,
                          opt.copy_attn, opt.cnn_kernel_width, opt.dropout,
                          embeddings)

    if not opt.input_feed:
        return StdRNNDecoder(
            opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
            opt.global_attention, opt.coverage_attn, opt.context_gate,
            opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn,
            opt.dist_type)

    print("input feed")
    if opt.inference_network_type == 'none':
        return InputFeedRNNDecoder(
            opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
            opt.global_attention, opt.coverage_attn, opt.context_gate,
            opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn,
            opt.dist_type)

    # Variational input-feed decoder: dist_type / normalization by keyword.
    return ViInputFeedRNNDecoder(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
        opt.global_attention, opt.coverage_attn, opt.context_gate,
        opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn,
        dist_type=opt.dist_type,
        normalization=opt.prior_normalization)
def make_decoder(opt, embeddings, stage1, basic_enc_dec):
    """Build the stage-2 input-feed decoder with hierarchical attention.

    Args:
        opt: the option in current environment; the stage-2 specific
            ``brnn2`` and ``dec_layers2`` settings are used here.
        embeddings (Embeddings): vocab embeddings for this decoder.
        stage1: stage1 decoder. Accepted for interface compatibility; this
            function does not use it.
        basic_enc_dec: accepted for interface compatibility; this function
            does not use it.

    Returns:
        An ``InputFeedRNNDecoder`` built with ``hier_attn=True``.
    """
    # The eighth positional argument (copy_attn in the sibling builders) is
    # hard-coded to True here rather than read from the options.
    forced_copy_attn = True
    decoder_args = (
        opt.rnn_type,
        opt.brnn2,
        opt.dec_layers2,
        opt.rnn_size,
        opt.global_attention,
        opt.coverage_attn,
        opt.context_gate,
        forced_copy_attn,
        opt.dropout,
        embeddings,
        opt.reuse_copy_attn,
    )
    return InputFeedRNNDecoder(*decoder_args, hier_attn=True)
Esempio n. 9
0
def make_decoder(opt, embeddings):
    """Build the decoder requested by the options.

    Args:
        opt: the option in current environment; ``decoder_type`` picks the
            architecture and ``input_feed`` picks the RNN variant.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        ``TransformerDecoder`` or ``CNNDecoder`` for those decoder types;
        otherwise ``InputFeedRNNDecoder`` when ``opt.input_feed`` is truthy,
        else ``StdRNNDecoder``.
    """
    kind = opt.decoder_type
    if kind == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.global_attention, opt.copy_attn,
                                  opt.dropout, embeddings)
    if kind == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size, opt.global_attention,
                          opt.copy_attn, opt.cnn_kernel_width, opt.dropout,
                          embeddings)

    # Both RNN variants take the exact same positional arguments.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
        opt.global_attention, opt.coverage_attn, opt.context_gate,
        opt.copy_attn, opt.dropout, embeddings, opt.reuse_copy_attn)
Esempio n. 10
0
def make_decoder(opt, embeddings):
    """Build the decoder(s) requested by the options.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        A single decoder module, except for ``decoder_type == "charrnn"``,
        which returns a two-element list
        ``[StdWordRNNDecoder, StdCharRNNDecoder]`` (both built from the same
        arguments and the same ``embeddings`` object).
    """

    def rnn_args():
        # Positional arguments shared by every RNN decoder built below.
        # Per the original author's note, they map onto: rnn_type,
        # bidirectional_encoder, num_layers, hidden_size,
        # attn_type="general", coverage_attn=False, context_gate=None,
        # copy_attn=False, dropout=0.0, embeddings=None,
        # reuse_copy_attn=False.
        return (opt.rnn_type, opt.brnn, opt.dec_layers,
                opt.rnn_size, opt.global_attention,
                opt.coverage_attn, opt.context_gate,
                opt.copy_attn, opt.dropout, embeddings,
                opt.reuse_copy_attn)

    if opt.decoder_type == "transformer":
        return TransformerDecoder(opt.dec_layers, opt.rnn_size,
                                  opt.global_attention, opt.copy_attn,
                                  opt.dropout, embeddings)
    if opt.decoder_type == "cnn":
        return CNNDecoder(opt.dec_layers, opt.rnn_size, opt.global_attention,
                          opt.copy_attn, opt.cnn_kernel_width, opt.dropout,
                          embeddings)

    # NOTE: THIS IS WHAT GETS TRIGGERED IN DEFAULT EXPERIMENTS
    if opt.decoder_type == "charrnn":
        # Debug dump of the options forwarded to the decoders.
        print(f"opt.rnn_type={opt.rnn_type}")
        print(f"opt.brnn={opt.brnn}")
        print(f"opt.dec_lay={opt.dec_layers}")
        print(f"opt.rnn_size={opt.rnn_size}")
        print(f"opt.global_attention={opt.global_attention}")
        print(f"opt.coverage_attn={opt.coverage_attn}")
        print(f"opt.context_gate={opt.context_gate}")
        print(f"opt.copy_attn={opt.copy_attn}")
        print(f"opt.dropout={opt.dropout}")
        print(f"embeddings={embeddings}")
        print(f"opt.reuse_copy_attn={opt.reuse_copy_attn}")

        word_decoder = StdWordRNNDecoder(*rnn_args())
        char_decoder = StdCharRNNDecoder(*rnn_args())
        return [word_decoder, char_decoder]

    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(*rnn_args())