def main(argv, base_tag=None, checkpoint=None, override_namespace=None):
    args = parse_args(argv)
    if override_namespace:
        args.__dict__.update(override_namespace.__dict__)
    set_package_verbosity(args.debug)

    with logging_block("Set global random seed"):
        set_global_random_seed(args.random_seed)

    with logging_block("Preprocess data"):
        data_collection, meta_data = data_factory.preprocess(args, return_meta=True)
        data_collection.summary()
        meta_data.summary()
        data_generator = DataGenerator(data_collection.train, batch_size=args.batch_size)

    with logging_block("Prepare Generator"):
        generator = generator_factory.create(args, meta_data)

    with logging_block("Prepare Generator Trainer"):
        trainer = trainer_factory.create(args, meta_data, generator)
        trainer.summary()

    with logging_block("Prepare Callback"):
        if base_tag is None:
            base_tag = f"{args.dataset}@{time.strftime('%Y%m%d-%H%M%S')}"
        data_generator.callback = CallbackFactory(
            trainer=trainer,
            generator=generator,
            data_collection=data_collection,
            meta_data=meta_data,
            tags=args.tags + [base_tag],
        ).create_by_args(args)

    with tf.Session(config=get_tf_config_proto(args.jit)) as sess:
        if checkpoint:
            print(f"Restore from checkpoint: {checkpoint}")
            tf.train.Saver().restore(sess, save_path=checkpoint)
            data_generator.skip_epochs(int(checkpoint.split('-')[-1]))
        else:
            tf.global_variables_initializer().run()

        for batch_data in data_generator.iter_batch_until(
            n_epochs=args.epochs,
            logs={'arg_string': " ".join(argv)},  # for callback.on_train_begin()
        ):
            trainer.fit_batch(batch_data)
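# Hypothetical entry-point sketch (not part of the original source): it assumes this
# script is executed directly and that `argv` excludes the program name, as the
# `parse_args(argv)` call above suggests.
if __name__ == '__main__':
    import sys
    main(sys.argv[1:])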
def create_generator_updater(self, real_samples, objective):
    with logging_block("Generator Optimization:"):
        optimizer = generator_factory.create_optimizer(self.args)
        objectives = [objective, *generator_factory.create_regularizers(self.args)]
        with logging_block("Objective:"):
            for obj in objectives:
                print(obj)

    updater = GeneratorUpdater(self.generator, optimizer=optimizer)
    for obj in objectives:
        updater.add_loss(obj)
    updater.build_graph(real_samples)
    return updater
def preprocess(self, corpus_config: CorpusConfig, return_meta: bool = False):
    with logging_block("Prepare text tokenizer..."):
        tokenizer = self._create_tokenizer(corpus_config)

    with logging_block("Preprocess text corpus..."):
        data_collection = self._process_data(tokenizer, corpus_config)

    if return_meta:
        meta_data = MetaData(
            tokenizer=tokenizer,
            corpus_config=corpus_config,
            cache_dir=self.get_cache_dir(corpus_config),
        )
        return data_collection, meta_data
    else:
        return data_collection
def create_objective_updater(self, real_samples, discriminator, loss):
    with logging_block("Discriminator Optimization:"):
        optimizer = discriminator_factory.create_optimizer(self.args)
        objectives = [loss, *discriminator_factory.create_regularizers(self.args)]
        with logging_block("Objective:"):
            for obj in objectives:
                print(obj)

    updater = DiscriminatorUpdater(discriminator, optimizer=optimizer)
    for obj in objectives:
        updater.add_loss(obj)
    updater.build_graph(
        real_samples=real_samples,
        fake_samples=self.generator.generate(real_samples.batch_size, real_samples.maxlen),
    )
    return updater
def load_pretrained_embeddings(self) -> np.ndarray:

    @cache_center.to_npz(self.cache_dir / 'word_vecs.npz')
    def load_embeddings():
        word_vec_config = self.language_config.load_pretrained_embeddings_msg()
        return word_vec_config.get_matrix_of_tokens(self.tokenizer.tokens)

    with logging_block("Load pretrained embeddings:"):
        embeddings = load_embeddings()
        print(f"Dimensions: {embeddings.shape[1]}.")
    return embeddings
def _process_data(self, tokenizer, corpus_config):
    data_collection = DataCollection()
    for key, path in corpus_config.path.items():

        @cache_center.to_npz(self.get_cache_dir(corpus_config) / f'{key}_data.npz')
        def _process_text_file(filepath) -> np.ndarray:
            print(f"Load corpus data from {format_path(filepath)}")
            return tokenizer.texts_to_array(with_iter(tqdm_open(filepath)))

        with logging_block(f"{key} data:", bullet=False):
            ids = _process_text_file(path)
            text = list(map(tokenizer.ids_to_text, ids))
            text_dataset = TextDataset(ids=ids, text=text)
            setattr(data_collection, key, text_dataset)

    return data_collection
def _create(
    optimizer_id: str,
    learning_rate: float,
    clip_value: float = None,
    clip_norm: float = None,
    clip_global_norm: float = None,
    weight_decay_rate: float = None,
    use_lookahead: bool = False,
    **kwargs,
):
    optimizer = optimizer_cls_table[optimizer_id](learning_rate, **kwargs)
    with logging_block(
        f"Optimizer: {format_object(optimizer, learning_rate=learning_rate, **kwargs)}",
    ):
        if clip_value:
            print(f"clip_value: {clip_value}")
            optimizer = GradientClipping(optimizer, clip_value, clip_by='value')
        elif clip_norm:
            print(f"clip_norm: {clip_norm}")
            optimizer = GradientClipping(optimizer, clip_norm, clip_by='norm')
        elif clip_global_norm:
            print(f"clip_global_norm: {clip_global_norm}")
            optimizer = GradientClipping(optimizer, clip_global_norm, clip_by='global_norm')
        if weight_decay_rate:
            print(f"weight_decay_rate: {weight_decay_rate}")
            optimizer = WeightDecay(optimizer, decay_rate=weight_decay_rate * learning_rate)
        if use_lookahead:
            print("use_lookahead: True")
            optimizer = LookAhead(optimizer)
    return optimizer
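# Hypothetical usage sketch (not part of the original source). The wrappers above
# compose outward around the base optimizer, so the call below would build roughly
# LookAhead(WeightDecay(GradientClipping(base))). The 'adam' key is an assumption
# about `optimizer_cls_table`; substitute whatever keys the table actually defines.
def _example_optimizer():
    return _create(
        'adam',
        learning_rate=1e-4,
        clip_global_norm=5.0,    # clip_value / clip_norm / clip_global_norm are mutually exclusive above
        weight_decay_rate=1e-2,  # applied as weight_decay_rate * learning_rate in _create
        use_lookahead=True,
    )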
def create_evaluator(self, bleu_n_gram: int = None, fed_sample_size: int = None):
    proxy = GeneratorProxy(self.text_generator)
    evaluator = DispatchCallback()
    evaluator.on_batch_end.attach(
        proxy.evaluate_ids(
            partial(mean_length, eos_idx=self.meta_data.eos_idx),
            sample_size=64,
            target_channel='samples',
        ),
        period=10,
    )

    if bleu_n_gram is not None:
        for key, dataset in self.data_collection.items():
            with logging_block(f"Building '{key}' data BLEU table..."):
                calculator = BLEUCalculator(
                    dataset.ids,
                    max_gram=bleu_n_gram,
                    eos_idx=self.meta_data.eos_idx,
                    smoothing=SmoothingFunction.fuzz_smoothing,
                    cache_dir=self.meta_data.cache_dir / f"{key}_BLEU",
                    verbose=True,
                )
            evaluator.on_batch_end.attach(
                proxy.evaluate_ids(calculator.mean_bleu, sample_size=64, target_channel=key),
                period=10,
            )

        def selfbleu(word_ids):
            print("Evaluating generated data SelfBLEU...")
            print()
            return BLEUCalculator.selfbleu(
                word_ids,
                max_gram=bleu_n_gram,
                eos_idx=self.meta_data.eos_idx,
                smoothing=SmoothingFunction.fuzz_smoothing,
            )

        evaluator.on_epoch_end.attach(
            proxy.evaluate_ids(
                selfbleu,
                target_channel='samples',
                sample_size=min(10000, 2 * len(self.data_collection.train)),
            ),
        )

    if fed_sample_size is not None:
        for key, dataset in self.data_collection.items():
            print(f"Building '{key}' data FED sentence encoder...")
            calculator = FEDCalculator(
                hub_url="https://tfhub.dev/google/universal-sentence-encoder-large/3",
                references=random_sample(dataset.texts, size=fed_sample_size),
            )

            def fed(texts, calculator=calculator):
                # Bind the current calculator as a default argument; a plain closure
                # would late-bind and every channel would reuse the last loop
                # iteration's calculator.
                print("Evaluating FED Score ...")
                print()
                return {"FED": calculator.calculate_fed_score(candidates=texts)}

            evaluator.on_epoch_end.attach(
                proxy.evaluate_text(
                    fed,
                    sample_size=fed_sample_size,
                    target_channel=key,
                ),
            )

    evaluator.on_batch_end.attach(
        proxy.log_samples_command(sample_size=3),
        period=100,
    )
    return evaluator
def summary(self):
    with logging_block("Special tokens config:"):
        for key, (token, idx) in self._attrs.items():
            print(f"{key} token: '{token}', index: {idx}.")
def summary(self):
    with logging_block(self.scope):
        trainable_params = count_params(self.trainable_variables)
        non_trainable_params = count_params(self.non_trainable_variables)
        print(f"Trainable params: {trainable_params:>12,}")
        print(f"Non-trainable params: {non_trainable_params:>12,}")
def summary(self):
    with logging_block("Data summary:"):
        for key, array in self.items():
            print(f"{key} data contains {len(array)} sentences.")
def summary(self):
    with logging_block('Model Summary'):
        for updater in self.updaters:
            updater.module.summary()
def summary(self):
    with logging_block(f"{self.__class__.__name__} summary:"):
        print(f"Maxlen: {self.maxlen}.")
        print(f"Vocabulary size: {self.vocab_size}.")
        self.special_token_config.summary()
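# `logging_block` is used throughout these snippets but not defined here. A hypothetical
# minimal stand-in consistent with its call sites (a context manager taking a title and
# an optional `bullet` flag) could look like the sketch below; the real implementation
# may indent or otherwise format the nested output differently.
from contextlib import contextmanager

@contextmanager
def _logging_block_sketch(title: str, bullet: bool = True):
    print(f"- {title}" if bullet else title)
    yield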