예제 #1
0
def construct_datasets(args) -> Dict[str, tx.data.RecordData]:
    r"""Preprocess the task data and load the resulting dataset splits.

    An XLNet tokenizer is created for ``args.pretrained_model_name`` with
    its ``do_lower_case`` flag set from ``args.uncased``. The processor for
    ``args.task`` is then handed to ``dataset.construct_dataset`` and the
    splits are read back through ``dataset.load_datasets``.

    Args:
        args: Parsed options. This function reads ``max_seq_len``,
            ``pretrained_model_name``, ``uncased``, ``task``, ``data_dir``,
            ``cache_dir``, ``batch_size`` and ``eval_batch_size``.

    Returns:
        The value returned by ``dataset.load_datasets``.
    """
    # Cache files are keyed by the maximum sequence length.
    prefix = f"length{args.max_seq_len}"

    xlnet_tokenizer = tx.data.XLNetTokenizer(
        pretrained_model_name=args.pretrained_model_name)
    xlnet_tokenizer.do_lower_case = args.uncased

    processor_cls = get_processor_class(args.task)

    # Fall back to per-task default directories when none are supplied.
    raw_dir = args.data_dir
    if not raw_dir:
        raw_dir = f"data/{processor_cls.task_name}"
    processed_dir = args.cache_dir
    if not processed_dir:
        processed_dir = f"processed_data/{processor_cls.task_name}"

    dataset.construct_dataset(
        processor_cls(raw_dir), processed_dir, args.max_seq_len,
        xlnet_tokenizer, file_prefix=prefix)

    return dataset.load_datasets(
        args.task, processed_dir, args.max_seq_len, args.batch_size,
        file_prefix=prefix,
        eval_batch_size=args.eval_batch_size,
        shuffle_buffer=None)
예제 #2
0
def construct_datasets(args, device: Optional[torch.device] = None) \
        -> Dict[str, tx.data.RecordData]:
    r"""Preprocess the task data and load the resulting dataset splits.

    A SentencePiece model is loaded from the ``spiece.model`` file inside
    the directory returned by ``XLNetEncoder.download_checkpoint``. It is
    wrapped into a tokenize function by ``data_utils.create_tokenize_fn``
    and passed to ``dataset.construct_dataset``. The splits are then read
    back through ``dataset.load_datasets``.

    Args:
        args: Parsed options. This function reads ``pretrained_model_name``,
            ``max_seq_len``, ``uncased``, ``task``, ``data_dir``,
            ``cache_dir``, ``batch_size`` and ``eval_batch_size``.
        device: Forwarded unchanged to ``dataset.load_datasets``.

    Returns:
        The value returned by ``dataset.load_datasets``.
    """
    sentence_piece = spm.SentencePieceProcessor()

    checkpoint_dir = tx.modules.XLNetEncoder.download_checkpoint(
        pretrained_model_name=args.pretrained_model_name)
    sentence_piece.Load(os.path.join(checkpoint_dir, "spiece.model"))

    # Cache files are keyed by the maximum sequence length.
    prefix = f"length{args.max_seq_len}"
    tokenize = data_utils.create_tokenize_fn(sentence_piece, args.uncased)

    processor_cls = get_processor_class(args.task)

    # Fall back to per-task default directories when none are supplied.
    raw_dir = args.data_dir
    if not raw_dir:
        raw_dir = f"data/{processor_cls.task_name}"
    processed_dir = args.cache_dir
    if not processed_dir:
        processed_dir = f"processed_data/{processor_cls.task_name}"

    dataset.construct_dataset(
        processor_cls(raw_dir), processed_dir, args.max_seq_len,
        tokenize, file_prefix=prefix)

    return dataset.load_datasets(
        args.task, processed_dir, args.max_seq_len, args.batch_size,
        file_prefix=prefix,
        eval_batch_size=args.eval_batch_size,
        shuffle_buffer=None,
        device=device)
