Example #1
File: base.py  Project: rmotsar/finetune
    def _initialize(self):
        # Initializes the non-serialized bits of the class.
        self._set_random_seed(self.config.seed)

        download_data_if_required()
        self.encoder = TextEncoder()

        # symbolic ops
        self.logits = None  # classification logits
        self.target_loss = None  # cross-entropy loss
        self.lm_losses = None  # language modeling losses
        self.lm_predict_op = None
        self.train = None  # gradient + parameter update
        self.features = None  # hidden representation fed to classifier
        self.summaries = None  # Tensorboard summaries
        self.train_writer = None
        self.valid_writer = None
        self.predict_params = None
        self.train_op = None
        self.predict_op = None
        self.predict_proba_op = None
        self.sess = None
        self.noop = tf.no_op()

        # indicator vars
        self.is_built = False  # has tf graph been constructed?
        self.is_trained = False  # has model been fine-tuned?
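The snippet above only resets state. A minimal sketch of how the is_built indicator flag is typically consulted to build the TensorFlow graph lazily on first use follows; the _build_model helper and its body are assumptions for illustration and are not part of this snippet:

    def _build_model(self):
        # Construct the TensorFlow graph only once, on first use.
        if self.is_built:
            return
        # ... define self.logits, self.target_loss, self.train_op, etc. ...
        self.is_built = True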
Example #2
File: base.py  Project: sungjunlee/finetune
    def _initialize(self):
        # Initializes the non-serialized bits of the class.
        self._set_random_seed(self.config.seed)

        download_data_if_required()
        self.encoder = TextEncoder()

        # symbolic ops
        self.logits = None  # classification logits
        self.target_loss = None  # cross-entropy loss
        self.lm_losses = None  # language modeling losses
        self.lm_predict_op = None
        self.train = None  # gradient + parameter update
        self.features = None  # hidden representation fed to classifier
        self.summaries = None  # Tensorboard summaries
        self.train_writer = None
        self.valid_writer = None
        self.predict_params = None
        self.train_op = None
        self.predict_op = None
        self.predict_proba_op = None
        self.sess = None
        self.noop = tf.no_op()

        # indicator vars
        self.is_built = False  # has tf graph been constructed?
        self.is_trained = False  # has model been fine-tuned?
        self.require_lm = False

        def process_embeddings(name, value):
            if "/we:0" not in name:
                return value

            vocab_size = self.encoder.vocab_size
            word_embeddings = value[:vocab_size -
                                    len(self.encoder.special_tokens)]
            special_embed = value[len(word_embeddings):vocab_size]
            positional_embed = value[vocab_size:]
            if self.config.interpolate_pos_embed and self.config.max_length != len(
                    positional_embed):
                positional_embed = interpolate_pos_embed(
                    positional_embed, self.config.max_length)
            elif self.config.max_length > len(positional_embed):
                raise ValueError(
                    "max_length cannot be greater than {} if interpolate_pos_embed is turned off"
                    .format(len(positional_embed)))
            else:
                positional_embed = positional_embed[:self.config.max_length]

            embeddings = np.concatenate(
                (word_embeddings, special_embed, positional_embed), axis=0)
            return embeddings

        self.saver = Saver(
            fallback_filename=JL_BASE,
            exclude_matches=None if self.config.save_adam_vars else "adam",
            variable_transforms=[process_embeddings])
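For clarity, here is a standalone NumPy sketch of the slicing that process_embeddings applies to the combined word/special/positional embedding matrix. The sizes below are made up for illustration and do not come from the project; the sketch follows the truncation branch (interpolation off, max_length not larger than the number of positional rows):

import numpy as np

# Hypothetical sizes: 8 vocabulary rows (of which 3 are special tokens)
# followed by 12 positional rows, each of width 4.
vocab_size, n_special, n_positions, width = 8, 3, 12, 4
value = np.random.rand(vocab_size + n_positions, width).astype(np.float32)

word_embeddings = value[:vocab_size - n_special]          # ordinary tokens
special_embed = value[vocab_size - n_special:vocab_size]  # special tokens
positional_embed = value[vocab_size:]                     # position vectors

# Truncating mirrors the final `else` branch above.
max_length = 6
positional_embed = positional_embed[:max_length]

embeddings = np.concatenate(
    (word_embeddings, special_embed, positional_embed), axis=0)
print(embeddings.shape)  # (vocab_size + max_length, width) == (14, 4)

In the actual transform, the positional rows are instead interpolated to max_length when interpolate_pos_embed is enabled, or an error is raised if max_length exceeds the available positional rows.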
Example #3
import logging
from abc import ABCMeta, abstractmethod

import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.python.data import Dataset
from sklearn.model_selection import train_test_split


from finetune.errors import FinetuneError
from finetune.config import PAD_TOKEN
from finetune.encoding import TextEncoder, ArrayEncodedOutput, EncodedOutput
from finetune.imbalance import compute_class_weights

ENCODER = TextEncoder()
LOGGER = logging.getLogger('finetune')


class BasePipeline(metaclass=ABCMeta):
    def __init__(self, config):
        self.config = config
        self.label_encoder = None
        self.target_dim = None
        self.pad_idx_ = None
        self.rebuild = False
        self.epoch = 0

    @abstractmethod
    def _target_encoder(self):
        # Overridden by subclass to produce the right target encoding for a given target model.