示例#1
0
def configure_net(abs_dir, ext_dir, cuda, net_type='ml_rnn_extractor'):
    """Load pretrained sub-modules and build the actor-critic network.

    Args:
        abs_dir: directory of the pretrained abstractor (its ``meta.json``
            is read); ``None`` selects ``identity`` instead of an
            ``Abstractor``.
        ext_dir: directory of the ML-trained extractor; passed to
            ``load_ext_net`` and its ``meta.json`` is read.
        cuda: forwarded to ``Abstractor`` and ``ArticleBatcher``; when
            truthy the agent is also moved with ``.cuda()``.
        net_type: forwarded to ``ArticleBatcher`` as ``net_type``.

    Returns:
        Tuple ``(agent, agent_vocab, abstractor, net_args)``. ``net_args``
        maps ``'abstractor'`` and ``'extractor'`` to the parsed
        ``meta.json`` contents (``None`` for the abstractor when
        ``abs_dir`` is ``None``).
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    # NOTE: an ActorCriticPreSumm agent built from the same sub-modules was
    # previously tried here in place of ActorCritic.
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc,
                        extractor._extractor,
                        ArticleBatcher(agent_vocab, cuda, net_type=net_type))
    if cuda:
        agent = agent.cuda()

    # Read the metadata with context managers so the file handles are
    # closed deterministically instead of being left to the GC.
    if abs_dir is None:
        abs_meta = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            abs_meta = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        ext_meta = json.load(meta_file)

    net_args = {}
    net_args['abstractor'] = abs_meta
    net_args['extractor'] = ext_meta

    return agent, agent_vocab, abstractor, net_args
def configure_net(abs_dir, ext_dir, cuda):
    """Load pretrained sub-modules and build the agent and target networks.

    Args:
        abs_dir: directory of the pretrained abstractor (its ``meta.json``
            is read); ``None`` selects ``identity`` instead of an
            ``Abstractor``.
        ext_dir: directory of the ML-trained extractor; passed to
            ``load_ext_net`` and its ``meta.json`` is read.
        cuda: forwarded to ``Abstractor`` and ``ArticleBatcher``; when
            truthy both networks are also moved with ``.cuda()``.

    Returns:
        Tuple ``(agent, target_agent, agent_vocab, abstractor_sent,
        net_args)``. ``target_agent`` is initialised from ``agent``'s
        state dict. ``net_args`` maps ``'abstractor'`` and ``'extractor'``
        to the parsed ``meta.json`` contents (``None`` for the abstractor
        when ``abs_dir`` is ``None``).
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        # Bind the same name that is returned below; the previous code
        # assigned `abstractor` here and then failed with a NameError on
        # `abstractor_sent` whenever abs_dir was None.
        abstractor_sent = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc, extractor._extractor,
                        ArticleBatcher(agent_vocab, cuda))

    target_agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                               extractor._extractor,
                               ArticleBatcher(agent_vocab, cuda))

    # Start the target network from exactly the agent's weights.
    target_agent.load_state_dict(agent.state_dict())

    if cuda:
        agent = agent.cuda()
        target_agent = target_agent.cuda()

    # Read the metadata with context managers so the file handles are
    # closed deterministically instead of being left to the GC.
    if abs_dir is None:
        abs_meta = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            abs_meta = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        ext_meta = json.load(meta_file)

    net_args = {}
    net_args['abstractor'] = abs_meta
    net_args['extractor'] = ext_meta

    return agent, target_agent, agent_vocab, abstractor_sent, net_args
示例#3
0
 def __init__(self, ext_dir, cuda=True):
     """Restore an actor-critic agent from the checkpoint in ``ext_dir``.

     Args:
         ext_dir: directory holding ``meta.json``, ``agent_vocab.pkl`` and
             the checkpoint picked by ``load_best_ckpt``. Its ``meta.json``
             must declare ``'net'`` as ``'rnn-ext_abs_rl'``.
         cuda: when truthy the network is placed on the ``'cuda'`` device,
             otherwise on ``'cpu'``; also forwarded to ``ArticleBatcher``.

     Raises:
         AssertionError: if ``meta.json`` names a different ``'net'``.
     """
     # Context managers close the files promptly instead of leaking the
     # handles until garbage collection.
     with open(join(ext_dir, 'meta.json')) as meta_file:
         ext_meta = json.load(meta_file)
     assert ext_meta['net'] == 'rnn-ext_abs_rl'
     ext_args = ext_meta['net_args']['extractor']['net_args']
     # NOTE(security): unpickling can execute arbitrary code; ext_dir must
     # only ever point at trusted checkpoints.
     with open(join(ext_dir, 'agent_vocab.pkl'), 'rb') as vocab_file:
         word2id = pkl.load(vocab_file)
     extractor = PtrExtractSumm(**ext_args)
     agent = ActorCritic(extractor, ArticleBatcher(word2id, cuda))
     ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
     agent.load_state_dict(ext_ckpt)
     self._device = torch.device('cuda' if cuda else 'cpu')
     self._net = agent.to(self._device)
     self._word2id = word2id
     # Inverse vocabulary: id -> word.
     self._id2word = {i: w for w, i in word2id.items()}
