    def create_ensemble(self, ensemble_size, model):
        # Build `ensemble_size` independently initialized copies of the
        # model, all sharing the same configuration.
        ensemble = []
        for _ in range(ensemble_size):
            model_copy = Model()
            model_copy.config = model.config
            model_copy.create_model()
            ensemble.append(model_copy)
        return ensemble
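    # Usage sketch (hypothetical): building a three-model ensemble from a
    # loaded configuration. `Model` and `ModelConfig` come from this
    # codebase; the path and the `trainer` instance are illustrative
    # placeholders.
    #
    #   config = ModelConfig()
    #   config.load_from_file("model/config.json")
    #   base = Model()
    #   base.config = config
    #   ensemble = trainer.create_ensemble(3, base)
    #   # -> three independently initialized models sharing one config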
    def predict(self):
        config_file_path = self.training_root_path + "/config.json"
        config = ModelConfig()
        config.load_from_file(config_file_path)
        model = Model()
        model.config = config
        ensemble = self.create_ensemble(len(self.ensemble_weights_path), model)
        self.load_ensemble_weights(ensemble, self.ensemble_weights_path)
        output = None
        # stdin alternates between two kinds of lines: the network input
        # (sample_x) and the data needed to turn the ensemble output into the
        # final WSD answer (sample_z); `parity` tracks which one comes next.
        parity = 0
        for line in sys.stdin:
            if parity == 0:
                sample_x = read_sample_x_or_y_from_string(line)
                output = self.predict_ensemble_on_sample(ensemble, sample_x)
                parity = 1
            else:
                sample_z = read_sample_z_from_string(line)
                sample_y = self.generate_wsd_on_sample(output, sample_z)
                sys.stdout.write(sample_y + "\n")
                sys.stdout.flush()
                parity = 0
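    # Stream protocol sketch (inferred from the loop above): predict()
    # consumes stdin two lines at a time. The first line of each pair is
    # parsed by read_sample_x_or_y_from_string() into the network input; the
    # second by read_sample_z_from_string() into the data needed to produce
    # the WSD answer. One disambiguated line is written per pair, e.g. with a
    # hypothetical driver script and placeholder file names:
    #
    #   $ paste -d '\n' samples_x.txt samples_z.txt | python predict.py
    #
    # (the actual line formats are defined by the read_sample_* helpers
    # elsewhere in this codebase)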
    def train(self):
        model_weights_last_path = self.model_path + "/model_weights_last"
        model_weights_loss_path = self.model_path + "/model_weights_loss"
        model_weights_wsd_path = self.model_path + "/model_weights_wsd"
        model_weights_end_of_epoch_path = self.model_path + "/model_weights_end_of_epoch_"
        training_info_path = self.model_path + "/training_info"
        training_losses_path = self.model_path + "/training_losses"
        train_file_path = self.data_path + "/train"
        dev_file_path = self.data_path + "/dev"
        config_file_path = self.data_path + "/config.json"

        print("Loading config and embeddings")
        config = ModelConfig()
        config.load_from_file(config_file_path)

        print("Creating model")
        model = Model()
        model.config = config
        self.recreate_model(model)

        # Run one throwaway batch so that any lazy graph construction and
        # memory allocation happens up front, then rebuild the model to
        # discard the fake update.
        print("Warming up on fake batch")
        batch_x, batch_y = create_fake_batch(
            batch_size=self.batch_size,
            sample_size=self.warmup_sample_size,
            input_features=model.config.input_features,
            input_vocabulary_sizes=model.config.input_vocabulary_sizes,
            output_features=model.config.output_features,
            output_vocabulary_sizes=model.config.output_vocabulary_sizes)
        model.begin_train_on_batch()
        model.train_on_batch(batch_x, batch_y, None)
        model.end_train_on_batch()
        self.recreate_model(model)

        print("Loading training and development data")
        train_samples = read_all_samples_from_file(train_file_path)
        dev_samples = read_all_samples_from_file(dev_file_path)

        current_ensemble = 0
        current_epoch = 0
        current_batch = 0
        current_sample_index = 0
        best_dev_wsd = None
        best_dev_loss = None
        if not self.reset and os.path.isfile(training_info_path) \
                and os.path.isfile(model_weights_last_path):
            print("Resuming from previous training")
            current_ensemble, current_epoch, current_batch, \
                current_sample_index, best_dev_wsd, best_dev_loss = \
                load_training_info(training_info_path)
            model.load_model_weights(model_weights_last_path)
        elif self.shuffle_train_on_init:
            print("Shuffling training data")
            random.shuffle(train_samples)

        create_directory_if_not_exists(self.model_path)
        self.print_state(current_ensemble, current_epoch, current_batch,
                         [None for _ in range(model.config.output_features)],
                         [None for _ in range(model.config.output_features)],
                         None)

        for current_ensemble in range(current_ensemble, self.ensemble_size):
            sample_accumulate_between_eval = 0
            train_losses = None
            while self.stop_after_epoch == -1 or current_epoch < self.stop_after_epoch:
                reached_eof = False
                # Accumulate gradients over `update_every_batch` batches
                # before applying a single optimizer update.
                model.begin_train_on_batch()
                for _ in range(self.update_every_batch):
                    batch_x, batch_y, batch_z, actual_batch_size, reached_eof = \
                        read_batch_from_samples(train_samples, self.batch_size,
                                                current_sample_index)
                    if actual_batch_size == 0:
                        break
                    batch_losses = model.train_on_batch(batch_x, batch_y, batch_z)
                    if train_losses is None:
                        train_losses = [0 for _ in batch_losses]
                    for i in range(len(batch_losses)):
                        # Weight per-batch losses by the batch size so the
                        # running totals can be averaged per sample later.
                        train_losses[i] += batch_losses[i] * actual_batch_size
                    current_sample_index += actual_batch_size
                    sample_accumulate_between_eval += actual_batch_size
                    current_batch += 1
                    if reached_eof:
                        break
                model.end_train_on_batch()

                if reached_eof:
                    print("Reached eof at batch " + str(current_batch))
                    if self.save_end_of_epoch:
                        model.save_model_weights(
                            model_weights_end_of_epoch_path
                            + str(current_epoch) + "_" + str(current_ensemble))
                    current_batch = 0
                    current_sample_index = 0
                    current_epoch += 1
                    random.shuffle(train_samples)

                # Skip evaluation if no samples were seen since the last one
                # (avoids dividing the accumulated losses by zero).
                if current_batch % self.test_every_batch == 0 \
                        and sample_accumulate_between_eval > 0:
                    dev_wsd, dev_losses = self.test_on_dev(self.batch_size,
                                                           dev_samples, model)
                    # Convert accumulated totals into per-sample averages.
                    for i in range(len(train_losses)):
                        train_losses[i] /= float(sample_accumulate_between_eval)
                    self.print_state(current_ensemble, current_epoch,
                                     current_batch, train_losses, dev_losses,
                                     dev_wsd)
                    save_training_losses(training_losses_path, train_losses[0],
                                         dev_losses[0], dev_wsd)
                    sample_accumulate_between_eval = 0
                    train_losses = None
                    if best_dev_loss is None or dev_losses[0] < best_dev_loss:
                        if self.save_best_loss:
                            model.save_model_weights(model_weights_loss_path
                                                     + str(current_ensemble))
                        best_dev_loss = dev_losses[0]
                    if best_dev_wsd is None or dev_wsd > best_dev_wsd:
                        model.save_model_weights(model_weights_wsd_path
                                                 + str(current_ensemble))
                        best_dev_wsd = dev_wsd
                        print("New best dev WSD: " + str(best_dev_wsd))
                    model.save_model_weights(model_weights_last_path)
                    save_training_info(training_info_path, current_ensemble,
                                       current_epoch, current_batch,
                                       current_sample_index, best_dev_wsd,
                                       best_dev_loss)

            # Reset state for the next ensemble member: fresh weights, epoch
            # counter, and best-score trackers.
            self.recreate_model(model)
            current_epoch = 0
            best_dev_wsd = None
            best_dev_loss = None
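    # Checkpointing sketch: save_training_info()/load_training_info() persist
    # the six-value resume state in the order (ensemble, epoch, batch, sample
    # index, best dev WSD, best dev loss), and model_weights_last always holds
    # the most recent weights. A hypothetical resume, assuming a `trainer`
    # configured with the same model_path as the interrupted run:
    #
    #   trainer.reset = False   # keep training_info and model_weights_last
    #   trainer.train()         # picks up at the saved ensemble/epoch/batch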