Esempio n. 1
0
def train(model, tokenizer, train_data, valid_data, args, eos=False):
    """Train ``model`` for ``args.max_steps`` optimizer steps with periodic validation.

    Each step computes the loss of one training batch with ``calc_loss``,
    back-propagates, optionally clips gradients, and advances an Adam optimizer
    and a warmup / inverse-square-root ``LambdaLR`` schedule. Every
    ``args.log_interval`` steps the training statistics are logged and reported
    to nsml. Every ``args.eval_interval`` steps the model is evaluated on
    ``valid_data`` (loss, exact match, GLEU). A checkpoint named ``"best"`` is
    saved through ``nsml.save`` whenever the validation GLEU improves.

    Args:
        model: torch module whose parameters are optimized.
        tokenizer: forwarded to ``collate_fn`` and ``correct``.
        train_data: training examples, wrapped in ``TextDataset``.
        valid_data: validation examples; each must be a mapping with at least
            ``'noisy'`` and ``'clean'`` keys.
        args: hyper-parameter namespace. Read attributes include
            ``train_batch_size``, ``eval_batch_size``, ``num_workers``,
            ``max_seq_length``, ``tokenizer``, ``max_steps``, ``lr``,
            ``adam_betas`` (a string holding a Python literal, e.g.
            ``"(0.9, 0.98)"``), ``eps``, ``weight_decay``,
            ``num_warmup_steps``, ``max_grad_norm``, ``log_interval``,
            ``eval_interval`` and ``device``.
        eos: forwarded to ``collate_fn`` and ``correct``.
    """
    import ast
    import itertools

    model.train()

    def collate(examples):
        # Shared batching function for the train and validation loaders.
        return collate_fn(examples, tokenizer, args.max_seq_length, eos=eos,
                          tokenizer_type=args.tokenizer)

    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size, num_workers=args.num_workers,
                                  collate_fn=collate)

    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(valid_dataset, sampler=SequentialSampler(valid_dataset),
                                  batch_size=args.eval_batch_size, num_workers=args.num_workers,
                                  collate_fn=collate)

    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]

    # Ceiling division: enough epochs to perform args.max_steps optimizer steps.
    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    # literal_eval (not eval) parses the betas tuple without executing
    # arbitrary code supplied through the command line / config.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=ast.literal_eval(args.adam_betas), eps=args.eps,
                                 weight_decay=args.weight_decay)
    # Linear warmup up to num_warmup_steps, then decay proportional to
    # 1/sqrt(step). The multiplier is 0 at x == 0.
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (x / args.num_warmup_steps) ** -0.5
    scheduler = LambdaLR(optimizer, lr_lambda)

    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        print("===EPOCH: ", epoch)
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            loss, items = calc_loss(model, batch)
            meter.add(*items)

            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()

            if step % args.log_interval == 0:
                # The lr the scheduler just applied. scheduler.get_lr() is
                # deprecated outside scheduler.step() in recent PyTorch.
                lr = optimizer.param_groups[0]['lr']
                loss_sent, loss_token = meter.average()

                logger.info(f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step, scope=locals(), summary=True,
                            train__lr=lr, train__loss_sent=loss_sent, train__token_ppl=math.exp(loss_token))
                meter.init()

            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate(model, valid_dataloader, args)
                prediction = correct(model, tokenizer, valid_noisy, args, eos=eos, length_limit=0.1)
                val_em = em(prediction, valid_clean)
                # Print the first 20 (noisy, predicted, clean) triples for a
                # quick qualitative check.
                samples = zip(valid_noisy, prediction, valid_clean)
                for noisy, pred, clean in itertools.islice(samples, 20):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                val_gleu = gleu(prediction, valid_clean)

                logger.info('-' * 89)
                logger.info(f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}')
                logger.info('-' * 89)
                nsml.report(step=step, scope=locals(), summary=True,
                            valid__loss_sent=val_loss, valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em, valid__gleu=val_gleu)

                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                # evaluate()/correct() may switch the model to eval mode
                # (NOTE(review): confirm in their definitions). Restore
                # training mode so train-only layers such as dropout stay
                # active for the remaining steps. This is a no-op if the model
                # is already in training mode.
                model.train()
                # Exclude evaluation time from the meter's training timing.
                meter.start += time.time() - start_eval

            if step >= args.max_steps:
                break
        if step >= args.max_steps:
            break
