Example #1
0
def main(argv, base_tag=None, checkpoint=None, override_namespace=None):
    """Parse arguments, build the training pipeline and run the training loop.

    Args:
        argv: argument list handed to ``parse_args``; it is also joined with
            spaces and passed to the callbacks as ``logs['arg_string']``.
        base_tag: tag appended to ``args.tags``. When None it defaults to
            ``"<args.dataset>@<YYYYmmdd-HHMMSS>"``.
        checkpoint: optional checkpoint path to restore variables from. The
            number of epochs to skip is parsed from the text after its last
            ``-``.
        override_namespace: optional namespace whose attributes overwrite the
            parsed arguments.
    """
    args = parse_args(argv)
    if override_namespace:
        vars(args).update(vars(override_namespace))

    set_package_verbosity(args.debug)

    with logging_block("Set global random seed"):
        set_global_random_seed(args.random_seed)

    with logging_block("Preprocess data"):
        data_collection, meta_data = data_factory.preprocess(
            args, return_meta=True)
        data_collection.summary()
        meta_data.summary()
        data_generator = DataGenerator(
            data_collection.train, batch_size=args.batch_size)

    with logging_block("Prepare Generator"):
        generator = generator_factory.create(args, meta_data)

    with logging_block("Prepare Generator Trainer"):
        trainer = trainer_factory.create(args, meta_data, generator)
        trainer.summary()

    with logging_block("Prepare Callback"):
        if base_tag is None:
            timestamp = time.strftime('%Y%m%d-%H%M%S')
            base_tag = f"{args.dataset}@{timestamp}"
        callback_factory = CallbackFactory(
            trainer=trainer,
            generator=generator,
            data_collection=data_collection,
            meta_data=meta_data,
            tags=args.tags + [base_tag],
        )
        data_generator.callback = callback_factory.create_by_args(args)

    with tf.Session(config=get_tf_config_proto(args.jit)) as sess:
        if not checkpoint:
            tf.global_variables_initializer().run()
        else:
            print(f"Restore from checkpoint: {checkpoint}")
            tf.train.Saver().restore(sess, save_path=checkpoint)
            finished_epochs = int(checkpoint.split('-')[-1])
            data_generator.skip_epochs(finished_epochs)

        # Forwarded to callback.on_train_begin().
        train_logs = {'arg_string': " ".join(argv)}
        batches = data_generator.iter_batch_until(
            n_epochs=args.epochs, logs=train_logs)
        for batch_data in batches:
            trainer.fit_batch(batch_data)
Example #2
0
    def create_generator_updater(self, real_samples, objective):
        """Build a ``GeneratorUpdater`` that optimizes ``self.generator``.

        ``objective`` is combined with the regularizers returned by
        ``generator_factory.create_regularizers(self.args)``. Every term is
        printed under an "Objective:" logging block, then added to the
        updater, whose graph is built on ``real_samples``.

        Returns:
            The configured ``GeneratorUpdater``.
        """
        with logging_block("Generator Optimization:"):
            optimizer = generator_factory.create_optimizer(self.args)
            losses = [objective]
            losses.extend(generator_factory.create_regularizers(self.args))
            with logging_block("Objective:"):
                for loss in losses:
                    print(loss)

        updater = GeneratorUpdater(self.generator, optimizer=optimizer)
        for loss in losses:
            updater.add_loss(loss)
        updater.build_graph(real_samples)
        return updater
Example #3
0
    def preprocess(self,
                   corpus_config: CorpusConfig,
                   return_meta: bool = False):
        """Tokenize the corpus described by ``corpus_config``.

        Args:
            corpus_config: configuration passed to the tokenizer factory and
                to ``_process_data``.
            return_meta: when true, also build and return a ``MetaData``.

        Returns:
            The ``DataCollection`` alone, or ``(data_collection, meta_data)``
            when ``return_meta`` is true.
        """
        with logging_block("Prepare text tokenizer..."):
            tokenizer = self._create_tokenizer(corpus_config)

        with logging_block("Preprocess text corpus..."):
            data_collection = self._process_data(tokenizer, corpus_config)

        if not return_meta:
            return data_collection

        meta_data = MetaData(
            tokenizer=tokenizer,
            corpus_config=corpus_config,
            cache_dir=self.get_cache_dir(corpus_config),
        )
        return data_collection, meta_data
Example #4
0
    def create_objective_updater(self, real_samples, discriminator, loss):
        """Build a ``DiscriminatorUpdater`` that optimizes ``discriminator``.

        ``loss`` is combined with the regularizers returned by
        ``discriminator_factory.create_regularizers(self.args)``. Every term
        is printed under an "Objective:" logging block and added to the
        updater. The graph is built on ``real_samples`` and on fake samples
        drawn from ``self.generator`` with the same batch size and maxlen.

        Returns:
            The configured ``DiscriminatorUpdater``.
        """
        with logging_block("Discriminator Optimization:"):
            optimizer = discriminator_factory.create_optimizer(self.args)
            objectives = [loss]
            objectives.extend(
                discriminator_factory.create_regularizers(self.args))
            with logging_block("Objective:"):
                for term in objectives:
                    print(term)

        updater = DiscriminatorUpdater(discriminator, optimizer=optimizer)
        for term in objectives:
            updater.add_loss(term)
        fake_samples = self.generator.generate(real_samples.batch_size,
                                               real_samples.maxlen)
        updater.build_graph(real_samples=real_samples,
                            fake_samples=fake_samples)
        return updater
