Example no. 1
def train():

    df = pd.read_csv(TRAIN_PATH).fillna('None')

    train, valid = model_selection.train_test_split(
        df, test_size=0.15, random_state=42, stratify=df['Class Index'].values)

    train = train.reset_index(drop=True)
    valid = valid.reset_index(drop=True)

    train = ohe(train, 'Class Index')
    valid = ohe(valid, 'Class Index')

    train_labels = train[train.columns[-4:]].values
    valid_labels = valid[valid.columns[-4:]].values

    train_data = prepare_dataset(text=train['Description'].values,
                                 label=train_labels)

    valid_data = prepare_dataset(text=valid['Description'].values,
                                 label=valid_labels)

    train_sampler = torch.utils.data.DistributedSampler(
        train_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    valid_sampler = torch.utils.data.DistributedSampler(
        valid_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)

    train_dataloader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=BATCH_SIZE,
                                                   num_workers=4,
                                                   sampler=train_sampler,
                                                   drop_last=True)
    valid_dataloader = torch.utils.data.DataLoader(valid_data,
                                                   batch_size=V_BATCH_SIZE,
                                                   num_workers=4,
                                                   sampler=valid_sampler,
                                                   drop_last=True)

    #     device= torch.device('cuda')

    device = xm.xla_device()

    model = BertBaseUncased()
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.001,
        },
        {
            "params":
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0,
        },
    ]

    num_train_steps = int(
        len(train_data) / BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    xm.master_print(
        f'num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()}'
    )

    lr = 3e-4 * xm.xrt_world_size()

    optimizer = AdamW(optimizer_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    best_acc = 0

    for epoch in range(EPOCHS):

        para_loader = pl.ParallelLoader(train_dataloader, [device])

        train_loss = train_loop(para_loader.per_device_loader(device),
                                model=model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                device=device)

        para_loader = pl.ParallelLoader(valid_dataloader, [device])

        val_acc, val_loss = eval_loop(para_loader.per_device_loader(device),
                                      model, device)

        #         print(f"EPOCH: {epoch} train_loss: {train_loss} val_loss: {val_loss} val_acc: {val_acc}")

        if val_acc > best_acc:
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, 'best_model.bin')

            best_acc = val_acc

        xm.master_print(
            f'Epoch: {epoch+1} train_loss: {train_loss} val_loss: {val_loss} Accuracy: {val_acc}'
        )
Example no. 2
    def fit(
        self,
        train_dataset,
        valid_dataset=None,
        train_sampler=None,
        valid_sampler=None,
        device="cuda",
        epochs=10,
        train_bs=16,
        valid_bs=16,
        n_jobs=8,
        callbacks=None,
        fp16=False,
        train_collate_fn=None,
        valid_collate_fn=None,
        train_shuffle=True,
        valid_shuffle=False,
        accumulation_steps=1,
        clip_grad_norm=None,
    ):
        """
        The model fit function. Heavily inspired by tf/keras, it is the core of Tez and the only
        function you need to train your models.

        """
        if device == "tpu":
            if XLA_AVAILABLE is False:
                raise RuntimeError(
                    "XLA is not available. Please install pytorch_xla")
            else:
                self.using_tpu = True
                fp16 = False
                device = xm.xla_device()
        self._init_model(
            device=device,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            train_sampler=train_sampler,
            valid_sampler=valid_sampler,
            train_bs=train_bs,
            valid_bs=valid_bs,
            n_jobs=n_jobs,
            callbacks=callbacks,
            fp16=fp16,
            train_collate_fn=train_collate_fn,
            valid_collate_fn=valid_collate_fn,
            train_shuffle=train_shuffle,
            valid_shuffle=valid_shuffle,
            accumulation_steps=accumulation_steps,
            clip_grad_norm=clip_grad_norm,
        )

        for _ in range(epochs):
            self.train_state = enums.TrainingState.EPOCH_START
            self.train_state = enums.TrainingState.TRAIN_EPOCH_START
            train_loss = self.train_one_epoch(self.train_loader)
            self.train_state = enums.TrainingState.TRAIN_EPOCH_END
            if self.valid_loader:
                self.train_state = enums.TrainingState.VALID_EPOCH_START
                valid_loss = self.validate_one_epoch(self.valid_loader)
                self.train_state = enums.TrainingState.VALID_EPOCH_END
            if self.scheduler:
                if self.step_scheduler_after == "epoch":
                    if self.step_scheduler_metric is None:
                        self.scheduler.step()
                    else:
                        step_metric = self.name_to_metric(
                            self.step_scheduler_metric)
                        self.scheduler.step(step_metric)
            self.train_state = enums.TrainingState.EPOCH_END
            if self._model_state.value == "end":
                break
            self.current_epoch += 1
        self.train_state = enums.TrainingState.TRAIN_END
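Given the signature above, a hypothetical call might look like the sketch below; MyTezModel and the dataset objects are placeholders rather than part of the example, and device="tpu" takes the XLA branch shown at the top of fit():

model = MyTezModel()  # a Tez model subclass defining forward/optimizer/scheduler (hypothetical)
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    device="tpu",      # resolved to xm.xla_device(); fp16 is forced off, as above
    epochs=10,
    train_bs=32,
    valid_bs=64,
    n_jobs=4,
)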
Example no. 3
def train_loop(folds, fold):

    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")
    elif CFG.device == 'TPU':
        if CFG.nprocs == 1:
            LOGGER.info(f"========== fold: {fold} training ==========")
        elif CFG.nprocs == 8:
            xm.master_print(f"========== fold: {fold} training ==========")
            
    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    
    train_folds = train_folds[train_folds['StudyInstanceUID'].isin(train_annotations['StudyInstanceUID'].unique())].reset_index(drop=True)
    
    valid_labels = valid_folds[CFG.target_cols].values
    
    train_dataset = TrainDataset(train_folds, train_annotations, use_annot=True,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds, train_annotations, use_annot=False,
                                 transform=get_transforms(data='valid'))
    
    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset, 
                                  batch_size=CFG.batch_size, 
                                  shuffle=True, 
                                  num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, 
                                  batch_size=CFG.batch_size * 2, 
                                  shuffle=False, 
                                  num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
        
    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                        num_replicas=xm.xrt_world_size(),
                                                                        rank=xm.get_ordinal(),
                                                                        shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=CFG.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   num_workers=CFG.num_workers)
        
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,
                                                                        num_replicas=xm.xrt_world_size(),
                                                                        rank=xm.get_ordinal(),
                                                                        shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=CFG.batch_size * 2,
                                                   sampler=valid_sampler,
                                                   drop_last=False,
                                                   num_workers=CFG.num_workers)
        
    # ====================================================
    # scheduler 
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler=='ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
        elif CFG.scheduler=='CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        elif CFG.scheduler=='CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
        return scheduler
    
    # ====================================================
    # model & optimizer
    # ====================================================
    if CFG.device == 'TPU':
        device = xm.xla_device()
    elif CFG.device == 'GPU':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    teacher_model = CustomSeResNet152D(CFG.model_name, pretrained=False)
    teacher_model.to(device)
    state = torch.load(CFG.teacher)
    teacher_model.load_state_dict(state['model'])
    for param in teacher_model.parameters():
        param.requires_grad = False
    teacher_model.eval()
#     teacher_model.to(device)
    
    model = CustomSeResNet152D_WLF(CFG.model_name, pretrained=True)
    model.to(device)
#     state = torch.load(CFG.student)
#     model.load_state_dict(state['model'])

    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = get_scheduler(optimizer)
    
    # ====================================================
    # loop
    # ====================================================
    train_criterion = CustomLoss(weights=CFG.weights)
    valid_criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf
    
    for epoch in range(CFG.epochs):
        
        start_time = time.time()
        
        # train
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_loss = train_fn(train_loader, teacher_model, model, train_criterion, optimizer, epoch, scheduler, device)
            elif CFG.nprocs == 8:
                para_train_loader = pl.ParallelLoader(train_loader, [device])
                avg_loss = train_fn(para_train_loader.per_device_loader(device), teacher_model, model, train_criterion, optimizer, epoch, scheduler, device)
        elif CFG.device == 'GPU':
            avg_loss = train_fn(train_loader, teacher_model, model, train_criterion, optimizer, epoch, scheduler, device)
        
        # eval
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_val_loss, preds, _ = valid_fn(valid_loader, model, valid_criterion, device)
            elif CFG.nprocs == 8:
                para_valid_loader = pl.ParallelLoader(valid_loader, [device])
                avg_val_loss, preds, valid_labels = valid_fn(para_valid_loader.per_device_loader(device), model, valid_criterion, device)
                preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy()
                valid_labels = idist.all_gather(torch.tensor(valid_labels)).to('cpu').numpy()
        elif CFG.device == 'GPU':
            avg_val_loss, preds, _ = valid_fn(valid_loader, model, valid_criterion, device)
            
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()
            
        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time
        
        if CFG.device == 'GPU':
            LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
            LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}')
        elif CFG.device == 'TPU':
            if CFG.nprocs == 1:
                LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
                LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}')
            elif CFG.nprocs == 8:
                xm.master_print(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
                xm.master_print(f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}')
                
        if score > best_score:
            best_score = score
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                torch.save({'model': model.state_dict(), 
                            'preds': preds},
                           OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                xm.save({'model': model, 
                         'preds': preds}, 
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score.pth')
                
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                torch.save({'model': model.state_dict(), 
                            'preds': preds},
                           OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_loss.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                xm.save({'model': model, 
                         'preds': preds}, 
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_loss.pth')
                
#         # Save every epoch's checkpoint for inference
#         if CFG.device == 'TPU':
#             xm.save({'model': model.state_dict()}, OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')
#         elif CFG.device == 'GPU':
#             torch.save({'model': model.state_dict()}, OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')
                
        if CFG.nprocs != 8:
            check_point = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score.pth')
            for c in [f'pred_{c}' for c in CFG.target_cols]:
                valid_folds[c] = np.nan
            valid_folds[[f'pred_{c}' for c in CFG.target_cols]] = check_point['preds']

    return valid_folds
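train_loop() returns valid_folds with the pred_* columns filled from the best-score checkpoint, so the usual outer loop concatenates the out-of-fold predictions across folds. A sketch assuming the surrounding notebook imports pandas as pd, defines CFG.n_fold, and reuses get_score from above:

oof_df = pd.DataFrame()
for fold in range(CFG.n_fold):
    _oof_df = train_loop(folds, fold)
    oof_df = pd.concat([oof_df, _oof_df])
# overall CV score computed on the concatenated out-of-fold predictions
cv_score, cv_scores = get_score(oof_df[CFG.target_cols].values,
                                oof_df[[f'pred_{c}' for c in CFG.target_cols]].values)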
Example no. 4
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " +
        ", ".join(processors.keys()))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run eval on the test set.")

    parser.add_argument("--train_file",
                        default="train.tsv",
                        type=str,
                        help="Training file.")
    parser.add_argument("--dev_file",
                        default="dev.tsv",
                        type=str,
                        help="Validation file.")
    parser.add_argument("--test_file",
                        default="test.tsv",
                        type=str,
                        help="Test file.")
    parser.add_argument("--results_file",
                        default="eval_results.txt",
                        type=str,
                        help="File name to write results.")

    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--tpu',
        action='store_true',
        help="Whether to run on the TPU defined in the environment variables")
    parser.add_argument(
        '--tpu_ip_address',
        type=str,
        default='',
        help="TPU IP address if none are set in the environment variables")
    parser.add_argument(
        '--tpu_name',
        type=str,
        default='',
        help="TPU name if none are set in the environment variables")
    parser.add_argument(
        '--xrt_tpu_config',
        type=str,
        default='',
        help="XRT TPU config if none are set in the environment variables")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    if args.tpu:
        if args.tpu_ip_address:
            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
        if args.tpu_name:
            os.environ["TPU_NAME"] = args.tpu_name
        if args.xrt_tpu_config:
            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config

        assert "TPU_IP_ADDRESS" in os.environ
        assert "TPU_NAME" in os.environ
        assert "XRT_TPU_CONFIG" in os.environ

        import torch_xla
        import torch_xla.core.xla_model as xm
        args.device = xm.xla_device()
        args.xla_model = xm

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    #print(processors)
    #print(args.task_name )
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()

    processor.set_train_file(args.train_file)
    processor.set_dev_file(args.dev_file)
    processor.set_dev_file(args.test_file)

    args.output_mode = output_modes[args.task_name]
    binary_label_list = processor.get_binary_labels()
    multi_label_list = processor.get_multi_labels()
    num_binary_labels = len(binary_label_list)
    num_multi_labels = len(multi_label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_binary_labels=num_binary_labels,
        num_multi_labels=num_multi_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None)

    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                args.task_name,
                                                tokenizer,
                                                evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or
                          torch.distributed.get_rank() == 0) and not args.tpu:
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(
            model,
            'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if (args.do_eval or args.do_test) and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split(
                '-')[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split(
                '/')[-1] if checkpoint.find('checkpoint') != -1 else ""

            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict(
                (k + '_{}'.format(global_step), v) for k, v in result.items())
            results.update(result)

    return results
Example no. 5
def train_mnist(flags, training_started=None, dynamic_graph=False, fetch_often=False):
    torch.manual_seed(1)

    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=600000 // flags.batch_size // xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=100000 // flags.batch_size // xm.xrt_world_size(),
        )
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
            )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers,
        )

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    server = xp.start_server(flags.profiler_port)

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                index = max(-step, -flags.batch_size + 1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()

            with xp.StepTrace("train_mnist", step_num=step):
                with xp.Trace("build_graph"):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update, args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace("test_mnist"):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print(
            "Epoch {} test end {}, Accuracy={:.2f}".format(epoch, test_utils.now(), accuracy)
        )
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer, epoch, dict_to_write={"Accuracy/test": accuracy}, write_xla_metrics=True
        )
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
Example no. 6
    def run_training_epoch(self):

        # get model
        model = self.get_model()

        # Epoch start events
        with self.profiler.profile('on_epoch_start'):
            # callbacks
            self.on_epoch_start()

            # model hooks
            if self.is_function_implemented('on_epoch_start'):
                model.on_epoch_start()

        # track local dataloader so TPU can wrap each epoch
        train_dataloader = self.train_dataloader

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device()
            train_dataloader = xla_pl.ParallelLoader(train_dataloader,
                                                     [device])
            train_dataloader = train_dataloader.per_device_loader(device)

        # bookkeeping
        outputs = []

        # run epoch
        for batch_idx, (batch,
                        is_last_batch) in self.profiler.profile_iterable(
                            enumerate(_with_is_last(train_dataloader)),
                            "get_train_batch"):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model.global_step = self.global_step

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            _outputs = self.run_training_batch(batch, batch_idx)
            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
            # detach tensors in batch_output before appending to outputs
            outputs.append(_recursive_detach(batch_output))

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_result == -1

            # update lr
            self.update_learning_rates(interval='step')

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch +
                               1) % self.check_val_every_n_epoch == 0
            can_check_val = not self.disable_validation and can_check_epoch
            should_check_val = is_val_check_batch or early_stop_epoch
            should_check_val = should_check_val or (
                is_last_batch and self.val_check_batch == float('inf'))
            should_check_val = can_check_val and should_check_val

            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=self.testing)

            # when logs should be saved
            should_save_log = (
                batch_idx +
                1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:
                # logs user requested information to logger
                self.log_metrics(batch_step_metrics, grad_norm_dic)

            # ---------------
            # CHECKPOINTING, EARLY STOPPING
            # ---------------
            # save checkpoint even when no test or val step are defined
            if self.fast_dev_run or should_check_val:
                self.call_checkpoint_callback()

                if self.enable_early_stop:
                    self.early_stop_callback.check_metrics(
                        self.callback_metrics)

            # progress global step according to grads progress
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                self.global_step += 1
            self.total_batch_idx += 1

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch or self.fast_dev_run:
                break

        # process epoch outputs
        if isinstance(
                model,
            (LightningDistributedDataParallel, LightningDataParallel)):
            model = model.module

        if self.is_overriden('training_epoch_end', model=model):
            epoch_output = model.training_epoch_end(outputs)
            _processed_outputs = self.process_output(epoch_output)
            log_epoch_metrics = _processed_outputs[2]
            callback_epoch_metrics = _processed_outputs[3]
            self.log_metrics(log_epoch_metrics, {})
            self.callback_metrics.update(callback_epoch_metrics)

        # in case validation step is missing and you are not running fast-dev to duplicate last batch
        if not self.is_overriden('validation_step') and not (
                self.fast_dev_run or should_check_val):
            self.call_checkpoint_callback()

            if self.enable_early_stop:
                self.early_stop_callback.check_metrics(self.callback_metrics)

        # Epoch end events
        with self.profiler.profile('on_epoch_end'):
            # callbacks
            self.on_epoch_end()
            # model hooks
            if self.is_function_implemented('on_epoch_end'):
                model.on_epoch_end()
Example no. 7
def main():
    parser = argparse.ArgumentParser()

    # Required params
    parser.add_argument("--input_dir", default=None, type=str, required=True,
                        help="Comma seperated list of data dir paths (Use prepare-*.py scripts for preprocessing.)")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--model", default=None, type=str, required=True,
                        help="Path to pretrained model or shortcut name (bert-base-multilingual-cased)")

    # Optional arguments
    parser.add_argument("--model_file", default='model', type=str, help="Filename of task model")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument('--force', action='store_true', help="Overwrite the content of the output directory")
    parser.add_argument("--tasks", default='all', type=str,
                        help="Comma seperated list of tasks. Names are subdirs of data dirs. Default: all")
    parser.add_argument("--cache_dir", default='cache', type=str, help="Path to cache directory")

    # TPU training
    parser.add_argument("--use_tpu", action='store_true',
                        help="Whether to use a TPU. (Make sure that the environement variables TPU_NAME,\
                              TPU_IP_ADDRESS and XRT_TPU_CONFIG are set)")

    args = parser.parse_args()

    if not args.do_train and not args.do_eval:
        print('Specify --do_train and/or --do_eval')
        exit(-1)

    if args.do_train and not args.force and os.path.exists(args.output_dir):
        print('Output path already exists')
        exit(-1)

    tasks = list_tasks(args.input_dir, args.output_dir, args.tasks)
    if len(tasks) == 0:
        print('No (whitelisted) tasks found')
        exit(-1)

    print('Starting benchmark tasks!')

    print('\n▶ Arguments:')
    for key in vars(args):
        print(' ➤ {:15}: {}'.format(key, getattr(args, key)))

    print('\n▶ Device:')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    xla_model = None

    if args.use_tpu:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
        xla_model = xm

    print(' ➤ Using device: {}'.format(device.type))

    print('\n▶ Scheduling to run {} tasks:'.format(len(tasks)))
    for task_in, task_out in tasks:
        print(' ➤ {} [Output: {}]'.format(task_in, task_out))

    print('\n' + '#' * 80)

    for i, (task_in, task_out) in enumerate(tasks, 1):
        print('\n▶ Task {}/{} [{}]:'.format(i, len(tasks), task_in))
        run_task(args.model, task_in, task_out, device, args.model_file, args.do_train,
                 args.do_eval, cache_dir=args.cache_dir, xla_model=xla_model)
        print(' ➤ Finished!')
Example no. 8
 def setUpClass(cls):
     # Sets the primary test device to the xla_device (CPU or TPU)
     cls.primary_device = str(xm.xla_device())
     torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
         use_full_mat_mul_precision=True)
Example no. 9
def main(args):
    blue = lambda x: '\033[94m' + x + '\033[0m'

    seeding(args.seed)

    if args.hfta:
        B = consolidate_hyperparams_and_determine_B(
            args,
            ['lr', 'beta1', 'beta2', 'weight_decay', 'gamma', 'step_size'],
        )
    else:
        B = 0
        (args.lr, args.beta1, args.beta2, args.weight_decay, args.gamma,
         args.step_size) = (args.lr[0], args.beta1[0], args.beta2[0],
                            args.weight_decay[0], args.gamma[0],
                            args.step_size[0])

    if args.device == 'cuda':
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        print('Enable cuDNN heuristics!')
    device = (xm.xla_device()
              if args.device == 'xla' else torch.device(args.device))

    dataset, test_dataset = build_dataset(args)
    dataloader, testdataloader = build_dataloader(args, dataset, test_dataset)

    print('len(dataset)={}'.format(len(dataset)),
          'len(test_dataset)={}'.format(len(test_dataset)))
    num_classes = len(dataset.classes)
    print('classes', num_classes)

    if args.outf is not None:
        try:
            os.makedirs(args.outf)
        except OSError:
            pass

    classifier = PointNetCls(
        k=num_classes,
        feature_transform=args.feature_transform,
        B=B,
        track_running_stats=(args.device != 'xla'),
    )

    if args.model != '':
        classifier.load_state_dict(torch.load(args.model))

    optimizer = get_hfta_optim_for(optim.Adam, B=B)(
        classifier.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
    )
    scheduler = get_hfta_lr_scheduler_for(optim.lr_scheduler.StepLR, B=B)(
        optimizer,
        step_size=args.step_size,
        gamma=args.gamma,
    )

    scaler = amp.GradScaler(enabled=(args.device == 'cuda' and args.amp))

    classifier.to(device)

    num_batch = len(dataloader)

    def loss_fn(output, label, batch_size, trans_feat):
        if B > 0:
            loss = B * F.nll_loss(output.view(B * batch_size, -1), label)
        else:
            loss = F.nll_loss(output, label)
        if args.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        return loss

    classifier = classifier.train()
    epoch_timer = EpochTimer()

    # Training loop
    for epoch in range(args.epochs):
        num_samples_per_epoch = 0
        epoch_timer.epoch_start(epoch)
        for i, data in enumerate(dataloader, 0):
            if i > args.iters_per_epoch:
                break
            if args.warmup_data_loading:
                continue

            points, target = data
            target = target[:, 0]
            points, target = points.to(device), target.to(device)
            N = points.size(0)
            if B > 0:
                points = points.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
                target = target.repeat(B)
            optimizer.zero_grad(set_to_none=True)
            if args.device == 'cuda':
                with amp.autocast(enabled=args.amp):
                    pred, trans, trans_feat = classifier(points)
                    loss = loss_fn(pred, target, N, trans_feat)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
            else:
                pred, trans, trans_feat = classifier(points)
                loss = loss_fn(pred, target, N, trans_feat)
                loss.backward()
                if args.device == 'xla':
                    xm.optimizer_step(optimizer, barrier=True)
                else:
                    optimizer.step()

            print('[{}: {}/{}] train loss: {}'.format(epoch, i, num_batch,
                                                      loss.item()))
            num_samples_per_epoch += N * max(B, 1)
            scaler.update()
        scheduler.step()
        epoch_timer.epoch_stop(num_samples_per_epoch)
        print('Epoch {} took {} s!'.format(epoch,
                                           epoch_timer.epoch_latency(epoch)))

    if args.device == 'xla' and not args.eval:
        print(met.metrics_report())
    if args.outf is not None:
        epoch_timer.to_csv(args.outf)

    if args.eval:
        # Run validation loop.
        print("Running validation loop ...")
        classifier = classifier.eval()
        with torch.no_grad():
            total_correct = torch.zeros(max(B, 1), device=device)
            total_testset = 0
            for data in testdataloader:
                if args.warmup_data_loading:
                    continue
                points, target = data
                target = target[:, 0]
                points, target = points.to(device), target.to(device)
                N = points.size(0)
                if B > 0:
                    points = points.unsqueeze(0).expand(B, -1, -1,
                                                        -1).contiguous()
                    target = target.repeat(B)
                pred, _, _ = classifier(points)
                pred_choice = pred.argmax(-1)

                correct = pred_choice.eq(
                    target.view(B, N) if B > 0 else target).sum(-1)

                total_correct.add_(correct)
                total_testset += N

            final_accuracy = total_correct / total_testset
            final_accuracy = final_accuracy.cpu().tolist()
            if args.outf is not None:
                pd.DataFrame({
                    'acc': final_accuracy
                }).to_csv(os.path.join(args.outf, 'eval.csv'))

            # Return test_accuracy
            return final_accuracy
Example no. 10
 def get_device(
     self,
 ):
     return xm.xla_device()
Example no. 11
 def get_hw_device(
     self,
 ):
     return xm._xla_real_device(xm.xla_device())
Example no. 12
def measure_tpu(warmups, steps, h_emb, h_indices, h_offsets, args):

    import torch_xla
    import torch_xla.core.xla_model as xm
    import os

    tsize = int(os.environ.get("MODEL_PARTITION_SIZE", 3000000))

    def syncTPU(tensor):
        torch_xla._XLAC._xla_sync_multi([tensor],
                                        devices=[],
                                        wait=True,
                                        sync_xla_data=True)

    alldev = xm.get_xla_supported_devices()
    allrealdev = xm.xla_real_devices(alldev)
    print("Found {0} devices: {1}".format(len(allrealdev), allrealdev))

    dev = xm.xla_device()
    if (args.features > tsize):
        if args.usexlabag:
            tsplit = torch.split(h_emb.embtable.weight, tsize, dim=0)
        else:
            tsplit = torch.split(h_emb.weight, tsize, dim=0)
        tsplit = list(tsplit)
        for i, chunk in enumerate(tsplit):
            tsplit[i] = chunk.to(dev)

        t = nn.Parameter(torch.ones(10, 10))
        if args.usexlabag:
            h_emb.embtable.weight = t
            t_emb = h_emb.to(dev)
            tsplit = torch.cat(tsplit)
            t_emb.embtable.weight = nn.Parameter(tsplit)
            print("Xla EMB weight shape: ", t_emb.embtable.weight.shape,
                  " on device: ", str(dev))
        else:
            h_emb.weight = t
            t_emb = h_emb.to(dev)
            tsplit = torch.cat(tsplit)
            t_emb.weight = nn.Parameter(tsplit)
            print("EMB weight shape: ", t_emb.weight.shape, " on device: ",
                  str(dev))
    else:
        t_emb = h_emb.to(dev)

    t_indices = h_indices.to(dev)
    t_offsets = h_offsets.to(dev)

    emb_times = 0.0
    start1 = time.perf_counter()
    for i in range(warmups + steps):
        start = time.perf_counter()
        results = t_emb(t_indices, t_offsets)
        syncTPU(results)
        end = time.perf_counter()
        print("Time: {0:.6f} ".format(end - start))
        if (i >= warmups):
            emb_times += end - start

    end1 = time.perf_counter()

    return end1 - start1, emb_times, results
Example no. 13
 def process_dataloader(self, dataloader):
     device = xm.xla_device(self.trainer.tpu_id)
     dataloader = xla_pl.ParallelLoader(dataloader, [device])
     dataloader = dataloader.per_device_loader(device)
     return dataloader
Example no. 14
def test_xla_device_is_a_tpu():
    """Check that the XLA device is a TPU"""
    device = xm.xla_device()
    device_type = xm.xla_device_hw(device)
    return device_type == "TPU"
Example no. 15
def get_tpu_device():
    return xm.xla_device()
Example no. 16
 def model_to_device(self) -> None:
     self.device = xm.xla_device()
     self.model = self.wrapped_model.to(self.device)
Example no. 17
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for the dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers":
        datamodule_config.get("num_workers",
                              training_config.get("num_workers", 4)),
        "pin_memory":
        datamodule_config.get("pin_memory",
                              training_config.get("pin_memory", False)),
        "shuffle":
        datamodule_config.get("shuffle", None),
        "batch_size":
        datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance,
                                                    other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(dataset_instance.dataset_name,
                                 dataset_instance.dataset_type),
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
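MpDeviceLoader, used in the is_xla() branch above, wraps an ordinary DataLoader so that each batch is moved to the XLA device as it is consumed; a minimal stand-alone sketch assuming a toy TensorDataset:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl

dataset = torch.utils.data.TensorDataset(torch.randn(64, 4),
                                         torch.randint(0, 2, (64,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

device = xm.xla_device()
device_loader = xla_pl.MpDeviceLoader(loader, device)

for features, labels in device_loader:
    # batches arrive already placed on the XLA device
    print(features.device, labels.shape)
    break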
Example n. 18
0
    def to_tpu(self) -> None:
        """Moves the model to the TPU."""

        self.model.to(xm.xla_device())
Example n. 19
0
 def root_device(self) -> torch.device:
     return xm.xla_device()
Example n. 20
0
    def run_training_epoch(self):

        # get model
        model = self.get_model()

        # Epoch start events
        with self.profiler.profile('on_epoch_start'):
            # callbacks
            self.on_epoch_start()

            # model hooks
            if self.is_function_implemented('on_epoch_start'):
                model.on_epoch_start()

        # track local dataloader so TPU can wrap each epoch
        train_dataloader = self.train_dataloader

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            train_dataloader = xla_pl.ParallelLoader(train_dataloader,
                                                     [device])
            train_dataloader = train_dataloader.per_device_loader(device)

        # bookkeeping
        outputs = []

        # run epoch
        for batch_idx, (batch,
                        is_last_batch) in self.profiler.profile_iterable(
                            enumerate(_with_is_last(train_dataloader)),
                            "get_train_batch"):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model.global_step = self.global_step

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            _outputs = self.run_training_batch(batch, batch_idx)
            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            if self.is_overridden('training_epoch_end',
                                  model=self.get_model()):
                outputs.append(batch_output)

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_result == -1

            # TODO: consolidate all actions that need to take place only after
            # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                # update lr
                self.update_learning_rates(interval='step')

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
            can_check_epoch = self.current_epoch % self.check_val_every_n_epoch == 0
            can_check_val = not self.disable_validation and can_check_epoch
            should_check_val = is_val_check_batch or early_stop_epoch
            should_check_val = should_check_val or (
                is_last_batch and self.val_check_batch == float('inf'))
            should_check_val = can_check_val and should_check_val

            # ---------------
            # CHECKPOINTING, EARLY STOPPING
            # ---------------
            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=self.testing)
                self.call_checkpoint_callback()

            # when logs should be saved
            should_save_log = (
                batch_idx +
                1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.is_global_zero and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:
                # logs user requested information to logger
                self.log_metrics(batch_step_metrics, grad_norm_dic)

            # progress global step according to grads progress
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                self.global_step += 1
            self.total_batch_idx += 1

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch or self.fast_dev_run:
                break

        if self.use_horovod:
            hvd.join(hvd.local_rank() if self.on_gpu else -1)

        # process epoch outputs
        model = self.get_model()
        if self.is_overridden('training_epoch_end', model=model):
            epoch_output = model.training_epoch_end(outputs)
            _processed_outputs = self.process_output(epoch_output)
            log_epoch_metrics = _processed_outputs[2]
            callback_epoch_metrics = _processed_outputs[3]
            self.log_metrics(log_epoch_metrics, {})
            self.callback_metrics.update(callback_epoch_metrics)
            self.add_progress_bar_metrics(_processed_outputs[1])

        # when no val loop is present or fast-dev-run still need to call checkpoints
        if not self.is_overridden('validation_step') and not (
                self.fast_dev_run or should_check_val):
            self.call_checkpoint_callback()

        # Epoch end events
        with self.profiler.profile('on_epoch_end'):
            # callbacks
            self.on_epoch_end()
            # model hooks
            if self.is_function_implemented('on_epoch_end'):
                model.on_epoch_end()
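The TPU branch above re-wraps the dataloader every epoch with ParallelLoader and iterates its per-device view; a minimal hedged sketch of the same idiom outside the trainer:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl

device = xm.xla_device()
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(32, 8)), batch_size=4)

for epoch in range(2):
    # a fresh ParallelLoader is built each epoch, mirroring run_training_epoch above
    para_loader = xla_pl.ParallelLoader(loader, [device])
    for (batch,) in para_loader.per_device_loader(device):
        pass  # batch already lives on the XLA device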
Example n. 21
0
"""#### Install PyTorch/XLA"""

if use_tpu:
  VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]
  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version $VERSION

import os
import torch

if use_tpu:
  # imports the torch_xla package for TPU support
  import torch_xla
  import torch_xla.core.xla_model as xm
  dev = xm.xla_device()
  print(dev)
  
import torchvision
import argparse

from torch.utils.tensorboard import SummaryWriter

apex = False
try:
    from apex import amp
    apex = True
except ImportError:
    print(
        "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
    )
Example n. 22
0
 def __init__(self,
              fp16: bool = None,
              cpu: bool = False,
              _from_accelerator: bool = False):
     self.__dict__ = self._shared_state
     if not getattr(self, "initialized", False):
         self.backend = None
         if not _from_accelerator:
             raise ValueError(
                 "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
                 "before using any functionality from the `accelerate` library."
             )
         elif is_tpu_available() and not cpu:
             self.distributed_type = DistributedType.TPU
             self.num_processes = xm.xrt_world_size()
             self.process_index = xm.get_ordinal()
             self.local_process_index = xm.get_local_ordinal()
             self.device = xm.xla_device()
             self.use_fp16 = False
         elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
             self.distributed_type = DistributedType.MULTI_GPU
             if not torch.distributed.is_initialized():
                 torch.distributed.init_process_group(backend="nccl")
                 self.backend = "nccl"
             self.num_processes = torch.distributed.get_world_size()
             self.process_index = torch.distributed.get_rank()
             self.local_process_index = int(os.environ.get(
                 "LOCAL_RANK", -1))
             self.device = torch.device("cuda", self.local_process_index)
             torch.cuda.set_device(self.device)
             self.use_fp16 = parse_flag_from_env(
                 "USE_FP16", False) if fp16 is None else fp16
         elif get_int_from_env([
                 "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE",
                 "WORLD_SIZE"
         ], 1) > 1:
             self.distributed_type = DistributedType.MULTI_CPU
             if is_ccl_available() and get_int_from_env(
                 ["CCL_WORKER_COUNT"], 0) > 0:
                 backend = "ccl"
             elif torch.distributed.is_mpi_available():
                 backend = "mpi"
             else:
                 backend = "gloo"
             # Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
             rank = get_int_from_env([
                 "RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK",
                 "MV2_COMM_WORLD_RANK"
             ], 0)
             size = get_int_from_env([
                 "WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE",
                 "MV2_COMM_WORLD_SIZE"
             ], 1)
             local_rank = get_int_from_env([
                 "LOCAL_RANK", "MPI_LOCALRANKID",
                 "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"
             ], 0)
             local_size = get_int_from_env([
                 "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE",
                 "MV2_COMM_WORLD_LOCAL_SIZE"
             ], 1)
             self.local_process_index = local_rank
             os.environ["RANK"] = str(rank)
             os.environ["WORLD_SIZE"] = str(size)
             os.environ["LOCAL_RANK"] = str(local_rank)
             if not os.environ.get("MASTER_PORT", None):
                 os.environ["MASTER_PORT"] = "29500"
             if not os.environ.get("MASTER_ADDR", None):
                 if local_size != size and backend != "mpi":
                     raise ValueError(
                         "Looks like distributed multinode run but MASTER_ADDR env not set, "
                         "please try exporting rank 0's hostname as MASTER_ADDR"
                     )
             if not torch.distributed.is_initialized():
                 torch.distributed.init_process_group(backend,
                                                      rank=rank,
                                                      world_size=size)
                 self.backend = backend
             self.num_processes = torch.distributed.get_world_size()
             self.process_index = torch.distributed.get_rank()
             self.local_process_index = local_rank
             self.device = torch.device("cpu")
             self.use_fp16 = False
         else:
             self.distributed_type = DistributedType.NO
             self.num_processes = 1
             self.process_index = self.local_process_index = 0
             self.device = torch.device(
                 "cuda" if torch.cuda.is_available() and not cpu else "cpu")
             self.use_fp16 = parse_flag_from_env(
                 "USE_FP16", False) if fp16 is None else fp16
         self.initialized = True
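On a TPU host, the is_tpu_available() branch above is what selects DistributedType.TPU; a minimal hedged usage sketch through the public accelerate API:

from accelerate import Accelerator

accelerator = Accelerator()          # initializes the shared state shown above
print(accelerator.distributed_type)  # DistributedType.TPU when torch_xla can see TPU cores
print(accelerator.device)            # e.g. xla:0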
Example n. 23
0
    def evaluate(self,
                 model,
                 dataloaders,
                 max_batches,
                 test_mode: bool = False):
        """Run evaluation code.

        :param model: PT model
        :param dataloaders: list of PT dataloaders
        :param max_batches: Scalar
        :param test_mode: run the test loop instead of validation
        :return:
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device()
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:  # pragma: no cover
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                    break

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                output = self.evaluation_forward(model, batch, batch_idx,
                                                 dataloader_idx, test_mode)

                # on dp / ddp2 might still want to do something with the batch parts
                if test_mode:
                    if self.is_overriden('test_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
                else:
                    if self.is_overriden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)

                # track outputs for collation
                dl_outputs.append(output)

                # batch done
                if batch_idx % self.progress_bar_refresh_rate == 0:
                    if test_mode:
                        self.test_progress_bar.update(
                            self.progress_bar_refresh_rate)
                    else:
                        self.val_progress_bar.update(
                            self.progress_bar_refresh_rate)
                        self.main_progress_bar.update(
                            self.progress_bar_refresh_rate)
            outputs.append(dl_outputs)

        eval_results = {}

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give model a chance to do something with the outputs (and method defined)
        model = self.get_model()

        if test_mode and self.is_overriden('test_epoch_end'):
            eval_results = model.test_epoch_end(outputs)
        elif self.is_overriden('validation_epoch_end'):
            eval_results = model.validation_epoch_end(outputs)

        # TODO: remove in v 1.0.0
        if test_mode and self.is_overriden('test_end'):
            eval_results = model.test_end(outputs)
            m = 'test_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
                'Use test_epoch_end instead.'
            warnings.warn(m, DeprecationWarning)
        elif self.is_overriden('validation_end'):
            eval_results = model.validation_end(outputs)
            m = 'validation_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
                'Use validation_epoch_end instead.'
            warnings.warn(m, DeprecationWarning)

        # enable train mode again
        model.train()

        # re-enable gradients now that evaluation is finished
        torch.set_grad_enabled(True)

        return eval_results
Example n. 24
0
def _setup_replication():
  # At this point xla_model.py APIs are allowed as the setup is already
  # completed.
  if xm.xrt_world_size() > 1:
    device = xm.xla_device()
    xm.set_replication(device, [device])
Example n. 25
0
def train_model(tpu=False):
    train_ids = load_pickle_file(cfg.train_ids_224_pkl)
    train_class = load_pickle_file(cfg.train_class_224_pkl)
    train_images = load_pickle_file(cfg.train_image_224_pkl)

    val_ids = load_pickle_file(cfg.val_ids_224_pkl)
    val_class = load_pickle_file(cfg.val_class_224_pkl)
    val_images = load_pickle_file(cfg.val_image_224_pkl)

    if tpu:
        device = xm.xla_device()
    else:
        device = 'cuda'

    model = Model()
    model = model.to(device)
    # writer = SummaryWriter('runs/gpu_experiment_1')
    # writer.add_graph(model)

    # train_dataset = ClassificationDataset(id=train_ids, classes = train_class, images = train_images)
    # val_dataset = ClassificationDataset(id=val_ids, classes=val_class, images = val_images, is_valid=True)

    train_loader = ClassificationDataLoader(id=train_ids,
                                            classes=train_class,
                                            images=train_images).fetch(
                                                batch_size=cfg.train_bs,
                                                drop_last=True,
                                                num_workers=0,
                                                shuffle=True,
                                                tpu=tpu)

    valid_loader = ClassificationDataLoader(id=val_ids,
                                            classes=val_class,
                                            images=val_images).fetch(
                                                batch_size=cfg.val_bs,
                                                drop_last=False,
                                                num_workers=0,
                                                shuffle=True,
                                                tpu=tpu)

    learning_rate = cfg.lr
    if tpu:
        xm.master_print(f"Training for {len(train_loader)} steps per epoch")

        # Scale the learning rate to the number of TPU cores
        learning_rate = 0.0001 * xm.xrt_world_size()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           threshold=0.001,
                                                           mode="min")

    eng = Engine(model, optimizer, device=device, use_tpu=tpu, tpu_print=10)

    for epoch in range(cfg.epochs):
        train_loss = eng.train(train_loader)
        valid_loss, final_preds = eng.evaluate(valid_loader)
        # neptune.log_metric(f"train_loss", train_loss)
        # neptune.log_metric(f"valid_loss", valid_loss)
        xm.master_print(f"Epoch = {epoch}, LOSS = {valid_loss}")
        scheduler.step(valid_loss)
    gc.collect()
Example n. 26
0
def train(index, flags):

    torch.manual_seed(flags['seed'])
    device = xm.xla_device()

    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    iteration = 0
    learning_rate = flags['hparams'].learning_rate

    train_dataset = TextMelLoader(flags['hparams'].training_files,
                                  flags['hparams'])

    val_dataset = TextMelLoader(flags['hparams'].validation_files,
                                flags['hparams'],
                                speaker_ids=train_dataset.speaker_ids)

    collate_fn = TextMelCollate(flags['hparams'].n_frames_per_step)

    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        collate_fn=collate_fn,
        drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=flags['batch_size'],
                                             sampler=val_sampler,
                                             shuffle=False,
                                             num_workers=flags['num_workers'],
                                             collate_fn=collate_fn,
                                             drop_last=True)

    model = load_model(flags['hparams']).to(device).train()
    criterion = Tacotron2Loss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=flags['hparams'].learning_rate,
                                 weight_decay=flags['hparams'].weight_decay)

    for epoch in range(flags['hparams'].epochs):
        train_start = time.time()
        para_train_loader = pl.ParallelLoader(
            train_loader, [device]).per_device_loader(device)

        # (text, mel, speaker_id, f0)
        for batch_num, batch in enumerate(para_train_loader):

            model.zero_grad()
            x, y = model.parse_batch(batch)
            y_pred = model(x)

            loss = criterion(y_pred, y)

            if flags['num_workers'] > 1:
                reduced_loss = reduce_tensor(loss.data,
                                             flags['num_workers']).item()
            else:
                reduced_loss = loss.item()

            optimizer.zero_grad()
            loss.backward()

            xm.optimizer_step(optimizer)

            elapsed_train_time = time.time() - train_start
            print("Batch Process", index, "finished training. Train time was:",
                  elapsed_train_time)

        if (iteration % flags['hparams'].iters_per_checkpoint == 0):
            # validate(model, criterion, val_dataset, iteration,
            #          flags['hparams'].batch_size, flags['n_gpus'], collate_fn, logger,
            #          flags['hparams'].distributed_run, flags['rank'])

            model.eval()
            eval_start = time.time()

            with torch.no_grad():

                para_train_loader = pl.ParallelLoader(
                    val_loader, [device]).per_device_loader(device)
                val_loss = 0.0
                for i, batch in enumerate(para_train_loader):
                    x, y = model.parse_batch(batch)
                    y_pred = model(x)
                    loss = criterion(y_pred, y)

                    if flags['num_workers'] > 1:
                        reduced_val_loss = reduce_tensor(
                            loss.data, flags['num_workers']).item()
                    else:
                        reduced_val_loss = loss.item()

                    val_loss += reduced_val_loss

                val_loss = val_loss / (i + 1)

            # return to training mode for the next epoch
            model.train()

        elapsed_eval_time = time.time() - eval_start
        print("Process", index, "finished evaluation. Evaluation time was:",
              elapsed_eval_time)
        print("Validation loss {}: {:9f}  ".format(iteration,
                                                   reduced_val_loss))

        iteration += 1
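xm.optimizer_step, used in the training loop above, replaces a plain optimizer.step(): it all-reduces gradients across TPU cores before applying them. A minimal hedged single-step sketch:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(8, 10, device=device)
y = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
# all-reduce gradients across cores, step, and (with barrier=True) mark the XLA step
xm.optimizer_step(optimizer, barrier=True)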
Example n. 27
0
def get_tpu_device(args):
    import torch_xla.core.xla_model as xm
    return xm.xla_device()
Example n. 28
0
    def pre_dispatch(self) -> None:
        if isinstance(self.device, int):
            self.device = xm.xla_device(self.device)

        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()
Example n. 29
0
 def test_get_xla_tensor(self):
     x = _gen_tensor(14, 24, 8, device=xm.xla_device())
     t = x.data.cpu()
     sx = x.select(1, 12)
     tx = t.select(1, 12)
     self.assertEqual(tx, sx.data.cpu())
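The test above checks that slicing an XLA tensor matches the same operation on its CPU copy; a minimal hedged sketch of moving results back and forth:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 3, device=device)  # built lazily on the XLA device
y = (x * 2).cpu()                     # .cpu() forces execution and copies the data back
print(y.shape, y.device)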
Example n. 30
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--train_path',
                        default='./data/wikitext-2/wiki.train.tokens',
                        type=str,
                        required=False)
    parser.add_argument('--val_path',
                        default='./data/wikitext-2/wiki.valid.tokens',
                        type=str,
                        required=False)
    parser.add_argument('--save_dir', default=None, type=str, required=False)

    parser.add_argument('--use_control_codes',
                        default=False,
                        action="store_true",
                        required=False)
    parser.add_argument('--control_codes',
                        nargs='+',
                        default=['<|endoftext|>'])

    parser.add_argument('--seq_len', default=256, type=int, required=False)
    parser.add_argument('--n_tokens', default=-1, type=int, required=False)
    parser.add_argument('--n_batches', default=-1, type=int, required=False)
    parser.add_argument('--min_seq_len', default=False, action='store_true')
    # Uses fast tokenization
    parser.add_argument('--fast',
                        default=False,
                        action="store_true",
                        required=False)
    # Efficient for large datasets
    parser.add_argument('--efficient',
                        default=False,
                        action="store_true",
                        required=False)
    parser.add_argument('--detokenizer',
                        default=False,
                        action="store_true",
                        required=False)

    # if from scratch
    parser.add_argument('--from_scratch', default=False, action="store_true")
    parser.add_argument('--config', default='gpt2', type=str)

    # if from a pretrained model
    parser.add_argument('--checkpoint', default='distilgpt2', type=str)

    parser.add_argument('--tokenizer', default='gpt2', type=str)
    parser.add_argument('--from_tf', default=False, action="store_true")

    parser.add_argument('--optimizer', default='AdamW', type=str)
    parser.add_argument('--lr', default=5e-5, type=float)
    # argparse's type=bool treats any non-empty string as True, so parse the flag explicitly
    parser.add_argument('--lr_schedule',
                        default=True,
                        type=lambda s: s.lower() in ('true', '1', 'yes'))

    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--grad_steps', default=1, type=int)
    parser.add_argument('--epochs', default=1, type=int)

    # check
    # add multiple tpu cores
    parser.add_argument('--accelerator', default='GPU', type=str)
    parser.add_argument('--fp16', default=False, action="store_true")
    parser.add_argument('--apex_mode', default='O1', type=str)

    parser.add_argument('--logging_steps', default=10, type=int)
    parser.add_argument('--hist_steps', default=100, type=int)
    parser.add_argument('--save_steps', default=100, type=int)

    parser.add_argument('--do_sample', default=False, action="store_true")
    parser.add_argument('--prompt', default=False, type=str)
    parser.add_argument('--n_samples', default=1, type=int)
    parser.add_argument('--max_length', default=256, type=int)
    parser.add_argument('--temperature', default=None, type=float)
    parser.add_argument('--top_k', default=None, type=int)
    parser.add_argument('--top_p', default=None, type=float)
    parser.add_argument('--repetition_penalty', default=None, type=float)

    parser.add_argument('--use_sliding_windows',
                        default=False,
                        action="store_true")
    parser.add_argument('--n_sliding_windows', default=5, type=int)
    parser.add_argument('--sliding_window_size', default=128, type=int)

    parser.add_argument('--eval_only', default=False, action="store_true")
    parser.add_argument('--sample_only', default=False, action="store_true")
    parser.add_argument('--debug', default=False, action="store_true")
    parser.add_argument('--tags', nargs='+')
    parser.add_argument('--seed', default=42, type=int)

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.debug:
        import ptvsd
        ptvsd.enable_attach(address=('localhost', 5678), redirect_output=True)
        ptvsd.wait_for_attach()
        breakpoint()

    if args.accelerator == 'TPU':
        import torch_xla.core.xla_model as xm

        args.device = xm.xla_device()
    else:
        args.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

    if args.sample_only:
        run_sample(args)
    elif args.eval_only:
        run_eval(args)
    else:
        train(args)