Example No. 1
 def __init__(self, args, params):
     super().__init__(args)
     fused_adam_cls = get_fused_adam_class()
     if fused_adam_cls is not None and torch.cuda.is_available():
         print('| using FusedAdam')
         self._optimizer = fused_adam_cls(params, **self.optimizer_config)
     else:
         self._optimizer = Adam(params, **self.optimizer_config)
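All of the variants below share the pattern shown in Example No. 1: probe for a fused CUDA Adam implementation and fall back to a plain Adam otherwise. The following is a minimal, self-contained sketch of that selection logic; it stands in apex.optimizers.FusedAdam for fairseq's get_fused_adam_class helper and torch.optim.Adam for fairseq's own Adam class, and it passes the hyperparameters directly instead of through self.optimizer_config, so it illustrates the pattern rather than reproducing the fairseq code itself.

import logging

import torch
from torch.optim import Adam

logger = logging.getLogger(__name__)


def get_fused_adam_class():
    # Stand-in for fairseq's helper: return apex's fused Adam kernel if it
    # is installed, otherwise None.
    try:
        from apex.optimizers import FusedAdam
        return FusedAdam
    except ImportError:
        return None


def build_adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    # Prefer the fused CUDA implementation when one is available and a GPU
    # is present; otherwise fall back to the plain PyTorch Adam.
    optimizer_config = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
    fused_adam_cls = get_fused_adam_class()
    if fused_adam_cls is not None and torch.cuda.is_available():
        logger.info("using FusedAdam")
        return fused_adam_cls(params, **optimizer_config)
    return Adam(params, **optimizer_config)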
Example No. 2
 def __init__(self, args, params):
     super().__init__(args)
     fused_adam_cls = get_fused_adam_class()
     if (not args.use_old_adam and fused_adam_cls is not None
             and torch.cuda.is_available()):
         logger.info('using FusedAdam')
         self._optimizer = fused_adam_cls(params, **self.optimizer_config)
     else:
         self._optimizer = Adam(params, **self.optimizer_config)
Example No. 3
 def __init__(self, args, params):
     super().__init__(args)
     fused_adam_cls = get_fused_adam_class()
     use_fused_adam = (not getattr(args, 'use_old_adam', False)
                       and fused_adam_cls is not None
                       and torch.cuda.is_available())
     if use_fused_adam:
         logger.info('using FusedAdam')
         self._optimizer = fused_adam_cls(params, **self.optimizer_config)
     else:
         self._optimizer = RAdam(params, **self.optimizer_config)
Example No. 4
 def __init__(self, cfg: DictConfig, params):
     super().__init__(cfg)
     fused_adam_cls = get_fused_adam_class()
     use_fused_adam = (not getattr(cfg, "use_old_adam", False)
                       and fused_adam_cls is not None
                       and torch.cuda.is_available())
     if getattr(cfg, "tpu", False):
         # on TPUs we use the Adam defined here, since it
         # automatically casts gradients to FP32
         self._optimizer = Adam(params, **self.optimizer_config)
     elif use_fused_adam:
         logger.info("using FusedAdam")
         self._optimizer = fused_adam_cls(params, **self.optimizer_config)
     else:
         self._optimizer = Adam(params, **self.optimizer_config)
Example No. 5
 def __init__(self, cfg: FairseqAdamConfig, params):
     super().__init__(cfg)
     fused_adam_cls = get_fused_adam_class()
     use_fused_adam = (not getattr(cfg, "use_old_adam", False)
                       and fused_adam_cls is not None
                       and torch.cuda.is_available())
     if getattr(cfg, "tpu", False):
         if self.cfg.fp16_adam_stats:
             raise NotImplementedError(
                 "--fp16-adam-stats is only supported on GPU")
         # on TPUs we use the Adam defined here, since it
         # automatically casts gradients to FP32
         self._optimizer = Adam(params, **self.optimizer_config)
     elif use_fused_adam:
         logger.info("using FusedAdam")
         self._optimizer = fused_adam_cls(
             params,
             use_fp16_stats=self.cfg.fp16_adam_stats,
             **self.optimizer_config)
     else:
         if self.cfg.fp16_adam_stats:
             raise NotImplementedError(
                 "--fp16-adam-stats is only supported with FusedAdamV1")
         self._optimizer = Adam(params, **self.optimizer_config)
Example No. 6
    def __init__(self, cfg: FairseqAdamConfig, params):
        super().__init__(cfg)

        fused_adam_cls = get_fused_adam_class()
        use_fused_adam = (not getattr(cfg, "use_old_adam", False)
                          and fused_adam_cls is not None
                          and torch.cuda.is_available())

        if getattr(cfg, "tpu", False):
            # on TPUs we use the Adam defined here, since it
            # automatically casts gradients to FP32
            self._optimizer = Adam(params, **self.optimizer_config)
        elif self.cfg.use_habana and self.cfg.use_fused_adam:
            logger.info("using FusedAdamW")
            try:
                from habana_frameworks.torch.hpex.optimizers import FusedAdamW
            except ImportError:
                raise ImportError("Please install habana_torch.")
            self._optimizer = FusedAdamW(params, **self.optimizer_config)
        elif use_fused_adam:
            logger.info("using FusedAdam")
            self._optimizer = fused_adam_cls(params, **self.optimizer_config)
        else:
            self._optimizer = Adam(params, **self.optimizer_config)
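Whichever branch runs, self._optimizer ends up holding an object with the standard torch.optim.Optimizer step/zero_grad interface, so the surrounding training code does not need to know which variant was selected. A minimal usage sketch, using torch.optim.Adam as a stand-in for whatever the constructor picked:

import torch

model = torch.nn.Linear(8, 2)
# Stand-in for the optimizer instance chosen by the constructors above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                             eps=1e-8, weight_decay=0.0)

loss = model(torch.randn(4, 8)).sum()  # forward pass on dummy data
loss.backward()                        # accumulate gradients
optimizer.step()                       # apply the Adam update
optimizer.zero_grad()                  # reset gradients for the next step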