コード例 #1
0
 def _format_precision_config(self):
     """Inject DeepSpeed 16-bit (AMP) settings into ``self.config``.

     When the Trainer requests ``precision=16``, adds either a DeepSpeed
     ``fp16`` section (native AMP) or an ``amp`` section (APEX) — unless the
     user already supplied one in their DeepSpeed config.

     Raises:
         MisconfigurationException: if ``zero_optimization`` is configured
             without any 16-bit section, since ZeRO requires precision=16.
     """
     # Hoist the shared attribute chain instead of repeating it three times.
     connector = self.lightning_module.trainer.accelerator_connector
     amp_type = connector.amp_type
     amp_level = connector.amp_level
     precision = connector.precision
     if precision == 16:
         if "fp16" not in self.config and amp_type == AMPType.NATIVE:
             # FP16 is a DeepSpeed standalone AMP implementation
             rank_zero_info("Enabling DeepSpeed FP16.")
             self.config["fp16"] = {
                 "enabled": True,
                 "loss_scale": self.loss_scale,
                 "initial_scale_power": self.initial_scale_power,
                 "loss_scale_window": self.loss_scale_window,
                 "hysteresis": self.hysteresis,
                 "min_loss_scale": self.min_loss_scale,
             }
         elif "amp" not in self.config and amp_type == AMPType.APEX:
             # Fix: ``rank_zero_only`` is a decorator, not a logging function —
             # calling it with a string silently logs nothing. Use
             # ``rank_zero_info`` as in the fp16 branch above.
             rank_zero_info("Enabling DeepSpeed APEX Implementation.")
             self.config["amp"] = {
                 "enabled": True,
                 "opt_level": amp_level,
             }
     if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config):
         raise MisconfigurationException(
             "To use DeepSpeed ZeRO Optimization, you must set precision=16."
         )
コード例 #2
0
 def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
     """Create the instance and wrap its data hooks so their calls are tracked.

     ``prepare_data`` is additionally restricted to rank zero, since data
     download/preparation must happen only once in distributed runs.
     """
     instance = super().__new__(cls)
     rank_zero_prepare = rank_zero_only(instance.prepare_data)
     instance.prepare_data = cls._track_data_hook_calls(instance, rank_zero_prepare)
     instance.setup = cls._track_data_hook_calls(instance, instance.setup)
     instance.teardown = cls._track_data_hook_calls(instance, instance.teardown)
     return instance
コード例 #3
0
 def _format_precision_config(self):
     """Inject DeepSpeed 16-bit (AMP) settings into ``self.config``.

     For ``precision`` 16 or ``"mixed"``, adds either a DeepSpeed ``fp16``
     section (native AMP) or an ``amp`` section (APEX), unless the user's
     DeepSpeed config already contains one.
     """
     if self.precision in (16, "mixed"):
         if "fp16" not in self.config and self.amp_type == AMPType.NATIVE:
             # FP16 is a DeepSpeed standalone AMP implementation
             rank_zero_info("Enabling DeepSpeed FP16.")
             self.config["fp16"] = {
                 "enabled": True,
                 "loss_scale": self.loss_scale,
                 "initial_scale_power": self.initial_scale_power,
                 "loss_scale_window": self.loss_scale_window,
                 "hysteresis": self.hysteresis,
                 "min_loss_scale": self.min_loss_scale,
             }
         elif "amp" not in self.config and self.amp_type == AMPType.APEX:
             # Fix: ``rank_zero_only`` is a decorator, not a logging function —
             # use ``rank_zero_info`` so the message is actually emitted.
             # Reading ``self.amp_level`` directly also removes the fragile
             # conditional binding the original did at the top of the method.
             rank_zero_info("Enabling DeepSpeed APEX Implementation.")
             self.config["amp"] = {"enabled": True, "opt_level": self.amp_level}
コード例 #4
0
 def _format_precision_config(self):
     """Inject DeepSpeed 16-bit (AMP) settings into ``self.config``.

     For ``precision`` 16 or ``'mixed'``, adds either a DeepSpeed ``fp16``
     section (native AMP) or an ``amp`` section (APEX), unless the user's
     DeepSpeed config already contains one.
     """
     # Hoist the shared attribute chain instead of repeating it three times.
     connector = self.lightning_module.trainer.accelerator_connector
     amp_type = connector.amp_type
     amp_level = connector.amp_level
     precision = connector.precision
     if precision in (16, 'mixed'):
         if "fp16" not in self.config and amp_type == AMPType.NATIVE:
             # FP16 is a DeepSpeed standalone AMP implementation
             rank_zero_info("Enabling DeepSpeed FP16.")
             self.config["fp16"] = {
                 "enabled": True,
                 "loss_scale": self.loss_scale,
                 "initial_scale_power": self.initial_scale_power,
                 "loss_scale_window": self.loss_scale_window,
                 "hysteresis": self.hysteresis,
                 "min_loss_scale": self.min_loss_scale,
             }
         elif "amp" not in self.config and amp_type == AMPType.APEX:
             # Fix: ``rank_zero_only`` is a decorator, not a logging function —
             # calling it with a string silently logs nothing. Use
             # ``rank_zero_info`` as in the fp16 branch above.
             rank_zero_info("Enabling DeepSpeed APEX Implementation.")
             self.config["amp"] = {
                 "enabled": True,
                 "opt_level": amp_level,
             }
コード例 #5
0
    # Evaluation-only CLI options; Lightning's Trainer flags are appended below.
    parser.add_argument("--num_eval_passages", type=int, default=100)
    parser.add_argument("--eval_batch_size", default=4, type=int)
    parser.add_argument("--target_dataset",
                        default="test",
                        choices=["validation", "test"])
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # Keep console output quiet: warnings and above only, and silence
    # Lightning's own logger down to errors.
    logging.basicConfig(
        level=logging.WARN,
        format=
        "[%(asctime)s] [%(levelname)s] %(message)s (%(funcName)s@%(filename)s:%(lineno)s)"
    )
    logging.getLogger("lightning").setLevel(logging.ERROR)

    # Reader.prepare_data = lambda self: None
    # Restore the model from the checkpoint and override eval-time
    # hyperparameters with the CLI values.
    # NOTE(review): ``args.checkpoint_file`` is presumably added by an
    # argparse option defined above this fragment — confirm against the caller.
    model = Reader.load_from_checkpoint(args.checkpoint_file)
    model.hparams.num_eval_passages = args.num_eval_passages
    model.hparams.eval_batch_size = args.eval_batch_size
    if args.target_dataset == "validation":
        # Evaluate on the validation split by redirecting the test-file
        # paths at the validation files.
        model.hparams.test_file = model.hparams.validation_file
        model.hparams.nq_gold_test_file = model.hparams.nq_gold_validation_file

    trainer = Trainer.from_argparse_args(args)
    result = trainer.test(model)

    def report_results():
        # Closure over ``result``; printed only once in distributed runs
        # because the call below is wrapped with ``rank_zero_only``.
        print("result: %s" % result)

    rank_zero_only(report_results)()