示例#4
0
def configure_net(abs_dir, ext_dir, cuda, sc, tv, rl_dir=''):
    """Load pretrained sub-modules and build the RL agent.

    Args:
        abs_dir: directory of the pretrained abstractor (its ``meta.json``
            is read); ``None`` selects ``identity`` instead of an
            ``Abstractor``.
        ext_dir: directory of the ML-trained extractor; passed to
            ``load_ext_net`` and its ``meta.json`` is read.
        cuda: forwarded to ``Abstractor`` and ``ArticleBatcher``; when
            truthy the agent is also moved with ``.cuda()``.
        sc: when truthy a ``SelfCritic`` agent is built, otherwise an
            ``ActorCritic`` agent.
        tv: forwarded to ``SelfCritic`` as ``time_variant``.
        rl_dir: if not the empty string, the checkpoint chosen by
            ``load_best_ckpt(rl_dir, reverse=True)`` is loaded into the
            agent.

    Returns:
        Tuple ``(agent, agent_vocab, abstractor, net_args)``. ``net_args``
        maps ``'abstractor'`` and ``'extractor'`` to the parsed
        ``meta.json`` contents (``None`` for the abstractor when
        ``abs_dir`` is ``None``).
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    if sc:
        agent = SelfCritic(extractor,
                           ArticleBatcher(agent_vocab, cuda),
                           time_variant=tv)
    else:
        agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))

    # Optionally resume from a previously saved RL checkpoint.
    if rl_dir != '':
        ckpt = load_best_ckpt(rl_dir, reverse=True)
        agent.load_state_dict(ckpt)

    if cuda:
        agent = agent.cuda()

    # Read the metadata with context managers so the file handles are
    # closed deterministically instead of being left to the GC.
    if abs_dir is None:
        abs_meta = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            abs_meta = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        ext_meta = json.load(meta_file)

    net_args = {}
    net_args['abstractor'] = abs_meta
    net_args['extractor'] = ext_meta
    print('agent:', agent)

    return agent, agent_vocab, abstractor, net_args
示例#5
0
 def __init__(self, ext_dir, cuda=True):
     """Restore an actor-critic agent from the checkpoint in ``ext_dir``.

     Args:
         ext_dir: directory holding ``meta.json``, ``agent_vocab.pkl`` and
             the checkpoint picked by ``load_best_ckpt``. Its ``meta.json``
             must declare ``'net'`` as ``'rnn-ext_abs_rl'``.
         cuda: when truthy the network is placed on the ``'cuda'`` device,
             otherwise on ``'cpu'``; also forwarded to ``ArticleBatcher``.

     Raises:
         AssertionError: if ``meta.json`` names a different ``'net'``.
     """
     # Context managers close the files promptly instead of leaking the
     # handles until garbage collection.
     with open(join(ext_dir, 'meta.json')) as meta_file:
         ext_meta = json.load(meta_file)
     assert ext_meta['net'] == 'rnn-ext_abs_rl'
     ext_args = ext_meta['net_args']['extractor']['net_args']
     # NOTE(security): unpickling can execute arbitrary code; ext_dir must
     # only ever point at trusted checkpoints.
     with open(join(ext_dir, 'agent_vocab.pkl'), 'rb') as vocab_file:
         word2id = pkl.load(vocab_file)
     extractor = PtrExtractSumm(**ext_args)
     agent = ActorCritic(extractor._sent_enc,
                         extractor._art_enc,
                         extractor._extractor,
                         ArticleBatcher(word2id, cuda))
     ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
     agent.load_state_dict(ext_ckpt)
     self._device = torch.device('cuda' if cuda else 'cpu')
     self._net = agent.to(self._device)
     self._word2id = word2id
     # Inverse vocabulary: id -> word.
     self._id2word = {i: w for w, i in word2id.items()}
示例#6
0
def configure_net(abs_dir, ext_dir, cuda):
    """Load pretrained sub-modules and build the actor-critic network.

    Args:
        abs_dir: directory of the pretrained abstractor (its ``meta.json``
            is read); ``None`` selects a two-argument pass-through that
            returns its first argument instead of an ``Abstractor``.
        ext_dir: directory of the ML-trained extractor; passed to
            ``load_ext_net`` and its ``meta.json`` is read.
        cuda: forwarded to ``Abstractor`` and ``ArticleBatcher``; when
            truthy the agent is also moved with ``.cuda()``.

    Returns:
        Tuple ``(agent, agent_vocab, abstractor, net_args)``. ``net_args``
        maps ``'abstractor'`` and ``'extractor'`` to the parsed
        ``meta.json`` contents (``None`` for the abstractor when
        ``abs_dir`` is ``None``).
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        # No abstractor configured: hand the first argument back untouched.
        def abstractor(x, y):
            return x

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor, ArticleBatcher(agent_vocab, cuda))
    if cuda:
        agent = agent.cuda()

    # Read the metadata with context managers so the file handles are
    # closed deterministically instead of being left to the GC.
    if abs_dir is None:
        abs_meta = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            abs_meta = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        ext_meta = json.load(meta_file)

    net_args = {}
    net_args['abstractor'] = abs_meta
    net_args['extractor'] = ext_meta

    return agent, agent_vocab, abstractor, net_args
