Example 1
    def __init__(self, **kwargs):
        dataset_folder = Path(kwargs["dataset_folder"]).resolve()
        check_valid_path(dataset_folder)
        result_folder = kwargs["result_folder"]

        self.initial_epoch = 1
        self.test_mode = kwargs["test"]
        self.epochs = kwargs["epochs"]
        self.hidden_size = kwargs["hidden_size"]
        self.num_heads = kwargs["heads"]
        self.use_label_smoothing = kwargs["label_smoothing"]
        
        self.ckpt_path = kwargs["ckpt_path"]
        self.ckpt_epoch = kwargs["ckpt_epoch"]
        
        # Create the folders and files the model needs
        self.log_folder, self.ckpt_folder, self.image_folder = create_folder(result_folder)
        if not self.test_mode:
            self.training_result_file = self.log_folder / "training_result.txt"
        self.test_result_file = None

        # Save the kwargs values
        msg = ""
        for k, v in kwargs.items():
            msg += "{} = {}\n".format(k, v)
        msg += "new model checkpoint path = {}\n".format(self.ckpt_folder)
        with (self.log_folder / "model_settings.txt").open("w", encoding="utf-8") as fp:
            fp.write(msg)
        
        # Load the required data
        self.src_word2id, self.src_id2word, self.src_vocab_size = load_word_dic(dataset_folder / "src_word2id.pkl")
        self.tar_word2id, self.tar_id2word, self.tar_vocab_size = load_word_dic(dataset_folder / "tar_word2id.pkl")

        if not self.test_mode:
            # encoder data: append only the <END> tag
            # decoder data: 1) input = prepend only the <START> tag 2) output = append only the <END> tag
            train_src, num_train_src = get_dataset(self.src_word2id, dataset_folder / "train_src.txt", False, True, True)
            train_tar, num_train_tar = get_dataset(self.tar_word2id, dataset_folder / "train_tar.txt", True, True, True)
            if num_train_src != num_train_tar:
                raise Exception("The Korean dataset ({}) and the English dataset ({}) have different sizes.".format(
                    num_train_src, num_train_tar))

            self.num_train = num_train_src
            self.train_dataset = tf.data.Dataset.from_generator(lambda: zip(train_src, train_tar), (tf.int32, tf.int32))
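            # Cache the generated pairs, shuffle with a buffer large enough to cover the
            # whole training set, pad each batch to its longest sequence with the <PAD>
            # ids, and prefetch one batch ahead of the training loop.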
            self.train_dataset = self.train_dataset.cache().shuffle(self.num_train + 1).padded_batch(
                batch_size=kwargs["batch_size"], padded_shapes=(tf.TensorShape([None]), tf.TensorShape([None])),
                padding_values=(self.src_word2id["<PAD>"], self.tar_word2id["<PAD>"])).prefetch(1)

        test_src_path = dataset_folder / "test.txt"
        if test_src_path.exists():
            test_src, self.num_test = get_dataset(self.src_word2id, test_src_path, False, True, False)
            self.test_dataset = tf.data.Dataset.from_generator(lambda: test_src, tf.int32)
            self.test_dataset = self.test_dataset.cache().batch(1).prefetch(1)
            self.test_result_file = self.log_folder / "test_result.txt"
        elif self.test_mode:
            raise FileNotFoundError("The path [ {} ] does not exist.".format(test_src_path))

        self.transformer = Transformer(self.src_vocab_size, self.tar_vocab_size, self.src_word2id["<PAD>"],
            kwargs["num_layers"], kwargs["heads"], kwargs["embedding_size"], kwargs["hidden_size"],
            kwargs["dropout_rate"], kwargs["use_conv"])
Example 2
def load_word_dic(wordDicPath):
    check_valid_path(wordDicPath)

    print("word dictionary 불러오는 중....")
    with wordDicPath.open("rb") as fp:
        word2id = pickle.load(fp)
    id2word = {v: k for k, v in list(word2id.items())}
    vocab_size = len(word2id)

    return word2id, id2word, vocab_size
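
load_word_dic expects a pickled dict mapping token strings to integer ids. A minimal sketch of writing a compatible file, purely for illustration; the token entries and special-tag ids are assumptions, and the real files are produced by the project's preprocessing step.

import pickle
from pathlib import Path

# Illustrative vocabulary only; real files come from preprocessing.
word2id = {"<PAD>": 0, "<START>": 1, "<END>": 2, "나는": 3, "학교에": 4, "간다": 5}
with Path("src_word2id.pkl").open("wb") as fp:
    pickle.dump(word2id, fp)

# load_word_dic then inverts the mapping into id2word and reports the vocabulary size.
word2id, id2word, vocab_size = load_word_dic(Path("src_word2id.pkl"))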
Example 3
def get_dataset(word2id, dataset_path, add_start_tag, add_end_tag,
                error_when_no_token):
    # The dataset file already contains preprocessed tokens
    check_valid_path(dataset_path)

    with dataset_path.open("r", encoding="utf-8") as fp:
        print("[ {} ] dataset 불러오는 중...".format(dataset_path))
        token_lists = [
            word_to_id(line.split(" "), word2id, add_start_tag, add_end_tag,
                       error_when_no_token)
            for line in [line.strip() for line in fp.readlines()]
        ]

    num_data = len(token_lists)

    return token_lists, num_data
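
get_dataset delegates the token-to-id conversion to word_to_id, which does not appear in this listing. A plausible reconstruction consistent with how it is called above; the <UNK> fallback and the exact error behavior are assumptions.

def word_to_id(tokens, word2id, add_start_tag, add_end_tag, error_when_no_token):
    # Hypothetical reconstruction for illustration; not the project's actual code.
    if not tokens or tokens == [""]:
        if error_when_no_token:
            raise ValueError("dataset line contains no tokens")
        tokens = []
    # Unknown tokens fall back to an assumed <UNK> entry (or <PAD> if absent).
    ids = [word2id.get(tok, word2id.get("<UNK>", word2id["<PAD>"])) for tok in tokens]
    if add_start_tag:
        ids.insert(0, word2id["<START>"])
    if add_end_tag:
        ids.append(word2id["<END>"])
    return ids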
Example 4
File: Train.py Project: bothe/NMT
    def __init__(self, **kwargs):
        dataset_folder = Path(kwargs["dataset_folder"]).resolve()
        check_valid_path(dataset_folder)
        result_folder = kwargs["result_folder"]

        self.initial_epoch = 1
        self.test_mode = kwargs["test"]
        self.epochs = kwargs["epochs"]
        self.use_label_smoothing = kwargs["label_smoothing"]

        self.ckpt_path = kwargs["ckpt_path"]
        self.ckpt_epoch = kwargs["ckpt_epoch"]

        # Create the folders and files the model needs
        self.log_folder, self.ckpt_folder, self.image_folder = create_folder(
            result_folder)
        if not self.test_mode:
            self.training_result_file = self.log_folder / "training_result.txt"
        self.test_result_file = None

        # Save the kwargs values
        msg = ""
        for k, v in kwargs.items():
            msg += "{} = {}\n".format(k, v)
        msg += "new model checkpoint path = {}\n".format(self.ckpt_folder)
        with (self.log_folder / "model_settings.txt").open(
                "w", encoding="utf-8") as fp:
            fp.write(msg)

        # Load the required data
        self.src_word2id, self.src_id2word, self.src_vocab_size = load_word_dic(
            dataset_folder / "src_word2id.pkl")
        self.tar_word2id, self.tar_id2word, self.tar_vocab_size = load_word_dic(
            dataset_folder / "tar_word2id.pkl")

        if not self.test_mode:
            train_src, num_train_src = get_dataset(
                self.src_word2id, dataset_folder / "train_src.txt", False,
                True, True)
            train_tar, num_train_tar = get_dataset(
                self.tar_word2id, dataset_folder / "train_tar.txt", True, True,
                True)
            if num_train_src != num_train_tar:
                raise Exception(
                    "The source dataset ({}) and the target dataset ({}) have different sizes.".format(
                        num_train_src, num_train_tar))

            self.num_train = num_train_src
            self.train_dataset = tf.data.Dataset.from_generator(
                lambda: zip(train_src, train_tar), (tf.int32, tf.int32))
            self.train_dataset = self.train_dataset.cache().shuffle(
                self.num_train + 1).padded_batch(
                    batch_size=kwargs["batch_size"],
                    padded_shapes=(tf.TensorShape([None]),
                                   tf.TensorShape([None])),
                    padding_values=(self.src_word2id["<PAD>"],
                                    self.tar_word2id["<PAD>"])).prefetch(1)

        test_src_path = dataset_folder / "test.txt"
        if test_src_path.exists():
            test_src, self.num_test = get_dataset(self.src_word2id,
                                                  test_src_path, False, True,
                                                  False)

            # self.test_src_max_len = max([len(sentence) for sentence in test_src])
            # padded_test_src = tf.keras.preprocessing.sequence.pad_sequences(
            #    test_src, maxlen = self.test_src_max_len, padding = 'post',
            #    dtype = 'int32', value = self.src_word2id["<PAD>"])

            self.test_dataset = tf.data.Dataset.from_generator(
                lambda: test_src, tf.int32)
            self.test_dataset = self.test_dataset.cache().batch(1).prefetch(1)
            self.test_result_file = self.log_folder / "test_result.txt"

        elif self.test_mode:
            raise FileNotFoundError(
                "[ {} ] 경로가 존재하지 않습니다.".format(test_src_path))

        self.encoder = Encoder(self.src_vocab_size, kwargs["embedding_size"],
                               kwargs["hidden_size"], kwargs["dropout_rate"],
                               kwargs["gru"], kwargs["bi"])
        self.decoder = Decoder(self.tar_vocab_size, kwargs["embedding_size"],
                               kwargs["hidden_size"], kwargs["attention_size"],
                               kwargs["dropout_rate"], kwargs["gru"],
                               kwargs["bi"])

        # The 6 lines below prevent broken Korean text rendering in Colab plots and can be omitted.
        # %config InlineBackend.figure_format = 'retina'
        # !apt -qq -y install fonts-nanum
        fontpath = '/usr/share/fonts/truetype/nanum/NanumBarunGothic.ttf'
        font = fm.FontProperties(fname=fontpath, size=9)
        plt.rc('font', family='NanumBarunGothic')
        mpl.font_manager._rebuild()
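
mpl.font_manager._rebuild() is a private helper that recent matplotlib releases no longer expose. A minimal alternative sketch, assuming the same NanumBarunGothic font file is installed, registers the font explicitly instead.

# Sketch: font setup without the private _rebuild() helper.
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt

fontpath = '/usr/share/fonts/truetype/nanum/NanumBarunGothic.ttf'
fm.fontManager.addfont(fontpath)            # register the font file with matplotlib
plt.rc('font', family='NanumBarunGothic')   # make it the default font family
plt.rc('axes', unicode_minus=False)         # keep minus signs rendering correctly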