Example #1
    def _setup_amp_backend(self, amp_type: str):
        if self.trainer.precision != 16:
            # no AMP requested, so we can leave now
            return

        amp_type = amp_type.lower()
        assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
        if amp_type == 'native':
            if not NATIVE_AMP_AVAILABLE:
                rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
                               ' Consider upgrading with `pip install torch>=1.6`.'
                               ' We will attempt to use NVIDIA Apex for this session.')
                amp_type = 'apex'
            else:
                self.trainer.amp_backend = AMPType.NATIVE
                log.info('Using native 16bit precision.')
                self.backend = NativeAMPPlugin(self.trainer)

        if amp_type == 'apex':
            if not APEX_AVAILABLE:
                rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                               ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
            else:
                log.info('Using APEX 16bit precision.')
                self.trainer.amp_backend = AMPType.APEX
                self.backend = ApexPlugin(self.trainer)
                log.warn("LightningOptimizer doesn't support Apex")

        if not self.trainer.amp_backend:
            raise ModuleNotFoundError(
                f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
                f' Consider installing torch >= 1.6 or NVIDIA Apex.'
            )
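The selection logic above boils down to: prefer native AMP when the installed PyTorch supports it, fall back to NVIDIA Apex, and raise if neither backend is usable. Below is a minimal standalone sketch of that fallback order only; the function name and the boolean flags (`native_available`, `apex_available`) are illustrative stand-ins, not Lightning's API, and the Trainer/plugin wiring is deliberately left out.

import logging

log = logging.getLogger(__name__)


def pick_amp_backend(amp_type: str, native_available: bool, apex_available: bool) -> str:
    # Mirrors the fallback order above: try native AMP first, then Apex, else fail.
    amp_type = amp_type.lower()
    assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'

    if amp_type == 'native' and not native_available:
        log.warning('Native AMP is not supported by this PyTorch build, trying Apex instead.')
        amp_type = 'apex'

    if amp_type == 'native':
        return 'native'
    if amp_type == 'apex' and apex_available:
        return 'apex'

    raise ModuleNotFoundError('Neither native AMP (torch>=1.6) nor NVIDIA Apex is available.')


# e.g. pick_amp_backend('native', native_available=False, apex_available=True) -> 'apex'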
Example #2
import json
import logging
import os
from typing import Dict

logger = logging.getLogger(__name__)


def dump_json(filename: str, data: Dict, complain: bool = False) -> None:
    r"""
    Save ``data`` as JSON to ``filename``, warning (or raising if ``complain``) when the file already exists.
    """
    if os.path.isfile(filename):
        if complain:
            raise ValueError(f"File {filename} already exists!")
        else:
            logger.warn(f"Overwriting {filename} file!")

    with open(filename, 'w') as outfile:
        json.dump(data, outfile)
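A quick usage sketch of the helper above (the path and payload are made up for illustration): the first call creates the file, the second overwrites it with a warning, and a call with `complain=True` raises because the file now exists.

payload = {"run": 1, "status": "ok"}

dump_json("metrics.json", payload)                       # creates the file
dump_json("metrics.json", payload)                       # logs a warning, then overwrites
try:
    dump_json("metrics.json", payload, complain=True)
except ValueError as err:
    print(err)                                           # File metrics.json already exists!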
Example #3
    def effective_block_size(self) -> int:
        if self.cfg.block_size is None:
            block_size = self.tokenizer.model_max_length
            if block_size > 1024:
                log.warn(
                    f"The tokenizer picked seems to have a very large `model_max_length` "
                    f"({self.tokenizer.model_max_length}). "
                    "Picking 1024 instead. You can change that default value by passing dataset.cfg.block_size=x."
                )
                # Only cap to 1024 when the tokenizer limit actually exceeds it.
                block_size = 1024
        else:
            if self.cfg.block_size > self.tokenizer.model_max_length:
                log.warn(
                    f"The block_size passed ({self.cfg.block_size}) is larger than the maximum length for the model "
                    f"({self.tokenizer.model_max_length}). Using block_size={self.tokenizer.model_max_length}."
                )
            block_size = min(self.cfg.block_size,
                             self.tokenizer.model_max_length)
        return block_size
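The same clamping rule can be written as a free function for clarity; this is an illustrative sketch, not part of the original class. With no configured block size, the tokenizer's `model_max_length` is capped at 1024; otherwise the configured value is used but never allowed to exceed the model limit.

from typing import Optional


def clamp_block_size(configured: Optional[int], model_max_length: int) -> int:
    # Mirrors effective_block_size(): default to the tokenizer limit capped at 1024,
    # otherwise honour the configured value but never exceed the model limit.
    if configured is None:
        return min(model_max_length, 1024)
    return min(configured, model_max_length)


# clamp_block_size(None, 100000) -> 1024
# clamp_block_size(None, 512)    -> 512
# clamp_block_size(2048, 1024)   -> 1024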