def load_rl_ckpt(abs_dir, ext_dir, cuda):
    """Rebuild the agent and target networks from an RL checkpoint.

    Args:
        abs_dir: directory of the pretrained abstractor, passed to
            ``Abstractor``.
        ext_dir: directory holding ``meta.json``, ``agent_vocab.pkl`` and
            the checkpoint picked by ``load_best_ckpt``. Its ``meta.json``
            must declare ``'net'`` as ``'rnn-ext_abs_rl'``.
        cuda: when truthy both networks are placed on the ``'cuda'``
            device, otherwise on ``'cpu'``; also forwarded to
            ``Abstractor`` and ``ArticleBatcher``.

    Returns:
        Tuple ``(agent, target_agent, vocab, abstractor_sent, ext_meta)``
        where ``ext_meta`` is the parsed ``meta.json`` of ``ext_dir``.

    Raises:
        AssertionError: if ``meta.json`` names a different ``'net'``.
    """
    # Context managers close the files promptly instead of leaking the
    # handles until garbage collection.
    with open(join(ext_dir, 'meta.json')) as meta_file:
        ext_meta = json.load(meta_file)
    assert ext_meta['net'] == 'rnn-ext_abs_rl'
    ext_args = ext_meta['net_args']['extractor']['net_args']
    # NOTE(security): unpickling can execute arbitrary code; ext_dir must
    # only ever point at trusted checkpoints.
    with open(join(ext_dir, 'agent_vocab.pkl'), 'rb') as vocab_file:
        vocab = pkl.load(vocab_file)
    extractor = PtrExtractSumm(**ext_args)
    abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)

    agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                        extractor._extractor, ArticleBatcher(vocab, cuda))

    target_agent = ActorCritic(extractor._sent_enc,
                               extractor._art_enc, extractor._extractor,
                               ArticleBatcher(vocab, cuda))

    ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
    agent.load_state_dict(ext_ckpt)
    # Give the target network the restored weights as well, the same way
    # configure_net syncs target_agent from agent; previously it was
    # returned without ever receiving the checkpoint.
    target_agent.load_state_dict(agent.state_dict())

    device = torch.device('cuda' if cuda else 'cpu')
    agent = agent.to(device)
    target_agent = target_agent.to(device)

    return agent, target_agent, vocab, abstractor_sent, ext_meta
示例#8
0
def configure_net(abs_dir, ext_dir, cuda, aux_device, bert_max_len, args):
    """Load pretrained sub-modules and build the actor-critic network.

    Args:
        abs_dir: directory of the pretrained abstractor (its ``meta.json``
            is read); ``None`` selects ``identity`` instead of an
            ``Abstractor``.
        ext_dir: directory of the ML-trained extractor; passed to
            ``load_ext_net`` and its ``meta.json`` is read.
        cuda: forwarded to ``Abstractor`` and the batcher. In the
            non-BERT branch a truthy value also moves the agent with
            ``.cuda()``; in the BERT branch it is passed to
            ``BertActorCritic`` instead.
        aux_device: forwarded to ``load_ext_net`` and ``BertActorCritic``.
        bert_max_len: forwarded to ``BertArticleBatcher``.
        args: forwarded to ``load_ext_net``; ``args.bert_max_sent`` is
            passed to ``BertArticleBatcher``.

    Returns:
        Tuple ``(agent, agent_vocab, abstractor, net_args)``. ``net_args``
        maps ``'abstractor'`` and ``'extractor'`` to the parsed
        ``meta.json`` contents (``None`` for the abstractor when
        ``abs_dir`` is ``None``).
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab, use_bert, bert_config = load_ext_net(
        ext_dir, aux_device, args)

    if use_bert:
        bert_type, tokenizer_cache = bert_config
        bert_tokenizer = BertTokenizer.from_pretrained(
            bert_type, cache_dir=tokenizer_cache)
        batcher = BertArticleBatcher(bert_tokenizer, bert_max_len,
                                     args.bert_max_sent, cuda)
        # Device placement is delegated to BertActorCritic through the
        # cuda / aux_device arguments; no explicit .cuda() call here.
        agent = BertActorCritic(extractor._sent_enc,
                                extractor._art_enc,
                                extractor._extractor,
                                batcher,
                                cuda=cuda,
                                aux_device=aux_device)
    else:
        agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))
        if cuda:
            agent = agent.cuda()

    # Read the metadata with context managers so the file handles are
    # closed deterministically instead of being left to the GC.
    if abs_dir is None:
        abs_meta = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            abs_meta = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        ext_meta = json.load(meta_file)

    net_args = {}
    net_args['abstractor'] = abs_meta
    net_args['extractor'] = ext_meta

    return agent, agent_vocab, abstractor, net_args