コード例 #1
0
    def train(self, x_train, y_train, x_valid=None, y_valid=None):
        """Compile ``self.model`` and fit it on the training data.

        Args:
            x_train: Training inputs, passed to ``batch_iter`` with ``y_train``.
            y_train: Training labels aligned with ``x_train``.
            x_valid: Optional validation inputs.
            y_valid: Optional validation labels aligned with ``x_valid``.

        Validation batches are built and handed to ``get_callbacks`` only
        when both ``x_valid`` and ``y_valid`` are given; otherwise the model
        is trained without the validation callback argument.
        """
        batch_size = self.training_config.batch_size

        # Prepare training data (steps per epoch, batch generator).
        train_steps, train_batches = batch_iter(
            x_train,
            y_train,
            batch_size,
            preprocessor=self.preprocessor)

        # Prepare validation data only when it was actually supplied;
        # batch_iter presumably cannot batch the None defaults.
        has_valid = x_valid is not None and y_valid is not None
        if has_valid:
            valid_steps, valid_batches = batch_iter(
                x_valid,
                y_valid,
                batch_size,
                preprocessor=self.preprocessor)

        self.model.compile(
            loss=self.model.crf.loss,
            optimizer=Adam(lr=self.training_config.learning_rate),
        )

        # Prepare callbacks. The keyword 'eary_stopping' (sic) is kept exactly
        # as in the original call -- presumably it mirrors get_callbacks'
        # parameter name, so do not correct the spelling here without
        # checking get_callbacks first.
        callback_kwargs = {
            'log_dir': self.checkpoint_path,
            'tensorboard': self.tensorboard,
            'eary_stopping': self.training_config.early_stopping,
        }
        if has_valid:
            callback_kwargs['valid'] = (valid_steps, valid_batches,
                                        self.preprocessor)
        # NOTE(review): without validation data `valid` is omitted, which
        # assumes get_callbacks provides a default for it -- confirm.
        callbacks = get_callbacks(**callback_kwargs)

        # Train the model
        self.model.fit_generator(generator=train_batches,
                                 steps_per_epoch=train_steps,
                                 epochs=self.training_config.max_epoch,
                                 callbacks=callbacks)
コード例 #2
0
ファイル: trainer.py プロジェクト: zaczou/anago
    def train(self, x_train, y_train, x_valid=None, y_valid=None):
        """Build a SeqLabeling model, train it, and save its weights.

        Args:
            x_train: Training inputs, zipped pairwise with ``y_train``.
            y_train: Training labels aligned with ``x_train``.
            x_valid: Optional validation inputs.
            y_valid: Optional validation labels aligned with ``x_valid``.

        Validation batches are built and handed to ``get_callbacks`` only
        when both ``x_valid`` and ``y_valid`` are given. The trained model is
        written to ``<self.save_path>/model_weights.h5``.
        """
        batch_size = self.training_config.batch_size

        # Prepare training data (steps per epoch, batch generator).
        train_steps, train_batches = batch_iter(
            list(zip(x_train, y_train)),
            batch_size,
            preprocessor=self.preprocessor)

        # Prepare validation data only when it was actually supplied:
        # zip(None, None) raises TypeError, so the None defaults could
        # never reach batch_iter successfully.
        has_valid = x_valid is not None and y_valid is not None
        if has_valid:
            valid_steps, valid_batches = batch_iter(
                list(zip(x_valid, y_valid)),
                batch_size,
                preprocessor=self.preprocessor)

        # Build the model; the tag-vocabulary size sets the output width.
        model = SeqLabeling(self.model_config, self.embeddings,
                            len(self.preprocessor.vocab_tag))
        model.compile(
            loss=model.crf.loss,
            optimizer=Adam(lr=self.training_config.learning_rate),
        )

        # Prepare callbacks for training. The keyword 'eary_stopping' (sic)
        # is kept exactly as in the original call -- presumably it mirrors
        # get_callbacks' parameter name, so check there before renaming.
        callback_kwargs = {
            'log_dir': self.checkpoint_path,
            'tensorboard': self.tensorboard,
            'eary_stopping': self.training_config.early_stopping,
        }
        if has_valid:
            callback_kwargs['valid'] = (valid_steps, valid_batches,
                                        self.preprocessor)
        # NOTE(review): without validation data `valid` is omitted, which
        # assumes get_callbacks provides a default for it -- confirm.
        callbacks = get_callbacks(**callback_kwargs)

        # Train the model
        model.fit_generator(generator=train_batches,
                            steps_per_epoch=train_steps,
                            epochs=self.training_config.max_epoch,
                            callbacks=callbacks)

        # Save the model
        model.save(os.path.join(self.save_path, 'model_weights.h5'))
コード例 #3
0
    def eval(self, x_test, y_test):
        """Score ``self.model`` on a test set via the F1score callback.

        Args:
            x_test: Test inputs, passed to ``batch_iter`` with ``y_test``.
            y_test: Test labels aligned with ``x_test``.

        Nothing is returned; the evaluation runs inside
        ``F1score.on_epoch_end``.
        """
        # Test data (steps, generator); shuffle=False presumably keeps the
        # batches in input order.
        test_steps, test_batches = batch_iter(
            x_test, y_test,
            batch_size=20,  # TODO: if batch_size=1, eval does not work.
            shuffle=False,
            preprocessor=self.preprocessor,
        )

        evaluator = F1score(test_steps, test_batches, self.preprocessor)
        evaluator.model = self.model
        # Trigger the callback by hand once; any integer epoch works.
        evaluator.on_epoch_end(epoch=-1)
コード例 #4
0
    def eval(self, x_test, y_test):
        """Load saved weights into a fresh SeqLabeling model and score it.

        Args:
            x_test: Test inputs, zipped pairwise with ``y_test``.
            y_test: Test labels aligned with ``x_test``.

        The weights are read from ``<self.save_path>/<self.weights>``. Nothing
        is returned; the evaluation runs inside ``F1score.on_epoch_end``.
        """
        test_pairs = list(zip(x_test, y_test))
        test_steps, test_batches = batch_iter(
            test_pairs,
            self.config.batch_size,
            preprocessor=self.preprocessor,
        )

        # Rebuild the network, then restore its weights from disk.
        tag_count = len(self.preprocessor.vocab_tag)
        model = SeqLabeling(self.config, ntags=tag_count)
        weights_file = os.path.join(self.save_path, self.weights)
        model.load(filepath=weights_file)

        evaluator = F1score(test_steps, test_batches, self.preprocessor)
        evaluator.model = model
        # Trigger the callback by hand once; any integer epoch works.
        evaluator.on_epoch_end(epoch=-1)
コード例 #5
0
ファイル: reader_test.py プロジェクト: pilehvar/phenebank
 def test_batch_iter(self):
     """batch_iter's generator should supply the number of steps it reports.

     The original check, ``len([_ for _ in batches])``, was flagged by its
     author as an infinite loop: a generator that cycles over the data never
     ends. So at most ``steps`` batches are drawn, and the test asserts that
     all of them were produced. It does not check for an upper bound, since
     the generator may legitimately keep yielding.
     """
     import itertools

     sents, labels = load_data_and_labels(self.filename)
     batch_size = 32
     p = prepare_preprocessor(sents, labels)
     steps, batches = batch_iter(list(zip(sents, labels)), batch_size, preprocessor=p)
     # Bounded draw: islice stops after `steps` items even if `batches` is endless.
     drawn = sum(1 for _ in itertools.islice(batches, steps))
     self.assertEqual(drawn, steps)