def main(args=None, model=None) -> GenerativeQAModule:
    Path(args.output_dir).mkdir(exist_ok=True)

    # named_actors = []
    # args.actor_handles = named_actors
    # assert args.actor_handles == named_actors

    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    data_module = Seq2SeqDataModule(model.tokenizer, args)

    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name, project=project)
    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
    elif args.logger_name == "tb-logs":
        from pytorch_lightning.loggers import TensorBoardLogger

        training_logger = TensorBoardLogger("tb_logs", name="my_model")

    es_callback = (
        get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
        if args.early_stopping_patience >= 0
        else False
    )

    trainer: pl.Trainer = generic_train(
        model,
        args,
        data_module,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=training_logger,
        accelerator=CustomAccel() if args.gpus > 1 else None,
        profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
    )

    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
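
# --- Usage sketch (not part of the original variants) ---
# The first variant expects a fully populated `args` namespace and does not build a
# parser itself. A hypothetical entry point could reuse the same argparse helpers as
# the two variants below (pl.Trainer.add_argparse_args,
# GenerativeQAModule.add_model_specific_args, GenerativeQAModule.add_retriever_specific_args),
# assuming they register every attribute this main() reads (output_dir, data_dir,
# logger_name, gpus, profile, do_predict, ...). Module-level imports of argparse, os
# and pl are assumed, as in the variants below.
def run_first_variant_sketch() -> GenerativeQAModule:
    cli_parser = argparse.ArgumentParser()
    cli_parser = pl.Trainer.add_argparse_args(cli_parser)
    cli_parser = GenerativeQAModule.add_model_specific_args(cli_parser, os.getcwd())
    cli_parser = GenerativeQAModule.add_retriever_specific_args(cli_parser)
    return main(cli_parser.parse_args())
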
def main(args=None, model=None) -> GenerativeQAModule:
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
    parser = GenerativeQAModule.add_retriever_specific_args(parser)
    args = args or parser.parse_args()

    Path(args.output_dir).mkdir(exist_ok=True)

    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)
    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    es_callback = (
        get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
        if args.early_stopping_patience >= 0
        else False
    )

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=logger,
        accelerator=CustomAccel() if args.gpus > 1 else None,
    )

    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
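
# --- Illustrative sketch (not the helper used above) ---
# get_early_stopping_callback() comes from the surrounding codebase and is not
# defined in this file. A minimal version could be built on PyTorch Lightning's
# EarlyStopping callback, assuming the validation metric is logged under the name
# "val_<metric>"; the real helper may monitor a different key or set other options.
from pytorch_lightning.callbacks import EarlyStopping


def get_early_stopping_callback_sketch(metric: str, patience: int) -> EarlyStopping:
    return EarlyStopping(
        monitor=f"val_{metric}",  # e.g. "val_loss" or "val_em"
        mode="min" if "loss" in metric else "max",  # only losses improve downwards
        patience=patience,
        verbose=True,
    )
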
def main(args=None, model=None) -> GenerativeQAModule:
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
    parser = GenerativeQAModule.add_retriever_specific_args(parser)
    args = args or parser.parse_args()

    Path(args.output_dir).mkdir(exist_ok=True)
    # save the dpr_context encoder separately for future use
    Path(args.output_dir + "/dpr_ctx_checkpoint").mkdir(exist_ok=True)

    print(args.shard_dir)
    # we do not need previous kb shards used in dataset re-encoding and re-indexing
    if os.path.exists(args.shard_dir):
        shutil.rmtree(args.shard_dir)
    Path(args.shard_dir).mkdir(exist_ok=True)

    # we do not need previous cache files used in dataset re-encoding and re-indexing
    if os.path.exists(args.cache_dir):
        shutil.rmtree(args.cache_dir)
    Path(args.cache_dir).mkdir(exist_ok=True)

    named_actors = []
    if args.distributed_retriever == "ray" and args.gpus > 1:
        if not is_ray_available():
            raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
        # Connect to an existing Ray cluster.
        try:
            ray.init(address=args.ray_address)
        except (ConnectionError, ValueError):
            logger.warning(
                "Connection to Ray cluster failed. Make sure a Ray "
                "cluster is running by either using Ray's cluster "
                "launcher (`ray up`) or by manually starting Ray on "
                "each node via `ray start --head` for the head node "
                "and `ray start --address='<ip address>:6379'` for "
                "additional nodes. See "
                "https://docs.ray.io/en/master/cluster/index.html "
                "for more info."
            )
            raise

        # Create Ray actors only for rank 0 (environment variables are strings).
        if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == "0") and (
            "NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == "0"
        ):
            remote_cls = ray.remote(RayRetriever)
            named_actors = [
                remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
                for i in range(args.num_retrieval_workers)
            ]
        else:
            logger.info(
                "Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
                    os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]
                )
            )
            named_actors = [
                ray.get_actor("retrieval_worker_{}".format(i))
                for i in range(args.num_retrieval_workers)
            ]
    args.actor_handles = named_actors
    assert args.actor_handles == named_actors

    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name, project=project)
    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    es_callback = (
        get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
        if args.early_stopping_patience >= 0
        else False
    )

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=training_logger,
        profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
    )

    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
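
# --- Illustrative sketch (not part of the original script) ---
# The third variant additionally reads args.distributed_retriever, args.ray_address,
# args.num_retrieval_workers, args.shard_dir and args.cache_dir. If those flags are
# not already registered by add_model_specific_args / add_retriever_specific_args,
# a small extra helper could register them before parse_args() runs. All flag names
# and defaults below are assumptions for illustration only.
def add_ray_and_reindex_args_sketch(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--distributed_retriever", choices=["ray", "pytorch"], default="ray",
                        help="Which distributed retriever implementation to use.")
    parser.add_argument("--ray_address", default="auto",
                        help="Address of a running Ray cluster (see `ray start --head`).")
    parser.add_argument("--num_retrieval_workers", type=int, default=4,
                        help="Number of RayRetriever actors created on rank 0.")
    parser.add_argument("--shard_dir", type=str, default="kb_shards",
                        help="Scratch directory for re-encoded KB shards; wiped at startup.")
    parser.add_argument("--cache_dir", type=str, default="cache",
                        help="Scratch directory for re-encoding caches; wiped at startup.")
    return parser


# Usage sketch: build the parser as in main(), then
#     parser = add_ray_and_reindex_args_sketch(parser)
# before calling parser.parse_args() and passing the result to main(args).
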