Example #1
    def __init__(self, training_root_path, ensemble_weights_path, clear_text,
                 batch_size, disambiguate, beam_size, output_all_features):
        print("Initialize Predicter")
        self.training_root_path: str = training_root_path
        self.ensemble_weights_path: List[str] = ensemble_weights_path
        self.clear_text: bool = clear_text
        self.batch_size: int = batch_size
        self.disambiguate: bool = disambiguate
        self.translate: bool = False
        self.beam_size: int = beam_size
        self.output_all_features: bool = output_all_features
        # Load the data and model configuration from the training directory.
        self.config_file_path = self.training_root_path + "/config.json"
        self.data_config: DataConfig = DataConfig()
        self.data_config.load_from_file(self.config_file_path)
        self.config = ModelConfig(self.data_config)
        self.config.load_from_file(self.config_file_path)
        if self.clear_text:
            self.config.data_config.input_clear_text = [
                True for _ in range(self.config.data_config.input_features)
            ]
        # Disable the tasks the loaded config cannot support.
        if self.data_config.output_features <= 0:
            self.disambiguate = False
        if self.data_config.output_translations <= 0:
            self.translate = False
        assert self.disambiguate or self.translate
        self.ensemble = self.create_ensemble(self.config,
                                             self.ensemble_weights_path)
        print("Predicter initialized")
Example #2
    def predict(self):
        config_file_path = self.training_root_path + "/config.json"
        self.data_config = DataConfig()
        self.data_config.load_from_file(config_file_path)
        config = ModelConfig(self.data_config)
        config.load_from_file(config_file_path)
        if self.clear_text:
            config.data_config.input_clear_text = [
                True for _ in range(config.data_config.input_features)
            ]
        if self.data_config.output_features <= 0:
            self.disambiguate = False
        if self.data_config.output_translations <= 0:
            self.translate = False

        assert (self.disambiguate or self.translate)

        ensemble = self.create_ensemble(config, self.ensemble_weights_path)

        # State machine over stdin: i == 0 expects an input (x) line,
        # i == 1 expects the matching sense-restriction (z) line.
        i = 0
        batch_x = None
        batch_z = None
        for line in sys.stdin:
            if i == 0:
                sample_x = read_sample_x_from_string(
                    line,
                    feature_count=config.data_config.input_features,
                    clear_text=config.data_config.input_clear_text)
                self.preprocess_sample_x(ensemble, sample_x)
                if batch_x is None:
                    batch_x = [[] for _ in range(len(sample_x))]
                for j in range(len(sample_x)):
                    batch_x[j].append(sample_x[j])
                if self.disambiguate and not self.output_all_features:
                    i = 1
                else:
                    if len(batch_x[0]) >= self.batch_size:
                        self.predict_and_output(
                            ensemble, batch_x, batch_z,
                            self.data_config.input_clear_text)
                        batch_x = None
            elif i == 1:
                sample_z = read_sample_z_from_string(
                    line, feature_count=config.data_config.output_features)
                if batch_z is None:
                    batch_z = [[] for _ in range(len(sample_z))]
                for j in range(len(sample_z)):
                    batch_z[j].append(sample_z[j])
                i = 0
                if len(batch_z[0]) >= self.batch_size:
                    self.predict_and_output(ensemble, batch_x, batch_z,
                                            self.data_config.input_clear_text)
                    batch_x = None
                    batch_z = None

        if batch_x is not None:
            self.predict_and_output(ensemble, batch_x, batch_z,
                                    self.data_config.input_clear_text)
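
The loop above assembles batches column-wise: batch_x[j] gathers feature j across all buffered samples. A self-contained sketch of that idiom, with a hypothetical parse_sample standing in for read_sample_x_from_string:

# parse_sample is a stand-in for read_sample_x_from_string:
# it returns one entry per input feature.
def parse_sample(line):
    return line.rstrip("\n").split("\t")

batch_x = None
for line in ["a\tb", "c\td", "e\tf"]:
    sample_x = parse_sample(line)
    if batch_x is None:
        batch_x = [[] for _ in range(len(sample_x))]
    for j in range(len(sample_x)):
        batch_x[j].append(sample_x[j])

print(batch_x)  # [['a', 'c', 'e'], ['b', 'd', 'f']]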
Example #3
class Predicter(object):
    def __init__(self, training_root_path, ensemble_weights_path, clear_text,
                 batch_size, disambiguate, beam_size, output_all_features):
        print("Initialize Predicter")
        self.training_root_path: str = training_root_path
        self.ensemble_weights_path: List[str] = ensemble_weights_path
        self.clear_text: bool = clear_text
        self.batch_size: int = batch_size
        self.disambiguate: bool = disambiguate
        self.translate: bool = False
        self.beam_size: int = beam_size
        self.output_all_features: bool = output_all_features
        self.config_file_path = self.training_root_path + "/config.json"
        self.data_config: DataConfig = DataConfig()
        self.data_config.load_from_file(self.config_file_path)
        self.config = ModelConfig(self.data_config)
        self.config.load_from_file(self.config_file_path)
        if self.clear_text:
            self.config.data_config.input_clear_text = [
                True for _ in range(self.config.data_config.input_features)
            ]
        if self.data_config.output_features <= 0:
            self.disambiguate = False
        if self.data_config.output_translations <= 0:
            self.translate = False
        assert (self.disambiguate or self.translate)
        self.ensemble = self.create_ensemble(self.config,
                                             self.ensemble_weights_path)
        print("Predicter initialized")

    def predict(self, lines):
        i = 0
        batch_x = None
        batch_z = None
        out = []
        for line in lines:
            if i == 0:
                sample_x = read_sample_x_from_string(
                    line,
                    feature_count=self.config.data_config.input_features,
                    clear_text=self.config.data_config.input_clear_text)
                self.preprocess_sample_x(self.ensemble, sample_x)
                if batch_x is None:
                    batch_x = [[] for _ in range(len(sample_x))]
                for j in range(len(sample_x)):
                    batch_x[j].append(sample_x[j])
                if self.disambiguate and not self.output_all_features:
                    i = 1
                else:
                    if len(batch_x[0]) >= self.batch_size:
                        out.append(
                            self.predict_and_output(
                                self.ensemble, batch_x, batch_z,
                                self.data_config.input_clear_text))
                        batch_x = None
            elif i == 1:
                sample_z = read_sample_z_from_string(
                    line,
                    feature_count=self.config.data_config.output_features)
                if batch_z is None:
                    batch_z = [[] for _ in range(len(sample_z))]
                for j in range(len(sample_z)):
                    batch_z[j].append(sample_z[j])
                i = 0
                if len(batch_z[0]) >= self.batch_size:
                    out.append(
                        self.predict_and_output(
                            self.ensemble, batch_x, batch_z,
                            self.data_config.input_clear_text))
                    batch_x = None
                    batch_z = None

        if batch_x is not None:
            out.append(
                self.predict_and_output(self.ensemble, batch_x, batch_z,
                                        self.data_config.input_clear_text))
        return out

    def predictFile(self, file_in, file_out):
        i = 0
        c = 0
        batch_x = None
        batch_z = None
        source_file = bz2.BZ2File(file_in, "r")
        sink_file = bz2.BZ2File(file_out, "w")
        for line_b in source_file:
            line = line_b.decode("utf-8").rstrip('\n')
            if c % 1000 == 0:
                print("Processing line " + str(c))
            c += 1
            # Pass JSON header lines straight through to the output.
            if line.startswith('{'):
                sink_file.write(bytes(line, "utf-8"))
                sink_file.write(bytes('\n', "utf-8"))
                continue
            if i == 0:
                sample_x = read_sample_x_from_string(
                    line,
                    feature_count=self.config.data_config.input_features,
                    clear_text=self.config.data_config.input_clear_text)
                self.preprocess_sample_x(self.ensemble, sample_x)
                if batch_x is None:
                    batch_x = [[] for _ in range(len(sample_x))]
                for j in range(len(sample_x)):
                    batch_x[j].append(sample_x[j])
                if self.disambiguate and not self.output_all_features:
                    i = 1
                else:
                    if len(batch_x[0]) >= self.batch_size:
                        sink_file.write(
                            bytes(
                                self.predict_and_output(
                                    self.ensemble, batch_x, batch_z,
                                    self.data_config.input_clear_text),
                                "utf-8"))
                        sink_file.write(bytes('\n', "utf-8"))
                        batch_x = None
            elif i == 1:
                sample_z = read_sample_z_from_string(
                    line,
                    feature_count=self.config.data_config.output_features)
                if batch_z is None:
                    batch_z = [[] for _ in range(len(sample_z))]
                for j in range(len(sample_z)):
                    batch_z[j].append(sample_z[j])
                i = 0
                if len(batch_z[0]) >= self.batch_size:
                    sink_file.write(
                        bytes(
                            self.predict_and_output(
                                self.ensemble, batch_x, batch_z,
                                self.data_config.input_clear_text), "utf-8"))
                    sink_file.write(bytes('\n', "utf-8"))
                    batch_x = None
                    batch_z = None
        if batch_x is not None:
            sink_file.write(
                bytes(
                    self.predict_and_output(self.ensemble, batch_x, batch_z,
                                            self.data_config.input_clear_text),
                    "utf-8"))
            sink_file.write(bytes('\n', "utf-8"))
        source_file.close()
        sink_file.close()

    def create_ensemble(self, config: ModelConfig,
                        ensemble_weights_paths: List[str]):
        ensemble = [Model(config) for _ in range(len(ensemble_weights_paths))]
        for i in range(len(ensemble)):
            ensemble[i].load_model_weights(ensemble_weights_paths[i])
            ensemble[i].set_beam_size(self.beam_size)
        return ensemble

    @staticmethod
    def preprocess_sample_x(ensemble: List[Model], sample_x):
        ensemble[0].preprocess_samples([[sample_x]])

    def predict_and_output(self, ensemble: List[Model], batch_x, batch_z,
                           clear_text):
        pad_batch_x(batch_x, clear_text)
        output_wsd, output_translation = None, None
        # TODO: refactor this horror
        if self.disambiguate and not self.translate and self.output_all_features:
            output_all_features = Predicter.predict_ensemble_all_features_on_batch(
                ensemble, batch_x)
            batch_all_features = Predicter.generate_all_features_on_batch(
                output_all_features, batch_x)
            result = ""
            for sample_all_features in batch_all_features:
                result = result + sample_all_features + "\n"
            return result
        if self.disambiguate and not self.translate:
            output_wsd = Predicter.predict_ensemble_wsd_on_batch(
                ensemble, batch_x)
        elif self.translate and not self.disambiguate:
            output_translation = Predicter.predict_ensemble_translation_on_batch(
                ensemble, batch_x)
        else:
            output_wsd, output_translation = Predicter.predict_ensemble_wsd_and_translation_on_batch(
                ensemble, batch_x)
        if output_wsd is not None and output_translation is None:
            batch_wsd = Predicter.generate_wsd_on_batch(output_wsd, batch_z)
            result = ""
            for sample_wsd in batch_wsd:
                # sys.stdout.write(sample_wsd + "\n")
                # result.append(sample_wsd)
                result = result + sample_wsd + "\n"
            return result
        elif output_translation is not None and output_wsd is None:
            batch_translation = Predicter.generate_translation_on_batch(
                output_translation, ensemble[0].config.data_config.
                output_translation_vocabularies[0][0])
            result = ""
            for sample_translation in batch_translation:
                # sys.stdout.write(sample_translation + "\n")
                # result.append(sample_translation)
                result = result + sample_translation + "\n"
            return result
        elif output_wsd is not None and output_translation is not None:
            batch_wsd = Predicter.generate_wsd_on_batch(output_wsd, batch_z)
            batch_translation = Predicter.generate_translation_on_batch(
                output_translation, ensemble[0].config.data_config.
                output_translation_vocabularies[0][0])
            assert len(batch_wsd) == len(batch_translation)
            result = ""
            for i in range(len(batch_wsd)):
                result = result + batch_wsd[i] + "\n"
                result = result + batch_translation[i] + "\n"
            return result

    @staticmethod
    def predict_ensemble_wsd_on_batch(ensemble: List[Model], batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_wsd_on_batch(batch_x)
        # Sum per-model log-probabilities: equivalent to multiplying the
        # models' probability distributions.
        ensemble_sample_y = None
        for model in ensemble:
            model_sample_y = model.predict_wsd_on_batch(batch_x)
            model_sample_y = log_softmax(model_sample_y, dim=2)
            if ensemble_sample_y is None:
                ensemble_sample_y = model_sample_y
            else:
                ensemble_sample_y = model_sample_y + ensemble_sample_y
        return ensemble_sample_y

    @staticmethod
    def predict_ensemble_all_features_on_batch(ensemble: List[Model], batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_all_features_on_batch(batch_x)
        else:
            # TODO: manage ensemble
            return None

    @staticmethod
    def predict_ensemble_translation_on_batch(ensemble: List[Model], batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_translation_on_batch(batch_x)
        else:
            # TODO: manage ensemble
            return None

    @staticmethod
    def predict_ensemble_wsd_and_translation_on_batch(ensemble: List[Model],
                                                      batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_wsd_and_translation_on_batch(batch_x)
        else:
            # TODO: manage ensemble
            return None

    @staticmethod
    def generate_wsd_on_batch(output, batch_z):
        batch_wsd = []
        for i in range(len(batch_z[0])):
            batch_wsd.append(
                Predicter.generate_wsd_on_sample(output[i], batch_z[0][i]))
        return batch_wsd

    @staticmethod
    def generate_all_features_on_batch(output, batch_x):
        batch_wsd = []
        for i in range(len(batch_x[0])):
            batch_wsd.append(
                Predicter.generate_all_features_on_sample(output, batch_x, i))
        return batch_wsd

    @staticmethod
    def generate_translation_on_batch(output, vocabulary):
        return unpad_turn_to_text_and_remove_bpe_of_batch_t(output, vocabulary)

    @staticmethod
    def generate_wsd_on_sample(output, sample_z):
        sample_wsd: List[str] = []
        for i in range(len(sample_z)):
            restricted_possibilities = sample_z[i]
            # Sentinels: 0 marks a token with no sense to predict; -1 allows
            # any sense (unrestricted argmax).
            if 0 in restricted_possibilities:
                sample_wsd.append("0")
            elif -1 in restricted_possibilities:
                sample_wsd.append(str(torch_argmax(output[i]).item()))
            else:
                max_proba = None
                max_possibility = None
                for possibility in restricted_possibilities:
                    proba = output[i][possibility]
                    if max_proba is None or proba > max_proba:
                        max_proba = proba
                        max_possibility = possibility
                sample_wsd.append(str(max_possibility))
        return " ".join(sample_wsd)

    @staticmethod
    def generate_all_features_on_sample(output, batch_x, i):
        return " ".join([
            "/".join([
                str(torch_argmax(output[k][i][j]).item())
                for k in range(len(output))
            ]) for j in range(len(batch_x[0][i]))
        ])
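
A usage sketch for this variant, assuming a predicter constructed as in Example #1; the input lines and file names are hypothetical, and predictFile reads and writes bzip2-compressed files:

# Hypothetical in-memory prediction; each element of sample_lines must follow
# the alternating x/z format that read_sample_x_from_string and
# read_sample_z_from_string expect.
for block in predicter.predict(sample_lines):
    print(block, end="")

# Hypothetical file-to-file prediction over bz2 streams.
predicter.predictFile("corpus_in.bz2", "corpus_out.bz2")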
Example #4
class Predicter(object):

    def __init__(self):
        self.training_root_path: str = str()
        self.ensemble_weights_path: List[str] = []
        self.clear_text: bool = bool()
        self.batch_size: int = int()
        self.disambiguate: bool = bool()
        self.translate: bool = False
        self.beam_size: int = int()
        self.output_all_features: bool = bool()
        self.write_log: bool = bool()
        self.data_config: DataConfig = None
        # The log file is opened unconditionally but only written to when
        # write_log is set.
        self.log = open('output.log', 'w+')

    def predict(self):
        config_file_path = self.training_root_path + "/config.json"
        self.data_config = DataConfig()
        self.data_config.load_from_file(config_file_path)
        config = ModelConfig(self.data_config)
        config.load_from_file(config_file_path)
        if self.clear_text:
            config.data_config.input_clear_text = [True for _ in range(config.data_config.input_features)]
        if self.data_config.output_features <= 0:
            self.disambiguate = False
        if self.data_config.output_translations <= 0:
            self.translate = False

        assert(self.disambiguate or self.translate)

        ensemble = self.create_ensemble(config, self.ensemble_weights_path)

        i = 0
        batch_x = None
        batch_z = None
        for line in sys.stdin:
            if self.write_log:
                self.log.write("input: " + line)
            if i == 0:
                sample_x = read_sample_x_from_string(line, feature_count=config.data_config.input_features, clear_text=config.data_config.input_clear_text)
                self.preprocess_sample_x(ensemble, sample_x)
                if batch_x is None:
                    batch_x = [[] for _ in range(len(sample_x))]
                for j in range(len(sample_x)):
                    batch_x[j].append(sample_x[j])
                if self.disambiguate and not self.output_all_features:
                    i = 1
                else:
                    if len(batch_x[0]) >= self.batch_size:
                        self.predict_and_output(ensemble, batch_x, batch_z, self.data_config.input_clear_text)
                        batch_x = None
            elif i == 1:
                sample_z = read_sample_z_from_string(line, feature_count=config.data_config.output_features)
                if batch_z is None:
                    batch_z = [[] for _ in range(len(sample_z))]
                for j in range(len(sample_z)):
                    batch_z[j].append(sample_z[j])
                i = 0
                if len(batch_z[0]) >= self.batch_size:
                    self.predict_and_output(ensemble, batch_x, batch_z, self.data_config.input_clear_text)
                    batch_x = None
                    batch_z = None

        if batch_x is not None:
            self.predict_and_output(ensemble, batch_x, batch_z, self.data_config.input_clear_text)

    def create_ensemble(self, config: ModelConfig, ensemble_weights_paths: List[str]):
        ensemble = [Model(config) for _ in range(len(ensemble_weights_paths))]
        for i in range(len(ensemble)):
            ensemble[i].load_model_weights(ensemble_weights_paths[i])
            ensemble[i].set_beam_size(self.beam_size)
        return ensemble

    @staticmethod
    def preprocess_sample_x(ensemble: List[Model], sample_x):
        ensemble[0].preprocess_samples([[sample_x]])

    def _write(self, val):
        if self.write_log:
            self.log.write("output: " + val)
        sys.stdout.write(val)

    def predict_and_output(self, ensemble: List[Model], batch_x, batch_z, clear_text):
        pad_batch_x(batch_x, clear_text)
        output_wsd, output_translation = None, None
        # TODO: refactor this horror
        if self.disambiguate and not self.translate and self.output_all_features:
            output_all_features = Predicter.predict_ensemble_all_features_on_batch(ensemble, batch_x)
            batch_all_features = Predicter.generate_all_features_on_batch(output_all_features, batch_x)
            for sample_all_features in batch_all_features:
                self._write(sample_all_features + "\n")
            sys.stdout.flush()
            return
        if self.disambiguate and not self.translate:
            output_wsd = Predicter.predict_ensemble_wsd_on_batch(ensemble, batch_x)
        elif self.translate and not self.disambiguate:
            output_translation = Predicter.predict_ensemble_translation_on_batch(ensemble, batch_x)
        else:
            output_wsd, output_translation = Predicter.predict_ensemble_wsd_and_translation_on_batch(ensemble, batch_x)
        if output_wsd is not None and output_translation is None:
            batch_wsd = Predicter.generate_wsd_on_batch(output_wsd, batch_z)
            for sample_wsd in batch_wsd:
                self._write(sample_wsd + "\n")
        elif output_translation is not None and output_wsd is None:
            batch_translation = Predicter.generate_translation_on_batch(output_translation, ensemble[0].config.data_config.output_translation_vocabularies[0][0])
            for sample_translation in batch_translation:
                self._write(sample_translation + "\n")
        elif output_wsd is not None and output_translation is not None:
            batch_wsd = Predicter.generate_wsd_on_batch(output_wsd, batch_z)
            batch_translation = Predicter.generate_translation_on_batch(output_translation, ensemble[0].config.data_config.output_translation_vocabularies[0][0])
            assert len(batch_wsd) == len(batch_translation)
            for wsd, trans in zip(batch_wsd, batch_translation):
                self._write(wsd + "\n" + trans + "\n")
        sys.stdout.flush()

    @staticmethod
    def predict_ensemble_wsd_on_batch(ensemble: List[Model], batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_wsd_on_batch(batch_x)
        ensemble_sample_y = None
        for model in ensemble:
            model_sample_y = model.predict_wsd_on_batch(batch_x)
            model_sample_y = log_softmax(model_sample_y, dim=2)
            if ensemble_sample_y is None:
                ensemble_sample_y = model_sample_y
            else:
                ensemble_sample_y = model_sample_y + ensemble_sample_y
        return ensemble_sample_y

    @staticmethod
    def predict_ensemble_all_features_on_batch(ensemble: List[Model], batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_all_features_on_batch(batch_x)
        else:
            # TODO: manage ensemble
            return None

    @staticmethod
    def predict_ensemble_translation_on_batch(ensemble: List[Model], batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_translation_on_batch(batch_x)
        else:
            # TODO: manage ensemble
            return None

    @staticmethod
    def predict_ensemble_wsd_and_translation_on_batch(ensemble: List[Model], batch_x):
        if len(ensemble) == 1:
            return ensemble[0].predict_wsd_and_translation_on_batch(batch_x)
        else:
            # TODO: manage ensemble
            return None

    @staticmethod
    def generate_wsd_on_batch(output, batch_z):
        batch_wsd = []
        for i in range(len(batch_z[0])):
            batch_wsd.append(Predicter.generate_wsd_on_sample(output[i], batch_z[0][i]))
        return batch_wsd

    @staticmethod
    def generate_all_features_on_batch(output, batch_x):
        batch_wsd = []
        for i in range(len(batch_x[0])):
            batch_wsd.append(Predicter.generate_all_features_on_sample(output, batch_x, i))
        return batch_wsd

    @staticmethod
    def generate_translation_on_batch(output, vocabulary):
        return unpad_turn_to_text_and_remove_bpe_of_batch_t(output, vocabulary)

    @staticmethod
    def generate_wsd_on_sample(output, sample_z):
        sample_wsd: List[str] = []
        for i in range(len(sample_z)):
            restricted_possibilities = sample_z[i]
            # Sentinels: 0 marks a token with no sense to predict; -1 allows
            # any sense (unrestricted argmax).
            if 0 in restricted_possibilities:
                sample_wsd.append("0")
            elif -1 in restricted_possibilities:
                sample_wsd.append(str(torch_argmax(output[i]).item()))
            else:
                max_proba = None
                max_possibility = None
                for possibility in restricted_possibilities:
                    proba = output[i][possibility]
                    if max_proba is None or proba > max_proba:
                        max_proba = proba
                        max_possibility = possibility
                sample_wsd.append(str(max_possibility))
        return " ".join(sample_wsd)

    @staticmethod
    def generate_all_features_on_sample(output, batch_x, i):
        return " ".join(["/".join([str(torch_argmax(output[k][i][j]).item()) for k in range(len(output))]) for j in range(len(batch_x[0][i]))])
Example #5
    def train(self):
        model_weights_last_path = self.model_path + "/model_weights_last"
        model_weights_before_last_path = self.model_path + "/model_weights_before_last"
        model_weights_loss_path = self.model_path + "/model_weights_loss"
        model_weights_loss_before_path = self.model_path + "/model_weights_loss_before"
        model_weights_wsd_path = self.model_path + "/model_weights_wsd"
        model_weights_wsd_before_path = self.model_path + "/model_weights_wsd_before"
        model_weights_bleu_path = self.model_path + "/model_weights_bleu"
        model_weights_bleu_before_path = self.model_path + "/model_weights_bleu_before"
        model_weights_end_of_epoch_path = self.model_path + "/model_weights_end_of_epoch_"
        training_info_path = self.model_path + "/training_info"
        tensorboard_path = self.model_path + "/tensorboard"
        train_file_path = self.data_path + "/train"
        dev_file_path = self.data_path + "/dev"
        config_file_path = self.data_path + "/config.json"

        print("Loading config and embeddings")
        data_config: DataConfig = DataConfig()
        data_config.load_from_file(config_file_path)
        config: ModelConfig = ModelConfig(data_config)
        config.load_from_file(config_file_path)

        # change config from CLI parameters
        config.input_embeddings_sizes = set_if_not_none(
            self.input_embeddings_size, config.input_embeddings_sizes)
        if self.input_embeddings_tokenize_model is not None:
            config.set_input_embeddings_tokenize_model(
                self.input_embeddings_tokenize_model)
        if self.input_elmo_model is not None:
            config.set_input_elmo_path(self.input_elmo_model)
        if self.input_bert_model is not None:
            config.set_input_bert_model(self.input_bert_model)
        if self.input_auto_model is not None:
            config.set_input_auto_model(self.input_auto_model,
                                        self.input_auto_path)
        if self.input_word_dropout_rate is not None:
            config.input_word_dropout_rate = self.input_word_dropout_rate
            eprint("Warning: input_word_dropout_rate is not implemented")
        if self.input_resize is not None:
            config.set_input_resize(self.input_resize)
        config.input_linear_size = set_if_not_none(self.input_linear_size,
                                                   config.input_linear_size)
        config.input_dropout_rate = set_if_not_none(self.input_dropout_rate,
                                                    config.input_dropout_rate)
        config.input_combination_method = set_if_not_none(
            self.input_combination_method, config.input_combination_method)
        config.encoder_type = set_if_not_none(self.encoder_type,
                                              config.encoder_type)
        config.encoder_lstm_hidden_size = set_if_not_none(
            self.encoder_lstm_hidden_size, config.encoder_lstm_hidden_size)
        config.encoder_lstm_layers = set_if_not_none(
            self.encoder_lstm_layers, config.encoder_lstm_layers)
        config.encoder_lstm_dropout = set_if_not_none(
            self.encoder_lstm_dropout, config.encoder_lstm_dropout)
        config.encoder_transformer_hidden_size = set_if_not_none(
            self.encoder_transformer_hidden_size,
            config.encoder_transformer_hidden_size)
        config.encoder_transformer_layers = set_if_not_none(
            self.encoder_transformer_layers, config.encoder_transformer_layers)
        config.encoder_transformer_heads = set_if_not_none(
            self.encoder_transformer_heads, config.encoder_transformer_heads)
        config.encoder_transformer_dropout = set_if_not_none(
            self.encoder_transformer_dropout,
            config.encoder_transformer_dropout)
        config.encoder_transformer_positional_encoding = set_if_not_none(
            self.encoder_transformer_positional_encoding,
            config.encoder_transformer_positional_encoding)
        config.encoder_transformer_scale_embeddings = set_if_not_none(
            self.encoder_transformer_scale_embeddings,
            config.encoder_transformer_scale_embeddings)
        config.decoder_translation_transformer_hidden_size = set_if_not_none(
            self.decoder_translation_transformer_hidden_size,
            config.decoder_translation_transformer_hidden_size)
        config.decoder_translation_transformer_layers = set_if_not_none(
            self.decoder_translation_transformer_layers,
            config.decoder_translation_transformer_layers)
        config.decoder_translation_transformer_heads = set_if_not_none(
            self.decoder_translation_transformer_heads,
            config.decoder_translation_transformer_heads)
        config.decoder_translation_transformer_dropout = set_if_not_none(
            self.decoder_translation_transformer_dropout,
            config.decoder_translation_transformer_dropout)
        config.decoder_translation_scale_embeddings = set_if_not_none(
            self.decoder_translation_scale_embeddings,
            config.decoder_translation_scale_embeddings)
        config.decoder_translation_share_embeddings = set_if_not_none(
            self.decoder_translation_share_embeddings,
            config.decoder_translation_share_embeddings)
        config.decoder_translation_share_encoder_embeddings = set_if_not_none(
            self.decoder_translation_share_encoder_embeddings,
            config.decoder_translation_share_encoder_embeddings)
        config.decoder_translation_tokenizer_bert = set_if_not_none(
            self.decoder_translation_tokenizer_bert,
            config.decoder_translation_tokenizer_bert)

        print("GPU is available: " + str(torch.cuda.is_available()))

        model: Model = Model(config)
        model.set_adam_parameters(adam_beta1=self.adam_beta1,
                                  adam_beta2=self.adam_beta2,
                                  adam_eps=self.adam_eps)
        model.set_lr_scheduler(lr_scheduler=self.lr_scheduler,
                               fixed_lr=self.lr_scheduler_fixed_lr,
                               warmup=self.lr_scheduler_noam_warmup,
                               model_size=self.lr_scheduler_noam_model_size)
        model.classifier_loss_factor = self.classifier_loss_factor
        model.decoder_loss_factor = self.decoder_loss_factor

        current_ensemble = 0
        current_epoch = 0
        current_batch = 0
        current_batch_total = 0
        current_sample_index = 0
        skipped_batch = 0
        best_dev_loss = None
        best_dev_wsd = None
        best_dev_bleu = None
        random_seed = self.generate_random_seed()

        if not self.reset and os.path.isfile(training_info_path) and (
                os.path.isfile(model_weights_last_path)
                or os.path.isfile(model_weights_before_last_path)):
            print("Resuming from previous training")
            current_ensemble, current_epoch, current_batch, current_batch_total, current_sample_index, best_dev_loss, best_dev_wsd, best_dev_bleu, random_seed = load_training_info(
                training_info_path)
            try:
                model.load_model_weights(model_weights_last_path)
            except RuntimeError as e:
                if os.path.isfile(model_weights_before_last_path):
                    print("Warning - loading before last weights: " + str(e))
                    model.load_model_weights(model_weights_before_last_path)
                else:
                    raise e
        else:
            print("Creating model")
            model.create_model()
            create_directory_if_not_exists(self.model_path)

        print("Random seed is " + str(random_seed))

        print("Config is: ")
        pprint.pprint(config.get_serializable_data())

        print("Number of parameters (total): " +
              model.get_number_of_parameters(filter_requires_grad=False))
        print("Number of parameters (learned): " +
              model.get_number_of_parameters(filter_requires_grad=True))

        print("Warming up on " + str(self.warmup_batch_count) + " batches")
        train_samples = read_samples_from_file(
            train_file_path, data_config.input_clear_text,
            data_config.output_features, data_config.output_translations,
            data_config.output_translation_features,
            data_config.output_translation_clear_text,
            self.batch_size * self.warmup_batch_count)
        model.preprocess_samples(train_samples)
        for i in range(self.warmup_batch_count):
            batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof = read_batch_from_samples(
                train_samples, self.batch_size, -1, 0,
                data_config.input_features, data_config.output_features,
                data_config.output_translations,
                data_config.output_translation_features,
                data_config.input_clear_text,
                data_config.output_translation_clear_text)
            model.begin_train_on_batch()
            model.train_on_batch(batch_x, batch_y, batch_tt)
            model.end_train_on_batch()

        print("Loading training and development data")
        train_samples = read_samples_from_file(
            train_file_path,
            input_clear_text=data_config.input_clear_text,
            output_features=data_config.output_features,
            output_translations=data_config.output_translations,
            output_translation_features=data_config.
            output_translation_features,
            output_translation_clear_text=data_config.
            output_translation_clear_text)
        dev_samples = read_samples_from_file(
            dev_file_path,
            input_clear_text=data_config.input_clear_text,
            output_features=data_config.output_features,
            output_translations=data_config.output_translations,
            output_translation_features=data_config.
            output_translation_features,
            output_translation_clear_text=data_config.
            output_translation_clear_text)

        print("Preprocessing training and development data")
        model.preprocess_samples(train_samples)
        model.preprocess_samples(dev_samples)

        if self.shuffle_train_on_init:
            print("Shuffling training data")
            random.seed(random_seed)
            random.shuffle(train_samples)

        self.print_state(
            current_ensemble,
            current_epoch, current_batch, current_batch_total,
            len(train_samples), current_sample_index, skipped_batch, [
                None for _ in range(data_config.output_features +
                                    data_config.output_translations *
                                    data_config.output_translation_features)
            ], [
                None for _ in range(data_config.output_features +
                                    data_config.output_translations *
                                    data_config.output_translation_features)
            ], [None for _ in range(data_config.output_features)], None)

        if self.reset:
            shutil.rmtree(tensorboard_path, ignore_errors=True)

        for current_ensemble in range(current_ensemble, self.ensemble_size):
            if tensorboardX is not None:
                tb_writer = tensorboardX.SummaryWriter(tensorboard_path +
                                                       '/ensemble' +
                                                       str(current_ensemble))
            else:
                tb_writer = None
            sample_accumulate_between_eval = 0
            train_losses = None
            while self.stop_after_epoch == -1 or current_epoch < self.stop_after_epoch:

                model.update_learning_rate(step=current_batch_total)

                if skipped_batch == 0:
                    print("training sample " + str(current_sample_index) +
                          "/" + str(len(train_samples)),
                          end="\r")
                else:
                    print("training sample " + str(current_sample_index) +
                          "/" + str(len(train_samples)) + " (skipped " +
                          str(skipped_batch) + " batch)",
                          end="\r")
                sys.stdout.flush()

                reached_eof = False
                model.begin_train_on_batch()
                sub_batch_index = 0
                while sub_batch_index < self.update_every_batch:
                    batch_x, batch_y, batch_z, batch_tt, actual_batch_size, reached_eof = read_batch_from_samples(
                        train_samples, self.batch_size, self.token_per_batch,
                        current_sample_index, data_config.input_features,
                        data_config.output_features,
                        data_config.output_translations,
                        data_config.output_translation_features,
                        data_config.input_clear_text,
                        data_config.output_translation_clear_text)
                    if actual_batch_size == 0:
                        break
                    try:
                        batch_losses = model.train_on_batch(
                            batch_x, batch_y, batch_tt)
                        if train_losses is None:
                            train_losses = [0 for _ in batch_losses]
                        for i in range(len(batch_losses)):
                            train_losses[
                                i] += batch_losses[i] * actual_batch_size
                        sub_batch_index += 1
                    except RuntimeError:
                        # Skip batches that fail to train (typically CUDA
                        # out-of-memory) and free cached GPU memory; note that
                        # batch_x[0] may be a tuple rather than a tensor (see
                        # BERT embeddings), so its size cannot be reported here.
                        skipped_batch += 1
                        torch.cuda.empty_cache()
                        model.begin_train_on_batch()
                    current_sample_index += actual_batch_size
                    sample_accumulate_between_eval += actual_batch_size
                    current_batch += 1
                    current_batch_total += 1
                    if reached_eof:
                        break
                model.end_train_on_batch()

                if reached_eof:
                    if self.save_end_of_epoch:
                        model.save_model_weights(
                            model_weights_end_of_epoch_path +
                            str(current_epoch) + "_" + str(current_ensemble))
                    current_batch = 0
                    current_sample_index = 0
                    current_epoch += 1
                    random_seed = self.generate_random_seed()
                    random.seed(random_seed)
                    random.shuffle(train_samples)

                if current_batch % self.eval_frequency == 0:
                    dev_losses, dev_wsd, dev_bleu = self.test_on_dev(
                        dev_samples, model)
                    for i in range(len(train_losses)):
                        train_losses[i] /= float(
                            sample_accumulate_between_eval)
                    self.print_state(current_ensemble, current_epoch,
                                     current_batch, current_batch_total,
                                     len(train_samples), current_sample_index,
                                     skipped_batch, train_losses, dev_losses,
                                     dev_wsd, dev_bleu)
                    self.write_tensorboard(
                        tb_writer, current_epoch, train_samples,
                        current_sample_index, train_losses, dev_losses,
                        dev_wsd, data_config.output_feature_names, dev_bleu,
                        model.optimizer.scheduler.get_learning_rate(
                            current_batch_total))
                    sample_accumulate_between_eval = 0
                    train_losses = None
                    skipped_batch = 0

                    if best_dev_loss is None or dev_losses[0] < best_dev_loss:
                        if self.save_best_loss:
                            rename_file_if_exists(
                                model_weights_loss_path +
                                str(current_ensemble),
                                model_weights_loss_before_path +
                                str(current_ensemble))
                            model.save_model_weights(model_weights_loss_path +
                                                     str(current_ensemble))
                            remove_file_if_exists(
                                model_weights_loss_before_path +
                                str(current_ensemble))
                            print("New best dev loss: " + str(dev_losses[0]))
                        best_dev_loss = dev_losses[0]

                    if len(dev_wsd) > 0 and (best_dev_wsd is None
                                             or dev_wsd[0] > best_dev_wsd):
                        rename_file_if_exists(
                            model_weights_wsd_path + str(current_ensemble),
                            model_weights_wsd_before_path +
                            str(current_ensemble))
                        model.save_model_weights(model_weights_wsd_path +
                                                 str(current_ensemble))
                        remove_file_if_exists(model_weights_wsd_before_path +
                                              str(current_ensemble))
                        best_dev_wsd = dev_wsd[0]
                        print("New best dev WSD: " + str(best_dev_wsd))

                    if dev_bleu is not None and (
                            best_dev_bleu is None or dev_bleu > best_dev_bleu):
                        rename_file_if_exists(
                            model_weights_bleu_path + str(current_ensemble),
                            model_weights_bleu_before_path +
                            str(current_ensemble))
                        model.save_model_weights(model_weights_bleu_path +
                                                 str(current_ensemble))
                        remove_file_if_exists(model_weights_bleu_before_path +
                                              str(current_ensemble))
                        best_dev_bleu = dev_bleu
                        print("New best dev BLEU: " + str(best_dev_bleu))

                    rename_file_if_exists(model_weights_last_path,
                                          model_weights_before_last_path)
                    model.save_model_weights(model_weights_last_path)
                    remove_file_if_exists(model_weights_before_last_path)
                    save_training_info(training_info_path, current_ensemble,
                                       current_epoch, current_batch,
                                       current_batch_total,
                                       current_sample_index, best_dev_loss,
                                       best_dev_wsd, best_dev_bleu,
                                       random_seed)

            model.create_model()
            current_epoch = 0
            current_batch_total = 0
            best_dev_loss = None
            best_dev_wsd = None
            best_dev_bleu = None
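
train leans on a set_if_not_none helper so that command-line overrides win over values loaded from config.json. Its contract follows directly from the call sites above; a minimal sketch consistent with that usage (parameter names are mine):

# Inferred from usage: return the override when one was given,
# otherwise keep the existing config value.
def set_if_not_none(override, current):
    return override if override is not None else current

# e.g. config.encoder_type = set_if_not_none(self.encoder_type,
#                                            config.encoder_type)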