示例#1
0
 def build_callbacks(self, save_dir, logger, **kwargs):
     """Create the list of Keras callbacks used for training.

     The checkpoint (weights only, best epoch only) and the optional early stopping both
     use ``mode='max'`` on ``val_<metric>``, where ``<metric>`` is ``kwargs['metrics']``
     (default ``'accuracy'``), or its last element when a list/tuple is given.
     TensorBoard logs go to ``<save_dir>/logs`` and a CSV log to ``<save_dir>/train.log``.
     Learning-rate decay, ReduceLROnPlateau and early stopping are added only when the
     corresponding keys in ``self.config`` are truthy.

     Args:
         save_dir: Directory receiving ``model.h5``, ``logs/`` and ``train.log``.
         logger: Logger used for debug/warning messages.
         **kwargs: Only ``metrics`` is read.

     Returns:
         list: Callbacks in the order checkpoint, TensorBoard, CSV logger, then any
         optional LearningRateScheduler, ReduceLROnPlateau and EarlyStopping.
     """
     metric_name = kwargs.get('metrics', 'accuracy')
     if isinstance(metric_name, (list, tuple)):
         # Only the last metric drives checkpointing / early stopping.
         metric_name = metric_name[-1]
     monitor = f'val_{metric_name}'

     checkpoint = tf.keras.callbacks.ModelCheckpoint(
         os.path.join(save_dir, 'model.h5'),
         monitor=monitor,
         save_best_only=True,
         mode='max',
         save_weights_only=True,
     )
     logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')

     tensorboard_dir = io_util.makedirs(io_util.path_join(save_dir, 'logs'))
     callbacks = [
         checkpoint,
         tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir),
         FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True),
     ]

     # Inverse-time decay: lr(epoch) = lr0 / (1 + lr_decay_per_epoch * epoch).
     lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
     if lr_decay_per_epoch:
         learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
         if learning_rate:
             logger.debug(
                 f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}'
             )

             def decayed_lr(epoch):
                 return learning_rate / (1 + lr_decay_per_epoch * epoch)

             callbacks.append(tf.keras.callbacks.LearningRateScheduler(decayed_lr))
         else:
             # The optimizer config exposes no usable initial learning rate to decay from.
             logger.warning('Learning rate decay not supported for optimizer={}'.format(
                 repr(self.model.optimizer)))

     anneal_factor = self.config.get('anneal_factor', None)
     if anneal_factor:
         anneal_patience = self.config.get('anneal_patience', 10)
         callbacks.append(
             tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor, patience=anneal_patience))

     early_stopping_patience = self.config.get('early_stopping_patience', None)
     if early_stopping_patience:
         callbacks.append(tf.keras.callbacks.EarlyStopping(
             monitor=monitor, mode='max', verbose=1, patience=early_stopping_patience))
     return callbacks
示例#2
0
 def __init__(self, filepath: str, padding=PAD, name=None, **kwargs):
     """Wrap a pre-trained fastText model as a frozen, string-input embedding layer.

     Args:
         filepath: Path or resource identifier of the fastText model; resolved through
             ``get_resource`` and required to be a file after resolution.
         padding: Token treated as padding (stored UTF-8 encoded in ``self.padding``).
             ``None`` disables masking (``mask_zero=False``) and leaves ``self.padding`` as ``None``.
         name: Layer name; defaults to the resolved model file name without extension.
         **kwargs: Forwarded to the parent constructor. ``input_dim``, ``output_dim`` and
             ``mask_zero`` are discarded because they are derived from the model and ``padding``.
     """
     import fasttext
     # padding=None is a supported value (see mask_zero below), so only encode real tokens.
     self.padding = padding.encode('utf-8') if padding is not None else None
     self.filepath = filepath
     filepath = get_resource(filepath)
     assert os.path.isfile(filepath), f'Resolved path {filepath} is not a file'
     logger.debug('Loading fasttext model from [{}].'.format(filepath))
     # fasttext prints a blank line while loading; silence it via the project's stdout_redirected helper.
     with stdout_redirected(to=os.devnull, stdout=sys.stderr):
         self.model = fasttext.load_model(filepath)
     # These are determined by the loaded model / padding, so caller-supplied values are dropped.
     for derived in ('input_dim', 'output_dim', 'mask_zero'):
         kwargs.pop(derived, None)
     if not name:
         name = os.path.splitext(os.path.basename(filepath))[0]
     super().__init__(input_dim=len(self.model.words), output_dim=self.model.get_dimension(),
                      mask_zero=padding is not None, trainable=False, dtype=tf.string, name=name, **kwargs)
     # Element-wise wrapper around self.embed; np.frompyfunc ufuncs return object arrays.
     self._embed_np = np.frompyfunc(self.embed, 1, 1)
