Example #1
0
    def __init__(self, opt, dicts):
        """Build the hierarchical encoder/decoder selected by ``opt.mem``.

        ``opt.mem`` is expected to look like ``'<enc>_<dec>'`` with each
        half in ``{'lstm', 'dnc'}``; a matching forward implementation
        named ``hier_<opt.mem>`` must exist on the class.

        Args:
            opt: experiment options namespace (word_vec_size, dropout,
                mem, merge_hidden, ... are read; rnn_size/word_vec_size
                may be overwritten in place — see notes below).
            dicts: vocabulary dictionaries forwarded to the embedding
                and LSTMseq constructors.
        """
        super(HierModel, self).__init__()

        self.set_embeddings(opt, dicts)

        mem = opt.mem.split('_')

        # Sentence-level encoder: halving the hidden size per direction
        # keeps the concatenated bidirectional output at word_vec_size.
        self.sen_encoder = nn.LSTM(opt.word_vec_size,
                                   opt.word_vec_size // 2,
                                   num_layers=2,
                                   dropout=opt.dropout,
                                   bidirectional=True)

        if mem[0] == 'lstm':
            # NOTE(review): mutates the shared opt in place — later
            # consumers of opt.rnn_size will observe this override.
            opt.rnn_size = self.embed_in.weight.size(1)
            self.diag_encoder = lstms.LSTMseq(opt, dicts, 'encode')
        elif mem[0] == 'dnc':
            self.diag_encoder = dnc.DNC(opt, 'diag_encode')

        if mem[1] == 'lstm':
            # Same in-place mutation pattern for the decoder side.
            opt.word_vec_size = self.embed_out.weight.size(1)
            self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
        elif mem[1] == 'dnc':
            self.decoder = dnc.DNC(opt, 'decode')

        self.merge_hidden = opt.merge_hidden
        if self.merge_hidden:
            # Projections that fold a concatenated (h, c) pair back down
            # to word_vec_size.
            self.merge_h = nn.Linear(2 * opt.word_vec_size, opt.word_vec_size)
            self.merge_c = nn.Linear(2 * opt.word_vec_size, opt.word_vec_size)

        # getattr instead of eval: identical attribute dispatch without
        # the arbitrary-code-execution risk of eval on a config string.
        self.forward = getattr(self, 'hier_' + opt.mem)
        self.generate = False
Example #2
0
    def __init__(self, opt, dicts):
        """Select and wire the encoder/decoder pair named by ``opt.mem``.

        Recognized values: ``'dnc_dnc'``, ``'baseline'``,
        ``'reasoning_nse'``, or a generic ``'<utt>_<diag>'`` pair with
        parts in ``{'lstm', 'dnc'}``.  A forward implementation with the
        same name as ``opt.mem`` must exist on the class.

        Args:
            opt: experiment options namespace; note that dropout/attn may
                be overwritten in place for DNC configurations.
            dicts: vocabulary dictionaries forwarded to embeddings and
                LSTMseq constructors.
        """
        super(HierModel, self).__init__()

        self.set_embeddings(opt, dicts)

        mem = opt.mem.split('_')

        def bi_lstm(l):
            # Bidirectional LSTM whose concatenated output width equals
            # word_vec_size (hidden size halved per direction).
            return nn.LSTM(opt.word_vec_size,
                           opt.word_vec_size // 2,
                           num_layers=l,
                           dropout=opt.dropout,
                           bidirectional=True)

        # nn hierarchical models
        if opt.mem == 'dnc_dnc':
            # NOTE(review): opt is mutated in place; later readers of
            # opt.dropout / opt.attn see these overrides.
            opt.dropout = .6
            opt.attn = 0
            self.diag_encoder = dnc.DNC(opt, 'encode')
            self.decoder = dnc.DNC(opt, 'decode')

        elif opt.mem == 'baseline':
            self.diag_encoder = bi_lstm(2)
            self.decoder = lstms.LSTMseq(opt, dicts, 'decode')

        elif opt.mem == 'reasoning_nse':
            self.utt_encoder = bi_lstm(1)
            self.utt_decoder = lstms.LSTMseq(opt, dicts, 'init_decode')
            self.context_mem = bi_lstm(2)
            self.decoder = reasoning_nse.Tweak(opt)

        # hierarchical models: generic '<utt>_<diag>' pair.
        # (mem was already split above; the original re-split it here.)
        else:
            if mem[0] == 'lstm':
                self.utt_encoder = bi_lstm(2)
            elif mem[0] == 'dnc':
                opt.dropout = .6
                self.utt_encoder = dnc.DNC(opt, 'encode')

            if mem[1] == 'lstm':
                self.diag_encoder = bi_lstm(2)
                self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
            elif mem[1] == 'dnc':
                opt.dropout = .6
                self.diag_encoder = dnc.DNC(opt, 'encode')
                self.decoder = dnc.DNC(opt, 'decode')

        # getattr instead of eval: same dispatch, no code execution on a
        # config-supplied string.
        self.forward = getattr(self, opt.mem)
        self.generate = False
