Example #1
0
    def __init__(self,
                 Vi,
                 Ei,
                 Hi,
                 init_orth=False,
                 use_bn_length=0,
                 cell_type=rnn_cells.LSTMCell):
        """Build the encoder's embedding table and its two recurrent cells.

        Args:
            Vi: Number of input IDs for the ``emb`` embedding table.
            Ei: Embedding size; also the first argument to ``cell_type``.
            Hi: Second argument to ``cell_type`` (presumably the cell's
                hidden/output size -- confirm in rnn_cells). Stored as
                ``self.Hi``.
            init_orth: If true, ``ortho_init`` is applied to ``gru_f`` and
                ``gru_b``.
            use_bn_length: When positive, a ``BNList(Hi, use_bn_length)`` link
                named ``bn_f`` is registered. Stored as ``self.use_bn_length``.
            cell_type: Cell constructor called as ``cell_type(Ei, Hi)``, or a
                cell name (str/unicode) that is resolved through
                ``rnn_cells.create_cell_model_from_string`` -- the same
                convention the decoder's constructor already follows.
        """
        # Accept a cell name as well as a cell class. Without this, a string
        # would reach the ``cell_type(Ei, Hi)`` calls below and fail there.
        if isinstance(cell_type, (str, unicode)):
            cell_type = rnn_cells.create_cell_model_from_string(cell_type)

        # Two independent cells of the same type; the "_f"/"_b" suffixes
        # presumably mean forward/backward direction -- confirm in __call__.
        gru_f = cell_type(Ei, Hi)
        gru_b = cell_type(Ei, Hi)

        log.info("constructing encoder [%s]" % (cell_type, ))
        super(Encoder, self).__init__(
            emb=L.EmbedID(Vi, Ei),
            gru_f=gru_f,
            gru_b=gru_b)
        self.Hi = Hi

        if use_bn_length > 0:
            self.add_link("bn_f", BNList(Hi, use_bn_length))
            # TODO: also register a matching "bn_b" BNList for gru_b.
        self.use_bn_length = use_bn_length

        if init_orth:
            ortho_init(self.gru_f)
            ortho_init(self.gru_b)
Example #2
0
    def __init__(self,
                 Hi,
                 Ha,
                 Ho,
                 init_orth=False,
                 prev_word_embedding_size=None):
        """Create the linear projections used by the attention module.

        Args:
            Hi: Input size of ``al_lin_h``; stored as ``self.Hi``.
            Ha: Output size of ``al_lin_h``, ``al_lin_s`` and (when present)
                ``al_lin_y``; input size of ``al_lin_o``. Stored as
                ``self.Ha``.
            Ho: Input size of ``al_lin_s``.
            init_orth: If true, ``ortho_init`` is applied to ``al_lin_h``,
                ``al_lin_s`` and ``al_lin_o`` (``al_lin_y`` is left as is).
            prev_word_embedding_size: When not None, an extra
                ``L.Linear(prev_word_embedding_size, Ha)`` link named
                ``al_lin_y`` is registered.
        """
        # Build the three projections in the original order (h, s, o) so any
        # random weight initialisation happens in the same sequence.
        proj_h = L.Linear(Hi, Ha, nobias=False)
        proj_s = L.Linear(Ho, Ha, nobias=True)
        # Ha -> 1: one scalar per position, presumably the unnormalised
        # attention score -- confirm in __call__.
        proj_o = L.Linear(Ha, 1, nobias=True)
        super(AttentionModule, self).__init__(al_lin_h=proj_h,
                                              al_lin_s=proj_s,
                                              al_lin_o=proj_o)
        self.Hi, self.Ha = Hi, Ha

        uses_prev_word = prev_word_embedding_size is not None
        if uses_prev_word:
            self.add_link("al_lin_y", L.Linear(prev_word_embedding_size, Ha))

        if init_orth:
            for projection in (self.al_lin_h, self.al_lin_s, self.al_lin_o):
                ortho_init(projection)
Example #3
0
    def __init__(self,
                 Vo,
                 Eo,
                 Ho,
                 Ha,
                 Hi,
                 Hl,
                 attn_cls=AttentionModule,
                 init_orth=False,
                 cell_type=rnn_cells.LSTMCell,
                 use_goto_attention=False):
        """Build the decoder's embedding, recurrent cell, output layers and attention.

        Args:
            Vo: Number of IDs in the ``emb`` table and output size of ``lin_o``.
            Eo: Embedding size; part of the cell and ``maxo`` input sizes and
                the size of ``bos_embeding``. Stored as ``self.Eo``.
            Ho: Second argument to ``cell_type``; part of the ``maxo`` input
                size; passed to ``attn_cls``. Stored as ``self.Ho``.
            Ha: Passed to ``attn_cls``.
            Hi: Added to ``Eo`` for the cell input size and to the ``maxo``
                input size; passed to ``attn_cls``. Stored as ``self.Hi``.
            Hl: Output size of ``maxo`` and input size of ``lin_o``.
            attn_cls: Called as ``attn_cls(Hi, Ha, Ho, init_orth=...,
                prev_word_embedding_size=...)`` to build ``attn_module``.
            init_orth: If true, ``ortho_init`` is applied to ``gru``, ``lin_o``
                and ``maxo``; it is also forwarded to ``attn_cls``.
            cell_type: Cell constructor, or a cell name (str/unicode) resolved
                through ``rnn_cells.create_cell_model_from_string``.
            use_goto_attention: When true, ``attn_cls`` receives
                ``prev_word_embedding_size=Eo`` (otherwise None) and an extra
                log line is emitted. Stored as ``self.use_goto_attention``.
        """
        # A cell may be specified by name; turn it into a constructor first.
        if isinstance(cell_type, (str, unicode)):
            cell_type = rnn_cells.create_cell_model_from_string(cell_type)

        rnn_input_size = Eo + Hi
        gru = cell_type(rnn_input_size, Ho)

        log.info("constructing decoder [%r]" % (cell_type, ))
        if use_goto_attention:
            log.info("using 'Goto' attention")

        # Sub-links are created in the original order (emb, maxo, lin_o,
        # attn_module) so random initialisation draws are unchanged.
        emb = L.EmbedID(Vo, Eo)
        maxo = L.Maxout(rnn_input_size + Ho, Hl, 2)
        lin_o = L.Linear(Hl, Vo, nobias=False)
        attn_prev_word_size = Eo if use_goto_attention else None
        attn_module = attn_cls(Hi,
                               Ha,
                               Ho,
                               init_orth=init_orth,
                               prev_word_embedding_size=attn_prev_word_size)
        super(Decoder, self).__init__(emb=emb,
                                      gru=gru,
                                      maxo=maxo,
                                      lin_o=lin_o,
                                      attn_module=attn_module)

        # Learned (1, Eo) parameter, presumably the embedding used at the
        # beginning-of-sentence step -- confirm in the decoding code. The
        # misspelled name is the registered parameter name, so it is kept.
        self.add_param("bos_embeding", (1, Eo))

        self.use_goto_attention = use_goto_attention
        self.Hi, self.Ho, self.Eo = Hi, Ho, Eo
        self.bos_embeding.data[...] = np.random.randn(Eo)

        if init_orth:
            for sublink in (self.gru, self.lin_o, self.maxo):
                ortho_init(sublink)