예제 #3
0
def load_datasets(task: str, input_dir: str, seq_length: int, batch_size: int,
                  drop_remainder: bool = False,
                  file_prefix: Optional[str] = None,
                  eval_batch_size: Optional[int] = None,
                  shuffle_buffer: Optional[int] = None,
                  device: Optional[torch.device] = None) \
        -> Dict[str, tx.data.RecordData]:
    r"""Load the pickled ``train``/``dev``/``test`` record files of a task.

    For each split the file ``<input_dir>/<file_prefix>.<split>.pkl`` (or
    ``<input_dir>/<split>.pkl`` when `file_prefix` is ``None``) is wrapped
    in a :class:`tx.data.RecordData` and moved to `device`. A split whose
    file does not exist is skipped with a warning.

    Args:
        task: Task name, resolved with ``get_processor_class``.
        input_dir: Directory holding the record files.
        seq_length: Passed to ``get_record_feature_types``.
        batch_size: Batch size of the ``train`` split.
        drop_remainder: If ``True``, a smaller final batch is not allowed.
        file_prefix: Optional file-name prefix; a ``.`` is appended to it.
        eval_batch_size: Batch size of the ``dev`` and ``test`` splits.
            A falsy value means `batch_size` is used instead.
        shuffle_buffer: Value for the ``shuffle_buffer_size`` hparam.
        device: Target of the ``.to()`` call on every created dataset.

    Returns:
        A dict from split name to dataset, holding only the splits whose
        record file exists.
    """
    task_cls = get_processor_class(task)

    if file_prefix is None:
        file_prefix = ''
    else:
        file_prefix = file_prefix + '.'
    if not eval_batch_size:
        eval_batch_size = batch_size

    record_types = get_record_feature_types(seq_length,
                                            task_cls.is_regression)

    logging.info("Loading records with prefix \"%s\" from %s", file_prefix,
                 input_dir)

    loaded = {}
    for split_name in ('train', 'dev', 'test'):
        record_path = os.path.join(input_dir,
                                   f"{file_prefix}{split_name}.pkl")
        if os.path.exists(record_path):
            if split_name == 'train':
                split_batch_size = batch_size
            else:
                split_batch_size = eval_batch_size
            split_hparams = {
                "dataset": {
                    "files": [record_path],
                    "feature_original_types": record_types,
                },
                "batch_size": split_batch_size,
                "allow_smaller_final_batch": not drop_remainder,
                # Every split is shuffled, the evaluation splits included.
                "shuffle": True,
                "shuffle_buffer_size": shuffle_buffer,
            }
            loaded[split_name] = tx.data.RecordData(
                hparams=split_hparams).to(device)
        else:
            logging.warning("%s set does not exist for task %s",
                            split_name.capitalize(), task_cls.task_name)

    return loaded
