Beispiel #1
0
def main():
    """Train a PUBG regressor on a contiguous 90/10 train/validation split.

    Builds the training options, wraps ``data.PUBGDataset`` in two
    ``DataLoader``s (first 90% of indices for training, last 10% for
    validation), trains ``models.PUBGSimpleRegressor`` with Adam through
    ``trainer.Trainer``, and finally executes the code string
    ``util.TEST_EMBEDDING`` inside this function.

    NOTE(review): ``exec`` runs with this function's globals and locals, so
    the code in ``util.TEST_EMBEDDING`` may read the local names below
    (``opt``, ``pubg_data``, ``pubg_reg``, ``my_trainer``, ...). Do not
    rename or remove locals without checking that string first.
    """
    opt = options.set(training=True)
    pubg_data = data.PUBGDataset(opt)
    TRAIN_LEN = len(pubg_data)
    # exec(util.TEST_EMBEDDING)

    # pubg_data_small = Subset(pubg_data, range(0, 100))

    # Contiguous index split (no shuffle before splitting): validation is the
    # tail 10% of the dataset in its stored order.
    train_indices = range(0, TRAIN_LEN * 9 // 10)
    valid_indices = range(TRAIN_LEN * 9 // 10, TRAIN_LEN)
    # Both loaders share one dataset object; SubsetRandomSampler restricts each
    # loader to its index range and visits those indices in random order.
    # num_workers=0 means batches are loaded in the main process.
    pubg_train_loader = DataLoader(pubg_data,
                                   batch_size=opt.batch_size,
                                   num_workers=0,
                                   sampler=SubsetRandomSampler(train_indices))
    pubg_valid_loader = DataLoader(pubg_data,
                                   batch_size=opt.batch_size,
                                   num_workers=0,
                                   sampler=SubsetRandomSampler(valid_indices))

    #model
    pubg_reg = models.PUBGSimpleRegressor(opt)
    # pubg_reg = models.PUBGSimpleAE(opt)

    #optimizer
    # Only parameters with requires_grad=True are handed to Adam, so any
    # frozen parameters are excluded from optimization.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  pubg_reg.parameters()),
                           lr=opt.lr_classifier)

    my_trainer = trainer.Trainer(opt, pubg_reg, optimizer, pubg_train_loader,
                                 pubg_valid_loader)
    my_trainer.train()
    # Executes project-provided code with access to this function's locals;
    # presumably an embedding sanity check given its name -- TODO confirm.
    exec(util.TEST_EMBEDDING)
Beispiel #2
0
#-*- coding: utf-8 -*-
import sys
from trainer import trainer
import os.path

# If you add a new function, use the following module to update the feature
# values stored in the model directory.
if __name__ == '__main__':
    # Expected invocation: trainer.py <model_number> <iteration_number>
    if len(sys.argv) < 3:
        # Not enough arguments: show the script directory and a usage line,
        # then stop with a non-zero status so callers can detect the misuse.
        # (print() as a function: the old Python 2 print statement is a
        # SyntaxError under Python 3.)
        print(os.path.dirname(__file__))
        print('$python trainer.py model_number iteration_number')
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(1)

    t = trainer.Trainer()
    # argv[1] is passed through as a string; argv[2] must parse as an int
    # (int() raises ValueError otherwise).
    t.training(sys.argv[1], int(sys.argv[2]))
Beispiel #3
0
def train(args):
    """Train the model, then augment the training records with it.

    Two loaders are built over ``args.train_record_file``: a shuffled one for
    training and an unshuffled one that is fed to ``augment`` once training
    finishes. The augmentation results are logged under ``dev/<key>`` and the
    augmented data is written to ``args.train_aug_file``.

    Gradients are accumulated over ``args.gradient_accumulation`` batches
    (each loss is divided by that factor), unscaled, clipped to
    ``args.max_grad_norm`` and then applied via the AMP scaler.

    NOTE(review): a checkpoint is saved only at the *start* of each epoch, so
    the weights after the final epoch are not checkpointed by this loop.
    ``model.eval()`` is not called before ``augment`` -- confirm that
    ``augment`` switches modes itself.

    Args:
        args: raw arguments; resolved by ``get_args`` and ``runner.setup``.
    """
    runner = base_trainer.Trainer()
    args, device = get_args(args)
    args, logger, writer = runner.setup(args)

    # --- BPE vocabulary ---------------------------------------------------
    logger.info("Loading BPE...")
    bpe = get_bpe(args)
    logger.info(f"Loaded {len(bpe)} BPE tokens")

    # --- Data -------------------------------------------------------------
    logger.info("Building dataset...")
    train_dataset, train_loader = get_dataset(
        args, args.train_record_file, bpe, shuffle=True)
    # Unshuffled pass over the *training* records; consumed by augment().
    _aug_dataset, aug_loader = get_dataset(
        args, args.train_record_file, bpe, shuffle=False)
    args.epoch_size = len(train_dataset)
    logger.info(f"Train has {args.epoch_size} examples")

    # --- Model ------------------------------------------------------------
    logger.info("Building model...")
    model = runner.setup_model(get_model(args, bpe), device)

    # --- Optimizer, scheduler, AMP scaler -----------------------------------
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(args.beta_1, args.beta_2),
        eps=args.eps,
        weight_decay=args.l2_wd,
    )

    # Return value is unused; presumably sets args.num_steps, which is read
    # on the next lines -- TODO confirm.
    get_num_steps(args)
    logger.info(f"Scheduler will decay over {args.num_steps} steps")
    scheduler = sched.get_linear_warmup_power_decay_scheduler(
        optimizer, args.warmup_steps, args.num_steps, power=args.power_decay)
    optimizer, scheduler, scaler = runner.setup_optimizer(
        optimizer, scheduler, amp.GradScaler())

    def _apply_accumulated_grads():
        # Unscale first so clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()

    # --- Training state (possibly restored by setup_step) ------------------
    logger.info("Training...")
    model.train()
    sample_num, samples_till_eval, epoch, step = (
        0, args.eval_per_n_samples, 0, 0)
    runner.setup_saver()
    runner.setup_random()
    sample_num, samples_till_eval, epoch, step = runner.setup_step(
        step_vars=(sample_num, samples_till_eval, epoch, step))
    runner.setup_close()

    while epoch != args.num_epochs:
        runner.save_checkpoint(
            step_vars=(sample_num, samples_till_eval, epoch, step))
        epoch += 1
        logger.info(f"Starting epoch {epoch}...")

        # One weight histogram per parameter per epoch.
        for name, param in model.named_parameters():
            writer.add_histogram(name, param.data, epoch)

        with torch.enable_grad(), tqdm(
                total=len(train_loader.dataset)) as progress_bar:
            for x, y, _, _ in train_loader:
                batch_size = x.size(0)
                loss, loss_val, _ = forward(x, y, args, device, model)

                # Scale down so the accumulated gradient is an average.
                scaler.scale(loss / args.gradient_accumulation).backward()
                if (step + 1) % args.gradient_accumulation == 0:
                    _apply_accumulated_grads()

                step += 1
                sample_num += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                writer.add_scalar("train/NLL", loss_val, sample_num)
                writer.add_scalar("train/LR",
                                  optimizer.param_groups[0]["lr"],
                                  sample_num)
                writer.add_scalar("train/steps",
                                  step // args.gradient_accumulation,
                                  sample_num)

    # --- Augmentation with the trained model --------------------------------
    results, augs = augment(model, aug_loader, device, bpe, args)
    for key, value in results.items():
        writer.add_scalar(f"dev/{key}", value, sample_num)
    save(args.train_aug_file, augs, "train aug")
