def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    agent, agent_vocab, abstractor, net_args = configure_net(
        args.abs_dir, args.ext_dir, args.cuda)

    # configure training setting
    assert args.stop > 0
    train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch,
        args.gamma, args.reward, args.stop, 'rouge-1'
    )
    train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    reward_fn = compute_rouge_l
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)

    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)

    pipeline = A2CPipeline(meta['net'], agent, abstractor,
                           train_batcher, val_batcher,
                           optimizer, grad_fn,
                           reward_fn, args.gamma,
                           stop_reward_fn, args.stop)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()

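
# Variant of train() for the graph-augmented setup: when docgraph/paragraph
# and/or BERT inputs are enabled it builds the matching graph/BERT batchers,
# picks the reward from args.reward, and trains with a self-critical
# objective (SCPipeline) instead of A2CPipeline.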
def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    if args.docgraph or args.paragraph:
        agent, agent_vocab, abstractor, net_args = configure_net_graph(
            args.abs_dir, args.ext_dir, args.cuda,
            args.docgraph, args.paragraph)
    else:
        agent, agent_vocab, abstractor, net_args = configure_net(
            args.abs_dir, args.ext_dir, args.cuda, True, False, args.rl_dir)
    if args.bert_stride > 0:
        assert args.bert_stride == agent._bert_stride

    # configure training setting
    assert args.stop > 0
    train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch,
        args.gamma, args.reward, args.stop, 'rouge-1'
    )

    if args.docgraph or args.paragraph:
        if args.bert:
            train_batcher, val_batcher = build_batchers_graph_bert(
                args.batch, args.key, args.adj_type,
                args.max_bert_word, args.docgraph, args.paragraph)
        else:
            train_batcher, val_batcher = build_batchers_graph(
                args.batch, args.key, args.adj_type,
                args.gold_key, args.docgraph, args.paragraph)
    elif args.bert:
        train_batcher, val_batcher = build_batchers_bert(
            args.batch, args.bert_sent, args.bert_stride, args.max_bert_word)
    else:
        train_batcher, val_batcher = build_batchers(args.batch)

    # TODO different reward
    if args.reward == 'rouge-l':
        reward_fn = compute_rouge_l
    elif args.reward == 'rouge-1':
        reward_fn = compute_rouge_n(n=1)
    elif args.reward == 'rouge-2':
        reward_fn = compute_rouge_n(n=2)
    elif args.reward == 'rouge-l-s':
        reward_fn = compute_rouge_l_summ
    else:
        raise Exception('Not prepared reward')
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir, reverse=True)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)

    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=1e-5,
                                  patience=args.lr_p)

    entity = args.docgraph or args.paragraph

    pipeline = SCPipeline(meta['net'], agent, abstractor,
                          train_batcher, val_batcher,
                          optimizer, grad_fn,
                          reward_fn, entity, args.bert)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
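

# --- Hedged usage sketch (not part of the original scripts) -----------------
# A minimal argparse driver showing how a train(args) like the ones above
# might be invoked. The flag names mirror the attributes the functions read
# (args.path, args.abs_dir, args.lr, ...), but the defaults, help strings,
# and choices below are assumptions, not the repo's actual CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='RL training of the extractor-abstractor agent (sketch)')
    # checkpoint / model directories
    parser.add_argument('--path', required=True,
                        help='directory for checkpoints, meta.json, vocab')
    parser.add_argument('--abs_dir', help='pretrained abstractor directory')
    parser.add_argument('--ext_dir', help='pretrained extractor directory')
    parser.add_argument('--rl_dir', default=None,
                        help='RL warm-start directory (assumed optional)')
    # graph / BERT variants (only read by the second train() above)
    parser.add_argument('--docgraph', action='store_true')
    parser.add_argument('--paragraph', action='store_true')
    parser.add_argument('--bert', action='store_true')
    parser.add_argument('--bert_sent', action='store_true')
    parser.add_argument('--bert_stride', type=int, default=0)
    parser.add_argument('--max_bert_word', type=int, default=512)
    parser.add_argument('--key', help='node key for the graph batchers')
    parser.add_argument('--gold_key', help='oracle label key (graph batchers)')
    parser.add_argument('--adj_type', help='adjacency type for graph batchers')
    # optimization
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--decay', type=float, default=0.5,
                        help='learning-rate decay factor for the scheduler')
    parser.add_argument('--lr_p', type=int, default=0,
                        help='ReduceLROnPlateau patience')
    parser.add_argument('--clip', type=float, default=2.0,
                        help='gradient clipping norm')
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--gamma', type=float, default=0.95,
                        help='discount factor for the RL objective')
    parser.add_argument('--reward', default='rouge-l',
                        choices=['rouge-l', 'rouge-1', 'rouge-2', 'rouge-l-s'])
    parser.add_argument('--stop', type=float, default=1.0,
                        help='stop coefficient (must be > 0)')
    # training loop
    parser.add_argument('--ckpt_freq', type=int, default=1000)
    parser.add_argument('--patience', type=int, default=3,
                        help='early-stopping patience')
    parser.add_argument('--cuda', action='store_true', help='run on GPU')
    args = parser.parse_args()

    train(args)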