Example #1
    def get_lr(self) -> List[float]:
        if self.max_iters != -1:
            if comm.is_main_process():
                logger = logging.getLogger(__name__)

            schedule_fct = getattr(self, self.warmup_method)
            progress = self.last_epoch / self.max_iters
            lr_cur = [base_lr * schedule_fct(progress, self.warmup) for base_lr in self.base_lrs]
            # warning for exceeding t_total (only active with warmup_linear)
            if self.warmup_method == 'warmup_linear' and progress > 1. and not self.warned_for_t_total:
                if comm.is_main_process():
                    logger.info(
                        "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
                        "Please set 't_total' of {} correctly.".format(self.warmup_method, lr_cur,
                                                                       self.__class__.__name__))
                self.warned_for_t_total = True
            # end warning
        else:
            lr_cur = [base_lr for base_lr in self.base_lrs]

        # Different definitions of half-cosine with warmup are possible. For
        # simplicity we multiply the standard half-cosine schedule by the warmup
        # factor. An alternative is to start the period of the cosine at warmup_iters
        # instead of at 0. In the case that warmup_iters << max_iters the two are
        # very close to each other.
        return lr_cur
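For context, here is a minimal sketch of a warmup_linear schedule function of the kind called via schedule_fct(progress, self.warmup) above. This is an assumption modeled on the common BERT-style definition, not code from this repository; it illustrates why the learning rate collapses once progress exceeds 1.0, which is what the warning guards against.

def warmup_linear(progress, warmup=0.002):
    # Ramp the factor linearly from 0 to 1 during the warmup phase ...
    if progress < warmup:
        return progress / warmup
    # ... then decay linearly to 0 at progress == 1.0 and stay clamped there,
    # so training past max_iters runs at (near) zero learning rate.
    return max((progress - 1.0) / (warmup - 1.0), 0.0)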
Example #2
def default_setup(args, cfg):
    output_dir = cfg.work_dir
    set_imix_work_dir(output_dir)
    if output_dir and dist_info.is_main_process():
        PathManager.mkdirs(output_dir)

    rank = dist_info.get_rank()
    logger = setup_logger(output_dir, distributed_rank=rank, name='imix')
    logger.info('Current environment information : \n{}'.format(
        collect_env_info()))
    logger.info('Command line args: \n {}'.format(args))
    if hasattr(args, 'config_file') and args.config_file != '':
        logger.info('{} file content:\n{}'.format(
            args.config_file,
            PathManager.open(args.config_file, 'r').read()))

    logger.info('full config file content: ')
    pprint.pprint({k: v for k, v in cfg.items()})

    if dist_info.is_main_process() and output_dir:
        cfg_path = os.path.join(output_dir, 'config.json')
        with open(cfg_path, 'w') as f:
            f.write(
                json.dumps({k: v
                            for k, v in cfg.items()},
                           indent=4,
                           separators=(',', ':')))
            logger.info('full config file saved to {}'.format(cfg_path))

    seed = getattr(cfg, 'seed', None)
    seed_all_rng(seed=None if seed is None else seed + rank)

    if not getattr(cfg, 'eval_only', False):
        torch.backends.cudnn.benchmark = getattr(cfg, 'CUDNN_BENCHMARK', False)
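The seed + rank offset passed to seed_all_rng gives every distributed worker its own reproducible RNG stream. Below is a self-contained sketch of that pattern; seed_everything is a hypothetical helper, not part of imix.

import random

import numpy as np
import torch


def seed_everything(seed, rank=0):
    # Offset the base seed by the process rank so each worker draws a
    # different but reproducible sequence of random numbers.
    s = seed + rank
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)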
Example #3
    def __init__(self, vqa_reader, vqa_info_cpler, limit_nums=None):
        if comm.is_main_process():
            logger = logging.getLogger(__name__)
            logger.info('start loading vqadata')

        self.reader = VQAReader(vqa_reader, vqa_reader.datasets)
        self.infocpler = VQAInfoCpler(vqa_info_cpler)
        self._limit_sample_nums = limit_nums
        self._split = vqa_reader.datasets
        if comm.is_main_process():
            logger.info('load vqadata {} successfully'.format(vqa_reader.datasets))
Example #4
    def __init__(self,
                 reader_cls,
                 reader,
                 info_cpler_cls,
                 info_cpler,
                 limit_nums=None):
        if comm.is_main_process():
            cls_name = self.__class__.__name__
            logger = logging.getLogger(__name__)
            logger.info('start loading ' + cls_name)

        self.reader = reader_cls(reader)
        self.infocpler = info_cpler_cls(info_cpler)
        self._limit_sample_nums = limit_nums
        self.splits = reader.datasets
        if comm.is_main_process():
            logger.info('load {} {} successfully'.format(
                cls_name, reader.datasets))
Example #5
    def __init__(self,
                 model: Module,
                 save_dir: str = '',
                 *,
                 is_save_disk=None,
                 is_record_ck: bool = True,
                 **other_train_info: object):
        is_master_process = comm.is_main_process()
        if is_save_disk is None:
            is_save_disk = is_master_process
        super().__init__(model=model,
                         save_dir=save_dir,
                         is_save_disk=is_save_disk,
                         is_record_ck=is_record_ck,
                         **other_train_info)
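Defaulting is_save_disk to the result of comm.is_main_process() is the usual way to make sure only rank 0 writes checkpoints to disk. A rough standalone sketch of that pattern follows; the helper names below are illustrative, not imix APIs.

import torch
import torch.distributed as dist


def is_main_process():
    # Treat rank 0 as the main process; single-process runs also count as main.
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0


def maybe_save(state, path, save_to_disk=None):
    if save_to_disk is None:
        save_to_disk = is_main_process()  # only rank 0 touches the filesystem
    if save_to_disk:
        torch.save(state, path)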
Example #6
    def save_checkpoint_ema(self, savePath, epochId):
        # If EMA is used, save averaged model
        if self.use_ema and comm.is_main_process():
            output_ema_state_dict = {}
            for param_name in self.model.state_dict():
                assert param_name in self.ema_state_dict
                if hasattr(self.model, 'module'):
                    output_ema_state_dict[
                        param_name[7:]] = self.ema_state_dict[
                            param_name]  # skip prefix "module."
                else:
                    output_ema_state_dict[param_name] = self.ema_state_dict[
                        param_name]
            output_ema_model_file = os.path.join(
                savePath, 'pytorch_model_' + str(epochId) + '_ema.bin')
            torch.save(output_ema_state_dict, output_ema_model_file)

            logger.info(
                'Saving ema checkpoint to {}'.format(output_ema_model_file))
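The param_name[7:] slice above drops the 'module.' prefix that nn.DataParallel and DistributedDataParallel prepend to parameter names. Here is a small self-contained sketch of the same idea, keyed on the prefix instead of a fixed offset.

import torch.nn as nn


def strip_module_prefix(state_dict):
    # Remove a leading 'module.' from every key, leaving other keys untouched.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}


wrapped = nn.DataParallel(nn.Linear(4, 2))
plain = strip_module_prefix(wrapped.state_dict())
assert all(not k.startswith('module.') for k in plain)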
Example #7
    def _download_model(self):
        from imix.utils.distributed_info import is_main_process
        from imix.utils.config import get_imix_cache_dir
        _is_master = is_main_process()
        model_file_path = os.path.join(get_imix_cache_dir(), 'wiki.en.bin')

        if not _is_master:
            return model_file_path

        if PathManager.exists(model_file_path):
            self.writer.write(f'Vectors already present at {model_file_path}.',
                              'info')
            return model_file_path

        import requests
        from tqdm import tqdm

        FASTTEXT_WIKI_URL = 'https://dl.fbaipublicfiles.com/pythia/pretrained_models/fasttext/wiki.en.bin'
        PathManager.mkdirs(os.path.dirname(model_file_path))
        response = requests.get(FASTTEXT_WIKI_URL, stream=True)

        with PathManager.open(model_file_path, 'wb') as f:
            pbar = tqdm(
                total=int(response.headers['Content-Length']) / 4096,
                miniters=50,
                disable=not _is_master,
            )

            idx = 0
            for data in response.iter_content(chunk_size=4096):
                if data:
                    if idx % 50 == 0:
                        pbar.update(len(data))
                    f.write(data)
                    idx += 1

            pbar.close()

        self.writer.write(f'fastText bin downloaded at {model_file_path}.',
                          'info')

        return model_file_path
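Note that the progress bar above uses a chunk count as its total but adds byte counts (and only for every 50th chunk), so the displayed progress is approximate. Below is a byte-accurate variant of the same streamed download, sketched with plain requests and tqdm; the function name is illustrative.

import requests
from tqdm import tqdm


def download_file(url, dest, chunk_size=4096):
    response = requests.get(url, stream=True)
    total = int(response.headers.get('Content-Length', 0))
    with open(dest, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)
                pbar.update(len(chunk))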
Example #8
    def __init__(self, vocab_file, embedding_name, *args, **kwargs):
        """Use this vocab class when you have a custom vocabulary class but you
        want to use pretrained embedding vectors for it. This will only load the
        vectors which intersect with your vocabulary. Use the embedding_name
        specified in torchtext's pretrained aliases:

        ['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d',
         'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d',
         'glove.twitter.27B.50d', 'glove.twitter.27B.100d',
         'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d',
         'glove.6B.200d', 'glove.6B.300d']

        Parameters
        ----------
        vocab_file : str
            Vocabulary file containing list of words with one word per line
            which will be used to collect vectors
        embedding_name : str
            Embedding name picked up from the list of the pretrained aliases
            mentioned above
        """
        super().__init__(vocab_file, *args, **kwargs)

        self.type = 'intersected'

        name = embedding_name.split('.')[0]
        dim = embedding_name.split('.')[2][:-1]
        middle = embedding_name.split('.')[1]

        class_name = 'GloVe'
        params = [middle]

        if name == 'glove':
            params.append(int(dim))

        vector_cache = get_imix_cache_dir()

        # Load the vectors on the master process first so that the other
        # processes do not all try to download them when they are missing.
        if is_main_process():
            vocab.pretrained_aliases[embedding_name](cache=vector_cache)
        synchronize()

        embedding = getattr(vocab, class_name)(*params, cache=vector_cache)

        self.vectors = torch.empty(
            (self.get_size(), len(embedding.vectors[0])), dtype=torch.float)

        self.embedding_dim = len(embedding.vectors[0])

        for i in range(0, 4):
            self.vectors[i] = torch.ones_like(self.vectors[i]) * 0.1 * i

        for i in range(4, self.get_size()):
            word = self.itos[i]
            embedding_index = embedding.stoi.get(word, None)

            if embedding_index is None:
                self.vectors[i] = self.vectors[self.UNK_INDEX]
            else:
                self.vectors[i] = embedding.vectors[embedding_index]
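For reference, a tiny sketch of how a pretrained alias such as 'glove.6B.300d' is decomposed by the split logic above before being handed to torchtext:

# Decompose a torchtext pretrained alias into (name, middle, dim).
embedding_name = 'glove.6B.300d'
name, middle, dim = embedding_name.split('.')
dim = dim[:-1]  # strip the trailing 'd' -> '300'
assert (name, middle, dim) == ('glove', '6B', '300')
# With these pieces the constructor above effectively calls
# vocab.GloVe('6B', 300, cache=vector_cache).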