Example #3
0
    def __init__(self, opt, dicts):
        """Build MemModel's encoder/decoder from ``opt.mem`` ('<enc>_<dec>').

        Each half of ``opt.mem`` is ``'lstm'`` or ``'dnc'``; a forward
        implementation with the same name as ``opt.mem`` must exist on
        the class.

        Args:
            opt: experiment options namespace; rnn_size/word_vec_size may
                be overwritten in place to match the embedding widths.
            dicts: vocabulary dictionaries forwarded to embeddings and
                LSTMseq constructors.
        """
        super(MemModel, self).__init__()

        self.set_embeddings(opt, dicts)

        mem = opt.mem.split('_')

        if mem[0] == 'lstm':
            # NOTE(review): mutates the shared opt — rnn_size is forced
            # to the input-embedding width before building the encoder.
            opt.rnn_size = self.embed_in.weight.size(1)
            self.encoder = lstms.LSTMseq(opt, dicts, 'encode')
        elif mem[0] == 'dnc':
            self.encoder = dnc.DNC(opt, 'encode')

        if mem[1] == 'lstm':
            # Same in-place mutation pattern for the decoder side.
            opt.word_vec_size = self.embed_out.weight.size(1)
            self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
        elif mem[1] == 'dnc':
            self.decoder = dnc.DNC(opt, 'decode')

        # getattr instead of eval: identical dispatch without the
        # arbitrary-code-execution risk.
        self.forward = getattr(self, opt.mem)
        self.generate = False
Example #4
0
    def __init__(self, opt, dicts):
        """Build HierDAModel's modules as selected by ``opt.mem``.

        Recognized configurations: ``'DAreasoning_nse'``, any
        ``'*_baseline'``, or the fallback ``'lstm_hierda'`` style pair.
        A forward implementation with the same name as ``opt.mem`` must
        exist on the class.

        Args:
            opt: experiment options namespace.
            dicts: vocabulary dictionaries forwarded to embeddings and
                LSTMseq constructors.
        """
        super(HierDAModel, self).__init__()

        self.set_embeddings(opt, dicts)

        mem = opt.mem.split('_')

        def bi_lstm(l):
            # Bidirectional LSTM whose concatenated output width equals
            # word_vec_size (hidden size halved per direction).
            return nn.LSTM(opt.word_vec_size,
                           opt.word_vec_size // 2,
                           num_layers=l,
                           dropout=opt.dropout,
                           bidirectional=True)

        if opt.mem == 'DAreasoning_nse':
            self.utt_encoder = bi_lstm(1)
            self.utt_decoder = lstms.LSTMseq(opt, dicts, 'init_decode')
            self.context_mem = bi_lstm(2)
            self.decoder = reasoning_nse.Tweak(opt)
        # elif (was a plain `if` in the original): the branches are
        # mutually exclusive, and the plain `if` let the DAreasoning_nse
        # case fall through into the `else` below for no reason.
        elif mem[1] == 'baseline':
            self.merge = nn.Sequential(
                nn.Linear(2 * opt.word_vec_size, opt.word_vec_size), nn.Tanh())
            self.diag_encoder = bi_lstm(2)
            self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
        else:
            if mem[0] == 'lstm':
                self.utt_encoder = bi_lstm(2)
                self.merge = nn.Sequential(
                    nn.Linear(2 * opt.word_vec_size, opt.word_vec_size),
                    nn.Tanh())

            if mem[1] == 'hierda':
                self.diag_encoder = bi_lstm(2)
                self.decoder = lstms.LSTMseq(opt, dicts, 'decode')

        # getattr instead of eval: same dispatch, no code execution on a
        # config-supplied string.
        self.forward = getattr(self, opt.mem)
        self.generate = False
    def __init__(self, opt, dicts):
        """Build KeyContModel's sentence and context modules from ``opt.mem``.

        ``opt.mem`` is ``'<context_enc>_<context_attn>'`` with parts in
        ``{'lstm', 'dnc'}``; a forward implementation with the same name
        as ``opt.mem`` must exist on the class.

        Args:
            opt: experiment options namespace (act_vec_size,
                word_vec_size, dropout, mem are read; dropout/attn may be
                overwritten in place for DNC configurations).
            dicts: vocabulary dictionaries forwarded to embeddings and
                LSTMseq constructors.
        """
        super(KeyContModel, self).__init__()

        self.set_embeddings(opt, dicts)

        mem = opt.mem.split('_')

        # Embedding for dialogue-act tokens, padded like word tokens.
        self.act_embedding = nn.Embedding(opt.act_vec_size,
                                          opt.word_vec_size,
                                          padding_idx=Constants.PAD)

        def bi_lstm(l):
            # Bidirectional LSTM whose concatenated output width equals
            # word_vec_size (hidden size halved per direction).
            return nn.LSTM(opt.word_vec_size,
                           opt.word_vec_size // 2,
                           num_layers=l,
                           dropout=opt.dropout,
                           bidirectional=True)

        self.sen_encoder = bi_lstm(2)
        self.sen_decoder = lstms.LSTMseq(opt, dicts, 'decode')

        if mem[0] == 'lstm':
            # Unidirectional context encoder at full word_vec_size.
            self.context_encoder = nn.LSTM(opt.word_vec_size,
                                           opt.word_vec_size,
                                           num_layers=1,
                                           dropout=opt.dropout,
                                           bidirectional=False)
        # elif (was a second plain `if`): the two string tests are
        # mutually exclusive, so this is behavior-identical but explicit.
        elif mem[0] == 'dnc':
            # NOTE(review): opt is mutated in place.
            opt.dropout = .6
            self.context_encoder = dnc.DNC(opt, 'encode')

        if mem[1] == 'lstm':
            self.context_attention = memories.attention.GlobalAttention(
                opt.word_vec_size)
        elif mem[1] == 'dnc':
            opt.dropout = .6
            opt.attn = 0
            self.context_attention = dnc.DNC(opt, 'act_decode')

        # getattr instead of eval: same dispatch, no code execution on a
        # config-supplied string.
        self.forward = getattr(self, opt.mem)
        self.generate = False