Example #1
0
    parser.add_argument('--train_batch_size', type=int, default=88888)

    return parser.parse_args()

if __name__ == '__main__':
    # Smoke test for the training DataLoader: build the vocabularies, then
    # iterate over the training batches on CPU and print each batch's
    # 'concept' and 'token_in' tensor sizes.
    from extract import LexicalMap
    import time
    args = parse_config()
    vocabs = dict()
    # NOTE(review): the second Vocab argument (5 / 100) looks like a minimum
    # frequency cut-off and the list looks like extra special tokens --
    # confirm against Vocab's definition.
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
    vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
    vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
    vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
    vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
    lexical_mapping = LexicalMap()

    train_data = DataLoader(vocabs, lexical_mapping, args.train_data, args.train_batch_size, for_train=True)
    epoch_idx = 0
    batch_idx = 0
    # `last` here and `st` below are not used anywhere in the visible code.
    last = 0
    # NOTE(review): this loop has no exit condition and epoch_idx is never
    # incremented, so every printed line reports epoch 0. The snippet appears
    # to be truncated -- confirm before relying on it.
    while True:
        st = time.time()
        for d in train_data:
            d = move_to_device(d, torch.device('cpu'))
            batch_idx += 1
            #if d['concept'].size(0) > 5:
            #    continue
            print (epoch_idx, batch_idx, d['concept'].size(), d['token_in'].size())
            # Two-name unpacking requires both tensors to be 2-D; the names
            # suggest (length, batch) -- TODO confirm. The second line
            # overwrites `bsz` from the first.
            c_len, bsz = d['concept'].size()
            t_len, bsz = d['token_in'].size()
