def train(args):
    """End-to-end RL (A2C) training driver for the extractor agent.

    Builds the agent from pretrained abstractor/extractor checkpoints,
    snapshots the abstractor and all configuration under ``args.path``,
    then runs the A2C training loop until convergence / early stop.
    """
    if not exists(args.path):
        os.makedirs(args.path)
    # make net: agent = RL extractor, abstractor = frozen sentence rewriter
    agent, agent_vocab, abstractor, net_args = configure_net(
        args.abs_dir, args.ext_dir, args.cuda)
    # configure training setting
    assert args.stop > 0
    train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch, args.gamma,
        args.reward, args.stop, 'rouge-1'
    )
    train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    # per-sentence reward is ROUGE-L; the stop action is rewarded with ROUGE-1
    reward_fn = compute_rouge_l
    stop_reward_fn = compute_rouge_n(n=1)
    # save abstractor binary so the trained model dir is self-contained
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        # saved as "epoch 0, step 0" so downstream loaders pick it up
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)
    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    # 'max' mode: scheduler tracks the validation reward score
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)
    pipeline = A2CPipeline(meta['net'], agent, abstractor,
                           train_batcher, val_batcher,
                           optimizer, grad_fn,
                           reward_fn, args.gamma,
                           stop_reward_fn, args.stop)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')
    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
def configure_net(abs_dir, ext_dir, cuda, sc, tv, rl_dir=''):
    """Load pretrained sub-modules and build the actor-critic network.

    Args:
        abs_dir: directory of the pretrained abstractor; ``None`` means
            no abstraction step (the identity function is used instead).
        ext_dir: directory of the ML-trained extractor checkpoint.
        cuda: move the agent to GPU when True.
        sc: when True build a SelfCritic agent, otherwise an ActorCritic
            agent from the extractor's sub-modules.
        tv: ``time_variant`` flag forwarded to SelfCritic.
        rl_dir: optional directory of a previous RL checkpoint to resume
            from ('' = start fresh).

    Returns:
        tuple of (agent, agent_vocab, abstractor, net_args).
    """
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity
    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    if sc:
        agent = SelfCritic(extractor,
                           ArticleBatcher(agent_vocab, cuda),
                           time_variant=tv)
    else:
        agent = ActorCritic(extractor._sent_enc,
                            extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))
    if rl_dir != '':
        # resume a previous RL run from its best checkpoint
        ckpt = load_best_ckpt(rl_dir, reverse=True)
        agent.load_state_dict(ckpt)
    if cuda:
        agent = agent.cuda()
    net_args = {}
    # use context managers so meta-file handles are closed deterministically
    # (the original `json.load(open(...))` leaked them to the GC)
    if abs_dir is None:
        net_args['abstractor'] = None
    else:
        with open(join(abs_dir, 'meta.json')) as f:
            net_args['abstractor'] = json.load(f)
    with open(join(ext_dir, 'meta.json')) as f:
        net_args['extractor'] = json.load(f)
    print('agent:', agent)
    return agent, agent_vocab, abstractor, net_args
def configure_net(net_type, vocab_size, emb_dim, conv_hidden, lstm_hidden,
                  lstm_layer, bidirectional, prev_ckpt=None):
    """Build an extractive summarization net of the requested type.

    Args:
        net_type: one of 'ff' (feed-forward scorer), 'rnn' (pointer net)
            or 'trans_rnn' (transformer-augmented pointer net).
        vocab_size, emb_dim, conv_hidden, lstm_hidden, lstm_layer,
        bidirectional: architecture hyper-parameters, recorded verbatim
            in the returned ``net_args`` so the model can be rebuilt later.
        prev_ckpt: optional directory of a previous run whose best
            checkpoint is loaded into the new net.

    Returns:
        tuple of (net, net_args).

    Raises:
        ValueError: if ``net_type`` is not a supported type.  (The
        original used ``assert``, which is stripped under ``python -O``.)
    """
    if net_type not in ('ff', 'rnn', 'trans_rnn'):
        raise ValueError(
            "net_type must be one of 'ff', 'rnn', 'trans_rnn'; "
            "got {!r}".format(net_type))
    net_args = {}
    net_args['vocab_size'] = vocab_size
    net_args['emb_dim'] = emb_dim
    net_args['conv_hidden'] = conv_hidden
    net_args['lstm_hidden'] = lstm_hidden
    net_args['lstm_layer'] = lstm_layer
    net_args['bidirectional'] = bidirectional
    if net_type == 'ff':
        net = ExtractSumm(**net_args)
    elif net_type == 'trans_rnn':
        net = TransExtractSumm(**net_args)
    else:  # 'rnn'
        net = PtrExtractSumm(**net_args)
    if prev_ckpt is not None:
        # warm-start from the best checkpoint of a previous run
        ext_ckpt = load_best_ckpt(prev_ckpt)
        net.load_state_dict(ext_ckpt)
    return net, net_args