예제 #4
0
def main(args) -> None:
    r"""Train or evaluate an XLNet classifier/regressor through `Executor`.

    The model is a regressor when the task's processor class has
    ``is_regression`` set, and a classifier otherwise. In ``'train'`` mode
    it is trained, saved and then tested on the ``dev`` and ``test`` splits.
    In any other mode it is only tested; if no ``args.checkpoint`` was
    given, ``executor.load(load_training_state=False)`` is called first.

    Args:
        args: Parsed options. This function reads ``seed``, ``task``,
            ``pretrained_model_name``, ``lr``, ``lr_layer_decay_rate``,
            ``adam_eps``, ``weight_decay``, ``train_steps``,
            ``warmup_steps``, ``min_lr_ratio``, ``backwards_per_step``,
            ``save_dir``, ``save_steps``, ``display_steps``, ``eval_steps``,
            ``grad_clip``, ``checkpoint`` and ``mode``.
    """
    seed = args.seed
    if seed != -1:  # -1 means "do not seed".
        make_deterministic(seed)
        print(f"Random seed set to {seed}")

    datasets = construct_datasets(args)
    print("Dataset constructed")

    task_cls = get_processor_class(args.task)
    regression = task_cls.is_regression
    model: Union[RegressorWrapper, ClassifierWrapper]
    if not regression:
        model = ClassifierWrapper(
            pretrained_model_name=args.pretrained_model_name,
            hparams={"num_classes": len(task_cls.labels)})
    else:
        model = RegressorWrapper(
            pretrained_model_name=args.pretrained_model_name)
    print("Model constructed")

    grouped_params = model.param_groups(args.lr, args.lr_layer_decay_rate)
    optim = torch.optim.Adam(grouped_params,
                             lr=args.lr,
                             eps=args.adam_eps,
                             weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim,
        model_utils.warmup_lr_lambda(args.train_steps,
                                     args.warmup_steps,
                                     args.min_lr_ratio))

    # Number of backward passes accumulated into one optimizer update.
    # Step counts from `args` are multiplied by it to get iteration counts.
    bps = args.backwards_per_step

    def every(steps: int) -> Optional[cond.Condition]:
        # A step count of -1 disables the corresponding periodic action.
        return None if steps == -1 else cond.iteration(steps * bps)

    metric_cls = metric.PearsonR if regression else metric.Accuracy
    valid_metric: metric.Metric = metric_cls(pred_name="preds",
                                             label_name="label_ids")

    log_format = ("{time} : Epoch {epoch} @ {iteration:5d}it "
                  "({speed}), LR = {LR:.3e}, loss = {loss:.3f}")

    executor = Executor(
        model=model,
        train_data=datasets["train"],
        valid_data=datasets["dev"],
        test_data=datasets.get("test", None),
        checkpoint_dir=args.save_dir or f"saved_models/{args.task}",
        save_every=every(args.save_steps),
        max_to_keep=1,
        train_metrics=[
            ("loss", metric.RunningAverage(args.display_steps * bps)),
            metric.LR(optim),
        ],
        optimizer=optim,
        lr_scheduler=scheduler,
        grad_clip=args.grad_clip,
        num_iters_per_update=bps,
        log_every=cond.iteration(args.display_steps * bps),
        validate_every=every(args.eval_steps),
        valid_metrics=[valid_metric, ("loss", metric.Average())],
        stop_training_on=cond.iteration(args.train_steps * bps),
        log_format=log_format,
        test_mode='eval',
        show_live_progress=True,
    )

    if args.checkpoint is not None:
        executor.load(args.checkpoint)

    if args.mode == 'train':
        executor.train()
        executor.save()
    elif args.checkpoint is None:
        # No explicit checkpoint was given: load previous best model.
        executor.load(load_training_state=False)

    # Both modes finish by testing on whichever of dev/test is available.
    executor.test(tx.utils.dict_fetch(datasets, ["dev", "test"]))