示例#3
0
    def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_size=32, shuffle=None, repeat=None,
                           drop_remainder=False,
                           prefetch=1, cache=True) -> tf.data.Dataset:
        """Build a ``tf.data.Dataset`` pipeline from ``samples``.

        Stages, in order: ``from_generator`` -> optional ``cache`` -> optional ``shuffle`` ->
        optional ``repeat`` -> optional ``padded_batch`` -> optional ``prefetch`` -> optional index mapping.

        Args:
            samples: Generator (callable) handed to ``tf.data.Dataset.from_generator``.
            map_x: Truthy to map inputs with ``self.x_to_idx``; ``None`` falls back to ``self.map_x``.
            map_y: Truthy to map targets with ``self.y_to_idx``; ``None`` falls back to ``self.map_y``.
            batch_size: Size of padded batches; falsy disables batching.
            shuffle: ``True`` uses a buffer of 1024; any other truthy value is used as the buffer size;
                falsy disables shuffling.
            repeat: Passed to ``Dataset.repeat`` when truthy.
            drop_remainder: Whether ``padded_batch`` drops the last incomplete batch.
            prefetch: Prefetch buffer size; falsy disables prefetching.
            cache: A ``str`` is used as the cache filename; any other truthy value caches with an empty
                filename (in memory); falsy disables caching.

        Returns:
            The assembled ``tf.data.Dataset``.
        """
        output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
        # Build types/shapes/padding lazily when any of the three is missing (e.g. build_config() not yet called).
        if not all([output_types, output_shapes, padding_values]):
            self.build_config()
            output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
        assert all([output_types, output_shapes,
                    padding_values]), 'Your create_types_shapes_values returns None, which is not allowed'
        dataset = tf.data.Dataset.from_generator(samples, output_types=output_types, output_shapes=output_shapes)
        if cache:
            logger.debug('Dataset cache enabled')
            dataset = dataset.cache(cache if isinstance(cache, str) else '')
        if shuffle:
            if isinstance(shuffle, bool):
                shuffle = 1024
            dataset = dataset.shuffle(shuffle)
        if repeat:
            dataset = dataset.repeat(repeat)
        if batch_size:
            dataset = dataset.padded_batch(batch_size, output_shapes, padding_values, drop_remainder)
        if prefetch:
            dataset = dataset.prefetch(prefetch)
        if map_x is None:
            map_x = self.map_x
        if map_y is None:
            map_y = self.map_y
        if map_x or map_y:
            def mapper(X, Y):
                if map_x:
                    X = self.x_to_idx(X)
                if map_y:
                    Y = self.y_to_idx(Y)
                return X, Y

            dataset = dataset.map(mapper, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        return dataset
示例#4
0
def get_resource(path: str,
                 save_dir=hanlp_home(),
                 extract=True,
                 prefix=HANLP_URL,
                 append_location=True,
                 verbose=HANLP_VERBOSE):
    """Fetch real (local) path for a resource (model, corpus, whatever) to ``save_dir``.

    Args:
      path: A key of ``hanlp.pretrained.ALL`` (first resolved to its registered location), a local path (which will
        be returned as is) or a remote URL (which will be downloaded, decompressed, then returned). A URL may end with
        ``#anchor`` to point at an entry inside the resource; an anchor starting with ``/`` uses its first segment as
        the local name of the resource and the remaining segments as the entry inside it.
      save_dir: Where to store the resource (Default value = :meth:`hanlp.utils.io_util.hanlp_home`)
      extract: Whether to decompress the resource if it is compressed (Default value = True)
      prefix: A URL prefix; when a URL (path) matches it, that URL is considered to be official. Official
        resources will not go to a folder called ``thirdparty`` under :const:`~hanlp_common.constants.IDX`.
      append_location: Forwarded to ``path_from_url`` (Default value = True)
      verbose: Whether to print log messages.

    Returns:
      The real path to the resource.

    """
    path = hanlp.pretrained.ALL.get(path, path)
    anchor: str = None
    compressed = None
    if os.path.isdir(path):
        return path
    elif os.path.isfile(path):
        pass
    elif path.startswith(('http:', 'https:')):
        if '#' in path:
            anchor = path.split('#', maxsplit=1)[1]
        realpath = path_from_url(path, save_dir, prefix, append_location)
        realpath, compressed = split_if_compressed(realpath)
        # check if resource is there
        if anchor:
            if anchor.startswith('/'):
                # indicates the folder name has to be polished
                anchor = anchor.lstrip('/')
                parts = anchor.split('/')
                renamed_realpath = str(
                    Path(realpath).parent.joinpath(parts[0]))
                # compressed may be falsy for uncompressed resources (see the `if compressed` checks below);
                # guard it so the string concatenation cannot raise TypeError.
                if compressed and os.path.isfile(realpath + compressed):
                    os.rename(realpath + compressed,
                              renamed_realpath + compressed)
                realpath = renamed_realpath
                anchor = '/'.join(parts[1:])
            child = path_join(realpath, anchor)
            if os.path.exists(child):
                return child
        elif os.path.isdir(realpath) or (os.path.isfile(realpath) and
                                         (compressed and extract)):
            return realpath
        else:
            if compressed:
                # Reuse an existing sibling (e.g. an already extracted file), ignoring partial downloads and the
                # archive itself.
                pattern = realpath + '.*'
                files = glob.glob(pattern)
                files = list(
                    filter(lambda x: not x.endswith('.downloading'), files))
                zip_path = realpath + compressed
                if zip_path in files:
                    files.remove(zip_path)
                if files:
                    if len(files) > 1:
                        logger.debug(
                            f'Found multiple files with {pattern}, will use the first one.'
                        )
                    return files[0]
        # realpath is where its path after extraction
        if compressed:
            realpath += compressed
        if not os.path.isfile(realpath):
            path = download(url=path, save_path=realpath, verbose=verbose)
        else:
            path = realpath
    if extract and compressed:
        path = uncompress(path, verbose=verbose)
        if anchor:
            path = path_join(path, anchor)

    return path
示例#5
0
def get_resource(path: str,
                 save_dir=None,
                 extract=True,
                 prefix=HANLP_URL,
                 append_location=True):
    """
    Fetch real path for a resource (model, corpus, whatever)
    :param path: the general path (can be a url or a real path). A url may end with ``#anchor`` to point at an
        entry inside the resource; an anchor starting with ``/`` uses its first segment as the local name of the
        resource and the remaining segments as the entry inside it.
    :param save_dir: forwarded to ``path_from_url`` to determine where the resource is stored
    :param extract: whether to decompress it if it's a compressed file
    :param prefix: url prefix of official resources; forwarded to ``path_from_url``
    :param append_location: forwarded to ``path_from_url``
    :return: the real path to the resource
    """
    anchor: str = None
    compressed = None
    if os.path.isdir(path):
        return path
    elif os.path.isfile(path):
        pass
    elif path.startswith(('http:', 'https:')):
        if '#' in path:
            anchor = path.split('#', maxsplit=1)[1]
        realpath = path_from_url(path, save_dir, prefix, append_location)
        realpath, compressed = split_if_compressed(realpath)
        # check if resource is there
        if anchor:
            if anchor.startswith('/'):
                # indicates the folder name has to be polished
                anchor = anchor.lstrip('/')
                parts = anchor.split('/')
                realpath = str(Path(realpath).parent.joinpath(parts[0]))
                anchor = '/'.join(parts[1:])
            child = path_join(realpath, anchor)
            if os.path.exists(child):
                return child
        elif os.path.isdir(realpath) or (os.path.isfile(realpath) and
                                         (compressed and extract)):
            return realpath
        else:
            # Reuse any existing file whose path starts with realpath.
            pattern = realpath + '*'
            files = glob.glob(pattern)
            # Skip the archive itself when it is going to be extracted. Uncompressed resources have no
            # archive, and concatenating a falsy `compressed` would raise TypeError before downloading.
            if extract and compressed:
                zip_path = realpath + compressed
                if zip_path in files:
                    files.remove(zip_path)
            if files:
                if len(files) > 1:
                    logger.debug(
                        f'Found multiple files with {pattern}, will use the first one.'
                    )
                return files[0]
        # realpath is where its path after extraction
        if compressed:
            realpath += compressed
        if not os.path.isfile(realpath):
            path = download(url=path, save_path=realpath)
        else:
            path = realpath
    if extract and compressed:
        path = uncompress(path)
        if anchor:
            path = path_join(path, anchor)

    return path