Esempio n. 2
0
def train(model, tokenizer, train_data, valid_data, args):
    """Fine-tune ``model`` (KcBERT variant) for ``args.max_steps`` optimizer steps.

    Batches are built with ``collate_fn_bert``. Each batch unpacks into
    ``(noise_input_ids, clean_input_ids, noise_mask, clean_mask)``. The model is
    called with the noisy ids as input and the clean ids as ``labels``, and the
    first element of its output is used as the loss. Optimization uses AdamW
    plus a warmup / inverse-square-root ``LambdaLR`` schedule. Every
    ``args.log_interval`` steps training statistics are logged and reported to
    nsml. Every ``args.eval_interval`` steps the model is validated with
    ``evaluate_kcBert`` / ``correct_kcBert``. A checkpoint named ``"best"`` is
    saved through ``nsml.save`` whenever the validation GLEU improves.

    Args:
        model: torch module returning the loss as ``outputs[0]`` when called
            with ``labels=``.
        tokenizer: forwarded to ``collate_fn_bert`` and ``correct_kcBert``.
        train_data: training examples, wrapped in ``TextDataset``.
        valid_data: validation examples; each must be a mapping with at least
            ``'noisy'`` and ``'clean'`` keys.
        args: hyper-parameter namespace. Read attributes include
            ``train_batch_size``, ``eval_batch_size``, ``num_workers``,
            ``max_seq_length``, ``max_steps``, ``num_warmup_steps``,
            ``max_grad_norm``, ``log_interval``, ``eval_interval`` and
            ``device``.
    """
    import itertools

    model.train()

    def collate(examples):
        # Shared batching function for the train and validation loaders.
        return collate_fn_bert(examples, tokenizer, args.max_seq_length)

    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=collate)

    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(valid_dataset,
                                  sampler=SequentialSampler(valid_dataset),
                                  batch_size=args.eval_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=collate)

    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]

    # Ceiling division: enough epochs to perform args.max_steps optimizer steps.
    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    # NOTE(review): the learning rate is hard-coded to 1e-5. args.lr,
    # args.adam_betas, args.eps and args.weight_decay are not used here.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # Linear warmup up to num_warmup_steps, then decay proportional to
    # 1/sqrt(step). The multiplier is 0 at x == 0.
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (
        x / args.num_warmup_steps)**-0.5
    scheduler = LambdaLR(optimizer, lr_lambda)

    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            noise_input_ids, clean_input_ids, noise_mask, clean_mask = batch

            outputs = model(noise_input_ids,
                            labels=clean_input_ids,
                            attention_mask=noise_mask)
            loss = outputs[0]

            # Meter entries: (loss value, number of sentences, number of
            # target tokens counted from the clean mask).
            bsz = clean_input_ids.size(0)
            items = [loss.item(), bsz, clean_mask.sum().item()]
            meter.add(*items)

            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()

            if step % args.log_interval == 0:
                # The lr the scheduler just applied. scheduler.get_lr() is
                # deprecated outside scheduler.step() in recent PyTorch.
                lr = optimizer.param_groups[0]['lr']
                loss_sent, loss_token = meter.average()

                logger.info(
                    f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            train__lr=lr,
                            train__loss_sent=loss_sent,
                            train__token_ppl=math.exp(loss_token))
                meter.init()

            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate_kcBert(
                    model, valid_dataloader, args)
                prediction = correct_kcBert(model,
                                            tokenizer,
                                            valid_noisy,
                                            args,
                                            length_limit=0.1)
                val_em = em(prediction, valid_clean)
                # Print the first 20 (noisy, predicted, clean) triples for a
                # quick qualitative check.
                samples = zip(valid_noisy, prediction, valid_clean)
                for noisy, pred, clean in itertools.islice(samples, 20):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                val_gleu = gleu(prediction, valid_clean)

                logger.info('-' * 89)
                logger.info(
                    f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}'
                )
                logger.info('-' * 89)
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            valid__loss_sent=val_loss,
                            valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em,
                            valid__gleu=val_gleu)

                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                # evaluate_kcBert()/correct_kcBert() may switch the model to
                # eval mode (NOTE(review): confirm in their definitions).
                # Restore training mode so train-only layers such as dropout
                # stay active for the remaining steps. This is a no-op if the
                # model is already in training mode.
                model.train()
                # Exclude evaluation time from the meter's training timing.
                meter.start += time.time() - start_eval

            if step >= args.max_steps:
                break
        if step >= args.max_steps:
            break