def __init__(self, opt, dicts):
    """Build the hierarchical dialogue model.

    Creates a bidirectional sentence-level encoder, then picks the
    dialogue-level encoder and the decoder from ``opt.mem``, which has the
    form '<enc>_<dec>' where each part is 'lstm' or 'dnc'.

    Args:
        opt: option namespace; reads ``mem``, ``word_vec_size``,
             ``dropout``, ``merge_hidden`` and mutates ``rnn_size`` /
             ``word_vec_size`` to match the embedding widths.
        dicts: vocabulary dictionaries forwarded to the sub-modules.
    """
    super(HierModel, self).__init__()
    self.set_embeddings(opt, dicts)
    mem = opt.mem.split('_')
    # Bidirectional, so hidden size is halved: the concatenated forward +
    # backward states come out at exactly opt.word_vec_size.
    self.sen_encoder = nn.LSTM(opt.word_vec_size, opt.word_vec_size // 2,
                               num_layers=2, dropout=opt.dropout,
                               bidirectional=True)
    if mem[0] == 'lstm':
        opt.rnn_size = self.embed_in.weight.size(1)
        self.diag_encoder = lstms.LSTMseq(opt, dicts, 'encode')
    elif mem[0] == 'dnc':
        self.diag_encoder = dnc.DNC(opt, 'diag_encode')
    if mem[1] == 'lstm':
        opt.word_vec_size = self.embed_out.weight.size(1)
        self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
    elif mem[1] == 'dnc':
        self.decoder = dnc.DNC(opt, 'decode')
    self.merge_hidden = opt.merge_hidden
    if self.merge_hidden:
        # Project the concatenated (h, c) pairs back down to word_vec_size.
        self.merge_h = nn.Linear(2 * opt.word_vec_size, opt.word_vec_size)
        self.merge_c = nn.Linear(2 * opt.word_vec_size, opt.word_vec_size)
    # Dispatch to the matching forward implementation; getattr replaces the
    # original eval(), which executed an option string as code.
    self.forward = getattr(self, 'hier_' + opt.mem)
    self.generate = False
def __init__(self, opt, dicts):
    """Build the hierarchical model selected by ``opt.mem``.

    Recognised special values: 'dnc_dnc', 'baseline', 'reasoning_nse'.
    Any other value is treated as '<enc>_<dec>' with each part 'lstm' or
    'dnc'.

    Args:
        opt: option namespace; reads ``mem``, ``word_vec_size``,
             ``dropout``, ``attn`` and may overwrite ``dropout``/``attn``
             for DNC variants.
        dicts: vocabulary dictionaries forwarded to the sub-modules.
    """
    super(HierModel, self).__init__()
    self.set_embeddings(opt, dicts)
    mem = opt.mem.split('_')

    def bi_lstm(l):
        # Bidirectional LSTM whose concatenated output is word_vec_size.
        return nn.LSTM(opt.word_vec_size, opt.word_vec_size // 2,
                       num_layers=l, dropout=opt.dropout, bidirectional=True)

    # nn hierarchical models
    if opt.mem == 'dnc_dnc':
        opt.dropout = .6
        opt.attn = 0
        self.diag_encoder = dnc.DNC(opt, 'encode')
        self.decoder = dnc.DNC(opt, 'decode')
    elif opt.mem == 'baseline':
        self.diag_encoder = bi_lstm(2)
        self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
    elif opt.mem == 'reasoning_nse':
        self.utt_encoder = bi_lstm(1)
        self.utt_decoder = lstms.LSTMseq(opt, dicts, 'init_decode')
        self.context_mem = bi_lstm(2)
        self.decoder = reasoning_nse.Tweak(opt)
    # hierarchical models
    else:
        # `mem` was already split above; no need to recompute it here.
        if mem[0] == 'lstm':
            self.utt_encoder = bi_lstm(2)
        elif mem[0] == 'dnc':
            opt.dropout = .6
            self.utt_encoder = dnc.DNC(opt, 'encode')
        if mem[1] == 'lstm':
            self.diag_encoder = bi_lstm(2)
            self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
        elif mem[1] == 'dnc':
            opt.dropout = .6
            self.diag_encoder = dnc.DNC(opt, 'encode')
            self.decoder = dnc.DNC(opt, 'decode')
    # Dispatch to the matching forward implementation; getattr replaces the
    # original eval(), which executed an option string as code.
    self.forward = getattr(self, opt.mem)
    self.generate = False
