def _test_case(self, paths, extra_flags):
    """Build an LM config over a temp data dir and run the valid-subset check.

    An empty ``<split>.bin`` file is created in a fresh temporary directory
    for every name in ``paths`` plus ``"train"``. A config is then built with
    ``make_lm_config`` and passed to
    ``raise_if_valid_subsets_unintentionally_ignored``. Any exception that
    check raises propagates to the caller.

    Args:
        paths: Iterable of split names (e.g. ``["valid", "valid1"]``) to
            create on disk alongside the train split.
        extra_flags: Forwarded unchanged to ``make_lm_config``.
    """
    with tempfile.TemporaryDirectory() as data_dir:
        # A plain loop, not a throwaway list comprehension: we only want the
        # file-creation side effect.
        for split in [*paths, "train"]:
            write_empty_file(os.path.join(data_dir, f"{split}.bin"))
        cfg = make_lm_config(data_dir, extra_flags=extra_flags)
        raise_if_valid_subsets_unintentionally_ignored(cfg)
def test_masked_dummy_task(self):
    """Run the valid-subset check on an LM config for ``dummy_masked_lm``."""
    raise_if_valid_subsets_unintentionally_ignored(
        make_lm_config(task="dummy_masked_lm")
    )
def main(cfg: FairseqConfig) -> None:
    """Set up and run a full training job described by ``cfg``.

    The job runs in this order:

    1. Converts an ``argparse.Namespace`` to an OmegaConf config.
    2. Configures logging and seeds, and builds the task, model and criterion.
    3. Loads the validation subsets and builds a ``Trainer`` (or
       ``MegatronTrainer`` when ``model_parallel_size != 1``).
    4. Restores the latest checkpoint and trains epoch by epoch.

    Training stops when any of these is true:

    - The epoch index exceeds ``max_epoch``.
    - The learning rate drops to ``stop_min_lr`` or below.
    - ``train`` signals ``should_stop``.

    Args:
        cfg: A ``FairseqConfig``, or an ``argparse.Namespace`` that is
            converted with ``convert_namespace_to_omegaconf``.

    Returns:
        None. Returns early without training if asynchronous checkpoint
        writing is requested but ``iopath`` cannot be imported.

    Raises:
        AssertionError: If neither ``--max-tokens`` nor ``--batch-size`` is
            set, or if no criterion is configured.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)
    add_defaults(cfg)

    if (
        distributed_utils.is_master(cfg.distributed_training)
        and "job_logging_cfg" in cfg
    ):
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            # Use the module logger rather than the root logger. The --log-file
            # handler above is attached to `logger`, and records emitted on the
            # root logger never reach a child logger's handlers.
            logger.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    _log_model_summary(task, model, criterion)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    _load_valid_datasets(cfg, task)

    # (optionally) Configure quantization
    quantizer = _build_quantizer(cfg)

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        # NOTE(review): `quantizer` is not forwarded to MegatronTrainer;
        # confirm this is intended.
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    _extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm

        xm.rendezvous("load_checkpoint")  # wait for all workers

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
            )
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish."
        )
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")


def _log_model_summary(task, model, criterion) -> None:
    """Log the model, the task/model/criterion class names and parameter counts.

    Parameters with a truthy ``expert`` attribute are counted as expert
    parameters; all others are counted as shared. For each group, the total
    element count and the count with ``requires_grad`` set are logged.
    """
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))

    # One pass over the parameters, bucketed by the optional `expert` flag.
    shared_total = shared_trained = 0
    expert_total = expert_trained = 0
    for p in model.parameters():
        numel = p.numel()
        if getattr(p, "expert", False):
            expert_total += numel
            if p.requires_grad:
                expert_trained += numel
        else:
            shared_total += numel
            if p.requires_grad:
                shared_trained += numel

    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            shared_total, shared_trained
        )
    )
    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            expert_total, expert_trained
        )
    )


def _load_valid_datasets(cfg, task) -> None:
    """Validate the valid-subset config and load each validation split.

    With ``combine_valid_subsets`` set, a single combined ``"valid"`` split is
    loaded. Otherwise, each comma-separated name in ``valid_subset`` is loaded
    separately. All splits are loaded for epoch 1.
    """
    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
    if cfg.dataset.combine_valid_subsets:
        task.load_dataset("valid", combine=True, epoch=1)
    else:
        for valid_sub_split in cfg.dataset.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)


def _build_quantizer(cfg):
    """Return a ``Quantizer`` if a quantization config path is set, else ``None``."""
    if cfg.common.quantization_config_path is None:
        return None
    return quantization_utils.Quantizer(
        config_path=cfg.common.quantization_config_path,
        max_epoch=cfg.optimization.max_epoch,
        max_update=cfg.optimization.max_update,
    )