# Build the Trainer from one keyword mapping (same keys, same order as the
# constructor call it replaces). The keyword names -- including the
# "temperture" spelling -- are the callee's and are passed through unchanged.
# NOTE: this statement rebinds the name ``trainer`` from the module to the
# Trainer instance; the module is no longer reachable under that name.
trainer = trainer.Trainer(**dict(
    cuda=cuda,
    student_model=student,
    teacher_model=teacher,
    style_net=style_net,
    # Optimizers for the student and teacher models
    student_optimizer=stu_optim,
    teacher_optimizer=tea_optim,
    train_loader=train_loader,
    val_loader=val_loader,
    out=out,
    source_record=source_record,
    src_ave_method=src_ave_method,
    tgt_ave_method=tgt_ave_method,
    src_temperture=src_temperture,
    tgt_temperture=tgt_temperture,
    train_generator=train_generator,
    style_weight=style_weight,
    content_weight=content_weight,
    style_net_optim=style_net_optim,
    # Local name differs from the keyword: max_iteration -> max_iter
    max_iter=max_iteration,
    unsup_weight=unsup_weight,
    clustering_weight=clustering_weight,
    loss_type=loss_type,
    pseudo_labeling=pseudo_labeling,
    balance_function=balance_function,
    confidence_thresh=confidence_thresh,
    internal_weight=internal_weight,
    size_average=size_average,
    class_dist_reg=class_dist_reg,
    interval_validate=interval_validate,
    rampup_function=rampup_function,
    src_style_alpha=src_style_alpha,
    tgt_style_alpha=tgt_style_alpha,
    src_transfer_rate=src_transfer_rate,
    tgt_transfer_rate=tgt_transfer_rate,
    tgt_style_method=tgt_style_method,
    pad=pad,
))
Beispiel #5
0
def train(args):
    """Train with an auxiliary ELECTRA-style MLM objective and periodic eval.

    Each training batch contributes:

    * the primary loss from ``forward(x, y, ...)``, and
    * ``args.mlm_samples`` extra losses on inputs produced by
      ``sample_mlm_pred`` from the primary scores, each weighted by
      ``args.lambda_weight``.

    Every loss is divided by ``loss_weight`` =
    ``gradient_accumulation * (mlm_samples * lambda_weight + 1)``, so the
    weights of the primary and MLM losses in one batch sum to
    ``1 / args.gradient_accumulation``. The optimizer steps every
    ``args.gradient_accumulation`` batches, after unscaling and clipping to
    ``args.max_grad_norm``.

    Every ``args.eval_per_n_samples`` training samples the model is
    evaluated on the dev loader. ``trainer.save_best`` is called with
    ``results[args.metric_name]``, metrics are logged to the console and
    TensorBoard, and both prediction sets are visualized.

    Args:
        args: raw arguments; resolved by ``get_args`` and ``trainer.setup``.
    """
    trainer = base_trainer.Trainer()
    args, device = get_args(args)
    args, log, tbx = trainer.setup(args)

    # Get BPE
    log.info("Loading BPE...")
    bpe = get_bpe(args)
    log.info("Loaded {} BPE tokens".format(len(bpe)))

    # Get data loaders (the dataset objects themselves are not needed here)
    log.info("Building dataset...")
    _, train_loader = get_dataset(args, args.epoch_size,
                                  args.train_record_file, bpe)
    _, dev_loader = get_dataset(args, args.dev_epoch_size,
                                args.dev_record_file, bpe)

    # Get model
    log.info("Building model...")
    model = get_model(args, bpe)
    model = trainer.setup_model(model, device)

    # Get optimizer, scheduler, and scaler
    optimizer = optim.AdamW(
        model.parameters(),
        args.lr,
        betas=(args.beta_1, args.beta_2),
        eps=args.eps,
        weight_decay=args.l2_wd,
    )

    scheduler = sched.get_linear_warmup_power_decay_scheduler(
        optimizer, args.warmup_steps, args.num_steps, power=args.power_decay)

    scaler = amp.GradScaler()
    optimizer, scheduler, scaler = trainer.setup_optimizer(
        optimizer, scheduler, scaler)

    # Train state (possibly restored from a checkpoint by setup_step)
    log.info("Training...")
    model.train()
    sample_num = 0
    samples_till_eval = args.eval_per_n_samples
    epoch = 0
    step = 0
    trainer.setup_saver()
    trainer.setup_random()
    sample_num, samples_till_eval, epoch, step = trainer.setup_step(
        step_vars=(sample_num, samples_till_eval, epoch, step))
    trainer.setup_close()

    # Loop-invariant normalizer: accumulation steps times the total per-batch
    # loss weight (1 for the primary loss + lambda_weight per MLM sample).
    # Previously recomputed for every batch.
    loss_weight = args.gradient_accumulation * (
        args.mlm_samples * args.lambda_weight + 1)

    while epoch != args.num_epochs:
        trainer.save_checkpoint(step_vars=(sample_num, samples_till_eval,
                                           epoch, step))
        epoch += 1
        log.info(f"Starting epoch {epoch}...")
        # Print histogram of weights every epoch
        for tags, params in model.named_parameters():
            tbx.add_histogram(tags, params.data, epoch)
        with torch.enable_grad(), tqdm(
                total=len(train_loader.dataset)) as progress_bar:
            for x, y in train_loader:
                batch_size = x.size(0)
                loss, loss_val, scores = forward(x, y, args, device, model)
                loss = loss / loss_weight
                scaler.scale(loss).backward()

                # ELECTRA MLM: derive args.mlm_samples alternative inputs from
                # the primary scores and add their (lambda-weighted) losses.
                x, y = x.to(device), y.to(device)
                x, y = sample_mlm_pred(model.module, x, y, scores,
                                       args.mlm_samples, args)
                # Drop the last reference to the primary scores so their
                # memory can be released before the extra forward passes.
                del scores

                loss_val_e = 0
                for i in range(args.mlm_samples):
                    loss, loss_val_item, _ = forward(x[i], y, args, device,
                                                     model)
                    loss = loss / loss_weight * args.lambda_weight
                    # Mean over the leading dim of the sampled x (presumably
                    # the mlm_samples axis -- TODO confirm).
                    loss_val_e += loss_val_item / x.size(0)
                    scaler.scale(loss).backward()

                # Optimizer step once every gradient_accumulation batches
                if (step + 1) % args.gradient_accumulation == 0:
                    # Unscale before clipping so the norm is on true grads.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()

                # Log info
                step += 1
                sample_num += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch,
                                         NLL=loss_val,
                                         NLL_E=loss_val_e)
                tbx.add_scalar("train/NLL", loss_val, sample_num)
                tbx.add_scalar("train/NLL_E", loss_val_e, sample_num)
                tbx.add_scalar("train/LR", optimizer.param_groups[0]["lr"],
                               sample_num)
                tbx.add_scalar("train/steps",
                               step // args.gradient_accumulation, sample_num)

                samples_till_eval -= batch_size
                if samples_till_eval <= 0:
                    samples_till_eval = args.eval_per_n_samples

                    # Evaluate and save checkpoint
                    log.info(f"Evaluating at sample step {sample_num}...")
                    results, preds, preds_e = evaluate(model, dev_loader,
                                                       device, args)
                    trainer.save_best(sample_num, results[args.metric_name])

                    # Log to console
                    results_str = ", ".join(f"{k}: {v:05.2f}"
                                            for k, v in results.items())
                    log.info(f"Dev {results_str}")

                    # Log to TensorBoard: scalars, then both prediction sets
                    log.info("Visualizing in TensorBoard...")
                    for k, v in results.items():
                        tbx.add_scalar(f"dev/{k}", v, sample_num)
                    for split_preds, split_name in ((preds, "dev"),
                                                    (preds_e, "dev_e")):
                        visualize(
                            tbx,
                            preds=split_preds,
                            bpe=bpe,
                            sample_num=sample_num,
                            split=split_name,
                            num_visuals=args.num_visuals,
                        )
Beispiel #6
0
def test(args):
    """Evaluate a trained model on ``args.split`` and write a submission file.

    Runs the model over the split under ``torch.no_grad()`` and tracks the
    average NLL. The predicted start/end probabilities are discretized and
    converted by ``util.convert_tokens`` into two dicts: ``pred_dict`` (used
    for scoring and TensorBoard) and ``sub_dict`` (written as the
    submission).

    For every split except "test" (which has no labels), the predictions are
    scored with ``eval.eval_dicts`` and logged to the console and TensorBoard.
    The submission is written under ``args.save_dir``; the "dev" split's file
    is prefixed with "val".

    Args:
        args: raw arguments; resolved by ``get_args`` and ``trainer.setup``.
    """
    trainer = base_trainer.Trainer(is_train=False)
    args, device = get_args(args)
    args, log, tbx = trainer.setup(args)

    # Get BPE
    log.info("Loading BPE...")
    bpe = get_bpe(args)
    log.info("Loaded {} BPE tokens".format(len(bpe)))

    # Get data loader (record file chosen by split name, e.g. dev_record_file)
    log.info("Building dataset...")
    record_file = vars(args)[f"{args.split}_record_file"]
    dataset, data_loader = get_dataset(args,
                                       record_file,
                                       shuffle=False,
                                       randomize=False)

    # Get model
    log.info("Building model...")
    model = get_model(args, bpe)
    model = trainer.setup_model(model, device)
    model.eval()

    trainer.setup_close()

    # Evaluate
    log.info(f"Evaluating on {args.split} split...")
    nll_meter = stats.AverageMeter()
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}  # Predictions for submission
    eval_file = vars(args)[f"{args.split}_eval_file"]
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(eval_file, "r", encoding="utf-8") as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), tqdm(total=len(dataset)) as progress_bar:
        for x, y, c_padding_mask, c_starts, ids in data_loader:
            batch_size = x.size(0)
            _, loss_val, scores = forward(x, y, c_padding_mask, args, device,
                                          model)
            nll_meter.update(loss_val, batch_size)

            # Split the model's probabilities into start (p1) / end (p2)
            # distributions and discretize them into answer spans.
            p1, p2 = model.module.get_prob(scores).split(1, dim=-1)
            p1, p2 = p1.squeeze(-1), p2.squeeze(-1)
            starts, ends = util.discretize(p1, p2, args.max_ans_len,
                                           args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if args.split != "test":
                # No labels for the test set, so NLL would be invalid
                progress_bar.set_postfix(NLL=nll_meter.avg)

            idx2pred, uuid2pred = util.convert_tokens(
                gold_dict,
                ids.tolist(),
                starts.tolist(),
                ends.tolist(),
                args.use_squad_v2,
                c_starts.tolist(),
            )
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    # Log results (except for test set, since it does not come with labels)
    if args.split != "test":

        results = {"NLL": nll_meter.avg}
        results.update(eval.eval_dicts(gold_dict, pred_dict,
                                       args.use_squad_v2))

        # Log to console
        results_str = ", ".join(f"{k}: {v:05.2f}" for k, v in results.items())
        log.info(f"{args.split.title()} {results_str}")

        # Log to TensorBoard.
        # NOTE(review): this opens a second writer that replaces the one
        # returned by trainer.setup. It is closed explicitly so buffered
        # events are flushed to disk instead of being lost at exit.
        tbx = SummaryWriter(args.save_dir)
        try:
            util.visualize(
                tbx,
                pred_dict=pred_dict,
                eval_path=eval_file,
                step=0,
                split=args.split,
                num_visuals=args.num_visuals,
            )
        finally:
            tbx.close()

    # Write submission file; the dev split is written as "val_<sub_file>".
    split_tag = "val" if args.split == "dev" else args.split
    sub_path = join(args.save_dir, f"{split_tag}_{args.sub_file}")
    log.info(f"Writing submission file to {sub_path}...")
    eval.write_submission(sub_path, sub_dict)