def configure_net(abs_dir, ext_dir, cuda):
    """Load pretrained sub-modules and build the actor-critic network.

    Args:
        abs_dir: directory of the pretrained abstractor, or ``None`` to use
            ``identity`` in place of an abstractor.
        ext_dir: directory of the ML-trained extractor; its ``meta.json`` is
            also read into ``net_args``.
        cuda: forwarded to the abstractor and article batchers; when true the
            agent and target agent are moved with ``.cuda()``.

    Returns:
        ``(agent, target_agent, agent_vocab, abstractor_sent, net_args)``.
        ``net_args['abstractor']`` is the parsed abstractor ``meta.json``
        (``None`` when ``abs_dir`` is ``None``) and ``net_args['extractor']``
        is the parsed extractor ``meta.json``.
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        # must bind the same name on both branches: it is returned below
        abstractor_sent = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc,
                        extractor._extractor,
                        ArticleBatcher(agent_vocab, cuda))
    # NOTE(review): the target agent is built from the *same* sub-module
    # objects as ``agent``; unless ActorCritic copies them, the two share
    # those weights. Confirm this is intended for a target network.
    target_agent = ActorCritic(extractor._sent_enc,
                               extractor._art_enc,
                               extractor._extractor,
                               ArticleBatcher(agent_vocab, cuda))
    # start the target network from the agent's current parameters
    target_agent.load_state_dict(agent.state_dict())
    if cuda:
        agent = agent.cuda()
        target_agent = target_agent.cuda()

    net_args = {}
    if abs_dir is None:
        net_args['abstractor'] = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            net_args['abstractor'] = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        net_args['extractor'] = json.load(meta_file)

    return agent, target_agent, agent_vocab, abstractor_sent, net_args
def configure_net(abs_dir, ext_dir, cuda, sc, tv, rl_dir=''):
    """Load pretrained sub-modules and build the RL agent.

    Args:
        abs_dir: directory of the pretrained abstractor, or ``None`` to use
            ``identity`` in place of an abstractor.
        ext_dir: directory of the ML-trained extractor; its ``meta.json`` is
            also read into ``net_args``.
        cuda: forwarded to the abstractor and article batcher; when true the
            agent is moved with ``.cuda()``.
        sc: when true build a ``SelfCritic`` agent, otherwise an
            ``ActorCritic`` agent.
        tv: passed to ``SelfCritic`` as ``time_variant`` (unused when
            ``sc`` is false).
        rl_dir: optional directory holding RL checkpoints. When non-empty, the
            checkpoint chosen by ``load_best_ckpt(rl_dir, reverse=True)`` is
            loaded into the agent. ``''`` or ``None`` skips loading.

    Returns:
        ``(agent, agent_vocab, abstractor, net_args)`` where ``net_args``
        holds the parsed abstractor ``meta.json`` (``None`` when
        ``abs_dir`` is ``None``) and the parsed extractor ``meta.json``.
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    if sc:
        agent = SelfCritic(extractor,
                           ArticleBatcher(agent_vocab, cuda),
                           time_variant=tv)
    else:
        agent = ActorCritic(extractor._sent_enc,
                            extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))

    # truthiness test so that both '' and None mean "no RL checkpoint"
    if rl_dir:
        ckpt = load_best_ckpt(rl_dir, reverse=True)
        agent.load_state_dict(ckpt)
    if cuda:
        agent = agent.cuda()

    net_args = {}
    if abs_dir is None:
        net_args['abstractor'] = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            net_args['abstractor'] = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        net_args['extractor'] = json.load(meta_file)

    print('agent:', agent)
    return agent, agent_vocab, abstractor, net_args
