def train():
    df = pd.read_csv(TRAIN_PATH).fillna('None')
    train, valid = model_selection.train_test_split(
        df, test_size=0.15, random_state=42,
        stratify=df['Class Index'].values)
    train = train.reset_index(drop=True)
    valid = valid.reset_index(drop=True)
    train = ohe(train, 'Class Index')
    valid = ohe(valid, 'Class Index')
    train_labels = train[train.columns[-4:]].values
    valid_labels = valid[valid.columns[-4:]].values
    train_data = prepare_dataset(text=train['Description'].values,
                                 label=train_labels)
    valid_data = prepare_dataset(text=valid['Description'].values,
                                 label=valid_labels)

    # Shard the data across TPU cores.
    train_sampler = torch.utils.data.DistributedSampler(
        train_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    valid_sampler = torch.utils.data.DistributedSampler(
        valid_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)
    train_dataloader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=BATCH_SIZE,
                                                   num_workers=4,
                                                   sampler=train_sampler,
                                                   drop_last=True)
    valid_dataloader = torch.utils.data.DataLoader(valid_data,
                                                   batch_size=V_BATCH_SIZE,
                                                   num_workers=4,
                                                   sampler=valid_sampler,
                                                   drop_last=True)

    # device = torch.device('cuda')
    device = xm.xla_device()
    model = BertBaseUncased()
    model.to(device)

    # Exclude bias/LayerNorm parameters from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.001,
        },
        {
            "params": [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    num_train_steps = int(
        len(train_data) / BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    xm.master_print(
        f'num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()}'
    )

    # Scale the learning rate by the number of cores.
    lr = 3e-4 * xm.xrt_world_size()
    optimizer = AdamW(optimizer_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    best_acc = 0
    for epoch in range(EPOCHS):
        para_loader = pl.ParallelLoader(train_dataloader, [device])
        train_loss = train_loop(para_loader.per_device_loader(device),
                                model=model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                device=device)
        para_loader = pl.ParallelLoader(valid_dataloader, [device])
        val_acc, val_loss = eval_loop(para_loader.per_device_loader(device),
                                      model, device)
        # print(f"EPOCH: {epoch} train_loss: {train_loss} val_loss: {val_loss} val_acc: {val_acc}")
        if val_acc > best_acc:
            # Note: on multi-core TPU, xm.save() is usually preferred over
            # torch.save(), since it moves tensors to CPU and writes from
            # the master ordinal only.
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, 'best_model.bin')
            best_acc = val_acc
        xm.master_print(
            f'Epoch: {epoch+1} train_loss: {train_loss} val_loss: {val_loss} Accuracy: {val_acc}'
        )

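# The train() above relies on per-core APIs (xm.get_ordinal(),
# xm.xrt_world_size()), so it is meant to run once per TPU core. A minimal
# launcher sketch using the standard torch_xla multiprocessing entry point;
# the _mp_fn wrapper name and nprocs=8 are illustrative assumptions, not
# part of the original snippet.
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # xmp.spawn passes the process index as the first positional argument;
    # train() derives its rank from xm.get_ordinal() instead.
    train()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')
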
def fit(
    self,
    train_dataset,
    valid_dataset=None,
    train_sampler=None,
    valid_sampler=None,
    device="cuda",
    epochs=10,
    train_bs=16,
    valid_bs=16,
    n_jobs=8,
    callbacks=None,
    fp16=False,
    train_collate_fn=None,
    valid_collate_fn=None,
    train_shuffle=True,
    valid_shuffle=False,
    accumulation_steps=1,
    clip_grad_norm=None,
):
    """
    The model fit function. Heavily inspired by tf/keras, this function is
    the core of Tez and the only function you need to train your models.
    """
    if device == "tpu":
        if not XLA_AVAILABLE:
            raise RuntimeError(
                "XLA is not available. Please install pytorch_xla")
        self.using_tpu = True
        fp16 = False
        device = xm.xla_device()
    self._init_model(
        device=device,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
        train_sampler=train_sampler,
        valid_sampler=valid_sampler,
        train_bs=train_bs,
        valid_bs=valid_bs,
        n_jobs=n_jobs,
        callbacks=callbacks,
        fp16=fp16,
        train_collate_fn=train_collate_fn,
        valid_collate_fn=valid_collate_fn,
        train_shuffle=train_shuffle,
        valid_shuffle=valid_shuffle,
        accumulation_steps=accumulation_steps,
        clip_grad_norm=clip_grad_norm,
    )

    for _ in range(epochs):
        self.train_state = enums.TrainingState.EPOCH_START
        self.train_state = enums.TrainingState.TRAIN_EPOCH_START
        train_loss = self.train_one_epoch(self.train_loader)
        self.train_state = enums.TrainingState.TRAIN_EPOCH_END
        if self.valid_loader:
            self.train_state = enums.TrainingState.VALID_EPOCH_START
            valid_loss = self.validate_one_epoch(self.valid_loader)
            self.train_state = enums.TrainingState.VALID_EPOCH_END
        if self.scheduler:
            if self.step_scheduler_after == "epoch":
                if self.step_scheduler_metric is None:
                    self.scheduler.step()
                else:
                    step_metric = self.name_to_metric(
                        self.step_scheduler_metric)
                    self.scheduler.step(step_metric)
        self.train_state = enums.TrainingState.EPOCH_END
        if self._model_state.value == "end":
            break
        self.current_epoch += 1
    self.train_state = enums.TrainingState.TRAIN_END

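# A usage sketch for the Keras-style fit() above, assuming a Tez-style Model
# subclass; MyModel and the two dataset objects are hypothetical placeholders.
model = MyModel()
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    device="tpu",  # raises RuntimeError if pytorch_xla is not installed
    epochs=10,
    train_bs=16,
    valid_bs=16,
)
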
def train_loop(folds, fold):

    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")
    elif CFG.device == 'TPU':
        if CFG.nprocs == 1:
            LOGGER.info(f"========== fold: {fold} training ==========")
        elif CFG.nprocs == 8:
            xm.master_print(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    train_folds = train_folds[train_folds['StudyInstanceUID'].isin(
        train_annotations['StudyInstanceUID'].unique())].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values

    train_dataset = TrainDataset(train_folds, train_annotations,
                                 use_annot=True,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds, train_annotations,
                                 use_annot=False,
                                 transform=get_transforms(data='valid'))

    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset,
                                  batch_size=CFG.batch_size,
                                  shuffle=True,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.batch_size * 2,
                                  shuffle=False,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=False)
    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=CFG.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   num_workers=CFG.num_workers)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=CFG.batch_size * 2,
                                                   sampler=valid_sampler,
                                                   drop_last=False,
                                                   num_workers=CFG.num_workers)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler == 'ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                          factor=CFG.factor,
                                          patience=CFG.patience,
                                          verbose=True, eps=CFG.eps)
        elif CFG.scheduler == 'CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max,
                                          eta_min=CFG.min_lr, last_epoch=-1)
        elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0,
                                                    T_mult=1,
                                                    eta_min=CFG.min_lr,
                                                    last_epoch=-1)
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    if CFG.device == 'TPU':
        device = xm.xla_device()
    elif CFG.device == 'GPU':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Frozen teacher network used for distillation.
    teacher_model = CustomSeResNet152D(CFG.model_name, pretrained=False)
    teacher_model.to(device)
    state = torch.load(CFG.teacher)
    teacher_model.load_state_dict(state['model'])
    for param in teacher_model.parameters():
        param.requires_grad = False
    teacher_model.eval()

    model = CustomSeResNet152D_WLF(CFG.model_name, pretrained=True)
    model.to(device)
    # state = torch.load(CFG.student)
    # model.load_state_dict(state['model'])

    optimizer = Adam(model.parameters(), lr=CFG.lr,
                     weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # loop
    # ====================================================
    train_criterion = CustomLoss(weights=CFG.weights)
    valid_criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_loss = train_fn(train_loader, teacher_model, model,
                                    train_criterion, optimizer, epoch,
                                    scheduler, device)
            elif CFG.nprocs == 8:
                para_train_loader = pl.ParallelLoader(train_loader, [device])
                avg_loss = train_fn(
                    para_train_loader.per_device_loader(device),
                    teacher_model, model, train_criterion, optimizer,
                    epoch, scheduler, device)
        elif CFG.device == 'GPU':
            avg_loss = train_fn(train_loader, teacher_model, model,
                                train_criterion, optimizer, epoch,
                                scheduler, device)

        # eval
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_val_loss, preds, _ = valid_fn(valid_loader, model,
                                                  valid_criterion, device)
            elif CFG.nprocs == 8:
                para_valid_loader = pl.ParallelLoader(valid_loader, [device])
                avg_val_loss, preds, valid_labels = valid_fn(
                    para_valid_loader.per_device_loader(device), model,
                    valid_criterion, device)
                # Gather predictions and labels from all 8 cores.
                preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy()
                valid_labels = idist.all_gather(
                    torch.tensor(valid_labels)).to('cpu').numpy()
        elif CFG.device == 'GPU':
            avg_val_loss, preds, _ = valid_fn(valid_loader, model,
                                              valid_criterion, device)

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, (CosineAnnealingLR,
                                    CosineAnnealingWarmRestarts)):
            scheduler.step()

        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        if CFG.device == 'GPU':
            LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} '
                        f'avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
            LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f} '
                        f'Scores: {np.round(scores, decimals=4)}')
        elif CFG.device == 'TPU':
            if CFG.nprocs == 1:
                LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} '
                            f'avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
                LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f} '
                            f'Scores: {np.round(scores, decimals=4)}')
            elif CFG.nprocs == 8:
                xm.master_print(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} '
                                f'avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
                xm.master_print(f'Epoch {epoch+1} - Score: {score:.4f} '
                                f'Scores: {np.round(scores, decimals=4)}')

        if score > best_score:
            best_score = score
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                torch.save({'model': model.state_dict(), 'preds': preds},
                           OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                # Note: this saves the whole model object, unlike the GPU
                # branch, which saves state_dict().
                xm.save({'model': model, 'preds': preds},
                        OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                torch.save({'model': model.state_dict(), 'preds': preds},
                           OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                xm.save({'model': model, 'preds': preds},
                        OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')

        # # Save all epoch checkpoints for inference
        # if CFG.device == 'TPU':
        #     xm.save({'model': model.state_dict()},
        #             OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')
        # elif CFG.device == 'GPU':
        #     torch.save({'model': model.state_dict()},
        #                OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')

    if CFG.nprocs != 8:
        check_point = torch.load(
            OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
        for c in [f'pred_{c}' for c in CFG.target_cols]:
            valid_folds[c] = np.nan
        valid_folds[[f'pred_{c}' for c in CFG.target_cols]] = check_point['preds']

    return valid_folds

def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files "
                             "(or other data files) for the task.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " +
                             ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected "
                             "in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train selected in the list: " +
                             ", ".join(processors.keys()))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. "
                             "Sequences longer than this will be truncated, "
                             "sequences shorter will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument("--train_file", default="train.tsv", type=str,
                        help="Training file.")
    parser.add_argument("--dev_file", default="dev.tsv", type=str,
                        help="Validation file.")
    parser.add_argument("--test_file", default="test.tsv", type=str,
                        help="Test file.")
    parser.add_argument("--results_file", default="eval_results.txt", type=str,
                        help="File name to write results.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before "
                             "performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to "
                             "perform. Overrides num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X update steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X update steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same "
                             "prefix as model_name and ending with a step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--tpu', action='store_true',
                        help="Whether to run on the TPU defined in the environment variables")
    parser.add_argument('--tpu_ip_address', type=str, default='',
                        help="TPU IP address if none are set in the environment variables")
    parser.add_argument('--tpu_name', type=str, default='',
                        help="TPU name if none are set in the environment variables")
    parser.add_argument('--xrt_tpu_config', type=str, default='',
                        help="XRT TPU config if none are set in the environment variables")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision "
                             "(through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in "
                             "['O0', 'O1', 'O2', and 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see
        # https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    if args.tpu:
        if args.tpu_ip_address:
            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
        if args.tpu_name:
            os.environ["TPU_NAME"] = args.tpu_name
        if args.xrt_tpu_config:
            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config

        assert "TPU_IP_ADDRESS" in os.environ
        assert "TPU_NAME" in os.environ
        assert "XRT_TPU_CONFIG" in os.environ

        import torch_xla
        import torch_xla.core.xla_model as xm
        args.device = xm.xla_device()
        args.xla_model = xm

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, "
        "16-bits training: %s", args.local_rank, device, args.n_gpu,
        bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    processor.set_train_file(args.train_file)
    processor.set_dev_file(args.dev_file)
    # The original called set_dev_file() here as well, clobbering the dev
    # file with the test file; set_test_file() is the presumed intent.
    processor.set_test_file(args.test_file)
    args.output_mode = output_modes[args.task_name]
    binary_label_list = processor.get_binary_labels()
    multi_label_list = processor.get_multi_labels()
    num_binary_labels = len(binary_label_list)
    num_multi_labels = len(multi_label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will
        # download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_binary_labels=num_binary_labels,
        num_multi_labels=num_multi_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will
        # download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name,
                                                tokenizer, evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use default names for the model, you can
    # reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or
                          torch.distributed.get_rank() == 0) and not args.tpu:
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using
        # `save_pretrained()`. They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if (args.do_eval or args.do_test) and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""

            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict((k + '_{}'.format(global_step), v)
                          for k, v in result.items())
            results.update(result)

    return results

def train_mnist(flags, training_started=None, dynamic_graph=False, fetch_often=False):
    torch.manual_seed(1)

    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=600000 // flags.batch_size // xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=100000 // flags.batch_size // xm.xrt_world_size(),
        )
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))]),
        )
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))]),
        )
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers,
        )

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()
    server = xp.start_server(flags.profiler_port)

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                index = max(-step, -flags.batch_size + 1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()
            with xp.StepTrace("train_mnist", step_num=step):
                with xp.Trace("build_graph"):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(
                        _train_update, args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace("test_mnist"):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader)
        xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={"Accuracy/test": accuracy},
            write_xla_metrics=True,
        )
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy

def run_training_epoch(self):

    # get model
    model = self.get_model()

    # Epoch start events
    with self.profiler.profile('on_epoch_start'):
        # callbacks
        self.on_epoch_start()

        # model hooks
        if self.is_function_implemented('on_epoch_start'):
            model.on_epoch_start()

    # track local dataloader so TPU can wrap each epoch
    train_dataloader = self.train_dataloader

    # on TPU we have to wrap it under the ParallelLoader
    if self.use_tpu:
        device = xm.xla_device()
        train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
        train_dataloader = train_dataloader.per_device_loader(device)

    # bookkeeping
    outputs = []

    # run epoch
    for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
            enumerate(_with_is_last(train_dataloader)), "get_train_batch"):
        # stop epoch if we limited the number of training batches
        if batch_idx >= self.num_training_batches:
            break

        self.batch_idx = batch_idx
        model.global_step = self.global_step

        # ---------------
        # RUN TRAIN STEP
        # ---------------
        _outputs = self.run_training_batch(batch, batch_idx)
        batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs

        # detach tensors in batch_output before appending to outputs
        outputs.append(_recursive_detach(batch_output))

        # when returning -1 from train_step, we end epoch early
        early_stop_epoch = batch_result == -1

        # update lr
        self.update_learning_rates(interval='step')

        # ---------------
        # RUN VAL STEP
        # ---------------
        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
        can_check_val = not self.disable_validation and can_check_epoch
        should_check_val = is_val_check_batch or early_stop_epoch
        should_check_val = should_check_val or (
            is_last_batch and self.val_check_batch == float('inf'))
        should_check_val = can_check_val and should_check_val

        # fast_dev_run always forces val checking after train batch
        if self.fast_dev_run or should_check_val:
            self.run_evaluation(test_mode=self.testing)

        # when logs should be saved
        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
        if should_save_log or self.fast_dev_run:
            if self.proc_rank == 0 and self.logger is not None:
                self.logger.save()

        # when metrics should be logged
        should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
        if should_log_metrics or self.fast_dev_run:
            # logs user requested information to logger
            self.log_metrics(batch_step_metrics, grad_norm_dic)

        # ---------------
        # CHECKPOINTING, EARLY STOPPING
        # ---------------
        # save checkpoint even when no test or val step are defined
        if self.fast_dev_run or should_check_val:
            self.call_checkpoint_callback()

            if self.enable_early_stop:
                self.early_stop_callback.check_metrics(self.callback_metrics)

        # progress global step according to grads progress
        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
            self.global_step += 1
        self.total_batch_idx += 1

        # max steps reached, end training
        if self.max_steps is not None and self.max_steps == self.global_step:
            break

        # end epoch early
        # stop when the flag is changed or we've gone past the amount
        # requested in the batches
        if early_stop_epoch or self.fast_dev_run:
            break

    # process epoch outputs
    if isinstance(model, (LightningDistributedDataParallel,
                          LightningDataParallel)):
        model = model.module

    if self.is_overriden('training_epoch_end', model=model):
        epoch_output = model.training_epoch_end(outputs)
        _processed_outputs = self.process_output(epoch_output)
        log_epoch_metrics = _processed_outputs[2]
        callback_epoch_metrics = _processed_outputs[3]
        self.log_metrics(log_epoch_metrics, {})
        self.callback_metrics.update(callback_epoch_metrics)

    # in case validation step is missing and you are not running fast-dev
    # to duplicate last batch
    if not self.is_overriden('validation_step') and not (
            self.fast_dev_run or should_check_val):
        self.call_checkpoint_callback()

        if self.enable_early_stop:
            self.early_stop_callback.check_metrics(self.callback_metrics)

    # Epoch end events
    with self.profiler.profile('on_epoch_end'):
        # callbacks
        self.on_epoch_end()

        # model hooks
        if self.is_function_implemented('on_epoch_end'):
            model.on_epoch_end()

def main():
    parser = argparse.ArgumentParser()

    # Required params
    parser.add_argument("--input_dir", default=None, type=str, required=True,
                        help="Comma separated list of data dir paths "
                             "(Use prepare-*.py scripts for preprocessing.)")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")
    parser.add_argument("--model", default=None, type=str, required=True,
                        help="Path to pretrained model or shortcut name "
                             "(bert-base-multilingual-cased)")

    # Optional arguments
    parser.add_argument("--model_file", default='model', type=str,
                        help="Filename of task model")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument('--force', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument("--tasks", default='all', type=str,
                        help="Comma separated list of tasks. Names are subdirs "
                             "of data dirs. Default: all")
    parser.add_argument("--cache_dir", default='cache', type=str,
                        help="Path to cache directory")

    # TPU training
    parser.add_argument("--use_tpu", action='store_true',
                        help="Whether to use a TPU. (Make sure that the "
                             "environment variables TPU_NAME, TPU_IP_ADDRESS "
                             "and XRT_TPU_CONFIG are set)")

    args = parser.parse_args()

    if not args.do_train and not args.do_eval:
        print('Specify --do_train and/or --do_eval')
        exit(-1)

    if args.do_train and not args.force and os.path.exists(args.output_dir):
        print('Output path already exists')
        exit(-1)

    tasks = list_tasks(args.input_dir, args.output_dir, args.tasks)
    if len(tasks) == 0:
        print('No (whitelisted) tasks found')
        exit(-1)

    print('Starting benchmark tasks!')
    print('\n▶ Arguments:')
    for key in vars(args):
        print('  ➤ {:15}: {}'.format(key, getattr(args, key)))

    print('\n▶ Device:')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    xla_model = None
    if args.use_tpu:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
        xla_model = xm
    print('  ➤ Using device: {}'.format(device.type))

    print('\n▶ Scheduling to run {} tasks:'.format(len(tasks)))
    for task_in, task_out in tasks:
        print('  ➤ {} [Output: {}]'.format(task_in, task_out))

    print('\n' + '#' * 80)

    for i, (task_in, task_out) in enumerate(tasks, 1):
        print('\n▶ Task {}/{} [{}]:'.format(i, len(tasks), task_in))
        run_task(args.model, task_in, task_out, device, args.model_file,
                 args.do_train, args.do_eval, cache_dir=args.cache_dir,
                 xla_model=xla_model)
        print('  ➤ Finished!')

def setUpClass(cls):
    # Sets the primary test device to the xla_device (CPU or TPU)
    cls.primary_device = str(xm.xla_device())
    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)

def main(args):
    blue = lambda x: '\033[94m' + x + '\033[0m'

    seeding(args.seed)

    if args.hfta:
        B = consolidate_hyperparams_and_determine_B(
            args,
            ['lr', 'beta1', 'beta2', 'weight_decay', 'gamma', 'step_size'],
        )
    else:
        B = 0
        (args.lr, args.beta1, args.beta2, args.weight_decay, args.gamma,
         args.step_size) = (args.lr[0], args.beta1[0], args.beta2[0],
                            args.weight_decay[0], args.gamma[0],
                            args.step_size[0])

    if args.device == 'cuda':
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        print('Enable cuDNN heuristics!')
    device = (xm.xla_device()
              if args.device == 'xla' else torch.device(args.device))

    dataset, test_dataset = build_dataset(args)
    dataloader, testdataloader = build_dataloader(args, dataset, test_dataset)
    print('len(dataset)={}'.format(len(dataset)),
          'len(test_dataset)={}'.format(len(test_dataset)))
    num_classes = len(dataset.classes)
    print('classes', num_classes)

    if args.outf is not None:
        try:
            os.makedirs(args.outf)
        except OSError:
            pass

    classifier = PointNetCls(
        k=num_classes,
        feature_transform=args.feature_transform,
        B=B,
        track_running_stats=(args.device != 'xla'),
    )

    if args.model != '':
        classifier.load_state_dict(torch.load(args.model))

    optimizer = get_hfta_optim_for(optim.Adam, B=B)(
        classifier.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
    )
    scheduler = get_hfta_lr_scheduler_for(optim.lr_scheduler.StepLR, B=B)(
        optimizer,
        step_size=args.step_size,
        gamma=args.gamma,
    )
    scaler = amp.GradScaler(enabled=(args.device == 'cuda' and args.amp))

    classifier.to(device)

    num_batch = len(dataloader)

    def loss_fn(output, label, batch_size, trans_feat):
        if B > 0:
            loss = B * F.nll_loss(output.view(B * batch_size, -1), label)
        else:
            loss = F.nll_loss(output, label)
        if args.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        return loss

    classifier = classifier.train()
    epoch_timer = EpochTimer()

    # Training loop
    for epoch in range(args.epochs):
        num_samples_per_epoch = 0
        epoch_timer.epoch_start(epoch)
        for i, data in enumerate(dataloader, 0):
            if i > args.iters_per_epoch:
                break
            if args.warmup_data_loading:
                continue
            points, target = data
            target = target[:, 0]
            points, target = points.to(device), target.to(device)
            N = points.size(0)
            if B > 0:
                # Replicate the batch for the B horizontally fused models.
                points = points.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
                target = target.repeat(B)
            optimizer.zero_grad(set_to_none=True)
            if args.device == 'cuda':
                with amp.autocast(enabled=args.amp):
                    pred, trans, trans_feat = classifier(points)
                    loss = loss_fn(pred, target, N, trans_feat)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
            else:
                pred, trans, trans_feat = classifier(points)
                loss = loss_fn(pred, target, N, trans_feat)
                loss.backward()
                if args.device == 'xla':
                    xm.optimizer_step(optimizer, barrier=True)
                else:
                    optimizer.step()
            print('[{}: {}/{}] train loss: {}'.format(epoch, i, num_batch,
                                                      loss.item()))
            num_samples_per_epoch += N * max(B, 1)
            scaler.update()
        scheduler.step()
        epoch_timer.epoch_stop(num_samples_per_epoch)
        print('Epoch {} took {} s!'.format(epoch,
                                           epoch_timer.epoch_latency(epoch)))

    if args.device == 'xla' and not args.eval:
        print(met.metrics_report())
    if args.outf is not None:
        epoch_timer.to_csv(args.outf)

    if args.eval:
        # Run validation loop.
        print("Running validation loop ...")
        classifier = classifier.eval()
        with torch.no_grad():
            total_correct = torch.zeros(max(B, 1), device=device)
            total_testset = 0
            for data in testdataloader:
                if args.warmup_data_loading:
                    continue
                points, target = data
                target = target[:, 0]
                points, target = points.to(device), target.to(device)
                N = points.size(0)
                if B > 0:
                    points = points.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
                    target = target.repeat(B)
                pred, _, _ = classifier(points)
                pred_choice = pred.argmax(-1)
                correct = pred_choice.eq(
                    target.view(B, N) if B > 0 else target).sum(-1)
                total_correct.add_(correct)
                total_testset += N

            final_accuracy = total_correct / total_testset
            final_accuracy = final_accuracy.cpu().tolist()
            if args.outf is not None:
                pd.DataFrame({
                    'acc': final_accuracy
                }).to_csv(os.path.join(args.outf, 'eval.csv'))

            # Return test accuracy
            return final_accuracy

def get_device(self):
    return xm.xla_device()

def get_hw_device(self):
    # Resolve the real hardware device (e.g. a TPU core) behind the XLA device.
    return xm._xla_real_device(xm.xla_device())

def measure_tpu(warmups, steps, h_emb, h_indices, h_offsets, args):
    import torch_xla
    import torch_xla.core.xla_model as xm
    import os

    tsize = int(os.environ.get("MODEL_PARTITION_SIZE", 3000000))

    def syncTPU(tensor):
        torch_xla._XLAC._xla_sync_multi([tensor],
                                        devices=[],
                                        wait=True,
                                        sync_xla_data=True)

    alldev = xm.get_xla_supported_devices()
    allrealdev = xm.xla_real_devices(alldev)
    print("Found {0} devices: {1}".format(len(allrealdev), allrealdev))

    dev = xm.xla_device()
    if (args.features > tsize):
        # The embedding table is too large to move in one piece: split it into
        # chunks, move the chunks to the device, then reassemble them there.
        if args.usexlabag:
            tsplit = torch.split(h_emb.embtable.weight, tsize, dim=0)
        else:
            tsplit = torch.split(h_emb.weight, tsize, dim=0)
        tsplit = list(tsplit)
        for i, chunk in enumerate(tsplit):
            tsplit[i] = chunk.to(dev)
        # Swap in a tiny placeholder weight so .to(dev) does not copy the
        # original full-size table.
        t = nn.Parameter(torch.ones(10, 10))
        if args.usexlabag:
            h_emb.embtable.weight = t
            t_emb = h_emb.to(dev)
            tsplit = torch.cat(tsplit)
            t_emb.embtable.weight = nn.Parameter(tsplit)
            print("Xla EMB weight shape: ", t_emb.embtable.weight.shape,
                  " on device: ", str(dev))
        else:
            h_emb.weight = t
            t_emb = h_emb.to(dev)
            tsplit = torch.cat(tsplit)
            t_emb.weight = nn.Parameter(tsplit)
            print("EMB weight shape: ", t_emb.weight.shape,
                  " on device: ", str(dev))
    else:
        t_emb = h_emb.to(dev)
    t_indices = h_indices.to(dev)
    t_offsets = h_offsets.to(dev)

    emb_times = 0.0
    start1 = time.perf_counter()
    for i in range(warmups + steps):
        start = time.perf_counter()
        results = t_emb(t_indices, t_offsets)
        syncTPU(results)
        end = time.perf_counter()
        print("Time: {0:.6f} ".format(end - start))
        if (i >= warmups):
            emb_times += end - start
    end1 = time.perf_counter()

    return end1 - start1, emb_times, results

def process_dataloader(self, dataloader):
    device = xm.xla_device(self.trainer.tpu_id)
    dataloader = xla_pl.ParallelLoader(dataloader, [device])
    dataloader = dataloader.per_device_loader(device)
    return dataloader

def test_xla_device_is_a_tpu():
    """Check that the XLA device is a TPU."""
    device = xm.xla_device()
    device_type = xm.xla_device_hw(device)
    return device_type == "TPU"

def get_tpu_device():
    return xm.xla_device()

def model_to_device(self) -> None:
    self.device = xm.xla_device()
    self.model = self.wrapped_model.to(self.device)

def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler.

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset
            for which dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration;
            required for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size as the user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(dataset_instance.dataset_name,
                                 dataset_instance.dataset_type),
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)

def to_tpu(self) -> None:
    """Moves the model to the TPU."""
    self.model.to(xm.xla_device())

def root_device(self) -> torch.device:
    return xm.xla_device()

def run_training_epoch(self):

    # get model
    model = self.get_model()

    # Epoch start events
    with self.profiler.profile('on_epoch_start'):
        # callbacks
        self.on_epoch_start()

        # model hooks
        if self.is_function_implemented('on_epoch_start'):
            model.on_epoch_start()

    # track local dataloader so TPU can wrap each epoch
    train_dataloader = self.train_dataloader

    # on TPU we have to wrap it under the ParallelLoader
    if self.use_tpu:
        device = xm.xla_device(self.tpu_id)
        train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
        train_dataloader = train_dataloader.per_device_loader(device)

    # bookkeeping
    outputs = []

    # run epoch
    for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
            enumerate(_with_is_last(train_dataloader)), "get_train_batch"):
        # stop epoch if we limited the number of training batches
        if batch_idx >= self.num_training_batches:
            break

        self.batch_idx = batch_idx
        model.global_step = self.global_step

        # ---------------
        # RUN TRAIN STEP
        # ---------------
        _outputs = self.run_training_batch(batch, batch_idx)
        batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs

        # only track outputs when user implements training_epoch_end
        # otherwise we will build up unnecessary memory
        if self.is_overridden('training_epoch_end', model=self.get_model()):
            outputs.append(batch_output)

        # when returning -1 from train_step, we end epoch early
        early_stop_epoch = batch_result == -1

        # TODO: consolidate all actions that need to take place only after
        # self.accumulate_grad_batches steps (optimizer step, lr update,
        # global step increment)
        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
            # update lr
            self.update_learning_rates(interval='step')

        # ---------------
        # RUN VAL STEP
        # ---------------
        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
        can_check_epoch = self.current_epoch % self.check_val_every_n_epoch == 0
        can_check_val = not self.disable_validation and can_check_epoch
        should_check_val = is_val_check_batch or early_stop_epoch
        should_check_val = should_check_val or (
            is_last_batch and self.val_check_batch == float('inf'))
        should_check_val = can_check_val and should_check_val

        # ---------------
        # CHECKPOINTING, EARLY STOPPING
        # ---------------
        # fast_dev_run always forces val checking after train batch
        if self.fast_dev_run or should_check_val:
            self.run_evaluation(test_mode=self.testing)
            self.call_checkpoint_callback()

        # when logs should be saved
        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
        if should_save_log or self.fast_dev_run:
            if self.is_global_zero and self.logger is not None:
                self.logger.save()

        # when metrics should be logged
        should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
        if should_log_metrics or self.fast_dev_run:
            # logs user requested information to logger
            self.log_metrics(batch_step_metrics, grad_norm_dic)

        # progress global step according to grads progress
        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
            self.global_step += 1
        self.total_batch_idx += 1

        # max steps reached, end training
        if self.max_steps is not None and self.max_steps == self.global_step:
            break

        # end epoch early
        # stop when the flag is changed or we've gone past the amount
        # requested in the batches
        if early_stop_epoch or self.fast_dev_run:
            break

    if self.use_horovod:
        hvd.join(hvd.local_rank() if self.on_gpu else -1)

    # process epoch outputs
    model = self.get_model()
    if self.is_overridden('training_epoch_end', model=model):
        epoch_output = model.training_epoch_end(outputs)
        _processed_outputs = self.process_output(epoch_output)
        log_epoch_metrics = _processed_outputs[2]
        callback_epoch_metrics = _processed_outputs[3]
        self.log_metrics(log_epoch_metrics, {})
        self.callback_metrics.update(callback_epoch_metrics)
        self.add_progress_bar_metrics(_processed_outputs[1])

    # when no val loop is present or fast-dev-run still need to call checkpoints
    if not self.is_overridden('validation_step') and not (
            self.fast_dev_run or should_check_val):
        self.call_checkpoint_callback()

    # Epoch end events
    with self.profiler.profile('on_epoch_end'):
        # callbacks
        self.on_epoch_end()

        # model hooks
        if self.is_function_implemented('on_epoch_end'):
            model.on_epoch_end()

"""#### Install PyTorch/XLA""" if use_tpu: VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"] !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py !python pytorch-xla-env-setup.py --version $VERSION import os import torch if use_tpu: # imports the torch_xla package for TPU support import torch_xla import torch_xla.core.xla_model as xm dev = xm.xla_device() print(dev) import torchvision import argparse from torch.utils.tensorboard import SummaryWriter apex = False try: from apex import amp apex = True except ImportError: print( "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training" )
def __init__(self, fp16: bool = None, cpu: bool = False,
             _from_accelerator: bool = False):
    self.__dict__ = self._shared_state
    if not getattr(self, "initialized", False):
        self.backend = None
        if not _from_accelerator:
            raise ValueError(
                "Please make sure to properly initialize your accelerator via "
                "`accelerator = Accelerator()` before using any functionality "
                "from the `accelerate` library.")
        elif is_tpu_available() and not cpu:
            self.distributed_type = DistributedType.TPU
            self.num_processes = xm.xrt_world_size()
            self.process_index = xm.get_ordinal()
            self.local_process_index = xm.get_local_ordinal()
            self.device = xm.xla_device()
            self.use_fp16 = False
        elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
            self.distributed_type = DistributedType.MULTI_GPU
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend="nccl")
                self.backend = "nccl"
            self.num_processes = torch.distributed.get_world_size()
            self.process_index = torch.distributed.get_rank()
            self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
            self.device = torch.device("cuda", self.local_process_index)
            torch.cuda.set_device(self.device)
            self.use_fp16 = parse_flag_from_env(
                "USE_FP16", False) if fp16 is None else fp16
        elif get_int_from_env(
                ["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE",
                 "WORLD_SIZE"], 1) > 1:
            self.distributed_type = DistributedType.MULTI_CPU
            if is_ccl_available() and get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0:
                backend = "ccl"
            elif torch.distributed.is_mpi_available():
                backend = "mpi"
            else:
                backend = "gloo"
            # Try to get launch configuration from environment variables set by
            # MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
            rank = get_int_from_env(
                ["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK",
                 "MV2_COMM_WORLD_RANK"], 0)
            size = get_int_from_env(
                ["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE",
                 "MV2_COMM_WORLD_SIZE"], 1)
            local_rank = get_int_from_env(
                ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK",
                 "MV2_COMM_WORLD_LOCAL_RANK"], 0)
            local_size = get_int_from_env(
                ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE",
                 "MV2_COMM_WORLD_LOCAL_SIZE"], 1)
            self.local_process_index = local_rank
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(size)
            os.environ["LOCAL_RANK"] = str(local_rank)
            if not os.environ.get("MASTER_PORT", None):
                os.environ["MASTER_PORT"] = "29500"
            if not os.environ.get("MASTER_ADDR", None):
                if local_size != size and backend != "mpi":
                    raise ValueError(
                        "Looks like distributed multinode run but MASTER_ADDR "
                        "env not set, please try exporting rank 0's hostname "
                        "as MASTER_ADDR")
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend, rank=rank,
                                                     world_size=size)
                self.backend = backend
            self.num_processes = torch.distributed.get_world_size()
            self.process_index = torch.distributed.get_rank()
            self.local_process_index = local_rank
            self.device = torch.device("cpu")
            self.use_fp16 = False
        else:
            self.distributed_type = DistributedType.NO
            self.num_processes = 1
            self.process_index = self.local_process_index = 0
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() and not cpu else "cpu")
            self.use_fp16 = parse_flag_from_env(
                "USE_FP16", False) if fp16 is None else fp16
        self.initialized = True

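# This state object is normally constructed indirectly. A minimal usage
# sketch of the Hugging Face accelerate front end that drives it; the model,
# optimizer, and train_loader names are placeholders assumed to exist.
from accelerate import Accelerator

accelerator = Accelerator()  # selects TPU/GPU/CPU from the environment
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

for batch in train_loader:
    optimizer.zero_grad()
    loss = model(**batch).loss
    accelerator.backward(loss)  # handles fp16 scaling and device specifics
    optimizer.step()
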
def evaluate(self, model, dataloaders, max_batches, test_mode: bool = False):
    """Run evaluation code.

    :param model: PT model
    :param dataloaders: list of PT dataloaders
    :param max_batches: Scalar
    :param test_mode:
    :return:
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # disable gradients to save memory
    torch.set_grad_enabled(False)

    # bookkeeping
    outputs = []

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device()
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:  # pragma: no cover
                continue

            # stop short when on fast_dev_run (sets max_batch=1)
            if batch_idx >= max_batches:
                break

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            output = self.evaluation_forward(model, batch, batch_idx,
                                             dataloader_idx, test_mode)

            # on dp / ddp2 might still want to do something with the batch parts
            if test_mode:
                if self.is_overriden('test_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('test_step_end'):
                        output = model_ref.test_step_end(output)
            else:
                if self.is_overriden('validation_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('validation_step_end'):
                        output = model_ref.validation_step_end(output)

            # track outputs for collation
            dl_outputs.append(output)

            # batch done
            if batch_idx % self.progress_bar_refresh_rate == 0:
                if test_mode:
                    self.test_progress_bar.update(self.progress_bar_refresh_rate)
                else:
                    self.val_progress_bar.update(self.progress_bar_refresh_rate)
                    self.main_progress_bar.update(self.progress_bar_refresh_rate)
        outputs.append(dl_outputs)

    eval_results = {}

    # with a single dataloader don't pass an array
    if len(dataloaders) == 1:
        outputs = outputs[0]

    # give model a chance to do something with the outputs (and method defined)
    model = self.get_model()
    if test_mode and self.is_overriden('test_epoch_end'):
        eval_results = model.test_epoch_end(outputs)
    elif self.is_overriden('validation_epoch_end'):
        eval_results = model.validation_epoch_end(outputs)

    # TODO: remove in v1.0.0
    if test_mode and self.is_overriden('test_end'):
        eval_results = model.test_end(outputs)
        m = 'test_end was deprecated in 0.7.0 and will be removed in 1.0.0. ' \
            'Use test_epoch_end instead.'
        warnings.warn(m, DeprecationWarning)
    elif self.is_overriden('validation_end'):
        eval_results = model.validation_end(outputs)
        m = 'validation_end was deprecated in 0.7.0 and will be removed in 1.0.0. ' \
            'Use validation_epoch_end instead.'
        warnings.warn(m, DeprecationWarning)

    # enable train mode again
    model.train()

    # re-enable gradients
    torch.set_grad_enabled(True)

    return eval_results

def _setup_replication():
    # At this point xla_model.py APIs are allowed as the setup is already
    # completed.
    if xm.xrt_world_size() > 1:
        device = xm.xla_device()
        xm.set_replication(device, [device])

def train_model(tpu=False):
    train_ids = load_pickle_file(cfg.train_ids_224_pkl)
    train_class = load_pickle_file(cfg.train_class_224_pkl)
    train_images = load_pickle_file(cfg.train_image_224_pkl)

    val_ids = load_pickle_file(cfg.val_ids_224_pkl)
    val_class = load_pickle_file(cfg.val_class_224_pkl)
    val_images = load_pickle_file(cfg.val_image_224_pkl)

    if tpu:
        device = xm.xla_device()
    else:
        device = 'cuda'

    model = Model()
    model = model.to(device)

    # writer = SummaryWriter('runs/gpu_experiment_1')
    # writer.add_graph(model)

    # train_dataset = ClassificationDataset(id=train_ids, classes=train_class,
    #                                       images=train_images)
    # val_dataset = ClassificationDataset(id=val_ids, classes=val_class,
    #                                     images=val_images, is_valid=True)

    train_loader = ClassificationDataLoader(
        id=train_ids, classes=train_class, images=train_images).fetch(
            batch_size=cfg.train_bs, drop_last=True, num_workers=0,
            shuffle=True, tpu=tpu)

    valid_loader = ClassificationDataLoader(
        id=val_ids, classes=val_class, images=val_images).fetch(
            batch_size=cfg.val_bs, drop_last=False, num_workers=0,
            shuffle=True, tpu=tpu)

    if tpu:
        xm.master_print(f"Training for {len(train_loader)} steps per epoch")
        # Scale learning rate to num cores. Note: learning_rate is computed
        # but unused below; the optimizer is built with cfg.lr, as in the
        # original snippet.
        learning_rate = 0.0001 * xm.xrt_world_size()

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, threshold=0.001, mode="min")

    eng = Engine(model, optimizer, device=device, use_tpu=tpu, tpu_print=10)

    for epoch in range(cfg.epochs):
        train_loss = eng.train(train_loader)
        valid_loss, final_preds = eng.evaluate(valid_loader)
        # neptune.log_metric("train_loss", train_loss)
        # neptune.log_metric("valid_loss", valid_loss)
        xm.master_print(f"Epoch = {epoch}, LOSS = {valid_loss}")
        scheduler.step(valid_loss)
    gc.collect()

def train(index, flags):
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()

    if not xm.is_master_ordinal():
        # Non-master cores wait here so only the master downloads the data.
        xm.rendezvous('download_only_once')

    iteration = 0
    learning_rate = flags['hparams'].learning_rate
    train_dataset = TextMelLoader(flags['hparams'].training_files,
                                  flags['hparams'])
    val_dataset = TextMelLoader(flags['hparams'].validation_files,
                                flags['hparams'],
                                speaker_ids=train_dataset.speaker_ids)
    collate_fn = TextMelCollate(flags['hparams'].n_frames_per_step)

    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=flags['batch_size'],
                                               sampler=train_sampler,
                                               num_workers=flags['num_workers'],
                                               collate_fn=collate_fn,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=flags['batch_size'],
                                             sampler=val_sampler,
                                             shuffle=False,
                                             num_workers=flags['num_workers'],
                                             collate_fn=collate_fn,
                                             drop_last=True)

    model = load_model(flags['hparams']).to(device).train()
    criterion = Tacotron2Loss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=flags['hparams'].learning_rate,
                                 weight_decay=flags['hparams'].weight_decay)

    for epoch in range(flags['hparams'].epochs):
        train_start = time.time()
        para_train_loader = pl.ParallelLoader(
            train_loader, [device]).per_device_loader(device)
        # (text, mel, speaker_id, f0)
        for batch_num, batch in enumerate(para_train_loader):
            model.zero_grad()
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if flags['num_workers'] > 1:
                reduced_loss = reduce_tensor(loss.data,
                                             flags['num_workers']).item()
            else:
                reduced_loss = loss.item()
            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer)

        elapsed_train_time = time.time() - train_start
        print("Batch Process", index, "finished training. Train time was:",
              elapsed_train_time)

        if (iteration % flags['hparams'].iters_per_checkpoint == 0):
            # validate(model, criterion, val_dataset, iteration,
            #          flags['hparams'].batch_size, flags['n_gpus'], collate_fn,
            #          logger, flags['hparams'].distributed_run, flags['rank'])
            model.eval()
            eval_start = time.time()
            with torch.no_grad():
                para_val_loader = pl.ParallelLoader(
                    val_loader, [device]).per_device_loader(device)
                # Bug fix: the original reset val_loss to 0.0 inside the loop,
                # so only the last batch contributed to the reported average.
                val_loss = 0.0
                for i, batch in enumerate(para_val_loader):
                    x, y = model.parse_batch(batch)
                    y_pred = model(x)
                    loss = criterion(y_pred, y)
                    if flags['num_workers'] > 1:
                        reduced_val_loss = reduce_tensor(
                            loss.data, flags['num_workers']).item()
                    else:
                        reduced_val_loss = loss.item()
                    val_loss += reduced_val_loss
                val_loss = val_loss / (i + 1)

            elapsed_eval_time = time.time() - eval_start
            print("Process", index, "finished evaluation. Evaluation time was:",
                  elapsed_eval_time)
            print("Validation loss {}: {:9f}".format(iteration,
                                                     reduced_val_loss))
        iteration += 1

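# train(index, flags) above already has the (index, args...) signature that
# xmp.spawn expects, so a launcher sketch looks like the following; the flags
# values and the create_hparams() factory are hypothetical placeholders.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    flags = {
        'seed': 1234,
        'batch_size': 32,
        'num_workers': 8,
        'hparams': create_hparams(),  # hypothetical hparams factory
    }
    xmp.spawn(train, args=(flags,), nprocs=8, start_method='fork')
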
def get_tpu_device(args):
    import torch_xla.core.xla_model as xm
    return xm.xla_device()

def pre_dispatch(self) -> None:
    if isinstance(self.device, int):
        self.device = xm.xla_device(self.device)

    self.tpu_local_core_rank = xm.get_local_ordinal()
    self.tpu_global_core_rank = xm.get_ordinal()

def test_get_xla_tensor(self):
    x = _gen_tensor(14, 24, 8, device=xm.xla_device())
    t = x.data.cpu()
    sx = x.select(1, 12)
    tx = t.select(1, 12)
    self.assertEqual(tx, sx.data.cpu())

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path', default='./data/wikitext-2/wiki.train.tokens',
                        type=str, required=False)
    parser.add_argument('--val_path', default='./data/wikitext-2/wiki.valid.tokens',
                        type=str, required=False)
    parser.add_argument('--save_dir', default=None, type=str, required=False)
    parser.add_argument('--use_control_codes', default=False,
                        action="store_true", required=False)
    parser.add_argument('--control_codes', nargs='+', default=['<|endoftext|>'])
    parser.add_argument('--seq_len', default=256, type=int, required=False)
    parser.add_argument('--n_tokens', default=-1, type=int, required=False)
    parser.add_argument('--n_batches', default=-1, type=int, required=False)
    parser.add_argument('--min_seq_len', default=False, action='store_true')
    # Uses fast tokenization
    parser.add_argument('--fast', default=False, action="store_true", required=False)
    # Efficient for large datasets
    parser.add_argument('--efficient', default=False, action="store_true", required=False)
    parser.add_argument('--detokenizer', default=False, action="store_true", required=False)

    # if from scratch
    parser.add_argument('--from_scratch', default=False, action="store_true")
    parser.add_argument('--config', default='gpt2', type=str)

    # if from a pretrained model
    parser.add_argument('--checkpoint', default='distilgpt2', type=str)
    parser.add_argument('--tokenizer', default='gpt2', type=str)
    parser.add_argument('--from_tf', default=False, action="store_true")

    parser.add_argument('--optimizer', default='AdamW', type=str)
    parser.add_argument('--lr', default=5e-5, type=float)
    parser.add_argument('--lr_schedule', default=True, type=bool)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--grad_steps', default=1, type=int)
    parser.add_argument('--epochs', default=1, type=int)  # check

    # add multiple tpu cores
    parser.add_argument('--accelerator', default='GPU', type=str)
    parser.add_argument('--fp16', default=False, action="store_true")
    parser.add_argument('--apex_mode', default='O1', type=str)
    parser.add_argument('--logging_steps', default=10, type=int)
    parser.add_argument('--hist_steps', default=100, type=int)
    parser.add_argument('--save_steps', default=100, type=int)

    parser.add_argument('--do_sample', default=False, action="store_true")
    parser.add_argument('--prompt', default=False, type=str)
    parser.add_argument('--n_samples', default=1, type=int)
    parser.add_argument('--max_length', default=256, type=int)
    # The original declared the next four options with `type=any`, which
    # argparse would call on the raw string; numeric types are the evident
    # intent.
    parser.add_argument('--temperature', default=None, type=float)
    parser.add_argument('--top_k', default=None, type=int)
    parser.add_argument('--top_p', default=None, type=float)
    parser.add_argument('--repetition_penalty', default=None, type=float)
    parser.add_argument('--use_sliding_windows', default=False, action="store_true")
    parser.add_argument('--n_sliding_windows', default=5, type=int)
    parser.add_argument('--sliding_window_size', default=128, type=int)

    parser.add_argument('--eval_only', default=False, action="store_true")
    parser.add_argument('--sample_only', default=False, action="store_true")
    parser.add_argument('--debug', default=False, action="store_true")
    parser.add_argument('--tags', nargs='+')
    parser.add_argument('--seed', default=42, type=int)

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.debug:
        import ptvsd
        ptvsd.enable_attach(address=('localhost', 5678), redirect_output=True)
        ptvsd.wait_for_attach()
        breakpoint()

    if args.accelerator == 'TPU':
        import torch_xla.core.xla_model as xm
        args.device = xm.xla_device()
    else:
        args.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

    if args.sample_only:
        run_sample(args)
    elif args.eval_only:
        run_eval(args)
    else:
        train(args)