Example #5
0
    def load_pretrained_embeddings(self) -> np.ndarray:
        """Return the pretrained embedding matrix for the tokenizer's tokens.

        The matrix is produced through ``cache_center.to_npz`` with the cache
        file ``self.cache_dir / 'word_vecs.npz'``. The embedding dimension
        (``shape[1]``) is printed inside a logging block.
        """
        cache_path = self.cache_dir / 'word_vecs.npz'

        @cache_center.to_npz(cache_path)
        def load_embeddings():
            config = self.language_config.load_pretrained_embeddings_msg()
            return config.get_matrix_of_tokens(self.tokenizer.tokens)

        with logging_block("Load pretrained embeddings:"):
            embeddings = load_embeddings()
            print(f"Dimensions: {embeddings.shape[1]}.")

        return embeddings
Example #6
0
    def _process_data(self, tokenizer, corpus_config):
        """Build a ``DataCollection`` holding one ``TextDataset`` per split.

        For every ``(key, path)`` in ``corpus_config.path``, the file is
        converted to token ids via ``tokenizer.texts_to_array``. That result
        is cached through ``cache_center.to_npz`` as ``<cache dir>/<key>_data.npz``.
        The ids are then decoded back to text and stored on the collection as
        attribute ``key``.
        """
        data_collection = DataCollection()
        for key, path in corpus_config.path.items():
            cache_file = self.get_cache_dir(corpus_config) / f'{key}_data.npz'

            @cache_center.to_npz(cache_file)
            def _process_text_file(filepath) -> np.ndarray:
                print(f"Load corpus data from {format_path(filepath)}")
                lines = with_iter(tqdm_open(filepath))
                return tokenizer.texts_to_array(lines)

            with logging_block(f"{key} data:", bullet=False):
                ids = _process_text_file(path)
                texts = [tokenizer.ids_to_text(row) for row in ids]
                setattr(data_collection, key, TextDataset(ids=ids, text=texts))

        return data_collection
Example #7
0
def _create(
    optimizer_id: str,
    learning_rate: float,
    clip_value: float = None,
    clip_norm: float = None,
    clip_global_norm: float = None,
    weight_decay_rate: float = None,
    use_lookahead: bool = False,
    **kwargs,
):
    """Instantiate optimizer ``optimizer_id`` and apply optional wrappers.

    The base optimizer is ``optimizer_cls_table[optimizer_id]`` built with
    ``learning_rate`` and ``kwargs``. Wrappers are then applied in order:

    * ``GradientClipping`` for at most one of ``clip_value``, ``clip_norm``,
      ``clip_global_norm``; the first truthy one (in that order) wins.
    * ``WeightDecay`` with ``decay_rate = weight_decay_rate * learning_rate``
      when ``weight_decay_rate`` is truthy.
    * ``LookAhead`` when ``use_lookahead`` is true.

    Each applied option is printed inside a logging block headed by the
    formatted optimizer description.

    Returns:
        The (possibly wrapped) optimizer.
    """
    optimizer = optimizer_cls_table[optimizer_id](learning_rate, **kwargs)
    description = format_object(optimizer,
                                learning_rate=learning_rate,
                                **kwargs)
    with logging_block(f"Optimizer: {description}"):
        # (threshold, log message, clip_by) in priority order.
        clip_choices = (
            (clip_value, f"clip_value: {clip_value}", 'value'),
            (clip_norm, f"clip_norm: {clip_norm}", 'norm'),
            (clip_global_norm, f"clip_global_norm: {clip_global_norm}",
             'global_norm'),
        )
        for threshold, message, clip_by in clip_choices:
            if threshold:
                print(message)
                optimizer = GradientClipping(optimizer,
                                             threshold,
                                             clip_by=clip_by)
                break

        if weight_decay_rate:
            print(f"weight_decay_rate: {weight_decay_rate}")
            # Decay is scaled by the learning rate before wrapping.
            scaled_decay = weight_decay_rate * learning_rate
            optimizer = WeightDecay(optimizer, decay_rate=scaled_decay)

        if use_lookahead:
            print("use_lookahead: True")
            optimizer = LookAhead(optimizer)

    return optimizer
