def load_ext_net(ext_dir):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'ml_rnn_extractor'
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))
    ext = PtrExtractSumm(**ext_args)
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
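Every loader in these examples depends on a load_best_ckpt helper that is not shown. The sketch below is a guess at its shape, assuming checkpoints sit in a ckpt/ subdirectory and are named ckpt-<val_metric>-<step>; the layout and naming are assumptions, not confirmed by the examples. reverse=True would select the largest metric (e.g. a reward) instead of the smallest (e.g. a loss):

import re
import torch
from os import listdir
from os.path import join

def load_best_ckpt(model_dir, reverse=False):
    # assumed naming scheme: ckpt-<val_metric>-<step>
    ckpts = [c for c in listdir(join(model_dir, 'ckpt'))
             if re.match(r'^ckpt-.*-\d+$', c)]
    # pick the best validation metric (smallest by default)
    ckpts = sorted(ckpts, key=lambda c: float(c.split('-')[1]),
                   reverse=reverse)
    ckpt = torch.load(join(model_dir, 'ckpt', ckpts[0]),
                      map_location='cpu')
    return ckpt['state_dict']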
Example #2
def load_ext_net(ext_dir):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] in ('ml_rnn_extractor', 'ml_entity_extractor')
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))
    if ext_meta['net'] == 'ml_rnn_extractor':
        ext = PtrExtractSumm(**ext_args)
    elif ext_meta['net'] == "ml_entity_extractor":
        ext = PtrExtractSummEntity(**ext_args)
    else:
        raise Exception('not implemented')
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
Example #3
def configure_net(net_type, vocab_size, emb_dim, conv_hidden,
                  lstm_hidden, lstm_layer, bidirectional, use_bert,
                  bert_type, bert_cache, tokenizer_cache, cuda, aux_device, fix_bert):
    assert net_type in ['ff', 'rnn']
    net_args = {}
    net_args['conv_hidden']   = conv_hidden
    net_args['lstm_hidden']   = lstm_hidden
    net_args['lstm_layer']    = lstm_layer
    net_args['bidirectional'] = bidirectional

    if not use_bert:
        net_args['vocab_size']    = vocab_size
        net_args['emb_dim']       = emb_dim

        net = (ExtractSumm(**net_args) if net_type == 'ff'
               else PtrExtractSumm(**net_args))

        if cuda:
            net = net.cuda()
    else:
        # bert config
        net_args['bert_type'] = bert_type
        net_args['bert_cache'] = bert_cache
        net_args['tokenizer_cache'] = tokenizer_cache
        net_args['fix_bert'] = fix_bert

        # add aux cuda
        added_net_args = dict(net_args)
        added_net_args['aux_device'] = aux_device
        net = BertPtrExtractSumm(**added_net_args)

    return net, net_args
Example #4
def configure_net(net_type,
                  vocab_size,
                  emb_dim,
                  conv_hidden,
                  lstm_hidden,
                  lstm_layer,
                  bidirectional,
                  prev_ckpt=None):
    assert net_type in ['ff', 'rnn', 'trans_rnn']
    net_args = {}
    net_args['vocab_size'] = vocab_size
    net_args['emb_dim'] = emb_dim
    net_args['conv_hidden'] = conv_hidden
    net_args['lstm_hidden'] = lstm_hidden
    net_args['lstm_layer'] = lstm_layer
    net_args['bidirectional'] = bidirectional

    if net_type == 'ff':
        net = ExtractSumm(**net_args)
    elif net_type == 'trans_rnn':
        net = TransExtractSumm(**net_args)
    else:
        net = PtrExtractSumm(**net_args)
    if prev_ckpt is not None:
        ext_ckpt = load_best_ckpt(prev_ckpt)
        net.load_state_dict(ext_ckpt)
    return net, net_args
Example #5
def configure_net(net_type, vocab_size, emb_dim, lstm_hidden, lstm_layer,
                  bidirectional):
    assert net_type in ['ff', 'rnn']
    net_args = {}
    net_args['vocab_size'] = vocab_size
    net_args['emb_dim'] = emb_dim
    net_args['lstm_hidden'] = lstm_hidden
    net_args['lstm_layer'] = lstm_layer
    net_args['bidirectional'] = bidirectional

    net = PtrExtractSumm(**net_args)
    return net, net_args
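For context, here is how the variant above might be wired into a training script. This is a hypothetical sketch (model_dir and all hyperparameter values are illustrative): net_args is persisted to meta.json under the same 'net' / 'net_args' keys that the load_ext_net loaders above read back.

# hypothetical usage; sizes and model_dir are illustrative
net, net_args = configure_net('rnn', vocab_size=30000, emb_dim=128,
                              lstm_hidden=256, lstm_layer=1,
                              bidirectional=True)
meta = {'net': 'ml_rnn_extractor', 'net_args': net_args}
with open(join(model_dir, 'meta.json'), 'w') as f:
    json.dump(meta, f, indent=4)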
Example #6
def __init__(self, ext_dir, cuda=True):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'rnn-ext_abs_rl'
    ext_args = ext_meta['net_args']['extractor']['net_args']
    word2id = pkl.load(open(join(ext_dir, 'agent_vocab.pkl'), 'rb'))
    extractor = PtrExtractSumm(**ext_args)
    agent = ActorCritic(extractor, ArticleBatcher(word2id, cuda))
    ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
    agent.load_state_dict(ext_ckpt)
    self._device = torch.device('cuda' if cuda else 'cpu')
    self._net = agent.to(self._device)
    self._word2id = word2id
    self._id2word = {i: w for w, i in word2id.items()}
Example #7
def configure_net(net_type, vocab_size, emb_dim, conv_hidden,
                  lstm_hidden, lstm_layer, bidirectional):
    assert net_type in ['ff', 'rnn', 'nnse']
    net_args = {}
    net_args['vocab_size']    = vocab_size
    net_args['emb_dim']       = emb_dim
    net_args['conv_hidden']   = conv_hidden
    net_args['lstm_hidden']   = lstm_hidden
    net_args['lstm_layer']    = lstm_layer
    net_args['bidirectional'] = bidirectional

    if net_type in ['ff', 'rnn']:
        net = (ExtractSumm(**net_args) if net_type == 'ff'
               else PtrExtractSumm(**net_args))
    elif net_type == 'nnse':
        net = NNSESumm(**net_args)
    return net, net_args
Example #8
def configure_net(net_type, vocab_size, emb_dim, conv_hidden,
                  lstm_hidden, lstm_layer, bidirectional):
    assert net_type in ['ff', 'rnn']
    net_args = {}
    net_args['vocab_size'] = vocab_size
    net_args['emb_dim'] = emb_dim
    net_args['conv_hidden'] = conv_hidden
    net_args['lstm_hidden'] = lstm_hidden
    net_args['lstm_layer'] = lstm_layer
    net_args['bidirectional'] = bidirectional

    net_args['dropoute'] = 0.0  # dropout to remove words from embedding layer (0 = no dropout)
    net_args['dropout'] = 0.2   # dropout applied to other layers (0 = no dropout)
    net_args['wdrop'] = 0.2     # amount of weight dropout to apply to the RNN hidden to hidden matrix
    net_args['dropouth'] = 0.2  # dropout for rnn layers (0 = no dropout)

    net = (ExtractSumm(**net_args) if net_type == 'ff'
           else PtrExtractSumm(**net_args))
    return net, net_args
Example #9
def configure_net(net_type, vocab_size, emb_dim, conv_hidden,
                  lstm_hidden, lstm_layer, bidirectional, pe, petrainable, stop):
    assert net_type in ['ff', 'rnn', 'nnse']
    net_args = {}
    net_args['vocab_size']    = vocab_size
    net_args['emb_dim']       = emb_dim
    net_args['conv_hidden']   = conv_hidden
    net_args['lstm_hidden']   = lstm_hidden
    net_args['lstm_layer']    = lstm_layer
    net_args['bidirectional'] = bidirectional
    net_args['pe'] = pe # positional encoding
    net_args['petrainable'] = petrainable
    net_args['stop'] = stop

    if net_type in ['ff', 'rnn']:
        net = (ExtractSumm(**net_args) if net_type == 'ff'
               else PtrExtractSumm(**net_args))
    elif net_type == 'nnse':
        net = NNSESumm(**net_args)
    return net, net_args
Example #10
def load_rl_ckpt(abs_dir, ext_dir, cuda):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'rnn-ext_abs_rl'
    ext_args = ext_meta['net_args']['extractor']['net_args']
    vocab = pkl.load(open(join(ext_dir, 'agent_vocab.pkl'), 'rb'))
    extractor = PtrExtractSumm(**ext_args)
    abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)

    agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                        extractor._extractor, ArticleBatcher(vocab, cuda))

    target_agent = ActorCritic(extractor._sent_enc,
                               extractor._art_enc, extractor._extractor,
                               ArticleBatcher(vocab, cuda))

    ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
    agent.load_state_dict(ext_ckpt)

    device = torch.device('cuda' if cuda else 'cpu')
    agent = agent.to(device)
    target_agent = target_agent.to(device)

    return agent, target_agent, vocab, abstractor_sent, ext_meta
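Note that load_rl_ckpt restores the checkpoint into agent only; target_agent is returned with freshly initialized weights. If the two are meant to start identical (the usual target-network setup), the caller presumably syncs them itself, e.g.:

# assumption: align the target network with the loaded agent before training
target_agent.load_state_dict(agent.state_dict())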
Example #11
def load_ext_net(ext_dir, aux_device, args):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'ml_rnn_extractor'
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))

    use_bert = 'bert_type' in ext_args
    # net_args is a plain dict, so use .get() here (getattr() on a dict
    # never finds the key and would always fall through to args.fix_bert)
    fix_bert = ext_args.get('fix_bert', args.fix_bert)
    if use_bert:
        # only the BERT-based extractor accepts these two arguments
        ext_args['fix_bert'] = fix_bert
        ext_args['aux_device'] = aux_device
        print('Use Bert based Extractor ...')
        ext = BertPtrExtractSumm(**ext_args)
        bert_type = ext_args['bert_type']
        tokenizer_cache = ext_args['tokenizer_cache']
        bert_config = (bert_type, tokenizer_cache)
    else:
        ext = PtrExtractSumm(**ext_args)
        bert_config = None

    ext.load_state_dict(ext_ckpt)
    return ext, vocab, use_bert, bert_config
Example #12
def load_ext_net(ext_dir):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] in [
        'ml_rnn_extractor', "ml_gat_extractor", "ml_subgraph_gat_extractor"
    ]
    net_name = ext_meta['net']
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))
    if ext_meta['net'] == 'ml_rnn_extractor':
        ext = PtrExtractSumm(**ext_args)
    elif ext_meta['net'] == "ml_gat_extractor":
        ext = PtrExtractSummGAT(**ext_args)
    elif ext_meta['net'] == "ml_subgraph_gat_extractor":
        ext = PtrExtractSummSubgraph(**ext_args)
    else:
        raise Exception('not implemented')
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
Example #13
def load_dis_net(emb_dim,
                 lstm_hidden,
                 lstm_layer,
                 bert_config,
                 dis_pretrain_file,
                 load=True,
                 cuda=True):
    # build the extractor first, then wrap its submodules for
    # policy-gradient training (avoids reusing the name 'dis')
    ext = PtrExtractSumm(emb_dim=emb_dim,
                         lstm_hidden=lstm_hidden,
                         lstm_layer=lstm_layer,
                         bert_config=bert_config)
    dis = PolicyGradient(ext.transformer, ext._extractor)
    if load:
        print("Restoring all non-adagrad variables from {}...".format(
            dis_pretrain_file))
        state_dict = torch.load(dis_pretrain_file)['state_dict']
        dis.load_state_dict(state_dict)
    if cuda:
        dis = dis.cuda()
    return dis
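A hypothetical call to the loader above (the checkpoint file name and the sizes are made up; bert_config stands for whatever configuration object this PtrExtractSumm variant expects):

# hypothetical usage; file name and sizes are illustrative
dis = load_dis_net(emb_dim=128, lstm_hidden=256, lstm_layer=1,
                   bert_config=bert_config,
                   dis_pretrain_file='dis_pretrain.ckpt')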