def __init__(self, opt, dicts):
    """Build a flat (non-hierarchical) encoder/decoder pair.

    ``opt.mem`` has the form '<enc>_<dec>' with each part 'lstm' or 'dnc'.

    Args:
        opt: option namespace; mutates ``rnn_size`` / ``word_vec_size`` to
             match the embedding widths when the LSTM variant is chosen.
        dicts: vocabulary dictionaries forwarded to the sub-modules.
    """
    super(MemModel, self).__init__()
    self.set_embeddings(opt, dicts)
    mem = opt.mem.split('_')
    if mem[0] == 'lstm':
        opt.rnn_size = self.embed_in.weight.size(1)
        self.encoder = lstms.LSTMseq(opt, dicts, 'encode')
    elif mem[0] == 'dnc':
        self.encoder = dnc.DNC(opt, 'encode')
    if mem[1] == 'lstm':
        opt.word_vec_size = self.embed_out.weight.size(1)
        self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
    elif mem[1] == 'dnc':
        self.decoder = dnc.DNC(opt, 'decode')
    # Dispatch to the matching forward implementation; getattr replaces the
    # original eval(), which executed an option string as code.
    self.forward = getattr(self, opt.mem)
    self.generate = False
def __init__(self, opt, dicts):
    """Build the hierarchical dialogue-act model selected by ``opt.mem``.

    Recognised values: 'DAreasoning_nse', '*_baseline', and 'lstm_hierda'.

    Args:
        opt: option namespace; reads ``mem``, ``word_vec_size``,
             ``dropout``.
        dicts: vocabulary dictionaries forwarded to the sub-modules.
    """
    super(HierDAModel, self).__init__()
    self.set_embeddings(opt, dicts)
    mem = opt.mem.split('_')

    def bi_lstm(l):
        # Bidirectional LSTM whose concatenated output is word_vec_size.
        return nn.LSTM(opt.word_vec_size, opt.word_vec_size // 2,
                       num_layers=l, dropout=opt.dropout, bidirectional=True)

    if opt.mem == 'DAreasoning_nse':
        self.utt_encoder = bi_lstm(1)
        self.utt_decoder = lstms.LSTMseq(opt, dicts, 'init_decode')
        self.context_mem = bi_lstm(2)
        self.decoder = reasoning_nse.Tweak(opt)
    # elif (not a second `if`) makes the chain exclusive: 'DAreasoning_nse'
    # no longer falls through into the trailing `else` branch.
    elif mem[1] == 'baseline':
        self.merge = nn.Sequential(
            nn.Linear(2 * opt.word_vec_size, opt.word_vec_size),
            nn.Tanh())
        self.diag_encoder = bi_lstm(2)
        self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
    else:
        if mem[0] == 'lstm':
            self.utt_encoder = bi_lstm(2)
            self.merge = nn.Sequential(
                nn.Linear(2 * opt.word_vec_size, opt.word_vec_size),
                nn.Tanh())
        if mem[1] == 'hierda':
            self.diag_encoder = bi_lstm(2)
            self.decoder = lstms.LSTMseq(opt, dicts, 'decode')
    # Dispatch to the matching forward implementation; getattr replaces the
    # original eval(), which executed an option string as code.
    self.forward = getattr(self, opt.mem)
    self.generate = False
def __init__(self, opt, dicts):
    """Build the key-content model: sentence encoder/decoder plus a
    context encoder and context attention chosen by ``opt.mem``
    ('<ctx_enc>_<ctx_attn>', each part 'lstm' or 'dnc').

    Args:
        opt: option namespace; reads ``mem``, ``word_vec_size``,
             ``act_vec_size``, ``dropout`` and may overwrite
             ``dropout``/``attn`` for DNC variants.
        dicts: vocabulary dictionaries forwarded to the sub-modules.
    """
    super(KeyContModel, self).__init__()
    self.set_embeddings(opt, dicts)
    mem = opt.mem.split('_')
    # Dialogue-act embeddings share the word embedding width so they can be
    # mixed with word vectors downstream.
    self.act_embedding = nn.Embedding(opt.act_vec_size, opt.word_vec_size,
                                      padding_idx=Constants.PAD)

    def bi_lstm(l):
        # Bidirectional LSTM whose concatenated output is word_vec_size.
        return nn.LSTM(opt.word_vec_size, opt.word_vec_size // 2,
                       num_layers=l, dropout=opt.dropout, bidirectional=True)

    self.sen_encoder = bi_lstm(2)
    self.sen_decoder = lstms.LSTMseq(opt, dicts, 'decode')
    if mem[0] == 'lstm':
        # Context encoder is unidirectional, full hidden width.
        self.context_encoder = nn.LSTM(opt.word_vec_size, opt.word_vec_size,
                                       num_layers=1, dropout=opt.dropout,
                                       bidirectional=False)
    if mem[0] == 'dnc':
        opt.dropout = .6
        self.context_encoder = dnc.DNC(opt, 'encode')
    if mem[1] == 'lstm':
        self.context_attention = memories.attention.GlobalAttention(
            opt.word_vec_size)
    elif mem[1] == 'dnc':
        opt.dropout = .6
        opt.attn = 0
        self.context_attention = dnc.DNC(opt, 'act_decode')
    # Dispatch to the matching forward implementation; getattr replaces the
    # original eval(), which executed an option string as code.
    self.forward = getattr(self, opt.mem)
    self.generate = False