def configure_net_entity(abs_dir, ext_dir, cuda, sc, tv):
    """Load pretrained sub-modules and build the entity-aware RL agent.

    Args:
        abs_dir: directory of the pretrained abstractor, or ``None`` to use
            ``identity`` in place of an abstractor.
        ext_dir: directory of the ML-trained extractor; its ``meta.json`` is
            also read into ``net_args``.
        cuda: forwarded to the abstractor and article batcher; when true the
            agent is moved with ``.cuda()``.
        sc: must be true; only the ``SelfCriticEntity`` agent exists.
        tv: passed to ``SelfCriticEntity`` as ``time_variant``.

    Returns:
        ``(agent, agent_vocab, abstractor, net_args)``.

    Raises:
        NotImplementedError: if ``sc`` is false (no actor-critic entity
            model). This is a subclass of ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    if not sc:
        raise NotImplementedError('actor critic entity model not implemented')
    agent = SelfCriticEntity(extractor,
                             ArticleBatcher(agent_vocab, cuda),
                             time_variant=tv)
    if cuda:
        agent = agent.cuda()

    net_args = {}
    if abs_dir is None:
        net_args['abstractor'] = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            net_args['abstractor'] = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        net_args['extractor'] = json.load(meta_file)

    return agent, agent_vocab, abstractor, net_args
def configure_net(abs_dir, ext_dir, cuda, net_type='ml_rnn_extractor'):
    """Load pretrained sub-modules and build the actor-critic network.

    Args:
        abs_dir: directory of the pretrained abstractor, or ``None`` to use
            ``identity`` in place of an abstractor.
        ext_dir: directory of the ML-trained extractor; its ``meta.json`` is
            also read into ``net_args``.
        cuda: forwarded to the abstractor and article batcher; when true the
            agent is moved with ``.cuda()``.
        net_type: forwarded to ``ArticleBatcher`` as ``net_type``.

    Returns:
        ``(agent, agent_vocab, abstractor, net_args)`` where ``net_args``
        holds the parsed abstractor ``meta.json`` (``None`` when
        ``abs_dir`` is ``None``) and the parsed extractor ``meta.json``.
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc,
                        extractor._extractor,
                        ArticleBatcher(agent_vocab, cuda, net_type=net_type))
    if cuda:
        agent = agent.cuda()

    net_args = {}
    if abs_dir is None:
        net_args['abstractor'] = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            net_args['abstractor'] = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        net_args['extractor'] = json.load(meta_file)

    return agent, agent_vocab, abstractor, net_args
def load_rl_ckpt(abs_dir, ext_dir, cuda):
    """Rebuild an RL actor-critic agent from a saved RL run.

    Args:
        abs_dir: directory of the pretrained abstractor (always loaded here).
        ext_dir: directory of the RL run. It must contain ``meta.json``
            (with ``net == 'rnn-ext_abs_rl'``), ``agent_vocab.pkl`` and the
            checkpoints read by ``load_best_ckpt``.
        cuda: forwarded to the abstractor and article batchers; selects the
            ``cuda``/``cpu`` device both agents are moved to.

    Returns:
        ``(agent, target_agent, vocab, abstractor_sent, ext_meta)``. The
        agent has the checkpoint chosen by
        ``load_best_ckpt(ext_dir, reverse=True)`` restored. The target agent
        is synced to the restored agent's state.
    """
    with open(join(ext_dir, 'meta.json')) as meta_file:
        ext_meta = json.load(meta_file)
    # NOTE: assert is skipped under ``python -O``
    assert ext_meta['net'] == 'rnn-ext_abs_rl'
    ext_args = ext_meta['net_args']['extractor']['net_args']
    # SECURITY: pickle can execute arbitrary code; only load vocab files
    # from trusted experiment directories.
    with open(join(ext_dir, 'agent_vocab.pkl'), 'rb') as vocab_file:
        vocab = pkl.load(vocab_file)

    extractor = PtrExtractSumm(**ext_args)
    abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc,
                        extractor._extractor,
                        ArticleBatcher(vocab, cuda))
    target_agent = ActorCritic(extractor._sent_enc,
                               extractor._art_enc,
                               extractor._extractor,
                               ArticleBatcher(vocab, cuda))

    ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
    agent.load_state_dict(ext_ckpt)
    # Keep the target network consistent with the restored agent, as
    # configure_net does at construction time. Otherwise any parameters not
    # shared between the two ActorCritic instances would keep their fresh
    # initialization.
    target_agent.load_state_dict(agent.state_dict())

    device = torch.device('cuda' if cuda else 'cpu')
    agent = agent.to(device)
    target_agent = target_agent.to(device)
    return agent, target_agent, vocab, abstractor_sent, ext_meta
