def build_callbacks(self, save_dir, logger, **kwargs):
    """Build the list of Keras callbacks used during training: checkpointing,
    TensorBoard logging, CSV logging, plus optional learning-rate decay,
    annealing and early stopping driven by ``self.config``."""
    metrics = kwargs.get('metrics', 'accuracy')
    if isinstance(metrics, (list, tuple)):
        # With multiple metrics, monitor the last one
        metrics = metrics[-1]
    monitor = f'val_{metrics}'
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(save_dir, 'model.h5'),
        # verbose=1,
        monitor=monitor,
        save_best_only=True,
        mode='max',
        save_weights_only=True)
    logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
    csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'),
                               separator=' | ', append=True)
    callbacks = [checkpoint, tensorboard_callback, csv_logger]
    lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
    if lr_decay_per_epoch:
        learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
        if not learning_rate:
            logger.warning('Learning rate decay not supported for optimizer={}'.format(
                repr(self.model.optimizer)))
        else:
            logger.debug(
                f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
            # Inverse-time decay: lr_t = lr_0 / (1 + decay * epoch)
            callbacks.append(tf.keras.callbacks.LearningRateScheduler(
                lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
    anneal_factor = self.config.get('anneal_factor', None)
    if anneal_factor:
        callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(
            factor=anneal_factor,
            patience=self.config.get('anneal_patience', 10)))
    early_stopping_patience = self.config.get('early_stopping_patience', None)
    if early_stopping_patience:
        callbacks.append(tf.keras.callbacks.EarlyStopping(
            monitor=monitor,
            mode='max',
            verbose=1,
            patience=early_stopping_patience))
    return callbacks
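# Usage sketch (not part of the original source): how build_callbacks plugs into
# Keras training. `tagger` stands for any component exposing this method along
# with `model` and `config`; the paths and metric name are placeholders.
#
#     callbacks = tagger.build_callbacks('data/model/tagger', logger, metrics='accuracy')
#     history = tagger.model.fit(trn_data, validation_data=dev_data,
#                                epochs=100, callbacks=callbacks)
#
# The LearningRateScheduler above applies inverse-time decay
# lr_t = lr_0 / (1 + lr_decay_per_epoch * t), so with lr_decay_per_epoch=0.1
# the learning rate halves by epoch 10.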
def __init__(self, filepath: str, padding=PAD, name=None, **kwargs):
    """Wrap a pretrained fasttext model as a frozen Keras embedding layer that
    consumes string tensors directly."""
    import fasttext
    self.padding = padding.encode('utf-8')
    self.filepath = filepath
    filepath = get_resource(filepath)
    assert os.path.isfile(filepath), f'Resolved path {filepath} is not a file'
    logger.debug('Loading fasttext model from [{}].'.format(filepath))
    # fasttext prints a blank line here, so silence its stdout
    with stdout_redirected(to=os.devnull, stdout=sys.stderr):
        self.model = fasttext.load_model(filepath)
    # These are determined by the loaded model, so drop any user overrides
    kwargs.pop('input_dim', None)
    kwargs.pop('output_dim', None)
    kwargs.pop('mask_zero', None)
    if not name:
        name = os.path.splitext(os.path.basename(filepath))[0]
    super().__init__(input_dim=len(self.model.words),
                     # Probe an arbitrary word to obtain the vector dimensionality
                     output_dim=self.model['king'].size,
                     mask_zero=padding is not None,
                     trainable=False,
                     dtype=tf.string,
                     name=name,
                     **kwargs)
    # Element-wise embedding lookup over numpy arrays of strings
    embed_fn = np.frompyfunc(self.embed, 1, 1)
    # vf = np.vectorize(self.embed, otypes=[np.ndarray])
    self._embed_np = embed_fn
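# Usage sketch (not part of the original source): this __init__ belongs to a
# string-in, vector-out embedding layer; `FastTextEmbedding` is an assumed class
# name and the .bin path is a placeholder.
#
#     embed = FastTextEmbedding('path/to/wiki.zh.bin')
#     vectors = embed(tf.constant([['商品', '和', '服务', PAD]]))
#
# Because the layer is created with dtype=tf.string and trainable=False, tokens
# are looked up in the fasttext model at run time and padding positions are
# masked rather than learned.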
def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None,
                       batch_size=32, shuffle=None, repeat=None,
                       drop_remainder=False, prefetch=1,
                       cache=True) -> tf.data.Dataset:
    output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
    if not all(v for v in [output_types, output_shapes, padding_values]):
        # print('Did you forget to call build_config() on your transform?')
        self.build_config()
        output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
    assert all(v for v in [output_types, output_shapes, padding_values]), \
        'Your create_types_shapes_values returns None, which is not allowed'
    # if not callable(samples):
    #     samples = Transform.generator_to_callable(samples)
    dataset = tf.data.Dataset.from_generator(samples,
                                             output_types=output_types,
                                             output_shapes=output_shapes)
    if cache:
        logger.debug('Dataset cache enabled')
        # A string turns on file-based caching; otherwise cache in memory
        dataset = dataset.cache(cache if isinstance(cache, str) else '')
    if shuffle:
        if isinstance(shuffle, bool):
            shuffle = 1024
        dataset = dataset.shuffle(shuffle)
    if repeat:
        dataset = dataset.repeat(repeat)
    if batch_size:
        dataset = dataset.padded_batch(batch_size, output_shapes, padding_values, drop_remainder)
    if prefetch:
        dataset = dataset.prefetch(prefetch)
    if map_x is None:
        map_x = self.map_x
    if map_y is None:
        map_y = self.map_y
    if map_x or map_y:
        def mapper(X, Y):
            if map_x:
                X = self.x_to_idx(X)
            if map_y:
                Y = self.y_to_idx(Y)
            return X, Y

        dataset = dataset.map(mapper, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset
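# Usage sketch (not part of the original source): `transform` stands for any
# Transform subclass and `load_tsv` is a hypothetical reader.
#
#     def read_samples():
#         for words, tags in load_tsv('train.tsv'):
#             yield words, tags
#
#     trn = transform.samples_to_dataset(read_samples, batch_size=32, shuffle=True)
#     for X, Y in trn.take(1):
#         print(X.shape, Y.shape)
#
# Note that tf.data.Dataset.from_generator expects a callable, so a factory
# function (rather than a generator object) is passed despite the type hint.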
def get_resource(path: str, save_dir=hanlp_home(), extract=True, prefix=HANLP_URL,
                 append_location=True, verbose=HANLP_VERBOSE):
    """Fetch the real (local) path of a resource (model, corpus, whatever), downloading it
    to ``save_dir`` if necessary.

    Args:
        path: A local path (which will be returned as is) or a remote URL (which will be
            downloaded, decompressed then returned).
        save_dir: Where to store the resource (Default value = :meth:`hanlp.utils.io_util.hanlp_home`)
        extract: Whether to unzip it if it's a zip file (Default value = True)
        prefix: A prefix which, when matched by a URL (path), marks that URL as official.
            Official resources will not go to a folder called ``thirdparty`` under
            :const:`~hanlp_common.constants.IDX`.
        append_location:  (Default value = True)
        verbose: Whether to print log messages.

    Returns:
        The real path to the resource.
    """
    path = hanlp.pretrained.ALL.get(path, path)
    anchor: str = None
    compressed = None
    if os.path.isdir(path):
        return path
    elif os.path.isfile(path):
        pass
    elif path.startswith('http:') or path.startswith('https:'):
        url = path
        if '#' in url:
            url, anchor = url.split('#', maxsplit=1)
        realpath = path_from_url(path, save_dir, prefix, append_location)
        realpath, compressed = split_if_compressed(realpath)
        # check if the resource is already there
        if anchor:
            if anchor.startswith('/'):
                # indicates that the folder name has to be polished
                anchor = anchor.lstrip('/')
                parts = anchor.split('/')
                renamed_realpath = str(Path(realpath).parent.joinpath(parts[0]))
                if os.path.isfile(realpath + compressed):
                    os.rename(realpath + compressed, renamed_realpath + compressed)
                realpath = renamed_realpath
                anchor = '/'.join(parts[1:])
            child = path_join(realpath, anchor)
            if os.path.exists(child):
                return child
        elif os.path.isdir(realpath) or (os.path.isfile(realpath) and (compressed and extract)):
            return realpath
        else:
            if compressed:
                pattern = realpath + '.*'
                files = glob.glob(pattern)
                files = list(filter(lambda x: not x.endswith('.downloading'), files))
                zip_path = realpath + compressed
                if zip_path in files:
                    files.remove(zip_path)
                if files:
                    if len(files) > 1:
                        logger.debug(f'Found multiple files with {pattern}, will use the first one.')
                    return files[0]
        # realpath is the path after extraction
        if compressed:
            realpath += compressed
        if not os.path.isfile(realpath):
            path = download(url=path, save_path=realpath, verbose=verbose)
        else:
            path = realpath
    if extract and compressed:
        path = uncompress(path, verbose=verbose)
        if anchor:
            path = path_join(path, anchor)
    return path
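# Usage sketch (not part of the original source): illustrative inputs only; the
# URL and key below are placeholders, not real resources.
#
#     get_resource('/some/local/dir')              # returned as-is
#     get_resource('SOME_PRETRAINED_KEY')          # resolved via hanlp.pretrained.ALL
#     get_resource('https://example.com/model.zip#vocab.json')
#     # downloads model.zip into hanlp_home(), extracts it, and returns the
#     # path of vocab.json inside the extracted folder (the '#' anchor).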