Example #8
0
    def create_evaluator(self,
                         bleu_n_gram: int = None,
                         fed_sample_size: int = None):
        """Assemble a ``DispatchCallback`` with the text generator's evaluators.

        Always attached:
          * mean sample length (64 samples, ``on_batch_end`` every 10
            batches, channel ``'samples'``);
          * sample logging (3 samples, ``on_batch_end`` every 100 batches).
        When ``bleu_n_gram`` is not None:
          * BLEU against each dataset in ``self.data_collection`` (64 samples,
            ``on_batch_end`` every 10 batches, one channel per dataset key);
          * SelfBLEU of generated samples on ``on_epoch_end`` using
            ``min(10000, 2 * len(self.data_collection.train))`` samples.
        When ``fed_sample_size`` is not None:
          * FED score against a random sample of each dataset's texts on
            ``on_epoch_end`` (one channel per dataset key).

        Args:
            bleu_n_gram: maximum n-gram order for BLEU/SelfBLEU; None disables
                both.
            fed_sample_size: number of reference/candidate texts for FED;
                None disables it.

        Returns:
            The configured ``DispatchCallback``.
        """
        proxy = GeneratorProxy(self.text_generator)
        evaluator = DispatchCallback()
        evaluator.on_batch_end.attach(
            proxy.evaluate_ids(
                partial(mean_length, eos_idx=self.meta_data.eos_idx),
                sample_size=64,
                target_channel='samples',
            ),
            period=10,
        )
        if bleu_n_gram is not None:
            for key, dataset in self.data_collection.items():
                with logging_block(f"Building '{key}' data BLEU table..."):
                    calculator = BLEUCalculator(
                        dataset.ids,
                        max_gram=bleu_n_gram,
                        eos_idx=self.meta_data.eos_idx,
                        smoothing=SmoothingFunction.fuzz_smoothing,
                        cache_dir=self.meta_data.cache_dir / f"{key}_BLEU",
                        verbose=True,
                    )
                # `calculator.mean_bleu` is a bound method evaluated now, so
                # rebinding `calculator` on the next iteration is harmless.
                evaluator.on_batch_end.attach(
                    proxy.evaluate_ids(calculator.mean_bleu,
                                       sample_size=64,
                                       target_channel=key),
                    period=10,
                )

            def selfbleu(word_ids):
                print("Evaluating generated data SelfBLEU...")
                print()
                return BLEUCalculator.selfbleu(
                    word_ids,
                    max_gram=bleu_n_gram,
                    eos_idx=self.meta_data.eos_idx,
                    smoothing=SmoothingFunction.fuzz_smoothing,
                )

            evaluator.on_epoch_end.attach(
                proxy.evaluate_ids(
                    selfbleu,
                    target_channel='samples',
                    sample_size=min(10000,
                                    2 * len(self.data_collection.train)),
                ))

        if fed_sample_size is not None:

            def make_fed(calculator):
                """Return a FED scorer bound to this specific ``calculator``.

                A closure defined directly inside the loop below would look up
                the loop variable lazily, so every dataset's callback would
                end up scoring against the calculator of the last dataset.
                """

                def fed(texts):
                    print("Evaluating FED Score ...")
                    print()
                    return {
                        "FED": calculator.calculate_fed_score(candidates=texts)
                    }

                return fed

            for key, dataset in self.data_collection.items():
                print(f"Building '{key}' data FED sentence encoder...")
                # NOTE(review): elsewhere in this source `TextDataset` is built
                # with a `text=` keyword; confirm datasets expose `.texts`.
                fed_calculator = FEDCalculator(
                    hub_url=(
                        "https://tfhub.dev/google/"
                        "universal-sentence-encoder-large/3"),
                    references=random_sample(dataset.texts,
                                             size=fed_sample_size),
                )
                evaluator.on_epoch_end.attach(
                    proxy.evaluate_text(
                        make_fed(fed_calculator),
                        sample_size=fed_sample_size,
                        target_channel=key,
                    ))

        evaluator.on_batch_end.attach(
            proxy.log_samples_command(sample_size=3),
            period=100,
        )
        return evaluator
Example #9
0
 def summary(self):
     """Print every special token and its index under a logging block."""
     with logging_block("Special tokens config:"):
         for key, pair in self._attrs.items():
             token, idx = pair
             print(f"{key} token: '{token}', index: {idx}.")
Example #10
0
 def summary(self):
     """Print trainable / non-trainable parameter counts under ``self.scope``.

     Counts come from ``count_params`` and are right-aligned to 12 columns
     with thousands separators.
     """
     with logging_block(self.scope):
         trainable_params, non_trainable_params = (
             count_params(self.trainable_variables),
             count_params(self.non_trainable_variables),
         )
         print(f"Trainable     params: {trainable_params:>12,}")
         print(f"Non-trainable params: {non_trainable_params:>12,}")
Example #11
0
 def summary(self):
     """Print how many sentences each split of this collection holds."""
     with logging_block("Data summary:"):
         report = (f"{key} data contains {len(array)} sentences."
                   for key, array in self.items())
         for line in report:
             print(line)
Example #12
0
 def summary(self):
     """Print the summary of each updater's module under 'Model Summary'."""
     with logging_block('Model Summary'):
         modules = (updater.module for updater in self.updaters)
         for module in modules:
             module.summary()
Example #13
0
 def summary(self):
     """Print maxlen, vocabulary size and the special-token configuration."""
     title = f"{self.__class__.__name__} summary:"
     with logging_block(title):
         print(f"Maxlen: {self.maxlen}.")
         print(f"Vocabulary size: {self.vocab_size}.")
         self.special_token_config.summary()