Example #1
    def create_ensemble(self, ensemble_size, model):
        # Build ensemble_size fresh models that share the reference model's
        # config but are initialized independently.
        ensemble = []
        for _ in range(ensemble_size):
            member = Model()
            member.config = model.config
            member.create_model()
            ensemble.append(member)
        return ensemble
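
A minimal usage sketch (not part of the original source), assuming `Model.create_model()` initializes fresh weights on each call, so ensemble members share the config object by reference but not parameters; `trainer` is a hypothetical instance of the class that defines create_ensemble:

    # Hypothetical illustration of create_ensemble
    base = Model()
    base.config = ModelConfig()  # assumed to be configured/loaded elsewhere
    ensemble = trainer.create_ensemble(3, base)
    assert len(ensemble) == 3
    assert all(member.config is base.config for member in ensemble)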
Example #2
    def predict(self):
        config_file_path = self.training_root_path + "/config.json"
        config = ModelConfig()
        config.load_from_file(config_file_path)

        model = Model()
        model.config = config

        # One set of weights per ensemble member: ensemble_weights_path is
        # expected to hold a list of checkpoint paths.
        ensemble = self.create_ensemble(len(self.ensemble_weights_path), model)
        self.load_ensemble_weights(ensemble, self.ensemble_weights_path)

        # stdin alternates between two kinds of lines: a sample of input
        # features (x) followed by a WSD sample (z). Each x line produces the
        # ensemble's raw output; the following z line turns that output into
        # a prediction (y) written to stdout.
        output = None
        i = 0
        for line in sys.stdin:
            if i == 0:
                sample_x = read_sample_x_or_y_from_string(line)
                output = self.predict_ensemble_on_sample(ensemble, sample_x)
                i = 1
            elif i == 1:
                sample_z = read_sample_z_from_string(line)
                sample_y = self.generate_wsd_on_sample(output, sample_z)
                sys.stdout.write(sample_y + "\n")
                sys.stdout.flush()
                i = 0
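
The loop above implements a simple line-oriented protocol over stdin/stdout. A hedged sketch of a client driving such a process (the entry-point name `predict.py` and the sample strings are assumptions, not from the original source):

    import subprocess

    # Feed alternating x/z lines, read one y line back per pair.
    proc = subprocess.Popen(
        ["python", "predict.py"],  # hypothetical entry point
        stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)
    for x_line, z_line in [("x1", "z1"), ("x2", "z2")]:  # placeholder samples
        proc.stdin.write(x_line + "\n")
        proc.stdin.write(z_line + "\n")
        proc.stdin.flush()
        print("prediction:", proc.stdout.readline().rstrip("\n"))
    proc.stdin.close()
    proc.wait()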
Example #3
    def train(self):
        model_weights_last_path = self.model_path + "/model_weights_last"
        model_weights_loss_path = self.model_path + "/model_weights_loss"
        model_weights_wsd_path = self.model_path + "/model_weights_wsd"
        model_weights_end_of_epoch_path = self.model_path + "/model_weights_end_of_epoch_"
        training_info_path = self.model_path + "/training_info"
        training_losses_path = self.model_path + "/training_losses"
        train_file_path = self.data_path + "/train"
        dev_file_path = self.data_path + "/dev"
        config_file_path = self.data_path + "/config.json"

        print("Loading config and embeddings")
        config = ModelConfig()
        config.load_from_file(config_file_path)

        print("Creating model")
        model = Model()
        model.config = config
        self.recreate_model(model)

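        # Warm-up: run a single optimization step on synthetic data,
        # presumably to force lazy graph/state construction up front; the
        # model is recreated afterwards so this step does not leak into the
        # real training weights.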
        print("Warming up on fake batch")
        batch_x, batch_y = create_fake_batch(
            batch_size=self.batch_size,
            sample_size=self.warmup_sample_size,
            input_features=model.config.input_features,
            input_vocabulary_sizes=model.config.input_vocabulary_sizes,
            output_features=model.config.output_features,
            output_vocabulary_sizes=model.config.output_vocabulary_sizes)
        model.begin_train_on_batch()
        model.train_on_batch(batch_x, batch_y, None)
        model.end_train_on_batch()

        self.recreate_model(model)

        print("Loading training and development data")
        train_samples = read_all_samples_from_file(train_file_path)
        dev_samples = read_all_samples_from_file(dev_file_path)

        current_ensemble = 0
        current_epoch = 0
        current_batch = 0
        current_sample_index = 0
        best_dev_wsd = None
        best_dev_loss = None

        if not self.reset and os.path.isfile(
                training_info_path) and os.path.isfile(
                    model_weights_last_path):
            print("Resuming from previous training")
            current_ensemble, current_epoch, current_batch, current_sample_index, best_dev_wsd, best_dev_loss = load_training_info(
                training_info_path)
            model.load_model_weights(model_weights_last_path)
        elif self.shuffle_train_on_init:
            print("Shuffling training data")
            random.shuffle(train_samples)

        create_directory_if_not_exists(self.model_path)

        self.print_state(
            current_ensemble, current_epoch, current_batch,
            [None for _ in range(model.config.output_features)],
            [None for _ in range(model.config.output_features)], None)

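        # Train each ensemble member in turn; per-member counters and
        # best-score trackers are reset at the bottom of this loop.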
        for current_ensemble in range(current_ensemble, self.ensemble_size):
            sample_accumulate_between_eval = 0
            train_losses = None
            while self.stop_after_epoch == -1 or current_epoch < self.stop_after_epoch:

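                # Each begin/end_train_on_batch pair spans up to
                # update_every_batch mini-batches (presumably a single
                # optimizer update over accumulated gradients).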
                reached_eof = False
                model.begin_train_on_batch()
                for _ in range(self.update_every_batch):
                    batch_x, batch_y, batch_z, actual_batch_size, reached_eof = read_batch_from_samples(
                        train_samples, self.batch_size, current_sample_index)
                    if actual_batch_size == 0:
                        break
                    batch_losses = model.train_on_batch(
                        batch_x, batch_y, batch_z)
                    if train_losses is None:
                        train_losses = [0 for _ in batch_losses]
                    for i in range(len(batch_losses)):
                        train_losses[i] += batch_losses[i] * actual_batch_size
                    current_sample_index += actual_batch_size
                    sample_accumulate_between_eval += actual_batch_size
                    current_batch += 1
                    if reached_eof:
                        break
                model.end_train_on_batch()

                if reached_eof:
                    print("Reached EOF at batch " + str(current_batch))
                    if self.save_end_of_epoch:
                        model.save_model_weights(
                            model_weights_end_of_epoch_path +
                            str(current_epoch) + "_" + str(current_ensemble))
                    # End of epoch: rewind to the start of the training data
                    # and reshuffle it for the next pass.
                    current_batch = 0
                    current_sample_index = 0
                    current_epoch += 1
                    random.shuffle(train_samples)

                if current_batch % self.test_every_batch == 0 and train_losses is not None:
                    dev_wsd, dev_losses = self.test_on_dev(
                        self.batch_size, dev_samples, model)
                    # Average the accumulated training losses over the number
                    # of samples seen since the last evaluation.
                    for i in range(len(train_losses)):
                        train_losses[i] /= float(
                            sample_accumulate_between_eval)
                    self.print_state(current_ensemble, current_epoch,
                                     current_batch, train_losses, dev_losses,
                                     dev_wsd)
                    save_training_losses(training_losses_path, train_losses[0],
                                         dev_losses[0], dev_wsd)
                    sample_accumulate_between_eval = 0
                    train_losses = None

                    if best_dev_loss is None or dev_losses[0] < best_dev_loss:
                        if self.save_best_loss:
                            model.save_model_weights(model_weights_loss_path +
                                                     str(current_ensemble))
                        best_dev_loss = dev_losses[0]

                    if best_dev_wsd is None or dev_wsd > best_dev_wsd:
                        model.save_model_weights(model_weights_wsd_path +
                                                 str(current_ensemble))
                        best_dev_wsd = dev_wsd
                        print("New best dev WSD: " + str(best_dev_wsd))

                    model.save_model_weights(model_weights_last_path)
                    save_training_info(training_info_path, current_ensemble,
                                       current_epoch, current_batch,
                                       current_sample_index, best_dev_wsd,
                                       best_dev_loss)

            self.recreate_model(model)
            current_epoch = 0
            best_dev_wsd = None
            best_dev_loss = None
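
The resume logic above depends on `save_training_info` and `load_training_info` round-tripping six values. Their implementation is not shown in this snippet; a minimal sketch under the assumption of a plain-text, one-value-per-line format:

    def save_training_info(path, ensemble, epoch, batch, sample_index,
                           best_dev_wsd, best_dev_loss):
        # One value per line; None is written literally and parsed back below.
        with open(path, "w") as f:
            for value in (ensemble, epoch, batch, sample_index,
                          best_dev_wsd, best_dev_loss):
                f.write(str(value) + "\n")

    def load_training_info(path):
        with open(path) as f:
            raw = [line.strip() for line in f]
        ensemble, epoch, batch, sample_index = (int(v) for v in raw[:4])
        best_dev_wsd = None if raw[4] == "None" else float(raw[4])
        best_dev_loss = None if raw[5] == "None" else float(raw[5])
        return ensemble, epoch, batch, sample_index, best_dev_wsd, best_dev_loss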