예제 #5
0
def _seed_everything(seed: int) -> None:
    r"""Seed the Python, NumPy and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def _select_device() -> torch.device:
    r"""Return the current CUDA device when available, otherwise the CPU."""
    if torch.cuda.is_available():
        index = torch.cuda.current_device()
        print(f"Using CUDA device {index}")
        return torch.device('cuda', index)
    print("Using CPU")
    return torch.device('cpu')


def _batch_loss(model, batch, is_regression: bool) -> torch.Tensor:
    r"""Return the unreduced per-example loss of `model` on `batch`.

    The loss is the squared error for regression and the cross entropy of
    the logits otherwise. Labels are flattened with ``view(-1)`` first.
    """
    labels = batch.label_ids.view(-1)
    outputs = model(token_ids=batch.input_ids,
                    segment_ids=batch.segment_ids,
                    input_mask=batch.input_mask)
    if is_regression:
        return (outputs - labels) ** 2
    # The classifier returns a pair; only its first element is used here.
    logits, _ = outputs
    return F.cross_entropy(logits, labels, reduction='none')


def main(args):
    r"""Fine-tune or evaluate an XLNet classifier/regressor.

    In ``'eval'`` mode every non-``train`` split is evaluated and the
    function returns. Otherwise the model is trained with gradient
    accumulation until ``args.train_steps`` optimizer updates have been
    made, then saved and evaluated on every non-``train`` split.

    Args:
        args: Parsed options. This function reads ``seed``, ``task``,
            ``pretrained_model_name``, ``checkpoint``, ``mode``,
            ``save_dir``, ``lr``, ``lr_layer_decay_rate``, ``adam_eps``,
            ``weight_decay``, ``train_steps``, ``warmup_steps``,
            ``min_lr_ratio``, ``batch_size``, ``backwards_per_step``,
            ``grad_clip``, ``display_steps``, ``save_steps`` and
            ``eval_steps``. A value of ``-1`` disables ``seed``,
            ``save_steps`` or ``eval_steps``.
    """
    if args.seed != -1:
        _seed_everything(args.seed)
        print(f"Random seed set to {args.seed}")

    device = _select_device()

    datasets = construct_datasets(args, device)
    iterator = tx.data.DataIterator(datasets)
    print("Dataset constructed")

    processor_class = get_processor_class(args.task)
    is_regression = processor_class.is_regression
    if is_regression:
        model = tx.modules.XLNetRegressor(
            pretrained_model_name=args.pretrained_model_name)
    else:
        model = tx.modules.XLNetClassifier(
            pretrained_model_name=args.pretrained_model_name,
            hparams={"num_classes": len(processor_class.labels)})
    print("Weights initialized")

    if args.checkpoint is not None:
        # NOTE: `torch.load` unpickles the file, so only load checkpoints
        # that come from a trusted source.
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        print(f"Loaded checkpoint from {args.checkpoint}")

    model = model.to(device)
    print("Model constructed")

    # Fall back to a per-task default when --save-dir is unset or empty, so
    # a missing option cannot crash the final save after training is done.
    save_dir = args.save_dir or f"saved_models/{args.task}"

    def eval_all_splits():
        model.eval()
        for split in datasets:
            if split != 'train':
                print(f"Evaluating on {split}")
                evaluate(model, iterator.get_iterator(split), is_regression)

    def save_model(step: int, model):
        # `exist_ok=True` avoids the check-then-create race.
        os.makedirs(save_dir, exist_ok=True)
        save_name = (f"{args.task}_step{step}_"
                     f"{time.strftime('%Y%m%d_%H%M%S')}")
        save_path = os.path.join(save_dir, save_name)
        torch.save(model.state_dict(), save_path)
        # `write` is a classmethod of `tqdm.tqdm`, so this does not depend
        # on the progress bar that is only created further down.
        tqdm.tqdm.write(f"Model at {step} steps saved to {save_path}")

    if args.mode == 'eval':
        eval_all_splits()
        return

    optim = torch.optim.Adam(model.param_groups(args.lr,
                                                args.lr_layer_decay_rate),
                             lr=args.lr,
                             eps=args.adam_eps,
                             weight_decay=args.weight_decay)
    lambda_lr = model_utils.warmup_lr_lambda(args.train_steps,
                                             args.warmup_steps,
                                             args.min_lr_ratio)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda_lr)

    avg_loss = tx.utils.AverageRecorder()
    train_steps = 0  # Optimizer updates performed so far.
    grad_steps = 0  # Backward passes accumulated since the last update.
    total_batch_size = args.batch_size * args.backwards_per_step

    # The context manager closes the progress bar even if training raises.
    with tqdm.tqdm(
            data_utils.repeat(lambda: iterator.get_iterator('train')),
            ncols=80) as progress:
        for batch in progress:
            model.train()
            # Dividing by the effective batch size makes the gradients
            # accumulated over `backwards_per_step` passes an average.
            loss = _batch_loss(model, batch, is_regression)
            loss = loss.sum() / total_batch_size
            avg_loss.add(loss.item() * args.backwards_per_step)
            loss.backward()
            grad_steps += 1
            if grad_steps != args.backwards_per_step:
                continue

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optim.step()
            optim.zero_grad()
            train_steps += 1
            grad_steps = 0

            if train_steps % args.display_steps == 0:
                progress.write(f"Step: {train_steps}, "
                               f"LR = {optim.param_groups[0]['lr']:.3e}, "
                               f"loss = {avg_loss.avg():.4f}")
                avg_loss.reset()

            scheduler.step()

            if train_steps >= args.train_steps:
                # Break before save & eval since we're doing them anyway.
                break

            if args.save_steps != -1 and train_steps % args.save_steps == 0:
                save_model(train_steps, model)

            if args.eval_steps != -1 and train_steps % args.eval_steps == 0:
                model.eval()
                evaluate(model,
                         iterator.get_iterator('dev'),
                         is_regression,
                         print_fn=progress.write,
                         tqdm_kwargs={"leave": False})

    save_model(args.train_steps, model)
    eval_all_splits()