Example #1
# standard-library and PyTorch imports implied by the code below; the
# project-local helpers (configure_net, configure_training, build_batchers,
# compute_rouge_l, compute_rouge_n, load_best_ckpt, get_grad_fn,
# A2CPipeline, BasicTrainer) are assumed to come from the surrounding package
import os
import json
import pickle as pkl
from os.path import join, exists

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    agent, agent_vocab, abstractor, net_args = configure_net(
        args.abs_dir, args.ext_dir, args.cuda)

    # configure training setting
    assert args.stop > 0
    train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch,
        args.gamma, args.reward, args.stop, 'rouge-1'
    )
    train_batcher, val_batcher = build_batchers(args.batch)
    # TODO: support rewards other than ROUGE-L; reward_fn is hard-coded here
    # even though args.reward is recorded in train_params above
    reward_fn = compute_rouge_l
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir)
        with open(join(args.abs_dir, 'vocab.pkl'), 'rb') as f:
            abs_vocab = pkl.load(f)
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net']           = 'rnn-ext_abs_rl'
    meta['net_args']      = net_args
    meta['train_params']  = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)

    pipeline = A2CPipeline(meta['net'], agent, abstractor,
                           train_batcher, val_batcher,
                           optimizer, grad_fn,
                           reward_fn, args.gamma,
                           stop_reward_fn, args.stop)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
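
For reference, a minimal sketch of how this train() could be invoked. The flag names map one-to-one to the attributes the function reads above; the use of argparse, the defaults, and the help strings are illustrative assumptions rather than part of the original example.

import argparse

def _build_arg_parser():
    # every flag maps to an attribute read by train(); defaults are placeholders
    parser = argparse.ArgumentParser(
        description='A2C fine-tuning of the extractor agent (sketch)')
    parser.add_argument('--path', required=True,
                        help='directory for checkpoints, meta.json and vocab')
    parser.add_argument('--abs_dir', default=None,
                        help='pretrained abstractor directory (optional)')
    parser.add_argument('--ext_dir', required=True,
                        help='pretrained extractor directory')
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--decay', type=float, default=0.5)
    parser.add_argument('--lr_p', type=int, default=0)
    parser.add_argument('--clip', type=float, default=2.0)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--reward', default='rouge-l')
    parser.add_argument('--stop', type=float, default=1.0)
    parser.add_argument('--ckpt_freq', type=int, default=1000)
    parser.add_argument('--patience', type=int, default=3)
    parser.add_argument('--cuda', action='store_true')
    return parser

if __name__ == '__main__':
    train(_build_arg_parser().parse_args())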
Example #2
# as in Example #1, the standard-library and PyTorch imports are listed here;
# the project-local helpers (configure_net, configure_net_graph,
# configure_training, the build_batchers* variants, the compute_rouge*
# metrics, load_best_ckpt, get_grad_fn, SCPipeline, BasicTrainer) are assumed
# to come from the surrounding package
import os
import json
import pickle as pkl
from os.path import join, exists

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    if args.docgraph or args.paragraph:
        agent, agent_vocab, abstractor, net_args = configure_net_graph(
            args.abs_dir, args.ext_dir, args.cuda, args.docgraph,
            args.paragraph)
    else:
        agent, agent_vocab, abstractor, net_args = configure_net(
            args.abs_dir, args.ext_dir, args.cuda, True, False, args.rl_dir)

    if args.bert_stride > 0:
        assert args.bert_stride == agent._bert_stride
    # configure training setting
    assert args.stop > 0
    train_params = configure_training('adam', args.lr, args.clip, args.decay,
                                      args.batch, args.gamma, args.reward,
                                      args.stop, 'rouge-1')

    if args.docgraph or args.paragraph:
        if args.bert:
            train_batcher, val_batcher = build_batchers_graph_bert(
                args.batch, args.key, args.adj_type, args.max_bert_word,
                args.docgraph, args.paragraph)
        else:
            train_batcher, val_batcher = build_batchers_graph(
                args.batch, args.key, args.adj_type, args.gold_key,
                args.docgraph, args.paragraph)
    elif args.bert:
        train_batcher, val_batcher = build_batchers_bert(
            args.batch, args.bert_sent, args.bert_stride, args.max_bert_word)
    else:
        train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    if args.reward == 'rouge-l':
        reward_fn = compute_rouge_l
    elif args.reward == 'rouge-1':
        reward_fn = compute_rouge_n(n=1)
    elif args.reward == 'rouge-2':
        reward_fn = compute_rouge_n(n=2)
    elif args.reward == 'rouge-l-s':
        reward_fn = compute_rouge_l_summ
    else:
        raise ValueError('unsupported reward: {}'.format(args.reward))
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir, reverse=True)
        with open(join(args.abs_dir, 'vocab.pkl'), 'rb') as f:
            abs_vocab = pkl.load(f)
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer,
                                  'max',
                                  verbose=True,
                                  factor=args.decay,
                                  min_lr=1e-5,
                                  patience=args.lr_p)

    # entity/graph features are consumed whenever either graph mode is on
    entity = bool(args.docgraph or args.paragraph)
    pipeline = SCPipeline(meta['net'], agent, abstractor, train_batcher,
                          val_batcher, optimizer, grad_fn, reward_fn, entity,
                          args.bert)

    trainer = BasicTrainer(pipeline,
                           args.path,
                           args.ckpt_freq,
                           args.patience,
                           scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
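
This variant additionally reads graph- and BERT-related flags. Below is a minimal invocation sketch built only from the attributes the function reads above; every concrete value is an illustrative placeholder (the valid choices for key, adj_type and gold_key are defined by the project's batcher code, which is not shown here).

from types import SimpleNamespace

# attribute names mirror exactly what train() reads above; every value is an
# illustrative placeholder, not a tested configuration
args = SimpleNamespace(
    path='rl_model', abs_dir='abs_model', ext_dir='ext_model', rl_dir=None,
    cuda=True, batch=32, lr=1e-4, decay=0.5, lr_p=0, clip=2.0,
    gamma=0.95, reward='rouge-l', stop=1.0, ckpt_freq=1000, patience=3,
    # graph-mode settings (valid choices come from the project's batchers)
    docgraph=True, paragraph=False, key='nodes', adj_type='edge_as_node',
    gold_key='summary_worthy',
    # BERT settings; bert_stride=0 skips the stride consistency assert
    bert=False, bert_sent=False, bert_stride=0, max_bert_word=500,
)
train(args)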