Example #1
def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    agent, agent_vocab, abstractor, net_args = configure_net(
        args.abs_dir, args.ext_dir, args.cuda)

    # configure training setting
    assert args.stop > 0
    train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch,
        args.gamma, args.reward, args.stop, 'rouge-1'
    )
    train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    reward_fn = compute_rouge_l
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net']           = 'rnn-ext_abs_rl'
    meta['net_args']      = net_args
    meta['train_params']  = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)

    pipeline = A2CPipeline(meta['net'], agent, abstractor,
                           train_batcher, val_batcher,
                           optimizer, grad_fn,
                           reward_fn, args.gamma,
                           stop_reward_fn, args.stop)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
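Note on the two reward lines above: compute_rouge_l is passed as-is, while compute_rouge_n(n=1) is called first and returns the actual scorer. A self-contained sketch of that factory pattern (the repo's real metric code may differ in tokenization and smoothing):

from collections import Counter

def compute_rouge_n(n=1):
    """Return a ROUGE-N F1 scorer closed over n, so the result can be
    passed around like a plain function (e.g. as stop_reward_fn)."""
    def scorer(output, reference):
        # output / reference: token lists
        out_grams = Counter(tuple(output[i:i + n])
                            for i in range(len(output) - n + 1))
        ref_grams = Counter(tuple(reference[i:i + n])
                            for i in range(len(reference) - n + 1))
        match = sum((out_grams & ref_grams).values())
        if match == 0:
            return 0.0
        precision = match / sum(out_grams.values())
        recall = match / sum(ref_grams.values())
        return 2 * precision * recall / (precision + recall)
    return scorer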
Example #2
def configure_net(abs_dir, ext_dir, cuda, sc, tv, rl_dir=''):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load the ML-trained extractor net and build the RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    if sc:
        agent = SelfCritic(extractor,
                           ArticleBatcher(agent_vocab, cuda),
                           time_variant=tv)
    else:
        agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))

    if rl_dir != '':
        ckpt = load_best_ckpt(rl_dir, reverse=True)
        agent.load_state_dict(ckpt)

    if cuda:
        agent = agent.cuda()

    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))
    print('agent:', agent)

    return agent, agent_vocab, abstractor, net_args
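load_best_ckpt appears in nearly every example but is not shown. A minimal sketch, assuming the ckpt-<score>-<step> file naming that Example #1 writes out (torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))); reverse=True is used when a higher validation score is better, as for the RL agents:

import os
import re
from os.path import join

import torch

def load_best_ckpt(model_dir, reverse=False):
    # keep only files matching the ckpt-<score>-<step> pattern
    ckpts = [c for c in os.listdir(join(model_dir, 'ckpt'))
             if re.match(r'ckpt-.*-\d+$', c)]
    # sort by the validation score encoded in the filename
    best = sorted(ckpts, key=lambda c: float(c.split('-')[1]),
                  reverse=reverse)[0]
    ckpt = torch.load(join(model_dir, 'ckpt', best),
                      map_location=lambda storage, loc: storage)
    return ckpt['state_dict']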
Example #3
def configure_net(net_type,
                  vocab_size,
                  emb_dim,
                  conv_hidden,
                  lstm_hidden,
                  lstm_layer,
                  bidirectional,
                  prev_ckpt=None):
    assert net_type in ['ff', 'rnn', 'trans_rnn']
    net_args = {}
    net_args['vocab_size'] = vocab_size
    net_args['emb_dim'] = emb_dim
    net_args['conv_hidden'] = conv_hidden
    net_args['lstm_hidden'] = lstm_hidden
    net_args['lstm_layer'] = lstm_layer
    net_args['bidirectional'] = bidirectional

    if net_type == 'ff':
        net = ExtractSumm(**net_args)
    elif net_type == 'trans_rnn':
        net = TransExtractSumm(**net_args)
    else:
        net = PtrExtractSumm(**net_args)
    if prev_ckpt is not None:
        ext_ckpt = load_best_ckpt(prev_ckpt)
        net.load_state_dict(ext_ckpt)
    return net, net_args
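A quick usage sketch of the dispatcher above (all hyper-parameter values here are hypothetical):

net, net_args = configure_net('rnn', vocab_size=30000, emb_dim=128,
                              conv_hidden=100, lstm_hidden=256,
                              lstm_layer=1, bidirectional=True)
# net_args can be dumped straight into meta.json, as the training
# scripts do, so the load_*_net helpers can rebuild the architecture.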
Example #4
def main(args):
    # create data batcher, vocabulary
    # batcher
    with open(join(DATA_DIR, 'vocab_cnt.pkl'), 'rb') as f:
        wc = pkl.load(f)
    word2id = make_vocab(wc, args.vsize)
    train_batcher, val_batcher = build_batchers(
        word2id, args.cuda, args.debug,
        cross_rev_bucket=args.cross_rev_bucket)

    # make net
    net, net_args = configure_net(len(word2id), args.emb_dim,
                                  args.n_hidden, args.bi, args.n_layer)
    if args.prev_trained:  # path to a previously trained abstractor
        abs_ckpt = load_best_ckpt(args.prev_trained)
        net.load_state_dict(abs_ckpt)

    if args.w2v:
        # NOTE: the pretrained embedding must have the same
        #       dimension as args.emb_dim
        embedding, _ = make_embedding(
            {i: w for w, i in word2id.items()}, args.w2v)
        net.set_embedding(embedding)

    # configure training setting
    criterion, train_params = configure_training(
        'adam', args.lr, args.clip, args.decay, args.batch
    )

    # save experiment setting
    if not exists(args.path):
        os.makedirs(args.path)
    with open(join(args.path, 'vocab.pkl'), 'wb') as f:
        pkl.dump(word2id, f, pkl.HIGHEST_PROTOCOL)
    meta = {}
    meta['net']           = 'base_abstractor'
    meta['net_args']      = net_args
    meta['train_params']  = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)

    # prepare trainer
    if args.cuda:
        # move the net to GPU before the optimizer captures its parameters
        net = net.cuda()
    val_fn = basic_validate(net, criterion)
    grad_fn = get_basic_grad_fn(net, args.clip)
    optimizer = optim.Adam(net.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True,
                                  factor=args.decay, min_lr=0,
                                  patience=args.lr_p)

    pipeline = BasicPipeline(meta['net'], net,
                             train_batcher, val_batcher, args.batch, val_fn,
                             criterion, optimizer, grad_fn)
    trainer = BasicTrainer(pipeline, args.path,
                           args.ckpt_freq, args.patience, scheduler)

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
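make_vocab, which turns the pickled word counts into word2id above, is not shown. A minimal sketch, assuming wc is a collections.Counter and that ids 0-3 are reserved for special tokens (the token names are an assumption):

def make_vocab(wc, vocab_size):
    # reserve the low ids for (assumed) special tokens
    word2id = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}
    for w, _ in wc.most_common(vocab_size):
        word2id[w] = len(word2id)
    return word2id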
Example #5
def load_ext_net(ext_dir):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'ml_rnn_extractor'
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))
    ext = PtrExtractSumm(**ext_args)
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
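For reference, the meta.json this loader expects mirrors what Example #3 returns as net_args (all values here are hypothetical):

ext_meta = {
    'net': 'ml_rnn_extractor',
    'net_args': {'vocab_size': 30000, 'emb_dim': 128, 'conv_hidden': 100,
                 'lstm_hidden': 256, 'lstm_layer': 1, 'bidirectional': True},
}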
Example #6
def load_abs_net(abs_dir):
    abs_meta = json.load(open(join(abs_dir, 'meta.json')))
    assert abs_meta['net'] == 'base_abstractor'
    abs_args = abs_meta['net_args']
    abs_ckpt = load_best_ckpt(abs_dir)
    word2id = pkl.load(open(join(abs_dir, 'vocab.pkl'), 'rb'))
    abstractor = CopySumm(**abs_args)
    abstractor.load_state_dict(abs_ckpt)
    return abstractor, word2id
Example #7
def load_ext_net(ext_dir):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] in ('ml_rnn_extractor', 'ml_entity_extractor')
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))
    if ext_meta['net'] == 'ml_rnn_extractor':
        ext = PtrExtractSumm(**ext_args)
    elif ext_meta['net'] == "ml_entity_extractor":
        ext = PtrExtractSummEntity(**ext_args)
    else:
        raise Exception('not implemented')
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
Example #8
def load_ext_net(ext_dir):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] in [
        'ml_rnn_extractor', "ml_gat_extractor", "ml_subgraph_gat_extractor"
    ]
    net_name = ext_meta['net']
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))
    if net_name == 'ml_rnn_extractor':
        ext = PtrExtractSumm(**ext_args)
    elif net_name == 'ml_gat_extractor':
        ext = PtrExtractSummGAT(**ext_args)
    elif net_name == 'ml_subgraph_gat_extractor':
        ext = PtrExtractSummSubgraph(**ext_args)
    else:
        raise Exception('not implemented')
    ext.load_state_dict(ext_ckpt)
    return ext, vocab
Example #9
def load_rl_ckpt(abs_dir, ext_dir, cuda):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'rnn-ext_abs_rl'
    ext_args = ext_meta['net_args']['extractor']['net_args']
    vocab = pkl.load(open(join(ext_dir, 'agent_vocab.pkl'), 'rb'))
    extractor = PtrExtractSumm(**ext_args)
    abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)

    agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                        extractor._extractor, ArticleBatcher(vocab, cuda))

    target_agent = ActorCritic(extractor._sent_enc,
                               extractor._art_enc, extractor._extractor,
                               ArticleBatcher(vocab, cuda))

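    # NOTE: both ActorCritic instances above wrap the same sent_enc /
    # art_enc / extractor module objects, so loading the checkpoint into
    # agent below also updates the parameters target_agent shares with it.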
    ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
    agent.load_state_dict(ext_ckpt)

    device = torch.device('cuda' if cuda else 'cpu')
    agent = agent.to(device)
    target_agent = target_agent.to(device)

    return agent, target_agent, vocab, abstractor_sent, ext_meta
Example #10
def load_ext_net(ext_dir, aux_device, args):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'ml_rnn_extractor'
    ext_ckpt = load_best_ckpt(ext_dir)
    ext_args = ext_meta['net_args']
    vocab = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb'))

    use_bert = 'bert_type' in ext_args
    # ext_args is a plain dict, so use .get() (getattr would never find
    # the key and would always fall back to args.fix_bert)
    fix_bert = ext_args.get('fix_bert', args.fix_bert)
    ext_args['fix_bert'] = fix_bert
    ext_args['aux_device'] = aux_device
    if use_bert:
        print('Use Bert based Extractor ...')
        ext = BertPtrExtractSumm(**ext_args)
        bert_type = ext_args['bert_type']
        tokenizer_cache = ext_args['tokenizer_cache']
        bert_config = (bert_type, tokenizer_cache)
    else:
        ext = PtrExtractSumm(**ext_args)
        bert_config = None

    ext.load_state_dict(ext_ckpt)
    return ext, vocab, use_bert, bert_config
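A hypothetical call site for the BERT-aware loader above (the path and device are made up):

ext, vocab, use_bert, bert_config = load_ext_net(
    'trained/ext_bert', aux_device=torch.device('cuda:1'), args=args)
if use_bert:
    bert_type, tokenizer_cache = bert_config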
Example #11
def train(args):
    if not exists(args.path):
        os.makedirs(args.path)

    # make net
    if args.docgraph or args.paragraph:
        agent, agent_vocab, abstractor, net_args = configure_net_graph(
            args.abs_dir, args.ext_dir, args.cuda, args.docgraph,
            args.paragraph)
    else:
        agent, agent_vocab, abstractor, net_args = configure_net(
            args.abs_dir, args.ext_dir, args.cuda, True, False, args.rl_dir)

    if args.bert_stride > 0:
        assert args.bert_stride == agent._bert_stride
    # configure training setting
    assert args.stop > 0
    train_params = configure_training('adam', args.lr, args.clip, args.decay,
                                      args.batch, args.gamma, args.reward,
                                      args.stop, 'rouge-1')

    if args.docgraph or args.paragraph:
        if args.bert:
            train_batcher, val_batcher = build_batchers_graph_bert(
                args.batch, args.key, args.adj_type, args.max_bert_word,
                args.docgraph, args.paragraph)
        else:
            train_batcher, val_batcher = build_batchers_graph(
                args.batch, args.key, args.adj_type, args.gold_key,
                args.docgraph, args.paragraph)
    elif args.bert:
        train_batcher, val_batcher = build_batchers_bert(
            args.batch, args.bert_sent, args.bert_stride, args.max_bert_word)
    else:
        train_batcher, val_batcher = build_batchers(args.batch)
    # TODO different reward
    if args.reward == 'rouge-l':
        reward_fn = compute_rouge_l
    elif args.reward == 'rouge-1':
        reward_fn = compute_rouge_n(n=1)
    elif args.reward == 'rouge-2':
        reward_fn = compute_rouge_n(n=2)
    elif args.reward == 'rouge-l-s':
        reward_fn = compute_rouge_l_summ
    else:
        raise ValueError('unsupported reward: {}'.format(args.reward))
    stop_reward_fn = compute_rouge_n(n=1)

    # save abstractor binary
    if args.abs_dir is not None:
        abs_ckpt = {}
        abs_ckpt['state_dict'] = load_best_ckpt(args.abs_dir, reverse=True)
        abs_vocab = pkl.load(open(join(args.abs_dir, 'vocab.pkl'), 'rb'))
        abs_dir = join(args.path, 'abstractor')
        os.makedirs(join(abs_dir, 'ckpt'))
        with open(join(abs_dir, 'meta.json'), 'w') as f:
            json.dump(net_args['abstractor'], f, indent=4)
        torch.save(abs_ckpt, join(abs_dir, 'ckpt/ckpt-0-0'))
        with open(join(abs_dir, 'vocab.pkl'), 'wb') as f:
            pkl.dump(abs_vocab, f)
    # save configuration
    meta = {}
    meta['net'] = 'rnn-ext_abs_rl'
    meta['net_args'] = net_args
    meta['train_params'] = train_params
    with open(join(args.path, 'meta.json'), 'w') as f:
        json.dump(meta, f, indent=4)
    with open(join(args.path, 'agent_vocab.pkl'), 'wb') as f:
        pkl.dump(agent_vocab, f)

    # prepare trainer
    grad_fn = get_grad_fn(agent, args.clip)
    optimizer = optim.Adam(agent.parameters(), **train_params['optimizer'][1])
    scheduler = ReduceLROnPlateau(optimizer,
                                  'max',
                                  verbose=True,
                                  factor=args.decay,
                                  min_lr=1e-5,
                                  patience=args.lr_p)

    if args.docgraph or args.paragraph:
        entity = True
    else:
        entity = False
    pipeline = SCPipeline(meta['net'], agent, abstractor, train_batcher,
                          val_batcher, optimizer, grad_fn, reward_fn, entity,
                          args.bert)

    trainer = BasicTrainer(pipeline,
                           args.path,
                           args.ckpt_freq,
                           args.patience,
                           scheduler,
                           val_mode='score')

    print('start training with the following hyper-parameters:')
    print(meta)
    trainer.train()
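The reward selection in the middle of train could equally be a lookup table; a behavior-equivalent sketch of that if/elif chain:

REWARD_FNS = {
    'rouge-l': compute_rouge_l,
    'rouge-1': compute_rouge_n(n=1),
    'rouge-2': compute_rouge_n(n=2),
    'rouge-l-s': compute_rouge_l_summ,
}
try:
    reward_fn = REWARD_FNS[args.reward]
except KeyError:
    raise ValueError('unsupported reward: {}'.format(args.reward))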