def configure_net(abs_dir, ext_dir, cuda):
    """Load pretrained sub-modules and build the actor-critic network.

    Args:
        abs_dir: directory of the pretrained abstractor, or ``None`` to use a
            two-argument pass-through that returns its first argument.
        ext_dir: directory of the ML-trained extractor; its ``meta.json`` is
            also read into ``net_args``.
        cuda: forwarded to the abstractor and article batcher; when true the
            agent is moved with ``.cuda()``.

    Returns:
        ``(agent, agent_vocab, abstractor, net_args)`` where ``net_args``
        holds the parsed abstractor ``meta.json`` (``None`` when
        ``abs_dir`` is ``None``) and the parsed extractor ``meta.json``.
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        def abstractor(x, y):
            # no abstractor: return the first input unchanged, ignore the second
            return x

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor, ArticleBatcher(agent_vocab, cuda))
    if cuda:
        agent = agent.cuda()

    net_args = {}
    if abs_dir is None:
        net_args['abstractor'] = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            net_args['abstractor'] = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        net_args['extractor'] = json.load(meta_file)

    return agent, agent_vocab, abstractor, net_args
def configure_net(abs_dir, ext_dir, cuda, aux_device, bert_max_len, args):
    """Load pretrained sub-modules and build the actor-critic network.

    A BERT-based agent is built when ``load_ext_net`` reports that the
    extractor uses BERT; otherwise a plain ``ActorCritic`` agent is built.

    Args:
        abs_dir: directory of the pretrained abstractor, or ``None`` to use
            ``identity`` in place of an abstractor.
        ext_dir: directory of the ML-trained extractor; its ``meta.json`` is
            also read into ``net_args``.
        cuda: forwarded to the abstractor, batchers and ``BertActorCritic``;
            when true the agent is moved with ``.cuda()``.
        aux_device: forwarded to ``load_ext_net`` and ``BertActorCritic``.
        bert_max_len: forwarded to ``BertArticleBatcher``.
        args: options object; ``args.bert_max_sent`` is read, and ``args``
            is forwarded to ``load_ext_net``.

    Returns:
        ``(agent, agent_vocab, abstractor, net_args)`` where ``net_args``
        holds the parsed abstractor ``meta.json`` (``None`` when
        ``abs_dir`` is ``None``) and the parsed extractor ``meta.json``.
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab, use_bert, bert_config = load_ext_net(
        ext_dir, aux_device, args)
    if use_bert:
        bert_type, tokenizer_cache = bert_config
        bert_tokenizer = BertTokenizer.from_pretrained(
            bert_type, cache_dir=tokenizer_cache)
        batcher = BertArticleBatcher(bert_tokenizer, bert_max_len,
                                     args.bert_max_sent, cuda)
        agent = BertActorCritic(extractor._sent_enc,
                                extractor._art_enc,
                                extractor._extractor,
                                batcher,
                                cuda=cuda,
                                aux_device=aux_device)
    else:
        agent = ActorCritic(extractor._sent_enc,
                            extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))
    if cuda:
        agent = agent.cuda()

    net_args = {}
    if abs_dir is None:
        net_args['abstractor'] = None
    else:
        with open(join(abs_dir, 'meta.json')) as meta_file:
            net_args['abstractor'] = json.load(meta_file)
    with open(join(ext_dir, 'meta.json')) as meta_file:
        net_args['extractor'] = json.load(meta_file)

    return agent, agent_vocab, abstractor, net_args