def main(args):
    """Train the base abstractor (ML objective) end to end.

    Builds vocab + batchers, constructs the net (optionally warm-started
    from a previous checkpoint and/or pretrained word2vec embeddings),
    saves the experiment configuration, and runs the training loop.
    """
    # create data batcher, vocabulary
    # batcher
    with open(join(DATA_DIR, 'vocab_cnt.pkl'), 'rb') as f:
        wc = pkl.load(f)
    word2id = make_vocab(wc, args.vsize)
    train_batcher, val_batcher = build_batchers(
        word2id, args.cuda, args.debug,
        cross_rev_bucket=args.cross_rev_bucket)
    # make net
    net, net_args = configure_net(len(word2id), args.emb_dim,
                                  args.n_hidden, args.bi, args.n_layer)
    # BUG FIX: `args.prev_trained is True` was always False for a path
    # string, so the previous checkpoint was never actually loaded.
    if args.prev_trained:
        abs_ckpt = load_best_ckpt(args.prev_trained)
        net.load_state_dict(abs_ckpt)
    if args.w2v:
        # NOTE: the pretrained embedding having the same dimension
        # as args.emb_dim should already be trained
        embedding, _ = make_embedding(
            {i: w for w, i in word2id.items()}, args.w2v)
        net.set_embedding(embedding)
    # configure training setting
    criterion, train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch
    )
    # save experiment setting
    if not exists(args.path):
        os.makedirs(args.path)
    with open(join(args.path, 'vocab.pkl'), 'wb') as f:
        pkl.dump(word2id, f, pkl.HIGHEST_PROTOCOL)
    meta = {}
    meta['net'] = 'base_abstractor'
    meta['net_args'] = net_args
    # key renamed from misspelled 'traing_params' to match the
    # 'train_params' key used by the other training scripts
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    # prepare trainer
    val_fn = basic_validate(net, criterion)
    grad_fn = get_basic_grad_fn(net, args.clip)
    optimizer = optim.Adam(net.parameters(), **train_params['optimizer'][1])
    # 'min' mode: scheduler tracks the validation loss
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)
    if args.cuda:
        net = net.cuda()
    pipeline = BasicPipeline(meta['net'], net,
                             train_batcher, val_batcher, args.batch, val_fn,
                             criterion, optimizer, grad_fn)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler)
    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
