Example #1
def __init__(self, filepath: str, src, dst=None, **kwargs) -> None:
    if not dst:
        dst = src + '_fasttext'
    self.filepath = filepath
    flash(f'Loading fasttext model {filepath} [blink][yellow]...[/yellow][/blink]')
    filepath = get_resource(filepath)
    # fasttext prints to stdout while loading; route that to /dev/null so
    # only the flash() status line is visible.
    with stdout_redirected(to=os.devnull, stdout=sys.stderr):
        self._model = fasttext.load_model(filepath)
    flash('')
    # Probe an arbitrary word to discover the embedding dimensionality.
    output_dim = self._model['king'].size
    super().__init__(output_dim, src, dst)
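The `stdout_redirected(...)` context above silences fastText's C++ loader. A rough stdlib-only approximation of the same idea (note: `contextlib.redirect_stdout` only catches Python-level prints, not writes from C extensions, which is presumably why the library redirects at the file-descriptor level instead):

import contextlib
import os

def call_quietly(fn, *args, **kwargs):
    # Discard anything fn prints to (Python-level) stdout while it runs.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        return fn(*args, **kwargs)

# model = call_quietly(fasttext.load_model, 'cc.en.300.bin')  # hypothetical path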
Example #2
def smatch_eval(pred, gold, use_fast=False) -> Union[SmatchScores, F1_]:
    script = get_resource(_FAST_SMATCH_SCRIPT if use_fast else _SMATCH_SCRIPT)
    home = os.path.dirname(script)
    # Resolve to absolute paths first, since pushd changes the working directory.
    pred = os.path.realpath(pred)
    gold = os.path.realpath(gold)
    with pushd(home):
        flash('Running evaluation script [blink][yellow]...[/yellow][/blink]')
        cmd = f'bash {script} {pred} {gold}'
        text = run_cmd(cmd)
        flash('')
    return format_fast_scores(text) if use_fast else format_official_scores(text)
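`pushd` is presumably a context manager that temporarily switches the working directory so the bash script can find its sibling files; a minimal sketch of such a helper:

import os
from contextlib import contextmanager

@contextmanager
def pushd(new_dir):
    # Enter new_dir for the duration of the block and restore the
    # previous working directory afterwards, even on exceptions.
    prev = os.getcwd()
    os.chdir(new_dir)
    try:
        yield
    finally:
        os.chdir(prev)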
Example #3
def load_word2vec(path, delimiter=' ', cache=True) -> Tuple[Dict[str, np.ndarray], int]:
    realpath = get_resource(path)
    binpath = replace_ext(realpath, '.pkl')
    if cache:
        try:
            # Fast path: reuse a previously parsed copy cached as a pickle.
            flash('Loading word2vec from cache [blink][yellow]...[/yellow][/blink]')
            word2vec, dim = load_pickle(binpath)
            flash('')
            return word2vec, dim
        except IOError:
            pass

    dim = None
    word2vec = dict()
    f = TimingFileIterator(realpath)
    for idx, line in enumerate(f):
        f.log('Loading word2vec from text file [blink][yellow]...[/yellow][/blink]')
        line = line.rstrip().split(delimiter)
        # Skip blank lines and the optional "vocab_size dim" header line.
        if len(line) > 2:
            if dim is None:
                dim = len(line)
            elif len(line) != dim:
                logger.warning(f'{path}#{idx + 1} length mismatch: expected {dim} fields')
                continue
            word, vec = line[0], line[1:]
            word2vec[word] = np.array(vec, dtype=np.float32)
    # dim counted the word itself as one field; the vector size is one less.
    dim -= 1
    if cache:
        flash('Caching word2vec [blink][yellow]...[/yellow][/blink]')
        save_pickle((word2vec, dim), binpath)
        flash('')
    return word2vec, dim
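A hypothetical call, assuming a GloVe-style text file where each line holds a word followed by its vector components:

# 'glove.6B.100d.txt' is an illustrative local path, not shipped with the code.
word2vec, dim = load_word2vec('glove.6B.100d.txt')
print(dim)                     # 100 for this file
print(word2vec['king'].shape)  # (100,)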
Example #4
def load_word2vec_as_vocab_tensor(
        path,
        delimiter=' ',
        cache=True) -> Tuple[Dict[str, int], torch.Tensor]:
    realpath = get_resource(path)
    vocab_path = replace_ext(realpath, '.vocab')
    matrix_path = replace_ext(realpath, '.pt')
    if cache:
        try:
            flash(
                'Loading vocab and matrix from cache [blink][yellow]...[/yellow][/blink]'
            )
            vocab = load_pickle(vocab_path)
            # Load onto CPU; callers move the matrix to a device later.
            matrix = torch.load(matrix_path, map_location='cpu')
            flash('')
            return vocab, matrix
        except IOError:
            pass

    word2vec, dim = load_word2vec(path, delimiter, cache)
    # Dict insertion order matches the row order of the matrix,
    # so vocab[word] indexes the right row.
    vocab = {k: i for i, k in enumerate(word2vec)}
    matrix = torch.Tensor(list(word2vec.values()))
    if cache:
        flash('Caching vocab and matrix [blink][yellow]...[/yellow][/blink]')
        save_pickle(vocab, vocab_path)
        torch.save(matrix, matrix_path)
        flash('')
    return vocab, matrix
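The returned pair plugs straight into an embedding layer, since row i of the matrix corresponds to the word with index i in the vocab. A hypothetical usage:

import torch
import torch.nn as nn

vocab, matrix = load_word2vec_as_vocab_tensor('glove.6B.100d.txt')  # hypothetical path
embed = nn.Embedding.from_pretrained(matrix, freeze=True)
ids = torch.tensor([vocab['king'], vocab['queen']])
vectors = embed(ids)  # shape: (2, dim)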
Example #5
def load(self, save_dir: str, devices=None, **kwargs):
    save_dir = get_resource(save_dir)
    # flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]')
    # Reuse the devices the current model lives on unless told otherwise.
    if devices is None and self.model:
        devices = self.devices
    self.load_config(save_dir, **kwargs)
    self.load_vocabs(save_dir)
    flash('Building model [blink][yellow]...[/yellow][/blink]')
    self.model = self.build_model(**merge_dict(self.config,
                                               training=False,
                                               **kwargs,
                                               overwrite=True,
                                               inplace=True))
    flash('')
    self.load_weights(save_dir, **kwargs)
    self.to(devices)
    # Loaded for inference: disable dropout etc.
    self.model.eval()
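`merge_dict` is library-internal and its exact signature isn't shown here, but judging from the call sites, a plausible sketch of the semantics relied on (the `overwrite` and `inplace` flags) is:

def merge_dict(d: dict, overwrite=False, inplace=False, **kwargs) -> dict:
    # Sketch only: overlay kwargs onto d. With overwrite=False, existing keys
    # in d win; with inplace=False, the original dict is left untouched.
    out = d if inplace else dict(d)
    for k, v in kwargs.items():
        if overwrite or k not in out:
            out[k] = v
    return out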
Example #6
def to(self,
       devices: Union[int, float, List[int],
                      Dict[str, Union[int, torch.device]]] = None,
       logger: logging.Logger = None):
    # -1 (or [-1]) selects CPU; an int/float or None is resolved to a list
    # of CUDA devices by cuda_devices().
    if devices == -1 or devices == [-1]:
        devices = []
    elif isinstance(devices, (int, float)) or devices is None:
        devices = cuda_devices(devices)
    if devices:
        if logger:
            logger.info(
                f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]'
            )
        if isinstance(devices, list):
            flash(
                f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]'
            )
            self.model = self.model.to(devices[0])
            # Wrap in DataParallel only for multi-GPU, non-debugging runs.
            if len(devices) > 1 and not isdebugging() and not isinstance(
                    self.model, nn.DataParallel):
                self.model = self.parallelize(devices)
        elif isinstance(devices, dict):
            # A dict maps module-name regexes to target devices.
            for name, module in self.model.named_modules():
                for regex, device in devices.items():
                    try:
                        on_device: torch.device = next(
                            module.parameters()).device
                    except StopIteration:
                        # This module owns no parameters of its own.
                        continue
                    if on_device == device:
                        continue
                    if isinstance(device, int):
                        if on_device.index == device:
                            continue
                    if re.match(regex, name):
                        if not name:
                            name = '*'
                        flash(
                            f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                            f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n'
                        )
                        module.to(device)
        else:
            raise ValueError(f'Unrecognized devices {devices}')
        flash('')
    else:
        if logger:
            logger.info('Using CPU')
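The three accepted shapes of `devices`, shown as hypothetical calls (the component and submodule names are made up for illustration):

component.to(-1)                            # CPU only
component.to([0, 1])                        # DataParallel over GPUs 0 and 1
component.to({'encoder': 0, 'decoder': 1})  # module-name regex -> device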
Example #7
def make_gold_conll(ontonotes_path, language):
    ontonotes_path = os.path.abspath(get_resource(ontonotes_path))
    to_conll = get_resource(
        'https://gist.githubusercontent.com/hankcs/46b9137016c769e4b6137104daf43a92/raw/66369de6c24b5ec47696ae307591f0d72c6f3f02/ontonotes_to_conll.sh'
    )
    to_conll = os.path.abspath(to_conll)
    # shutil.rmtree(os.path.join(ontonotes_path, 'conll-2012'), ignore_errors=True)
    with pushd(ontonotes_path):
        try:
            flash(
                f'Converting [blue]{language}[/blue] to CoNLL format, '
                f'this might take half an hour [blink][yellow]...[/yellow][/blink]'
            )
            run_cmd(f'bash {to_conll} {ontonotes_path} {language}')
            flash('')
        except RuntimeError:
            flash(
                f'[red]Failed[/red] to convert {language} of {ontonotes_path} to CoNLL. '
                f'See the exception for details.'
            )
            raise
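A hypothetical invocation, assuming a local copy of the OntoNotes 5.0 release (both arguments are illustrative):

make_gold_conll('/data/ontonotes-release-5.0', 'english')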
Example #8
def fit(self,
        trn_data,
        dev_data,
        save_dir,
        batch_size,
        epochs,
        devices=None,
        logger=None,
        seed=None,
        finetune=False,
        eval_trn=True,
        _device_placeholder=False,
        **kwargs):
    # Common initialization steps
    config = self._capture_config(locals())
    if not logger:
        logger = self.build_logger('train', save_dir)
    if not seed:
        # A fixed seed keeps debugging runs reproducible.
        self.config.seed = 233 if isdebugging() else int(time.time())
    set_seed(self.config.seed)
    logger.info(self._savable_config.to_json(sort=True))
    if isinstance(devices, list) or devices is None or isinstance(
            devices, float):
        flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
        devices = -1 if isdebugging() else cuda_devices(devices)
        flash('')
    # flash(f'Available GPUs: {devices}')
    if isinstance(devices, list):
        first_device = (devices[0] if devices else -1)
    elif isinstance(devices, dict):
        first_device = next(iter(devices.values()))
    elif isinstance(devices, int):
        first_device = devices
    else:
        first_device = -1
    if _device_placeholder and first_device >= 0:
        _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
    if finetune:
        # A string means "load this checkpoint"; True means resume from save_dir.
        if isinstance(finetune, str):
            self.load(finetune, devices=devices)
        else:
            self.load(save_dir, devices=devices)
        logger.info(
            f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
            f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
        )
    self.on_config_ready(**self.config)
    trn = self.build_dataloader(**merge_dict(config,
                                             data=trn_data,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             training=True,
                                             device=first_device,
                                             logger=logger,
                                             vocabs=self.vocabs,
                                             overwrite=True))
    dev = self.build_dataloader(
        **merge_dict(config,
                     data=dev_data,
                     batch_size=batch_size,
                     shuffle=False,
                     training=None,
                     device=first_device,
                     logger=logger,
                     vocabs=self.vocabs,
                     overwrite=True)) if dev_data else None
    if not finetune:
        flash('[yellow]Building model [blink]...[/blink][/yellow]')
        self.model = self.build_model(**merge_dict(config, training=True))
        flash('')
        logger.info(
            f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
            f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
        )
        assert self.model, 'build_model is not properly implemented.'
    _description = repr(self.model)
    if len(_description.split('\n')) < 10:
        logger.info(_description)
    self.save_config(save_dir)
    self.save_vocabs(save_dir)
    self.to(devices, logger)
    if _device_placeholder and first_device >= 0:
        del _dummy_placeholder
    criterion = self.build_criterion(**merge_dict(config, trn=trn))
    optimizer = self.build_optimizer(
        **merge_dict(config, trn=trn, criterion=criterion))
    metric = self.build_metric(**self.config)
    if hasattr(trn.dataset, '__len__') and dev and hasattr(
            dev.dataset, '__len__'):
        logger.info(
            f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.')
        trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
        ratio_width = len(f'{trn_size}/{trn_size}')
    else:
        ratio_width = None
    return self.execute_training_loop(**merge_dict(config,
                                                   trn=trn,
                                                   dev=dev,
                                                   epochs=epochs,
                                                   criterion=criterion,
                                                   optimizer=optimizer,
                                                   metric=metric,
                                                   logger=logger,
                                                   save_dir=save_dir,
                                                   devices=devices,
                                                   ratio_width=ratio_width,
                                                   trn_data=trn_data,
                                                   dev_data=dev_data,
                                                   eval_trn=eval_trn,
                                                   overwrite=True))
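A hypothetical training call; every path and hyperparameter below is illustrative rather than taken from the library:

component.fit(trn_data='data/train.txt',  # hypothetical dataset paths
              dev_data='data/dev.txt',
              save_dir='checkpoints/demo',
              batch_size=32,
              epochs=10,
              devices=[0],                # train on GPU 0
              seed=42)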