Example #1
def load_checkpoint(sess,
                    checkpoint_dir,
                    filename=None,
                    blacklist=(),
                    prefix=None,
                    variable_mapping=(),
                    reverse_mapping=()):
    """
    if `filename` is None, we load last checkpoint, otherwise
      we ignore `checkpoint_dir` and load the given checkpoint file.
    `variable_mapping` and `reverse_mapping` are sequences of
      (regex, replacement) pairs used to rename variables between
      the checkpoint and the current graph.
    """
    if filename is None:
        # load last checkpoint
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt is not None:
            filename = ckpt.model_checkpoint_path
    else:
        checkpoint_dir = os.path.dirname(filename)

    vars_ = []
    var_names = []
    for var in tf.global_variables():
        if prefix is None or var.name.startswith(prefix):
            name = var.name if prefix is None else var.name[len(prefix) + 1:]
            vars_.append(var)
            var_names.append(name)

    var_file = os.path.join(checkpoint_dir, 'vars.pkl')
    if os.path.exists(var_file):
        with open(var_file, 'rb') as f:
            old_names = pickle.load(f)
    else:
        old_names = list(var_names)

    name_mapping = {}
    for name in old_names:
        name_ = name
        for key, value in variable_mapping:
            name_ = re.sub(key, value, name_)
        name_mapping[name] = name_

    var_names_ = []
    for name in var_names:
        for key, value in reverse_mapping:
            name = re.sub(key, value, name)
        var_names_.append(name)
    vars_ = dict(zip(var_names_, vars_))

    variables = {
        old_name[:-2]: vars_[new_name]
        for old_name, new_name in name_mapping.items()
        if new_name in vars_ and not any(pattern in new_name
                                         for pattern in blacklist)
    }

    if filename is not None:
        utils.log('reading model parameters from {}'.format(filename))
        tf.train.Saver(variables).restore(sess, filename)

        utils.debug('retrieved parameters ({})'.format(len(variables)))
        for var in sorted(variables.values(), key=lambda var: var.name):
            utils.debug('  {} {}'.format(var.name, var.get_shape()))
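
For context, a minimal, hypothetical usage sketch of this variant follows; the checkpoint directory, blacklist patterns, and variable-scope prefix are illustrative placeholders, not values taken from the repository.

# Hypothetical usage sketch: restore the latest checkpoint of a sub-model,
# skipping optimizer slots. Paths and patterns below are placeholders.
import tensorflow as tf

with tf.Session() as sess:
    # ... build the model here, e.g. under the variable scope 'model_1' ...
    sess.run(tf.global_variables_initializer())
    load_checkpoint(sess,
                    checkpoint_dir='models/my_experiment/checkpoints',
                    blacklist=('Adam', 'beta1_power', 'beta2_power'),
                    prefix='model_1')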
Example #2
    def read_data(self, max_train_size, max_dev_size):
        utils.debug('reading training data')
        train_set = utils.read_dataset(self.filenames.train,
                                       self.extensions,
                                       self.vocabs,
                                       max_size=max_train_size,
                                       binary_input=self.binary_input,
                                       character_level=self.character_level)
        self.batch_iterator = utils.read_ahead_batch_iterator(train_set,
                                                              self.batch_size,
                                                              read_ahead=10)

        utils.debug('reading development data')
        dev_sets = [
            utils.read_dataset(dev,
                               self.extensions,
                               self.vocabs,
                               max_size=max_dev_size,
                               binary_input=self.binary_input,
                               character_level=self.character_level)
            for dev in self.filenames.dev
        ]
        # subset of the dev set whose perplexity is periodically evaluated
        self.dev_batches = [
            utils.get_batches(dev_set, batch_size=self.batch_size, batches=-1)
            for dev_set in dev_sets
        ]
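
The read-ahead iterator used above comes from the repository's `utils` module. Purely as illustration, a rough, self-contained approximation of what such an iterator does (buffer a few batches ahead, shuffle inside the buffer, then yield one batch at a time) might look like the sketch below; the actual implementation in the repository may differ.

import random

def read_ahead_batch_iterator_sketch(dataset, batch_size, read_ahead=10, shuffle=True):
    """Toy approximation: loop over the data forever, read `read_ahead`
    batches worth of examples into a buffer, optionally shuffle the
    buffer, then yield it one batch at a time."""
    buffer_size = batch_size * read_ahead
    while True:
        for start in range(0, len(dataset), buffer_size):
            buffer = list(dataset[start:start + buffer_size])
            if shuffle:
                random.shuffle(buffer)
            for i in range(0, len(buffer), batch_size):
                yield buffer[i:i + batch_size]

# usage on a dummy dataset
batches = read_ahead_batch_iterator_sketch(list(range(100)), batch_size=8)
first_batch = next(batches)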
Example #3
def load_checkpoint(sess, checkpoint_dir, filename, variables):
    if filename is not None:
        ckpt_file = os.path.join(checkpoint_dir, filename)
        utils.log('reading model parameters from {}'.format(ckpt_file))
        tf.train.Saver(variables).restore(sess, ckpt_file)

        utils.debug('retrieved parameters ({})'.format(len(variables)))
        for var in sorted(variables, key=lambda var: var.name):
            utils.debug('  {} {}'.format(var.name, var.get_shape()))
Example #4
    def initialize(self,
                   checkpoints=None,
                   reset=False,
                   reset_learning_rate=False,
                   max_to_keep=1,
                   keep_every_n_hours=0,
                   sess=None,
                   **kwargs):
        """
        :param checkpoints: list of checkpoints to load (instead of latest checkpoint)
        :param reset: don't load latest checkpoint, reset learning rate and global step
        :param reset_learning_rate: reset the learning rate to its initial value
        :param max_to_keep: keep this many latest checkpoints at all times
        :param keep_every_n_hours: and keep checkpoints every n hours
        """
        sess = sess or tf.get_default_session()

        if keep_every_n_hours is None or keep_every_n_hours <= 0:
            keep_every_n_hours = float('inf')

        self.saver = tf.train.Saver(
            max_to_keep=max_to_keep,
            keep_checkpoint_every_n_hours=keep_every_n_hours,
            sharded=False)

        sess.run(tf.global_variables_initializer())
        blacklist = ['dropout_keep_prob']

        if reset_learning_rate or reset:
            blacklist.append('learning_rate')
        if reset:
            blacklist.append('global_step')

        if checkpoints and len(self.models) > 1:
            assert len(self.models) == len(checkpoints)
            for i, checkpoint in enumerate(checkpoints, 1):
                load_checkpoint(sess,
                                None,
                                checkpoint,
                                blacklist=blacklist,
                                prefix='model_{}'.format(i))
        elif checkpoints:  # load partial checkpoints
            for checkpoint in checkpoints:  # checkpoint files to load
                load_checkpoint(sess, None, checkpoint, blacklist=blacklist)
        elif not reset:
            load_checkpoint(sess, self.checkpoint_dir, blacklist=blacklist)

        utils.debug('global step: {}'.format(self.global_step.eval()))
        utils.debug('baseline step: {}'.format(self.baseline_step.eval()))
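
To make the restore behaviour explicit, here is a small standalone restatement of how the blacklist is built from the `reset` flags; the helper function itself is illustrative, not part of the repository.

# Pure-Python sketch of how the restore blacklist is assembled above.
def build_blacklist(reset=False, reset_learning_rate=False):
    blacklist = ['dropout_keep_prob']          # never restored
    if reset_learning_rate or reset:
        blacklist.append('learning_rate')      # keep the configured value
    if reset:
        blacklist.append('global_step')        # start counting from zero
    return blacklist

assert build_blacklist(reset=True) == ['dropout_keep_prob', 'learning_rate', 'global_step']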
Example #5
    def align(self, sess, output=None, align_encoder_id=0, **kwargs):
        if self.binary and any(self.binary):
            raise NotImplementedError

        if len(self.filenames.test) != len(self.extensions):
            raise Exception('wrong number of input files')

        for line_id, lines in enumerate(utils.read_lines(self.filenames.test)):
            token_ids = [
                sentence if vocab is None else utils.sentence_to_token_ids(
                    sentence, vocab.vocab,
                    character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs, lines)
            ]

            _, weights = self.seq2seq_model.step(sess,
                                                 data=[token_ids],
                                                 forward_only=True,
                                                 align=True,
                                                 update_model=False)

            trg_vocab = self.trg_vocab[0]  # FIXME
            trg_token_ids = token_ids[len(self.src_ext)]
            trg_tokens = [
                trg_vocab.reverse[i]
                if i < len(trg_vocab.reverse) else utils._UNK
                for i in trg_token_ids
            ]

            weights = weights.squeeze()
            max_len = weights.shape[1]

            utils.debug(weights)

            trg_tokens.append(utils._EOS)
            src_tokens = lines[align_encoder_id].split()[:max_len - 1] + [utils._EOS]

            if output is not None:
                output_file = '{}.{}.svg'.format(output, line_id + 1)
            else:
                output_file = None

            utils.heatmap(src_tokens,
                          trg_tokens,
                          weights,
                          output_file=output_file)
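
`utils.heatmap` is the repository's own plotting helper. Purely for illustration, a rough matplotlib sketch in the same spirit (function name, styling, and the dummy tokens below are assumptions, not the repo's implementation) could look like this:

import numpy as np
import matplotlib.pyplot as plt

def heatmap_sketch(src_tokens, trg_tokens, weights, output_file=None):
    """Illustrative only: plot an attention matrix with source tokens on
    the x-axis and target tokens on the y-axis."""
    fig, ax = plt.subplots()
    ax.imshow(weights, cmap='Greys', aspect='auto')
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=90)
    ax.set_yticks(range(len(trg_tokens)))
    ax.set_yticklabels(trg_tokens)
    if output_file is not None:
        fig.savefig(output_file)
    else:
        plt.show()

# dummy example: 3 target tokens attending over 4 source tokens
heatmap_sketch(['the', 'black', 'cat', '</s>'],
               ['le', 'chat', '</s>'],
               np.random.rand(3, 4))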
Example #6
    def read_data(self, max_train_size, max_dev_size, read_ahead=10, batch_mode='standard', shuffle=True,
                  crash_test=False, **kwargs):
        utils.debug('reading training data')
        self.batch_iterator, self.train_size = utils.get_batch_iterator(
            self.filenames.train, self.extensions, self.vocabs, self.batch_size,
            max_size=max_train_size, character_level=self.character_level, max_seq_len=self.max_len,
            read_ahead=read_ahead, mode=batch_mode, shuffle=shuffle, binary=self.binary, crash_test=crash_test
        )

        utils.debug('reading development data')

        dev_sets = [
            utils.read_dataset(dev, self.extensions, self.vocabs, max_size=max_dev_size,
                               character_level=self.character_level, binary=self.binary)[0]
            for dev in self.filenames.dev
        ]
        # subset of the dev set whose loss is periodically evaluated
        self.dev_batches = [utils.get_batches(dev_set, batch_size=self.batch_size) for dev_set in dev_sets]
Example #7
    def read_data(self,
                  max_train_size,
                  max_dev_size,
                  read_ahead=10,
                  batch_mode='standard',
                  shuffle=True,
                  **kwargs):
        utils.debug('reading training data')
        train_set = utils.read_dataset(self.filenames.train,
                                       self.extensions,
                                       self.vocabs,
                                       max_size=max_train_size,
                                       binary_input=self.binary_input,
                                       character_level=self.character_level,
                                       max_seq_len=self.max_input_len)
        self.train_size = len(train_set)
        self.batch_iterator = utils.read_ahead_batch_iterator(
            train_set,
            self.batch_size,
            read_ahead=read_ahead,
            mode=batch_mode,
            shuffle=shuffle)

        utils.debug('reading development data')
        dev_sets = [
            utils.read_dataset(dev,
                               self.extensions,
                               self.vocabs,
                               max_size=max_dev_size,
                               binary_input=self.binary_input,
                               character_level=self.character_level)
            for dev in self.filenames.dev
        ]
        # subset of the dev set whose perplexity is periodically evaluated
        self.dev_batches = [
            utils.get_batches(dev_set, batch_size=self.batch_size)
            for dev_set in dev_sets
        ]
Example #8
def load_checkpoint(sess, checkpoint_dir, filename=None, blacklist=()):
    """ `checkpoint_dir` should be unique to this model
    if `filename` is None, we load last checkpoint, otherwise
      we ignore `checkpoint_dir` and load the given checkpoint file.
    """
    if filename is None:
        # load last checkpoint
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt is not None:
            filename = ckpt.model_checkpoint_path
    else:
        checkpoint_dir = os.path.dirname(filename)

    var_file = os.path.join(checkpoint_dir, 'vars.pkl')

    if os.path.exists(var_file):
        with open(var_file, 'rb') as f:
            var_names = pickle.load(f)
            variables = [
                var for var in tf.global_variables() if var.name in var_names
            ]
    else:
        variables = tf.global_variables()

    # remove variables from blacklist
    variables = [
        var for var in variables
        if not any(prefix in var.name for prefix in blacklist)
    ]

    if filename is not None:
        utils.log('reading model parameters from {}'.format(filename))
        tf.train.Saver(variables).restore(sess, filename)

        utils.debug('retrieved parameters ({})'.format(len(variables)))
        for var in variables:
            utils.debug('  {} {}'.format(var.name, var.get_shape()))
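
The blacklist above works by substring match against full variable names (so 'learning_rate' also excludes 'model_1/learning_rate:0'). A tiny standalone check of that filtering rule, using made-up variable names:

# Standalone illustration of the substring-based blacklist filter above.
var_names = ['model_1/encoder/W:0',
             'learning_rate:0',
             'global_step:0',
             'model_1/decoder/b:0']
blacklist = ('learning_rate', 'global_step')

kept = [name for name in var_names
        if not any(prefix in name for prefix in blacklist)]
assert kept == ['model_1/encoder/W:0', 'model_1/decoder/b:0']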
Example #9
    def __init__(self,
                 name,
                 encoders,
                 decoder,
                 checkpoint_dir,
                 learning_rate,
                 learning_rate_decay_factor,
                 batch_size,
                 keep_best=1,
                 load_embeddings=None,
                 max_input_len=None,
                 **kwargs):
        super(TranslationModel, self).__init__(name, checkpoint_dir, keep_best,
                                               **kwargs)

        self.batch_size = batch_size
        self.src_ext = [
            encoder.get('ext') or encoder.name for encoder in encoders
        ]
        self.trg_ext = decoder.get('ext') or decoder.name
        self.extensions = self.src_ext + [self.trg_ext]
        self.max_input_len = max_input_len

        encoders_and_decoder = encoders + [decoder]
        self.binary_input = [
            encoder_or_decoder.binary
            for encoder_or_decoder in encoders_and_decoder
        ]
        self.character_level = [
            encoder_or_decoder.character_level
            for encoder_or_decoder in encoders_and_decoder
        ]

        self.learning_rate = tf.Variable(learning_rate,
                                         trainable=False,
                                         name='learning_rate',
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)

        with tf.device('/cpu:0'):
            self.global_step = tf.Variable(0,
                                           trainable=False,
                                           name='global_step')

        self.filenames = utils.get_filenames(extensions=self.extensions,
                                             **kwargs)
        # TODO: check that filenames exist
        utils.debug('reading vocabularies')
        self._read_vocab()

        for encoder_or_decoder, vocab in zip(encoders + [decoder],
                                             self.vocabs):
            if encoder_or_decoder.vocab_size <= 0 and vocab is not None:
                encoder_or_decoder.vocab_size = len(vocab.reverse)

        # this adds an `embedding' attribute to each encoder and decoder
        utils.read_embeddings(self.filenames.embeddings, encoders + [decoder],
                              load_embeddings, self.vocabs)

        # main model
        utils.debug('creating model {}'.format(name))
        self.seq2seq_model = Seq2SeqModel(encoders,
                                          decoder,
                                          self.learning_rate,
                                          self.global_step,
                                          max_input_len=max_input_len,
                                          **kwargs)

        self.batch_iterator = None
        self.dev_batches = None
        self.train_size = None
        self.use_sgd = False
Example #10
def main(args=None):
    args = parser.parse_args(args)

    # read config file and default config
    with open('config/default.yaml') as f:
        default_config = utils.AttrDict(yaml.safe_load(f))

    with open(args.config) as f:
        config = utils.AttrDict(yaml.safe_load(f))
        
        if args.learning_rate is not None:
            args.reset_learning_rate = True
        
        # command-line parameters have higher precedence than config file
        for k, v in vars(args).items():
            if v is not None:
                config[k] = v

        # set default values for parameters that are not defined
        for k, v in default_config.items():
            config.setdefault(k, v)

    if config.score_function:
        config.score_functions = evaluation.name_mapping[config.score_function]

    if args.crash_test:
        config.max_train_size = 0

    if not config.debug:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # disable TensorFlow's debugging logs
    decoding_mode = any(arg is not None for arg in (args.decode, args.eval, args.align))

    # enforce parameter constraints
    assert config.steps_per_eval % config.steps_per_checkpoint == 0, (
        'steps-per-eval should be a multiple of steps-per-checkpoint')
    assert decoding_mode or args.train or args.save or args.save_embedding, (
        'you need to specify at least one action (decode, eval, align, train, save, or save-embedding)')
    assert not (args.average and args.ensemble)

    if args.train and args.purge:
        utils.log('deleting previous model')
        shutil.rmtree(config.model_dir, ignore_errors=True)

    os.makedirs(config.model_dir, exist_ok=True)

    # copy config file to model directory
    config_path = os.path.join(config.model_dir, 'config.yaml')
    if args.train and not os.path.exists(config_path):
        with open(args.config) as config_file, open(config_path, 'w') as dest_file:
            content = config_file.read()
            content = re.sub(r'model_dir:.*?\n', 'model_dir: {}\n'.format(config.model_dir), content,
                             flags=re.MULTILINE)
            dest_file.write(content)

    # also copy default config
    config_path = os.path.join(config.model_dir, 'default.yaml')
    if args.train and not os.path.exists(config_path):
        shutil.copy('config/default.yaml', config_path)

    # copy source code to model directory
    tar_path = os.path.join(config.model_dir, 'code.tar.gz')
    if args.train and not os.path.exists(tar_path):
        with tarfile.open(tar_path, "w:gz") as tar:
            for filename in os.listdir('translate'):
                if filename.endswith('.py'):
                    tar.add(os.path.join('translate', filename), arcname=filename)

    logging_level = logging.DEBUG if args.verbose else logging.INFO
    # always log to stdout in decoding and eval modes (to avoid overwriting precious train logs)
    log_path = os.path.join(config.model_dir, config.log_file)
    logger = utils.create_logger(log_path if args.train else None)
    logger.setLevel(logging_level)

    utils.log('label: {}'.format(config.label))
    utils.log('description:\n  {}'.format('\n  '.join(config.description.strip().split('\n'))))

    utils.log(' '.join(sys.argv))  # print command line
    try:  # print git hash
        commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()
        utils.log('commit hash {}'.format(commit_hash))
    except Exception:  # e.g. git not available or not a git repository
        pass

    utils.log('tensorflow version: {}'.format(tf.__version__))

    # log parameters
    utils.debug('program arguments')
    for k, v in sorted(config.items(), key=itemgetter(0)):
        utils.debug('  {:<20} {}'.format(k, pformat(v)))

    if isinstance(config.dev_prefix, str):
        config.dev_prefix = [config.dev_prefix]

    if config.tasks is not None:
        config.tasks = [utils.AttrDict(task) for task in config.tasks]
        tasks = config.tasks
    else:
        tasks = [config]

    for task in tasks:
        for parameter, value in config.items():
            task.setdefault(parameter, value)

        task.encoders = [utils.AttrDict(encoder) for encoder in task.encoders]
        task.decoders = [utils.AttrDict(decoder) for decoder in task.decoders]

        for encoder_or_decoder in task.encoders + task.decoders:
            for parameter, value in task.items():
                encoder_or_decoder.setdefault(parameter, value)

        if args.max_len:
            args.max_input_len = args.max_len
        if args.max_output_len:   # override decoder's max len
            task.decoders[0].max_len = args.max_output_len
        if args.max_input_len:    # override encoder's max len
            task.encoders[0].max_len = args.max_input_len

    config.checkpoint_dir = os.path.join(config.model_dir, 'checkpoints')

    # setting random seeds
    if config.seed is None:
        config.seed = random.randrange(sys.maxsize)
    if config.tf_seed is None:
        config.tf_seed = random.randrange(sys.maxsize)
    utils.log('python random seed: {}'.format(config.seed))
    utils.log('tf random seed:     {}'.format(config.tf_seed))
    random.seed(config.seed)
    tf.set_random_seed(config.tf_seed)

    device = None
    if config.no_gpu:
        device = '/cpu:0'
        device_id = None
    elif config.gpu_id is not None:
        device = '/gpu:{}'.format(config.gpu_id)
        device_id = config.gpu_id
    else:
        device_id = 0

    # hide other GPUs so that TensorFlow won't use memory on them
    os.environ['CUDA_VISIBLE_DEVICES'] = '' if device_id is None else str(device_id)

    utils.log('creating model')
    utils.log('using device: {}'.format(device))

    with tf.device(device):
        if config.weight_scale:
            if config.initializer == 'uniform':
                initializer = tf.random_uniform_initializer(minval=-config.weight_scale, maxval=config.weight_scale)
            else:
                initializer = tf.random_normal_initializer(stddev=config.weight_scale)
        else:
            initializer = None

        tf.get_variable_scope().set_initializer(initializer)

        # exempt from creating gradient ops
        config.decode_only = decoding_mode

        if config.tasks is not None:
            model = MultiTaskModel(**config)
        else:
            model = TranslationModel(**config)

    # count parameters
    # not counting parameters created by training algorithm (e.g. Adam)
    variables = [var for var in tf.global_variables() if not var.name.startswith('gradients')]
    utils.log('model parameters ({})'.format(len(variables)))
    parameter_count = 0
    for var in sorted(variables, key=lambda var: var.name):
        utils.log('  {} {}'.format(var.name, var.get_shape()))
        v = 1
        for d in var.get_shape():
            v *= d.value
        parameter_count += v
    utils.log('number of parameters: {:.2f}M'.format(parameter_count / 1e6))

    tf_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = config.allow_growth
    tf_config.gpu_options.per_process_gpu_memory_fraction = config.mem_fraction

    def average_checkpoints(main_sess, sessions):
        for var in tf.global_variables():
            avg_value = sum(sess.run(var) for sess in sessions) / len(sessions)
            main_sess.run(var.assign(avg_value))

    with tf.Session(config=tf_config) as sess:
        best_checkpoint = os.path.join(config.checkpoint_dir, 'best')

        params = {'variable_mapping': config.variable_mapping, 'reverse_mapping': config.reverse_mapping,
                  'rnn_lm_model_dir': None, 'rnn_mt_model_dir': None,
                  'rnn_lm_cell_name': None, 'origin_model_ckpt': None}
        if config.ensemble and len(config.checkpoints) > 1:
            model.initialize(config.checkpoints, **params)
        elif config.average and len(config.checkpoints) > 1:
            model.initialize(reset=True)
            sessions = [tf.Session(config=tf_config) for _ in config.checkpoints]
            for sess_, checkpoint in zip(sessions, config.checkpoints):
                model.initialize(sess=sess_, checkpoints=[checkpoint], **params)
            average_checkpoints(sess, sessions)
        elif (not config.checkpoints and decoding_mode and
              os.path.isfile(best_checkpoint + '.index')):
            # in decoding and evaluation mode, unless specified otherwise (by `checkpoints`),
            # try to load the best checkpoint
            model.initialize([best_checkpoint], **params)
        else:
            # loads last checkpoint, unless `reset` is true
            model.initialize(**config)

        if config.output is not None:
            dirname = os.path.dirname(config.output)
            if dirname:
                os.makedirs(dirname, exist_ok=True)

        try:
            if args.save:
                model.save()
            elif args.save_embedding:
                if config.embedding_output_dir is None:
                    output_dir = "."
                else:
                    output_dir = config.embedding_output_dir
                model.save_embedding(output_dir)
            elif args.decode is not None:
                if config.align is not None:
                    config.align = True
                model.decode(**config)
            elif args.eval is not None:
                model.evaluate(on_dev=False, **config)
            elif args.align is not None:
                model.align(**config)
            elif args.train:
                model.train(**config)
        except KeyboardInterrupt:
            sys.exit()
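
The configuration precedence implemented above is: command-line flags override the experiment config file, which in turn falls back to `config/default.yaml`. A small standalone restatement of that merge order (the dictionaries and keys below are made up):

# Standalone sketch of the merge order used in main() above.
default_config = {'batch_size': 32, 'learning_rate': 0.001, 'debug': False}
file_config    = {'batch_size': 64, 'model_dir': 'models/demo'}
cli_args       = {'learning_rate': 0.0005, 'debug': None}  # None means "not given"

config = dict(file_config)
for k, v in cli_args.items():          # command line has highest precedence
    if v is not None:
        config[k] = v
for k, v in default_config.items():    # defaults fill whatever is still missing
    config.setdefault(k, v)

assert config == {'batch_size': 64, 'model_dir': 'models/demo',
                  'learning_rate': 0.0005, 'debug': False}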
Example #11
    def initialize(self,
                   checkpoints=None,
                   reset=False,
                   reset_learning_rate=False,
                   max_to_keep=1,
                   keep_every_n_hours=0,
                   sess=None,
                   whitelist=None,
                   blacklist=None,
                   **kwargs):
        """
        :param checkpoints: list of checkpoints to load (instead of latest checkpoint)
        :param reset: don't load latest checkpoint, reset learning rate and global step
        :param reset_learning_rate: reset the learning rate to its initial value
        :param max_to_keep: keep this many latest checkpoints at all times
        :param keep_every_n_hours: and keep checkpoints every n hours
        """
        sess = sess or tf.get_default_session()

        if keep_every_n_hours is None or keep_every_n_hours <= 0:
            keep_every_n_hours = float('inf')

        self.saver = tf.train.Saver(
            max_to_keep=max_to_keep,
            keep_checkpoint_every_n_hours=keep_every_n_hours,
            sharded=False)

        sess.run(tf.global_variables_initializer())

        # load pre-trained embeddings
        for encoder_or_decoder, vocab in zip(self.encoders + self.decoders,
                                             self.vocabs):
            if encoder_or_decoder.embedding_file:
                utils.log('loading embeddings from: {}'.format(
                    encoder_or_decoder.embedding_file))
                embeddings = {}
                with open(encoder_or_decoder.embedding_file) as embedding_file:
                    for line in embedding_file:
                        word, vector = line.split(' ', 1)
                        if word in vocab.vocab:
                            embeddings[word] = np.array(
                                list(map(float, vector.split())))
                # standardize (mean of 0, std of 0.01)
                mean = sum(embeddings.values()) / len(embeddings)
                # sample standard deviation (the division belongs inside the sqrt)
                std = np.sqrt(
                    sum((value - mean)**2 for value in embeddings.values()) /
                    (len(embeddings) - 1))
                for key in embeddings:
                    embeddings[key] = 0.01 * (embeddings[key] - mean) / std

                # change TensorFlow variable's value
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    embedding_var = tf.get_variable('embedding_' +
                                                    encoder_or_decoder.name)
                    embedding_value = embedding_var.eval()
                    for word, i in vocab.vocab.items():
                        if word in embeddings:
                            embedding_value[i] = embeddings[word]
                    sess.run(embedding_var.assign(embedding_value))

        if whitelist:
            with open(whitelist) as f:
                whitelist = list(line.strip() for line in f)
        if blacklist:
            with open(blacklist) as f:
                blacklist = list(line.strip() for line in f)
        else:
            blacklist = []

        blacklist.append('dropout_keep_prob')

        if reset_learning_rate or reset:
            blacklist.append('learning_rate')
        if reset:
            blacklist.append('global_step')

        params = {
            k: kwargs.get(k)
            for k in ('variable_mapping', 'reverse_mapping')
        }

        if checkpoints and len(self.models) > 1:
            assert len(self.models) == len(checkpoints)
            for i, checkpoint in enumerate(checkpoints, 1):
                load_checkpoint(sess,
                                None,
                                checkpoint,
                                blacklist=blacklist,
                                whitelist=whitelist,
                                prefix='model_{}'.format(i),
                                **params)
        elif checkpoints:  # load partial checkpoints
            for checkpoint in checkpoints:  # checkpoint files to load
                load_checkpoint(sess,
                                None,
                                checkpoint,
                                blacklist=blacklist,
                                whitelist=whitelist,
                                **params)
        elif not reset:
            load_checkpoint(sess,
                            self.checkpoint_dir,
                            blacklist=blacklist,
                            whitelist=whitelist,
                            **params)

        utils.debug('global step: {}'.format(self.global_step.eval()))
        utils.debug('baseline step: {}'.format(self.baseline_step.eval()))
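
The embedding standardization above rescales the pre-trained vectors to zero mean and a standard deviation of 0.01 before copying them into the TensorFlow variable. A numpy-only restatement of that step (using the sample standard deviation), with made-up vectors:

import numpy as np

# made-up pre-trained vectors for a 3-word vocabulary
embeddings = {'cat': np.array([1.0, 2.0]),
              'dog': np.array([2.0, 4.0]),
              'sun': np.array([3.0, 6.0])}

mean = sum(embeddings.values()) / len(embeddings)
std = np.sqrt(sum((v - mean) ** 2 for v in embeddings.values()) /
              (len(embeddings) - 1))
standardized = {w: 0.01 * (v - mean) / std for w, v in embeddings.items()}

# each dimension now has zero mean and a standard deviation of 0.01
values = np.stack(list(standardized.values()))
print(values.mean(axis=0), values.std(axis=0, ddof=1))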
Example #12
    def train_step(self,
                   steps_per_checkpoint,
                   model_dir,
                   steps_per_eval=None,
                   max_steps=0,
                   max_epochs=0,
                   eval_burn_in=0,
                   decay_if_no_progress=None,
                   decay_after_n_epoch=None,
                   decay_every_n_epoch=None,
                   sgd_after_n_epoch=None,
                   sgd_learning_rate=None,
                   min_learning_rate=None,
                   loss_function='xent',
                   use_baseline=True,
                   **kwargs):
        if (min_learning_rate is not None
                and self.learning_rate.eval() < min_learning_rate):
            utils.debug('learning rate is too small: stopping')
            raise utils.FinishedTrainingException
        if (0 < max_steps <= self.global_step.eval()
                or 0 < max_epochs <= self.epoch.eval()):
            raise utils.FinishedTrainingException

        start_time = time.time()

        if loss_function == 'reinforce':
            step_function = self.seq2seq_model.reinforce_step
        else:
            step_function = self.seq2seq_model.step

        res = step_function(next(self.batch_iterator),
                            update_model=True,
                            use_sgd=self.training.use_sgd,
                            update_baseline=True)

        self.training.loss += res.loss
        self.training.baseline_loss += getattr(res, 'baseline_loss', 0)

        self.training.time += time.time() - start_time
        self.training.steps += 1

        global_step = self.global_step.eval()
        epoch = self.epoch.eval()

        if decay_after_n_epoch is not None and self.batch_size * global_step >= decay_after_n_epoch * self.train_size:
            if decay_every_n_epoch is not None and (
                    self.batch_size * (global_step - self.training.last_decay)
                    >= decay_every_n_epoch * self.train_size):
                self.learning_rate_decay_op.eval()
                utils.debug('  decaying learning rate to: {:.3g}'.format(
                    self.learning_rate.eval()))
                self.training.last_decay = global_step

        if sgd_after_n_epoch is not None and epoch >= sgd_after_n_epoch:
            if not self.training.use_sgd:
                utils.debug('epoch {}, starting to use SGD'.format(epoch + 1))
                self.training.use_sgd = True
                if sgd_learning_rate is not None:
                    self.learning_rate.assign(sgd_learning_rate).eval()
                self.training.last_decay = global_step  # reset learning rate decay

        if steps_per_checkpoint and global_step % steps_per_checkpoint == 0:
            loss = self.training.loss / self.training.steps
            baseline_loss = self.training.baseline_loss / self.training.steps
            step_time = self.training.time / self.training.steps

            summary = 'step {} epoch {} learning rate {:.3g} step-time {:.3f} loss {:.3f}'.format(
                global_step, epoch + 1, self.learning_rate.eval(), step_time,
                loss)

            if self.name is not None:
                summary = '{} {}'.format(self.name, summary)
            if use_baseline and loss_function == 'reinforce':
                summary = '{} baseline-loss {:.4f}'.format(
                    summary, baseline_loss)

            utils.log(summary)

            if decay_if_no_progress and len(
                    self.training.losses) >= decay_if_no_progress:
                if loss >= max(self.training.losses[:decay_if_no_progress]):
                    self.learning_rate_decay_op.eval()

            self.training.losses.append(loss)
            self.training.loss, self.training.time, self.training.steps, self.training.baseline_loss = 0, 0, 0, 0

        if steps_per_eval and global_step % steps_per_eval == 0 and 0 <= eval_burn_in <= global_step:

            eval_dir = 'eval' if self.name is None else 'eval_{}'.format(
                self.name)
            eval_output = os.path.join(model_dir, eval_dir)

            os.makedirs(eval_output, exist_ok=True)

            # if there are several dev files, we define several output files
            output = [
                os.path.join(eval_output,
                             '{}.{}.out'.format(prefix, global_step))
                for prefix in self.dev_prefix
            ]

            kwargs_ = dict(kwargs)
            kwargs_['output'] = output
            score, *_ = self.evaluate(on_dev=True, **kwargs_)
            self.training.scores.append((global_step, score))

        if steps_per_eval and global_step % steps_per_eval == 0:
            raise utils.EvalException
        elif steps_per_checkpoint and global_step % steps_per_checkpoint == 0:
            raise utils.CheckpointException
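
The epoch-based decay conditions above convert step counts into epochs via `batch_size * global_step / train_size`. A small worked example of that step/epoch arithmetic (the numbers are made up):

# Worked example of the step/epoch arithmetic behind the decay conditions above.
batch_size = 32
train_size = 100000            # number of training examples (made up)
decay_after_n_epoch = 2
decay_every_n_epoch = 1

steps_per_epoch = train_size / batch_size                       # 3125.0 steps
# batch_size * global_step >= decay_after_n_epoch * train_size
first_eligible_step = decay_after_n_epoch * steps_per_epoch     # 6250.0
# batch_size * (global_step - last_decay) >= decay_every_n_epoch * train_size
min_gap_between_decays = decay_every_n_epoch * steps_per_epoch  # 3125.0

print(first_eligible_step, min_gap_between_decays)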
Example #13
    def __init__(self,
                 encoders,
                 decoders,
                 checkpoint_dir,
                 learning_rate,
                 learning_rate_decay_factor,
                 batch_size,
                 keep_best=1,
                 dev_prefix=None,
                 name=None,
                 ref_ext=None,
                 pred_edits=False,
                 dual_output=False,
                 binary=None,
                 truncate_lines=True,
                 ensemble=False,
                 checkpoints=None,
                 beam_size=1,
                 len_normalization=1,
                 lexicon=None,
                 debug=False,
                 **kwargs):

        self.batch_size = batch_size
        self.character_level = {}
        self.binary = []
        self.debug = debug

        for encoder_or_decoder in encoders + decoders:
            encoder_or_decoder.ext = encoder_or_decoder.ext or encoder_or_decoder.name
            self.character_level[encoder_or_decoder.ext] = encoder_or_decoder.character_level
            self.binary.append(encoder_or_decoder.get('binary', False))

        self.encoders, self.decoders = encoders, decoders

        self.char_output = decoders[0].character_level

        self.src_ext = [encoder.ext for encoder in encoders]
        self.trg_ext = [decoder.ext for decoder in decoders]

        self.extensions = self.src_ext + self.trg_ext

        self.ref_ext = ref_ext
        if self.ref_ext is not None:
            self.binary.append(False)

        self.pred_edits = pred_edits
        self.dual_output = dual_output

        self.dev_prefix = dev_prefix
        self.name = name

        self.max_input_len = [encoder.max_len for encoder in encoders]
        self.max_output_len = [decoder.max_len for decoder in decoders]
        self.beam_size = beam_size

        if truncate_lines:
            self.max_len = None  # we let seq2seq.get_batch handle long lines (by truncating them)
        else:  # the line reader will drop lines that are too long
            self.max_len = dict(
                zip(self.extensions, self.max_input_len + self.max_output_len))

        self.learning_rate = tf.Variable(learning_rate,
                                         trainable=False,
                                         name='learning_rate',
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)

        with tf.device('/cpu:0'):
            self.global_step = tf.Variable(0,
                                           trainable=False,
                                           name='global_step')
            self.baseline_step = tf.Variable(0,
                                             trainable=False,
                                             name='baseline_step')

        self.filenames = utils.get_filenames(extensions=self.extensions,
                                             dev_prefix=dev_prefix,
                                             name=name,
                                             ref_ext=ref_ext,
                                             binary=self.binary,
                                             **kwargs)
        utils.debug('reading vocabularies')
        self.vocabs = None
        self.src_vocab, self.trg_vocab = None, None
        self.read_vocab()

        for encoder_or_decoder, vocab in zip(encoders + decoders, self.vocabs):
            if vocab:
                if encoder_or_decoder.vocab_size:  # reduce vocab size
                    vocab.reverse[:] = vocab.reverse[:encoder_or_decoder.vocab_size]
                    for token, token_id in list(vocab.vocab.items()):
                        if token_id >= encoder_or_decoder.vocab_size:
                            del vocab.vocab[token]
                else:
                    encoder_or_decoder.vocab_size = len(vocab.reverse)

        utils.debug('creating model')

        self.models = []
        if ensemble and checkpoints is not None:
            for i, _ in enumerate(checkpoints, 1):
                with tf.variable_scope('model_{}'.format(i)):
                    model = Seq2SeqModel(encoders,
                                         decoders,
                                         self.learning_rate,
                                         self.global_step,
                                         name=name,
                                         pred_edits=pred_edits,
                                         dual_output=dual_output,
                                         baseline_step=self.baseline_step,
                                         **kwargs)
                    self.models.append(model)
            self.seq2seq_model = self.models[0]
        else:
            self.seq2seq_model = Seq2SeqModel(encoders,
                                              decoders,
                                              self.learning_rate,
                                              self.global_step,
                                              name=name,
                                              pred_edits=pred_edits,
                                              dual_output=dual_output,
                                              baseline_step=self.baseline_step,
                                              **kwargs)
            self.models.append(self.seq2seq_model)

        self.seq2seq_model.create_beam_op(self.models, len_normalization)

        self.batch_iterator = None
        self.dev_batches = None
        self.train_size = None
        self.saver = None
        self.keep_best = keep_best
        self.checkpoint_dir = checkpoint_dir
        self.epoch = None

        self.training = utils.AttrDict()  # used to keep track of training

        if lexicon:
            with open(lexicon) as lexicon_file:
                self.lexicon = dict(line.split() for line in lexicon_file)
        else:
            self.lexicon = None
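
The vocabulary-reduction loop above truncates `vocab.reverse` and drops the corresponding entries from `vocab.vocab`. A tiny standalone illustration with a made-up vocabulary object:

# Standalone illustration of the vocabulary truncation performed above.
class Vocab:
    def __init__(self, tokens):
        self.reverse = list(tokens)                        # id -> token
        self.vocab = {t: i for i, t in enumerate(tokens)}  # token -> id

vocab = Vocab(['<unk>', '<s>', '</s>', 'the', 'cat', 'sat'])
vocab_size = 4   # keep only the first 4 entries

vocab.reverse[:] = vocab.reverse[:vocab_size]
for token, token_id in list(vocab.vocab.items()):
    if token_id >= vocab_size:
        del vocab.vocab[token]

assert vocab.reverse == ['<unk>', '<s>', '</s>', 'the']
assert set(vocab.vocab) == {'<unk>', '<s>', '</s>', 'the'}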
Example #14
    def train(self,
              sess,
              beam_size,
              steps_per_checkpoint,
              steps_per_eval=None,
              eval_output=None,
              max_steps=0,
              max_epochs=0,
              eval_burn_in=0,
              decay_if_no_progress=5,
              decay_after_n_epoch=None,
              decay_every_n_epoch=None,
              sgd_after_n_epoch=None,
              loss_function='xent',
              baseline_steps=0,
              reinforce_baseline=True,
              reward_function=None,
              use_edits=False,
              **kwargs):
        utils.log('reading training and development data')

        self.global_step = 0
        for model in self.models:
            model.read_data(**kwargs)
            # those parameters are used to track the progress of each task
            model.loss, model.time, model.steps = 0, 0, 0
            model.baseline_loss = 0
            model.previous_losses = []
            global_step = model.global_step.eval(sess)
            model.epoch = model.batch_size * global_step // model.train_size
            model.last_decay = global_step

            for _ in range(global_step):  # read all the data up to this step
                next(model.batch_iterator)

            self.global_step += global_step

        # pre-train baseline
        if loss_function == 'reinforce' and baseline_steps > 0 and reinforce_baseline:
            utils.log('pre-training baseline')
            for model in self.models:
                baseline_loss = 0
                for step in range(1, baseline_steps + 1):
                    baseline_loss += model.baseline_step(
                        sess,
                        reward_function=reward_function,
                        use_edits=use_edits)

                    if step % steps_per_checkpoint == 0:
                        loss = baseline_loss / steps_per_checkpoint
                        baseline_loss = 0
                        utils.log('{} step {} baseline loss {:.4f}'.format(
                            model.name, step, loss))

        utils.log('starting training')
        while True:
            i = np.random.choice(len(self.models), 1, p=self.ratios)[0]
            model = self.models[i]

            start_time = time.time()
            res = model.train_step(sess,
                                   loss_function=loss_function,
                                   reward_function=reward_function,
                                   use_edits=use_edits)
            model.loss += res.loss

            if loss_function == 'reinforce':
                model.baseline_loss += res.baseline_loss

            model.time += time.time() - start_time
            model.steps += 1
            self.global_step += 1
            model_global_step = model.global_step.eval(sess)

            epoch = model.batch_size * model_global_step / model.train_size
            model.epoch = int(epoch) + 1

            if decay_after_n_epoch is not None and epoch >= decay_after_n_epoch:
                if decay_every_n_epoch is not None and (
                        model.batch_size *
                    (model_global_step - model.last_decay) >=
                        decay_every_n_epoch * model.train_size):
                    sess.run(model.learning_rate_decay_op)
                    utils.debug('  decaying learning rate to: {:.4f}'.format(
                        model.learning_rate.eval()))
                    model.last_decay = model_global_step

            if sgd_after_n_epoch is not None and epoch >= sgd_after_n_epoch:
                if not model.use_sgd:
                    utils.debug('  epoch {}, starting to use SGD'.format(
                        model.epoch))
                    model.use_sgd = True

            if steps_per_checkpoint and self.global_step % steps_per_checkpoint == 0:
                for model_ in self.models:
                    if model_.steps == 0:
                        continue

                    loss_ = model_.loss / model_.steps
                    step_time_ = model_.time / model_.steps

                    if loss_function == 'reinforce':
                        baseline_loss_ = ' baseline loss {:.4f}'.format(
                            model_.baseline_loss / model_.steps)
                        model_.baseline_loss = 0
                    else:
                        baseline_loss_ = ''

                    utils.log(
                        '{} step {} epoch {} learning rate {:.4f} step-time {:.4f}{} loss {:.4f}'
                        .format(model_.name,
                                model_.global_step.eval(sess), model.epoch,
                                model_.learning_rate.eval(), step_time_,
                                baseline_loss_, loss_))

                    if decay_if_no_progress and len(
                            model_.previous_losses) >= decay_if_no_progress:
                        if loss_ >= max(
                                model_.previous_losses[:decay_if_no_progress]):
                            sess.run(model_.learning_rate_decay_op)

                    model_.previous_losses.append(loss_)
                    model_.loss, model_.time, model_.steps = 0, 0, 0
                    model_.eval_step(sess)

                self.save(sess)

            if steps_per_eval and self.global_step % steps_per_eval == 0 and 0 <= eval_burn_in <= self.global_step:
                score = 0

                for ratio, model_ in zip(self.ratios, self.models):
                    if eval_output is None:
                        output = None
                    elif len(model_.filenames.dev) > 1:
                        # if there are several dev files, we define several output files
                        # TODO: put dev_prefix into the name of the output file (also in the logging output)
                        output = [
                            '{}.{}.{}.{}'.format(eval_output, i + 1,
                                                 model_.name,
                                                 model_.global_step.eval(sess))
                            for i in range(len(model_.filenames.dev))
                        ]
                    else:
                        output = '{}.{}.{}'.format(
                            eval_output, model_.name,
                            model_.global_step.eval(sess))

                    # kwargs_ = {**kwargs, 'output': output}
                    kwargs_ = dict(kwargs)
                    kwargs_['output'] = output
                    scores_ = model_.evaluate(sess,
                                              beam_size,
                                              on_dev=True,
                                              use_edits=use_edits,
                                              **kwargs_)
                    # in case there are several dev files, only the first one counts
                    score_ = scores_[0]

                    # if there is a main task, pick best checkpoint according to its score
                    # otherwise use the average score across tasks
                    if self.main_task is None:
                        score += ratio * score_
                    elif model_.name == self.main_task:
                        score = score_

                self.manage_best_checkpoints(self.global_step, score)

            if 0 < max_steps <= self.global_step or 0 < max_epochs <= epoch:
                utils.log('finished training')
                # TODO: save models
                return
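
Task selection above draws one model per training step according to `self.ratios`. A standalone numpy sketch of that sampling scheme, with made-up task names and ratios:

import numpy as np

# Standalone sketch of sampling a task proportionally to its ratio.
models = ['translation', 'post_editing', 'language_model']  # made-up task names
ratios = [0.5, 0.3, 0.2]                                    # must sum to 1

counts = {name: 0 for name in models}
rng = np.random.RandomState(0)
for _ in range(10000):
    i = rng.choice(len(models), 1, p=ratios)[0]
    counts[models[i]] += 1

print(counts)  # roughly 5000 / 3000 / 2000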
Example #15
def main(args=None):
    args = parser.parse_args(args)

    # read config file and default config
    with open('config/default.yaml') as f:
        default_config = utils.AttrDict(yaml.safe_load(f))

    with open(args.config) as f:
        config = utils.AttrDict(yaml.safe_load(f))

        if args.learning_rate is not None:
            args.reset_learning_rate = True

        # command-line parameters have higher precedence than config file
        for k, v in vars(args).items():
            if v is not None:
                config[k] = v

        # set default values for parameters that are not defined
        for k, v in default_config.items():
            config.setdefault(k, v)

    # enforce parameter constraints
    assert config.steps_per_eval % config.steps_per_checkpoint == 0, (
        'steps-per-eval should be a multiple of steps-per-checkpoint')
    assert args.decode is not None or args.eval or args.train or args.align, (
        'you need to specify at least one action (decode, eval, align, or train)'
    )
    assert not (args.avg_checkpoints and args.ensemble)

    if args.purge:
        utils.log('deleting previous model')
        shutil.rmtree(config.model_dir, ignore_errors=True)

    os.makedirs(config.model_dir, exist_ok=True)

    # copy config file to model directory
    config_path = os.path.join(config.model_dir, 'config.yaml')
    if not os.path.exists(config_path):
        shutil.copy(args.config, config_path)

    # also copy default config
    config_path = os.path.join(config.model_dir, 'default.yaml')
    if not os.path.exists(config_path):
        shutil.copy('config/default.yaml', config_path)

    # copy source code to model directory
    tar_path = os.path.join(config.model_dir, 'code.tar.gz')
    if not os.path.exists(tar_path):
        with tarfile.open(tar_path, "w:gz") as tar:
            for filename in os.listdir('translate'):
                if filename.endswith('.py'):
                    tar.add(os.path.join('translate', filename),
                            arcname=filename)

    logging_level = logging.DEBUG if args.verbose else logging.INFO
    # always log to stdout in decoding and eval modes (to avoid overwriting precious train logs)
    log_path = os.path.join(config.model_dir, config.log_file)
    logger = utils.create_logger(log_path if args.train else None)
    logger.setLevel(logging_level)

    utils.log('label: {}'.format(config.label))
    utils.log('description:\n  {}'.format('\n  '.join(
        config.description.strip().split('\n'))))

    utils.log(' '.join(sys.argv))  # print command line
    try:  # print git hash
        commit_hash = subprocess.check_output(['git', 'rev-parse',
                                               'HEAD']).decode().strip()
        utils.log('commit hash {}'.format(commit_hash))
    except Exception:  # e.g. git not available or not a git repository
        pass

    utils.log('tensorflow version: {}'.format(tf.__version__))

    # log parameters
    utils.debug('program arguments')
    for k, v in sorted(config.items(), key=itemgetter(0)):
        utils.debug('  {:<20} {}'.format(k, pformat(v)))

    if isinstance(config.dev_prefix, str):
        config.dev_prefix = [config.dev_prefix]

    if config.tasks is not None:
        config.tasks = [utils.AttrDict(task) for task in config.tasks]
        tasks = config.tasks
    else:
        tasks = [config]

    for task in tasks:
        for parameter, value in config.items():
            task.setdefault(parameter, value)

        task.encoders = [utils.AttrDict(encoder) for encoder in task.encoders]
        task.decoders = [utils.AttrDict(decoder) for decoder in task.decoders]

        for encoder_or_decoder in task.encoders + task.decoders:
            for parameter, value in task.items():
                encoder_or_decoder.setdefault(parameter, value)

    device = None
    if config.no_gpu:
        device = '/cpu:0'
    elif config.gpu_id is not None:
        device = '/gpu:{}'.format(config.gpu_id)

    utils.log('creating model')
    utils.log('using device: {}'.format(device))

    with tf.device(device):
        config.checkpoint_dir = os.path.join(config.model_dir, 'checkpoints')

        if config.weight_scale:
            if config.initializer == 'uniform':
                initializer = tf.random_uniform_initializer(
                    minval=-config.weight_scale, maxval=config.weight_scale)
            else:
                initializer = tf.random_normal_initializer(
                    stddev=config.weight_scale)
        else:
            initializer = None

        tf.get_variable_scope().set_initializer(initializer)

        config.decode_only = args.decode is not None or args.eval or args.align  # exempt from creating gradient ops

        if config.tasks is not None:
            model = MultiTaskModel(**config)
        else:
            model = TranslationModel(**config)

    # count parameters
    utils.log('model parameters ({})'.format(len(tf.global_variables())))
    parameter_count = 0
    for var in tf.global_variables():
        utils.log('  {} {}'.format(var.name, var.get_shape()))

        # not counting parameters created by training algorithm (e.g. Adam)
        if not var.name.startswith('gradients'):
            v = 1
            for d in var.get_shape():
                v *= d.value
            parameter_count += v
    utils.log('number of parameters: {:.2f}M'.format(parameter_count / 1e6))

    tf_config = tf.ConfigProto(log_device_placement=False,
                               allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = config.allow_growth
    tf_config.gpu_options.per_process_gpu_memory_fraction = config.mem_fraction

    def average_checkpoints(main_sess, sessions):
        for var in tf.global_variables():
            avg_value = sum(sess.run(var) for sess in sessions) / len(sessions)
            main_sess.run(var.assign(avg_value))

    with tf.Session(config=tf_config) as sess:
        best_checkpoint = os.path.join(config.checkpoint_dir, 'best')

        if ((config.ensemble or config.avg_checkpoints)
                and (args.eval or args.decode is not None)
                and len(config.checkpoints) > 1):
            # create one session for each model in the ensemble
            sessions = [tf.Session() for _ in config.checkpoints]
            for sess_, checkpoint in zip(sessions, config.checkpoints):
                model.initialize(sess_, [checkpoint])

            if config.ensemble:
                sess = sessions
            else:
                sess = sessions[0]
                average_checkpoints(sess, sessions)
        elif (not config.checkpoints
              and (args.eval or args.decode is not None or args.align)
              and os.path.isfile(best_checkpoint + '.index')):
            # in decoding and evaluation mode, unless specified otherwise (by `checkpoints`),
            # try to load the best checkpoint
            model.initialize(sess, [best_checkpoint])
        else:
            # loads last checkpoint, unless `reset` is true
            model.initialize(sess, **config)

        if args.decode is not None:
            model.decode(sess, **config)
        elif args.eval:
            model.evaluate(sess, on_dev=False, **config)
        elif args.align:
            model.align(sess, **config)
        elif args.train:
            try:
                model.train(sess=sess, **config)
            except (KeyboardInterrupt, utils.FinishedTrainingException):
                utils.log('exiting...')
                model.save(sess)
                sys.exit()
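
`average_checkpoints` above averages each variable's value across several sessions. The same arithmetic on plain numpy arrays (made-up weights, no TensorFlow involved):

import numpy as np

# Plain-numpy restatement of the checkpoint averaging above.
checkpoint_a = {'encoder/W': np.array([[1.0, 2.0]]), 'decoder/b': np.array([0.0])}
checkpoint_b = {'encoder/W': np.array([[3.0, 4.0]]), 'decoder/b': np.array([2.0])}
checkpoints = [checkpoint_a, checkpoint_b]

averaged = {
    name: sum(ckpt[name] for ckpt in checkpoints) / len(checkpoints)
    for name in checkpoint_a
}

assert np.allclose(averaged['encoder/W'], [[2.0, 3.0]])
assert np.allclose(averaged['decoder/b'], [1.0])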
Example #16
    def __init__(self,
                 encoders,
                 decoder,
                 learning_rate,
                 global_step,
                 max_gradient_norm,
                 num_samples=512,
                 dropout_rate=0.0,
                 freeze_variables=None,
                 lm_weight=None,
                 max_output_len=50,
                 attention=True,
                 feed_previous=0.0,
                 optimizer='sgd',
                 max_input_len=None,
                 decode_only=False,
                 len_normalization=1.0,
                 **kwargs):
        self.lm_weight = lm_weight
        self.encoders = encoders
        self.decoder = decoder

        self.learning_rate = learning_rate
        self.global_step = global_step

        self.encoder_count = len(encoders)
        self.trg_vocab_size = decoder.vocab_size
        self.trg_cell_size = decoder.cell_size
        self.binary_input = [
            encoder.name for encoder in encoders if encoder.binary
        ]

        self.max_output_len = max_output_len
        self.max_input_len = max_input_len
        self.len_normalization = len_normalization

        # if we use sampled softmax, we need an output projection;
        # sampled softmax only makes sense if we sample fewer classes than the vocabulary size
        if num_samples == 0 or num_samples >= self.trg_vocab_size:
            output_projection = None
            softmax_loss_function = None
        else:
            with tf.device('/cpu:0'):
                with variable_scope.variable_scope('decoder_{}'.format(
                        decoder.name)):
                    w = decoders.get_variable_unsafe(
                        'proj_w', [self.trg_cell_size, self.trg_vocab_size])
                    w_t = tf.transpose(w)
                    b = decoders.get_variable_unsafe('proj_b',
                                                     [self.trg_vocab_size])
                output_projection = (w, b)

            def softmax_loss_function(inputs, labels):
                # note: the (inputs, labels) argument order follows the
                # pre-TF-1.0 sampled_softmax_loss API
                with tf.device('/cpu:0'):
                    labels = tf.reshape(labels, [-1, 1])
                    return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels,
                                                      num_samples,
                                                      self.trg_vocab_size)

        if dropout_rate > 0:
            self.dropout = tf.Variable(1 - dropout_rate,
                                       trainable=False,
                                       name='dropout_keep_prob')
            self.dropout_off = self.dropout.assign(1.0)
            self.dropout_on = self.dropout.assign(1 - dropout_rate)
        else:
            self.dropout = None

        self.feed_previous = tf.constant(feed_previous, dtype=tf.float32)

        self.encoder_inputs = []
        self.encoder_input_length = []

        self.encoder_names = [encoder.name for encoder in encoders]
        self.decoder_name = decoder.name
        self.extensions = self.encoder_names + [self.decoder_name]

        for encoder in self.encoders:
            if encoder.binary:
                placeholder = tf.placeholder(
                    tf.float32,
                    shape=[None, None, encoder.embedding_size],
                    name='encoder_{}'.format(encoder.name))
            else:
                placeholder = tf.placeholder(tf.int32,
                                             shape=[None, None],
                                             name='encoder_{}'.format(
                                                 encoder.name))

            self.encoder_inputs.append(placeholder)
            self.encoder_input_length.append(
                tf.placeholder(tf.int64,
                               shape=[None],
                               name='encoder_{}_length'.format(encoder.name)))

        self.decoder_inputs = tf.placeholder(tf.int32,
                                             shape=[None, None],
                                             name='decoder_{}'.format(
                                                 self.decoder.name))
        self.decoder_input = tf.placeholder(tf.int32,
                                            shape=[None],
                                            name='beam_search_{}'.format(
                                                decoder.name))
        self.target_weights = tf.placeholder(tf.float32,
                                             shape=[None, None],
                                             name='weight_{}'.format(
                                                 self.decoder.name))
        self.targets = tf.placeholder(tf.int32,
                                      shape=[None, None],
                                      name='target_{}'.format(
                                          self.decoder.name))

        self.decoder_input_length = tf.placeholder(
            tf.int64,
            shape=[None],
            name='decoder_{}_length'.format(decoder.name))

        parameters = dict(encoders=encoders,
                          decoder=decoder,
                          dropout=self.dropout,
                          output_projection=output_projection)

        self.attention_states, self.encoder_state = decoders.multi_encoder(
            self.encoder_inputs,
            encoder_input_length=self.encoder_input_length,
            **parameters)

        decoder = decoders.attention_decoder if attention else decoders.decoder

        self.outputs, self.attention_weights = decoder(
            attention_states=self.attention_states,
            initial_state=self.encoder_state,
            decoder_inputs=self.decoder_inputs,
            feed_previous=self.feed_previous,
            decoder_input_length=self.decoder_input_length,
            **parameters)

        self.beam_output, self.beam_tensors = decoders.beam_search_decoder(
            decoder_input=self.decoder_input,
            attention_states=self.attention_states,
            initial_state=self.encoder_state,
            **parameters)

        self.loss = decoders.sequence_loss(
            logits=self.outputs,
            targets=self.targets,
            weights=self.target_weights,
            softmax_loss_function=softmax_loss_function)

        if not decode_only:
            # gradients and parameter-update operation for training the model
            if freeze_variables is None:
                freeze_variables = []

            # compute gradients only for variables that are not frozen
            frozen_parameters = [
                var.name for var in tf.trainable_variables() if any(
                    re.match(var_, var.name) for var_ in freeze_variables)
            ]
            if frozen_parameters:
                utils.debug('frozen parameters: {}'.format(
                    ', '.join(frozen_parameters)))
            params = [
                var for var in tf.trainable_variables()
                if var.name not in frozen_parameters
            ]

            if optimizer.lower() == 'adadelta':
                opt = tf.train.AdadeltaOptimizer(learning_rate=learning_rate)
            elif optimizer.lower() == 'adagrad':
                opt = tf.train.AdagradOptimizer(learning_rate=learning_rate)
            elif optimizer.lower() == 'adam':
                opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
            else:
                opt = tf.train.GradientDescentOptimizer(
                    learning_rate=learning_rate)

            gradients = tf.gradients(self.loss, params)
            clipped_gradients, self.gradient_norms = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            self.updates = opt.apply_gradients(zip(clipped_gradients, params),
                                               global_step=self.global_step)

        def tensor_prod(x, w, b):
            # apply the output projection to a 3D tensor: flatten the two
            # leading dimensions, project, then restore the original shape
            # (tf.stack/tf.multiply are the current names of tf.pack/tf.mul)
            shape = tf.shape(x)
            x = tf.reshape(x, tf.stack([tf.multiply(shape[0], shape[1]), shape[2]]))
            x = tf.matmul(x, w) + b
            x = tf.reshape(x, tf.stack([shape[0], shape[1], b.get_shape()[0]]))
            return x

        if output_projection is not None:
            w, b = output_projection
            self.outputs = tensor_prod(self.outputs, w, b)
            self.beam_output = tf.nn.xw_plus_b(self.beam_output, w, b)

        self.beam_output = tf.nn.softmax(self.beam_output)
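
The training update above strings together tf.gradients, tf.clip_by_global_norm and apply_gradients. A toy sketch of that pattern in isolation (the quadratic loss and all names here are made up for illustration):

import tensorflow as tf

x = tf.Variable([3.0, -4.0])
loss = tf.reduce_sum(tf.square(x))   # toy loss; its gradient norm is 10 at the start
params = tf.trainable_variables()

opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
gradients = tf.gradients(loss, params)
clipped, global_norm = tf.clip_by_global_norm(gradients, clip_norm=1.0)
train_op = opt.apply_gradients(zip(clipped, params))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, norm = sess.run([train_op, global_norm])
    print('global norm before clipping: {:.1f}'.format(norm))  # prints 10.0
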
Example No. 17
0
    def __init__(self,
                 encoders,
                 decoders,
                 checkpoint_dir,
                 learning_rate,
                 learning_rate_decay_factor,
                 batch_size,
                 keep_best=1,
                 dev_prefix=None,
                 score_function='corpus_scores',
                 name=None,
                 ref_ext=None,
                 pred_edits=False,
                 dual_output=False,
                 binary=None,
                 **kwargs):

        self.batch_size = batch_size
        self.character_level = {}
        self.binary = []

        for encoder_or_decoder in encoders + decoders:
            encoder_or_decoder.ext = encoder_or_decoder.ext or encoder_or_decoder.name
            self.character_level[
                encoder_or_decoder.ext] = encoder_or_decoder.character_level
            self.binary.append(encoder_or_decoder.get('binary', False))

        self.char_output = decoders[0].character_level

        self.src_ext = [encoder.ext for encoder in encoders]
        self.trg_ext = [decoder.ext for decoder in decoders]

        self.extensions = self.src_ext + self.trg_ext

        self.ref_ext = ref_ext
        if self.ref_ext is not None:
            self.binary.append(False)

        self.pred_edits = pred_edits
        self.dual_output = dual_output

        self.dev_prefix = dev_prefix
        self.name = name

        self.max_input_len = [encoder.max_len for encoder in encoders]
        self.max_output_len = [decoder.max_len for decoder in decoders]
        self.max_len = dict(
            zip(self.extensions, self.max_input_len + self.max_output_len))

        self.learning_rate = tf.Variable(learning_rate,
                                         trainable=False,
                                         name='learning_rate',
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)

        with tf.device('/cpu:0'):
            self.global_step = tf.Variable(0,
                                           trainable=False,
                                           name='global_step')
            self.baseline_step = tf.Variable(0,
                                             trainable=False,
                                             name='baseline_step')

        self.filenames = utils.get_filenames(extensions=self.extensions,
                                             dev_prefix=dev_prefix,
                                             name=name,
                                             ref_ext=ref_ext,
                                             binary=self.binary,
                                             **kwargs)
        utils.debug('reading vocabularies')
        self.vocabs = None
        self.src_vocab, self.trg_vocab = None, None
        self.read_vocab()

        for encoder_or_decoder, vocab in zip(encoders + decoders, self.vocabs):
            if vocab:
                encoder_or_decoder.vocab_size = len(vocab.reverse)

        utils.debug('creating model')
        self.seq2seq_model = Seq2SeqModel(encoders,
                                          decoders,
                                          self.learning_rate,
                                          self.global_step,
                                          name=name,
                                          pred_edits=pred_edits,
                                          dual_output=dual_output,
                                          baseline_step=self.baseline_step,
                                          **kwargs)

        self.batch_iterator = None
        self.dev_batches = None
        self.train_size = None
        self.saver = None
        self.keep_best = keep_best
        self.checkpoint_dir = checkpoint_dir

        self.training = utils.AttrDict()  # used to keep track of training

        try:
            self.reversed_scores = getattr(
                evaluation, score_function).reversed  # the lower the better
        except AttributeError:
            self.reversed_scores = False  # the higher the better
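
A quick sketch of the learning-rate mechanics above: the rate lives in a non-trainable variable, and an explicit assign op multiplies it by the decay factor whenever training decides to decay (the 0.5 starting rate and 0.8 factor below are arbitrary):

import tensorflow as tf

learning_rate = tf.Variable(0.5, trainable=False, dtype=tf.float32, name='learning_rate')
learning_rate_decay_op = learning_rate.assign(learning_rate * 0.8)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        print(sess.run(learning_rate))     # 0.5, 0.4, 0.32
        sess.run(learning_rate_decay_op)   # in practice, run only when dev scores stall
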
Example No. 18
0
def load_checkpoint(sess, checkpoint_dir, filename=None, blacklist=()):
    """
    if `filename` is None, we load last checkpoint, otherwise
      we ignore `checkpoint_dir` and load the given checkpoint file.
    """
    if filename is None:
        # load last checkpoint
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt is not None:
            filename = ckpt.model_checkpoint_path
    else:
        checkpoint_dir = os.path.dirname(filename)

    var_file = os.path.join(checkpoint_dir, 'vars.pkl')

    def get_variable_by_name(name):
        for var in tf.global_variables():
            if var.name == name:
                return var
        return None

    if os.path.exists(var_file):
        with open(var_file, 'rb') as f:
            var_names = pickle.load(f)

        variables = {}

        for var_name in var_names:
            skip = False
            for var in tf.global_variables():
                name = var.name
                for key, value in reverse_mapping:
                    name = re.sub(key, value, name)
                if var_name == name:
                    variables[var_name] = var
                    skip = True
                    break

            if skip:
                continue

            name = var_name
            for key, value in variable_mapping:
                name = re.sub(key, value, name)

            for var in tf.global_variables():
                if var.name == name:
                    variables[var_name] = var
                    break
    else:
        variables = {var.name: var for var in tf.global_variables()}

    # drop blacklisted variables and strip the ':0' suffix from variable names
    # variables = [var for var in variables if not any(prefix in var.name for prefix in blacklist)]
    variables = {
        name[:-2]: var
        for name, var in variables.items()
        if not any(prefix in name for prefix in blacklist)
    }

    if filename is not None:
        utils.log('reading model parameters from {}'.format(filename))
        tf.train.Saver(variables).restore(sess, filename)

        utils.debug('retrieved parameters ({})'.format(len(variables)))
        for var in sorted(variables.values(), key=lambda var: var.name):
            utils.debug('  {} {}'.format(var.name, var.get_shape()))
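
The name remapping in load_checkpoint boils down to applying a list of (pattern, replacement) pairs with re.sub. A small illustration with a made-up mapping; the module's real variable_mapping and reverse_mapping tables are defined elsewhere and differ from this one:

import re

variable_mapping = [
    (r'^encoder_(\w+)/', r'multi_encoder/encoder_\1/'),  # hypothetical rule
]

def remap(name, mapping):
    # apply every substitution rule in order
    for pattern, replacement in mapping:
        name = re.sub(pattern, replacement, name)
    return name

print(remap('encoder_fr/embedding:0', variable_mapping))
# -> multi_encoder/encoder_fr/embedding:0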