Example #2
0
def main(args, local_rank):
    """Train the ``Parser`` model on one (possibly distributed) worker.

    Builds the vocabularies and lexical mapping, constructs the model on
    ``cuda:local_rank`` and an ``AdamWeightDecayOptimizer`` whose bias and
    layer-norm parameters are exempt from weight decay, then trains on batches
    produced by a background ``data_proc`` process through a bounded
    ``mp.Queue``. Gradients are accumulated over ``args.batches_per_update``
    batches per optimizer step. Once more than ``args.warmup_steps`` steps
    have been taken, batches whose arc loss exceeds five times the running
    per-batch mean are discarded. The single worker (or rank 0) prints running
    loss means every ``args.print_every`` steps and writes a checkpoint to
    ``args.ckpt`` every ``args.eval_every`` steps.

    The training loop has no exit condition; it runs until the process is
    stopped externally.

    Args:
        args: Parsed command-line namespace (vocabulary paths, model sizes,
            optimisation hyper-parameters, data/checkpoint locations and an
            optional ``resume_ckpt`` path).
        local_rank: Index of the CUDA device used by this worker.
    """
    vocabs = dict()
    vocabs['tok'] = Vocab(args.tok_vocab, 5, [CLS])
    vocabs['lem'] = Vocab(args.lem_vocab, 5, [CLS])
    vocabs['pos'] = Vocab(args.pos_vocab, 5, [CLS])
    vocabs['ner'] = Vocab(args.ner_vocab, 5, [CLS])
    vocabs['predictable_concept'] = Vocab(args.predictable_concept_vocab, 10,
                                          [DUM, END])
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [DUM, END])
    vocabs['rel'] = Vocab(args.rel_vocab, 50, [NIL])
    vocabs['word_char'] = Vocab(args.word_char_vocab, 100, [CLS, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [CLS, END])
    lexical_mapping = LexicalMap(args.lexical_mapping)
    if args.pretrained_word_embed is not None:
        vocab, pretrained_embs = load_pretrained_word_embed(
            args.pretrained_word_embed)
        vocabs['glove'] = vocab
    else:
        pretrained_embs = None

    for name in vocabs:
        print((name, vocabs[name].size))

    # Same seed on every worker before building the model (presumably so all
    # ranks start from identical weights); reseeded per rank further below.
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device('cuda', local_rank)
    model = Parser(vocabs,
                   args.word_char_dim,
                   args.word_dim,
                   args.pos_dim,
                   args.ner_dim,
                   args.concept_char_dim,
                   args.concept_dim,
                   args.cnn_filters,
                   args.char2word_dim,
                   args.char2concept_dim,
                   args.embed_dim,
                   args.ff_embed_dim,
                   args.num_heads,
                   args.dropout,
                   args.snt_layers,
                   args.graph_layers,
                   args.inference_layers,
                   args.rel_dim,
                   pretrained_embs,
                   device=device)

    if args.world_size > 1:
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())

    model = model.cuda(local_rank)
    train_data = DataLoader(vocabs,
                            lexical_mapping,
                            args.train_data,
                            args.train_batch_size,
                            for_train=True)
    # NOTE(review): dev_data is not used anywhere in this function.
    dev_data = DataLoader(vocabs,
                          lexical_mapping,
                          args.dev_data,
                          args.dev_batch_size,
                          for_train=True)
    train_data.set_unk_rate(args.unk_rate)

    # Exempt biases and layer-norm parameters from weight decay.
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{
        'params': weight_decay_params,
        'weight_decay': 1e-4
    }, {
        'params': no_weight_decay_params,
        'weight_decay': 0.
    }]
    optimizer = AdamWeightDecayOptimizer(grouped_params,
                                         lr=args.lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-6)

    # batches_acm counts optimizer steps (restored on resume); the *_loss_acm
    # accumulators only cover batches accepted by this process.
    batches_acm, loss_acm, concept_loss_acm, arc_loss_acm, rel_loss_acm = 0, 0, 0, 0, 0
    discarded_batches_acm = 0
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc,
                                      args=(train_data, queue))
    train_data_generator.start()

    # Number of batches accepted (backpropagated) in this process; this is
    # the correct denominator for the running loss means.
    used_batches = 0
    if args.resume_ckpt:
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        batches_acm = ckpt['batches_acm']
        del ckpt

    model.train()
    epoch = 0
    while True:
        batch = queue.get()
        if isinstance(batch, str):
            # A string item from the producer marks the end of an epoch.
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
        else:
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss = model(batch)
            loss = (concept_loss + arc_loss +
                    rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            # Compare against the per-batch mean over this process's accepted
            # batches. Dividing by batches_acm was wrong: it counts optimizer
            # steps (inflating the mean by batches_per_update), and after a
            # resume it starts at the checkpoint value while the accumulators
            # start at 0, which made the threshold 0 and discarded every batch.
            if (batches_acm > args.warmup_steps and used_batches > 0
                    and arc_loss_value > 5. * (arc_loss_acm / used_batches)):
                discarded_batches_acm += 1
                print('abnormal', concept_loss_value, arc_loss_value,
                      rel_loss_value)
                continue
            loss_acm += loss_value
            concept_loss_acm += concept_loss_value
            arc_loss_acm += arc_loss_value
            rel_loss_acm += rel_loss_value
            loss.backward()

            used_batches += 1
            # Step when used_batches % k == k - 1 (k = batches_per_update);
            # for k > 1 the very first accumulation window is one batch short.
            if not (used_batches % args.batches_per_update
                    == -1 % args.batches_per_update):
                continue
            batches_acm += 1

            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            update_lr(optimizer, args.embed_dim, batches_acm,
                      args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()
            if args.world_size == 1 or (dist.get_rank() == 0):
                if batches_acm % args.print_every == -1 % args.print_every:
                    # used_batches > 0 here because a step was just taken.
                    print(
                        'Train Epoch %d, Batch %d, Discarded Batch %d, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f'
                        % (epoch, batches_acm, discarded_batches_acm,
                           concept_loss_acm / used_batches, arc_loss_acm /
                           used_batches, rel_loss_acm / used_batches))
                    model.train()

                if batches_acm % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    torch.save(
                        {
                            'args': args,
                            'model': model.state_dict(),
                            'batches_acm': batches_acm,
                            'optimizer': optimizer.state_dict()
                        },
                        '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                    model.train()
Example #3
0
def main(args, local_rank):
    """Train the ``Generator`` model for ``args.total_train_steps`` steps.

    Builds the vocabularies and lexical mapping, constructs the model on
    ``cuda:local_rank`` and an ``AdamWeightDecayOptimizer`` whose bias and
    layer-norm parameters are exempt from weight decay, then trains on batches
    produced by a background ``data_proc`` process through a bounded
    ``mp.Queue``. One optimizer step is taken per accepted batch. Once more
    than ``args.warmup_steps`` steps have been taken, batches whose loss
    exceeds five times the running mean are discarded. The single worker (or
    rank 0) prints the running mean loss every ``args.print_every`` steps
    and, after warm-up, saves a checkpoint to ``args.ckpt`` every
    ``args.eval_every`` steps. The producer process is stopped when training
    finishes.

    Args:
        args: Parsed command-line namespace (vocabulary paths, model sizes,
            optimisation hyper-parameters, data/checkpoint locations).
        local_rank: Index of the CUDA device used by this worker.
    """
    vocabs = dict()
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
    vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
    vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
    vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
    vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
    lexical_mapping = LexicalMap()

    for name in vocabs:
        print((name, vocabs[name].size, vocabs[name].coverage))

    # Same seed on every worker before building the model (presumably so all
    # ranks start from identical weights); reseeded per rank further below.
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)

    device = torch.device('cuda', local_rank)
    model = Generator(vocabs, args.token_char_dim, args.token_dim,
                      args.concept_char_dim, args.concept_dim,
                      args.cnn_filters, args.char2word_dim,
                      args.char2concept_dim, args.rel_dim,
                      args.rnn_hidden_size, args.rnn_num_layers,
                      args.embed_dim, args.ff_embed_dim, args.num_heads,
                      args.dropout, args.snt_layers, args.graph_layers,
                      args.inference_layers, args.pretrained_file, device)

    if args.world_size > 1:
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())

    model = model.to(device)
    train_data = DataLoader(vocabs,
                            lexical_mapping,
                            args.train_data,
                            args.train_batch_size,
                            for_train=True)
    # Validation is currently disabled; re-enable together with the block
    # in the checkpoint branch below:
    #dev_data = DataLoader(vocabs, lexical_mapping, args.dev_data, args.dev_batch_size, for_train=False)
    train_data.set_unk_rate(args.unk_rate)

    # Exempt biases and layer-norm parameters from weight decay.
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{
        'params': weight_decay_params,
        'weight_decay': 1e-4
    }, {
        'params': no_weight_decay_params,
        'weight_decay': 0.
    }]
    optimizer = AdamWeightDecayOptimizer(grouped_params,
                                         lr=args.lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-6)

    batches_acm, loss_acm = 0, 0
    discarded_batches_acm = 0

    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc,
                                      args=(train_data, queue))
    train_data_generator.start()

    model.train()
    epoch = 0
    while batches_acm < args.total_train_steps:
        batch = queue.get()
        if isinstance(batch, str):
            # A string item from the producer marks the end of an epoch.
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
            continue
        batch = move_to_device(batch, device)
        # (A leftover debug `exit(0)` here used to terminate the process after
        # the first forward pass, making the optimisation code unreachable.)
        loss = model(batch)

        loss_value = loss.item()
        # batches_acm > warmup_steps >= 0 guarantees a non-zero denominator.
        if batches_acm > args.warmup_steps and loss_value > 5. * (loss_acm /
                                                                  batches_acm):
            discarded_batches_acm += 1
            print('abnormal', loss_value)
            continue
        loss_acm += loss_value
        batches_acm += 1
        loss.backward()
        if args.world_size > 1:
            average_gradients(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        update_lr(optimizer, args.embed_dim, batches_acm, args.warmup_steps)
        optimizer.step()
        optimizer.zero_grad()
        if args.world_size == 1 or (dist.get_rank() == 0):
            if batches_acm % args.print_every == -1 % args.print_every:
                print(
                    'Train Epoch %d, Batch %d, Discarded Batch %d, loss %.3f' %
                    (epoch, batches_acm, discarded_batches_acm,
                     loss_acm / batches_acm))
                model.train()
            if batches_acm > args.warmup_steps and batches_acm % args.eval_every == -1 % args.eval_every:
                #model.eval()
                #bleu, chrf = validate(model, dev_data)
                #print ("epoch", "batch", "bleu", "chrf")
                #print (epoch, batches_acm, bleu, chrf)
                torch.save({
                    'args': args,
                    'model': model.state_dict()
                }, '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                model.train()

    # The producer is a non-daemon child that presumably keeps filling the
    # bounded queue and would block on put() forever; stop it so it cannot
    # hang interpreter shutdown once training is done.
    train_data_generator.terminate()
    train_data_generator.join()