def train(self, x_train, y_train, x_valid=None, y_valid=None):
    # TBD if valid is None, segment train to get one
    # guard against a missing validation split before concatenating
    x_all = np.concatenate((x_train, x_valid), axis=0) if x_valid is not None else x_train
    y_all = np.concatenate((y_train, y_valid), axis=0) if y_valid is not None else y_train
    self.p = prepare_preprocessor(x_all, y_all, self.model_config)
    self.model_config.char_vocab_size = len(self.p.vocab_char)
    self.model_config.case_vocab_size = len(self.p.vocab_case)
    """
    if self.embeddings.use_ELMo:
        # dump the context-independent token data for the train set, done once for the training
        x_train_local = x_train
        if not self.training_config.early_stop:
            # in case we want to train with the validation set too, we also dump
            # the ELMo embeddings for the tokens of the valid set
            x_train_local = np.concatenate((x_train, x_valid), axis=0)
        self.embeddings.dump_ELMo_token_embeddings(x_train_local)
    """
    self.model = get_model(self.model_config, self.p, len(self.p.vocab_tag))
    trainer = Trainer(self.model,
                      self.models,
                      self.embeddings,
                      self.model_config,
                      self.training_config,
                      checkpoint_path=self.log_dir,
                      preprocessor=self.p)
    trainer.train(x_train, y_train, x_valid, y_valid)
    if self.embeddings.use_ELMo:
        self.embeddings.clean_ELMo_cache()
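# The guarded concatenation above avoids passing a None validation split to
# np.concatenate. A standalone sketch of the same pattern (the helper name is
# illustrative, not part of this module), runnable with numpy alone:
import numpy as np

def concat_with_optional(train_part, valid_part, axis=0):
    # with no validation split, the training split alone feeds the vocabulary
    if valid_part is None:
        return train_part
    return np.concatenate((train_part, valid_part), axis=axis)

assert concat_with_optional(np.zeros((2, 3)), None).shape == (2, 3)
assert concat_with_optional(np.zeros((2, 3)), np.zeros((1, 3))).shape == (3, 3)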
def train_nfold(self, x_train, y_train, x_valid=None, y_valid=None, fold_number=10):
    if x_valid is not None and y_valid is not None:
        x_all = np.concatenate((x_train, x_valid), axis=0)
        y_all = np.concatenate((y_train, y_valid), axis=0)
        self.p = prepare_preprocessor(x_all, y_all, self.model_config)
    else:
        self.p = prepare_preprocessor(x_train, y_train, self.model_config)
    self.model_config.char_vocab_size = len(self.p.vocab_char)
    self.model_config.case_vocab_size = len(self.p.vocab_case)
    self.p.return_lengths = True
    # one model instance per fold
    self.models = []
    for _ in range(fold_number):
        model = get_model(self.model_config, self.p, len(self.p.vocab_tag))
        self.models.append(model)
    trainer = Trainer(self.model,
                      self.models,
                      self.embeddings,
                      self.model_config,
                      self.training_config,
                      checkpoint_path=self.log_dir,
                      preprocessor=self.p)
    trainer.train_nfold(x_train, y_train, x_valid, y_valid)
    if self.embeddings.use_ELMo:
        self.embeddings.clean_ELMo_cache()
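# Trainer.train_nfold() is not shown in this section. As an illustration of
# what the fold loop above sets up, here is a minimal sketch of the k-fold
# index splitting such a trainer typically performs (an assumption about its
# internals, not the actual implementation): each of the fold_number models
# is trained on fold_number-1 folds and evaluated on the held-out fold.
import numpy as np

def kfold_indices(n_samples, fold_number=10):
    folds = np.array_split(np.arange(n_samples), fold_number)
    for k in range(fold_number):
        eval_idx = folds[k]
        train_idx = np.concatenate([folds[i] for i in range(fold_number) if i != k])
        yield train_idx, eval_idx

for train_idx, eval_idx in kfold_indices(100, fold_number=10):
    assert len(train_idx) + len(eval_idx) == 100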
def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_valid=None, callbacks=None):
    # TBD if valid is None, segment train to get one if early_stop is True

    # we concatenate all the training+validation data to create the model vocabulary
    if x_valid is not None:
        x_all = np.concatenate((x_train, x_valid), axis=0)
    else:
        x_all = x_train

    if y_valid is not None:
        y_all = np.concatenate((y_train, y_valid), axis=0)
    else:
        y_all = y_train

    features_all = concatenate_or_none((f_train, f_valid), axis=0)

    self.p = prepare_preprocessor(x_all, y_all, features=features_all, model_config=self.model_config)
    self.model_config.char_vocab_size = len(self.p.vocab_char)
    self.model_config.case_vocab_size = len(self.p.vocab_case)

    self.model = get_model(self.model_config, self.p, len(self.p.vocab_tag), load_pretrained_weights=True)

    print_parameters(self.model_config, self.training_config)
    self.model.print_summary()

    # uncomment to plot the model graph
    #plot_model(self.model,
    #    to_file='data/models/textClassification/'+self.model_config.model_name+'_'+self.model_config.architecture+'.png')

    trainer = Trainer(self.model,
                      self.models,
                      self.embeddings,
                      self.model_config,
                      self.training_config,
                      checkpoint_path=self.log_dir,
                      preprocessor=self.p,
                      transformer_preprocessor=self.model.transformer_preprocessor)
    trainer.train(x_train, y_train, x_valid, y_valid, features_train=f_train, features_valid=f_valid, callbacks=callbacks)

    if self.embeddings and self.embeddings.use_ELMo:
        self.embeddings.clean_ELMo_cache()
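# concatenate_or_none() is called above but not defined in this section. A
# minimal sketch consistent with how it is used here (an assumption, not the
# library's actual implementation): features are optional, so the helper
# returns None when every input is None instead of failing in np.concatenate.
import numpy as np

def concatenate_or_none(arrays, axis=0):
    non_empty = [a for a in arrays if a is not None]
    if not non_empty:
        return None
    return np.concatenate(non_empty, axis=axis)

assert concatenate_or_none((None, None)) is None
assert concatenate_or_none((np.zeros((2, 3)), None)).shape == (2, 3)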
def train_nfold(self, x_train, y_train, x_valid=None, y_valid=None,
                f_train: np.ndarray = None, f_valid: np.ndarray = None,
                fold_number=10, callbacks=None):
    x_all = np.concatenate((x_train, x_valid), axis=0) if x_valid is not None else x_train
    y_all = np.concatenate((y_train, y_valid), axis=0) if y_valid is not None else y_train
    features_all = concatenate_or_none((f_train, f_valid), axis=0)

    self.p = prepare_preprocessor(x_all, y_all, features=features_all, model_config=self.model_config)
    self.model_config.char_vocab_size = len(self.p.vocab_char)
    self.model_config.case_vocab_size = len(self.p.vocab_case)
    self.p.return_lengths = True

    if 'bert' in self.model_config.model_type.lower():
        self.model = get_model(self.model_config, self.p, len(self.p.vocab_tag))

    # one model instance per fold
    self.models = []
    for _ in range(fold_number):
        model = get_model(self.model_config, self.p, len(self.p.vocab_tag))
        self.models.append(model)

    trainer = Trainer(self.model,
                      self.models,
                      self.embeddings,
                      self.model_config,
                      self.training_config,
                      checkpoint_path=self.log_dir,
                      preprocessor=self.p)
    trainer.train_nfold(x_train, y_train, x_valid, y_valid, f_train=f_train, f_valid=f_valid, callbacks=callbacks)

    if self.embeddings.use_ELMo:
        self.embeddings.clean_ELMo_cache()
    if self.embeddings.use_BERT:
        self.embeddings.clean_BERT_cache()

    if 'bert' in self.model_config.model_type.lower():
        self.save()
def train(self, x_train, y_train, f_train: np.ndarray = None,
          x_valid=None, y_valid=None, f_valid: np.ndarray = None, callbacks=None):
    # TBD if valid is None, segment train to get one
    x_all = np.concatenate((x_train, x_valid), axis=0) if x_valid is not None else x_train
    y_all = np.concatenate((y_train, y_valid), axis=0) if y_valid is not None else y_train
    features_all = concatenate_or_none((f_train, f_valid), axis=0)

    self.p = prepare_preprocessor(x_all, y_all, features=features_all, model_config=self.model_config)
    self.model_config.char_vocab_size = len(self.p.vocab_char)
    self.model_config.case_vocab_size = len(self.p.vocab_case)

    self.model = get_model(self.model_config, self.p, len(self.p.vocab_tag))

    if self.p.return_features:
        print('x_train.shape: ', x_train.shape)
        print('features_train.shape: ', f_train.shape)
        # the last axis of the transformed features gives the feature vector size
        sample_transformed_features = self.p.transform_features(f_train)
        self.model_config.max_feature_size = np.asarray(sample_transformed_features).shape[-1]
        print('max_feature_size: ', self.model_config.max_feature_size)

    trainer = Trainer(self.model,
                      self.models,
                      self.embeddings,
                      self.model_config,
                      self.training_config,
                      checkpoint_path=self.log_dir,
                      preprocessor=self.p)
    trainer.train(x_train, y_train, x_valid, y_valid, features_train=f_train, features_valid=f_valid, callbacks=callbacks)

    if self.embeddings.use_ELMo:
        self.embeddings.clean_ELMo_cache()
    if self.embeddings.use_BERT:
        self.embeddings.clean_BERT_cache()
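# The real transform_features() belongs to the preprocessor and is not shown
# here. A toy one-hot transform (purely illustrative) makes the shape[-1]
# convention used for max_feature_size above concrete: the last axis of the
# transformed features is the per-token feature vector length.
import numpy as np

def toy_transform_features(feature_ids, vocab_size=4):
    # one row of np.eye per feature id -> shape (n_tokens, vocab_size)
    return np.eye(vocab_size)[feature_ids]

f_sample = np.array([0, 2, 3, 1])
max_feature_size = np.asarray(toy_transform_features(f_sample)).shape[-1]
assert max_feature_size == 4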
def train_nfold(self, x_train, y_train, x_valid=None, y_valid=None, f_train=None, f_valid=None, callbacks=None):
    x_all = np.concatenate((x_train, x_valid), axis=0) if x_valid is not None else x_train
    y_all = np.concatenate((y_train, y_valid), axis=0) if y_valid is not None else y_train
    features_all = concatenate_or_none((f_train, f_valid), axis=0)

    self.p = prepare_preprocessor(x_all, y_all, features=features_all, model_config=self.model_config)
    self.model_config.char_vocab_size = len(self.p.vocab_char)
    self.model_config.case_vocab_size = len(self.p.vocab_case)

    self.models = []

    trainer = Trainer(self.model,
                      self.models,
                      self.embeddings,
                      self.model_config,
                      self.training_config,
                      checkpoint_path=self.log_dir,
                      preprocessor=self.p)
    trainer.train_nfold(x_train, y_train, x_valid, y_valid, f_train=f_train, f_valid=f_valid, callbacks=callbacks)

    if self.embeddings and self.embeddings.use_ELMo:
        self.embeddings.clean_ELMo_cache()
def train(self, x_train, y_train, x_valid=None, y_valid=None):
    # TBD if valid is None, segment train to get one
    # guard against a missing validation split before concatenating
    x_all = np.concatenate((x_train, x_valid), axis=0) if x_valid is not None else x_train
    y_all = np.concatenate((y_train, y_valid), axis=0) if y_valid is not None else y_train
    self.p = prepare_preprocessor(x_all, y_all, self.model_config)
    self.model_config.char_vocab_size = len(self.p.vocab_char)
    self.model_config.case_vocab_size = len(self.p.vocab_case)
    self.model = get_model(self.model_config, self.p, len(self.p.vocab_tag))
    trainer = Trainer(self.model,
                      self.models,
                      self.embeddings,
                      self.model_config,
                      self.training_config,
                      checkpoint_path=self.log_dir,
                      preprocessor=self.p)
    trainer.train(x_train, y_train, x_valid, y_valid)
    if self.embeddings.use_ELMo:
        self.embeddings.clean_ELMo_cache()
    if self.embeddings.use_BERT:
        self.embeddings.clean_BERT_cache()