def load_ext_net(ext_dir):
    """Load an ML-trained pointer extractor and its vocabulary.

    Args:
        ext_dir: model directory containing meta.json, vocab.pkl and
            checkpoints; its net type must be 'ml_rnn_extractor'.

    Returns:
        tuple of (extractor net, vocab word->id mapping).
    """
    # context managers close the handles promptly (the original
    # `json.load(open(...))` / `pkl.load(open(...))` leaked them)
    with open(join(ext_dir, 'meta.json')) as f:
        ext_meta = json.load(f)
    assert ext_meta['net'] == 'ml_rnn_extractor'
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    with open(join(ext_dir, 'vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
    ext = PtrExtractSumm(**ext_args)
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
def load_abs_net(abs_dir):
    """Load a trained base abstractor (CopySumm) and its vocabulary.

    Args:
        abs_dir: model directory containing meta.json, vocab.pkl and
            checkpoints; its net type must be 'base_abstractor'.

    Returns:
        tuple of (abstractor net, word2id mapping).
    """
    # context managers close the handles promptly (the original
    # `json.load(open(...))` / `pkl.load(open(...))` leaked them)
    with open(join(abs_dir, 'meta.json')) as f:
        abs_meta = json.load(f)
    assert abs_meta['net'] == 'base_abstractor'
    abs_args = abs_meta['net_args']
    abs_ckpt = load_best_ckpt(abs_dir)
    with open(join(abs_dir, 'vocab.pkl'), 'rb') as f:
        word2id = pkl.load(f)
    abstractor = CopySumm(**abs_args)
    abstractor.load_state_dict(abs_ckpt)
    return abstractor, word2id
def load_ext_net(ext_dir):
    """Load an ML-trained extractor (plain or entity-aware) and its vocab.

    Args:
        ext_dir: model directory; its meta.json 'net' field selects the
            extractor class.

    Returns:
        tuple of (extractor net, vocab word->id mapping).

    Raises:
        Exception: for an unsupported net type.  (The original both
        asserted membership and had an unreachable ``else`` raising this;
        the assert was stripped under ``python -O``, so the explicit
        check is kept instead.)
    """
    # context managers close the handles promptly (originally leaked)
    with open(join(ext_dir, 'meta.json')) as f:
        ext_meta = json.load(f)
    # dispatch table: recorded net name -> extractor class
    net_classes = {
        'ml_rnn_extractor': PtrExtractSumm,
        'ml_entity_extractor': PtrExtractSummEntity,
    }
    if ext_meta['net'] not in net_classes:
        raise Exception('not implemented')
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    with open(join(ext_dir, 'vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
    ext = net_classes[ext_meta['net']](**ext_args)
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
def load_ext_net(ext_dir):
    """Load an ML-trained extractor (plain, GAT, or subgraph-GAT variant).

    Args:
        ext_dir: model directory; its meta.json 'net' field selects the
            extractor class.

    Returns:
        tuple of (extractor net, vocab word->id mapping).

    Raises:
        Exception: for an unsupported net type.  (The original both
        asserted membership and had an unreachable ``else`` raising this;
        the assert was stripped under ``python -O``, so the explicit
        check is kept instead.)
    """
    # context managers close the handles promptly (originally leaked)
    with open(join(ext_dir, 'meta.json')) as f:
        ext_meta = json.load(f)
    # dispatch table: recorded net name -> extractor class
    net_classes = {
        'ml_rnn_extractor': PtrExtractSumm,
        'ml_gat_extractor': PtrExtractSummGAT,
        'ml_subgraph_gat_extractor': PtrExtractSummSubgraph,
    }
    net_name = ext_meta['net']
    if net_name not in net_classes:
        raise Exception('not implemented')
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    with open(join(ext_dir, 'vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
    ext = net_classes[net_name](**ext_args)
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
def load_rl_ckpt(abs_dir, ext_dir, cuda):
    """Rebuild an RL-trained agent (plus a target copy) from checkpoints.

    Args:
        abs_dir: directory of the pretrained sentence abstractor.
        ext_dir: directory of the RL run ('rnn-ext_abs_rl'); supplies the
            extractor architecture, agent vocab and best RL checkpoint.
        cuda: place the agents on GPU when True.

    Returns:
        tuple of (agent, target_agent, vocab, abstractor_sent, ext_meta).

    NOTE(review): both agents wrap the *same* extractor sub-modules, so
    loading the checkpoint into ``agent`` appears to update
    ``target_agent`` as well — confirm this sharing is intended.
    """
    # context managers close the handles promptly (originally leaked)
    with open(join(ext_dir, 'meta.json')) as f:
        ext_meta = json.load(f)
    assert ext_meta['net'] == 'rnn-ext_abs_rl'
    # architecture args of the underlying ML extractor
    ext_args = ext_meta['net_args']['extractor']['net_args']
    with open(join(ext_dir, 'agent_vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
    extractor = PtrExtractSumm(**ext_args)
    abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc,
                        extractor._extractor,
                        ArticleBatcher(vocab, cuda))
    target_agent = ActorCritic(extractor._sent_enc,
                               extractor._art_enc,
                               extractor._extractor,
                               ArticleBatcher(vocab, cuda))
    ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
    agent.load_state_dict(ext_ckpt)
    device = torch.device('cuda' if cuda else 'cpu')
    agent = agent.to(device)
    target_agent = target_agent.to(device)
    return agent, target_agent, vocab, abstractor_sent, ext_meta
def load_ext_net(ext_dir, aux_device, args):
    """Load an ML-trained extractor, BERT-based or plain pointer net.

    Args:
        ext_dir: model directory ('ml_rnn_extractor'); the presence of a
            'bert_type' entry in its net_args selects the BERT variant.
        aux_device: device passed through to the net for auxiliary
            computation.
        args: runtime args; ``args.fix_bert`` is the fallback when the
            checkpoint's net_args did not record a 'fix_bert' setting.

    Returns:
        tuple of (extractor net, vocab, use_bert flag, bert_config)
        where bert_config is (bert_type, tokenizer_cache) or None.
    """
    # context managers close the handles promptly (originally leaked)
    with open(join(ext_dir, 'meta.json')) as f:
        ext_meta = json.load(f)
    assert ext_meta['net'] == 'ml_rnn_extractor'
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    with open(join(ext_dir, 'vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
    use_bert = 'bert_type' in ext_args
    # BUG FIX: ext_args is a dict (from json.load), so
    # `getattr(ext_args, 'fix_bert', ...)` could never find the key and
    # always returned the fallback; use dict .get() instead.
    fix_bert = ext_args.get('fix_bert', args.fix_bert)
    ext_args['fix_bert'] = fix_bert
    ext_args['aux_device'] = aux_device
    if use_bert:
        print('Use Bert based Extractor ...')
        ext = BertPtrExtractSumm(**ext_args)
        bert_type = ext_args['bert_type']
        tokenizer_cache = ext_args['tokenizer_cache']
        bert_config = (bert_type, tokenizer_cache)
    else:
        ext = PtrExtractSumm(**ext_args)
        bert_config = None
    ext.load_state_dict(ext_ckpt)
    return ext, vocab, use_bert, bert_config
def train(args):
    """Self-critical RL training driver, with graph/BERT extractor options.

    Depending on ``args.docgraph`` / ``args.paragraph`` / ``args.bert``,
    builds the matching agent and data batchers, snapshots the abstractor
    and configuration under ``args.path``, then runs the SC training loop.
    """
    if not exists(args.path):
        os.makedirs(args.path)
    # make net: graph-aware agent, or plain self-critic (sc=True, tv=False)
    if args.docgraph or args.paragraph:
        agent, agent_vocab, abstractor, net_args = configure_net_graph(
            args.abs_dir, args.ext_dir, args.cuda, args.docgraph,
            args.paragraph)
    else:
        agent, agent_vocab, abstractor, net_args = configure_net(
            args.abs_dir, args.ext_dir, args.cuda, True, False, args.rl_dir)
    if args.bert_stride > 0:
        # sanity check: CLI stride must match the loaded agent's stride
        assert args.bert_stride == agent._bert_stride
    # configure training setting
    assert args.stop > 0
    train_params = configure_training('adam', args.lr, args.clip, args.decay,
                                      args.batch, args.gamma, args.reward,
                                      args.stop, 'rouge-1')
    # pick the batcher matching the agent variant (graph/BERT combinations)
    if args.docgraph or args.paragraph:
        if args.bert:
            train_batcher, val_batcher = build_batchers_graph_bert(
                args.batch, args.key, args.adj_type, args.max_bert_word,
                args.docgraph, args.paragraph)
        else:
            train_batcher, val_batcher = build_batchers_graph(
                args.batch, args.key, args.adj_type, args.gold_key,
                args.docgraph, args.paragraph)
    elif args.bert:
        train_batcher, val_batcher = build_batchers_bert(
            args.batch, args.bert_sent, args.bert_stride, args.max_bert_word)
    else:
        train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    # choose the per-extraction reward from the CLI setting
    if args.reward == 'rouge-l':
        reward_fn = compute_rouge_l
    elif args.reward == 'rouge-1':
        reward_fn = compute_rouge_n(n=1)
    elif args.reward == 'rouge-2':
        reward_fn = compute_rouge_n(n=2)
    elif args.reward == 'rouge-l-s':
        reward_fn = compute_rouge_l_summ
    else:
        raise Exception('Not prepared reward')
    # the stop action is always rewarded with ROUGE-1
    stop_reward_fn = compute_rouge_n(n=1)
    # save abstractor binary so the trained model dir is self-contained
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir, reverse=True)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        # saved as "epoch 0, step 0" so downstream loaders pick it up
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)
    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    # 'max' mode: scheduler tracks the validation reward score
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=1e-5,
                                  patience=args.lr_p)
    # graph variants carry entity information through the pipeline
    if args.docgraph or args.paragraph:
        entity = True
    else:
        entity = False
    pipeline = SCPipeline(meta['net'], agent, abstractor,
                          train_batcher, val_batcher,
                          optimizer, grad_fn,
                          reward_fn, entity, args.bert)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')
    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()