def convert_to_coverage_model():
    """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
    tf.logging.info("converting non-coverage model to coverage model..")

    # initialize an entire coverage model from scratch
    sess = tf.Session(config=utils.get_config())
    print("initializing everything...")
    sess.run(tf.global_variables_initializer())

    # load all non-coverage weights from checkpoint
    saver = tf.train.Saver([
        v for v in tf.global_variables()
        if "coverage" not in v.name and "Adagrad" not in v.name
    ])
    print("restoring non-coverage variables...")
    curr_ckpt = utils.load_ckpt(saver, sess)
    print("restored.")

    # save this model and quit
    new_fname = curr_ckpt + '_cov_init'
    print("saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver()  # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print("saved.")
    exit()
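

# For reference, a minimal sketch of what a checkpoint-loading helper such as
# `utils.load_ckpt` (used above and below) might look like: poll
# <log_root>/<ckpt_dir> for the latest checkpoint, restore it into the session,
# and return its path. This is an assumed sketch, not the project's actual helper.
def load_ckpt_sketch(saver, sess, ckpt_dir="train"):
    while True:
        ckpt_dir_path = os.path.join(PARAMS.log_root, ckpt_dir)
        ckpt_state = tf.train.get_checkpoint_state(ckpt_dir_path)
        if ckpt_state and ckpt_state.model_checkpoint_path:
            tf.logging.info("Loading checkpoint %s",
                            ckpt_state.model_checkpoint_path)
            saver.restore(sess, ckpt_state.model_checkpoint_path)
            return ckpt_state.model_checkpoint_path
        tf.logging.info("No checkpoint found in %s yet. Sleeping...",
                        ckpt_dir_path)
        time.sleep(10)
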
def setup_training(model, batcher):
    """Does setup before starting training (run_training)"""
    train_dir = os.path.join(PARAMS.log_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)

    model.build_graph()  # build the graph
    if PARAMS.convert_to_coverage_model:
        assert PARAMS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
        convert_to_coverage_model()
    if PARAMS.restore_best_model:
        restore_best_model()
    saver = tf.train.Saver(max_to_keep=3)  # keep 3 checkpoints at a time

    sv = tf.train.Supervisor(
        logdir=train_dir,
        is_chief=True,
        saver=saver,
        summary_op=None,
        save_summaries_secs=60,  # save summaries for tensorboard every 60 secs
        save_model_secs=60,  # checkpoint every 60 secs
        global_step=model.global_step)
    summary_writer = sv.summary_writer
    tf.logging.info("Preparing or waiting for session...")
    sess_context_manager = sv.prepare_or_wait_for_session(
        config=utils.get_config())
    tf.logging.info("Created session.")
    try:
        run_training(
            model, batcher, sess_context_manager, sv,
            summary_writer)  # this is an infinite loop until interrupted
    except KeyboardInterrupt:
        tf.logging.info(
            "Caught keyboard interrupt on worker. Stopping supervisor...")
        sv.stop()
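

# run_training (called above) is defined elsewhere in the project. A minimal
# sketch of what that loop might look like, assuming model.run_train_step
# returns the same keys ('loss', 'summaries', 'global_step') as run_eval_step
# does in run_eval below; the actual implementation may differ.
def run_training_sketch(model, batcher, sess_context_manager, sv, summary_writer):
    tf.logging.info("starting run_training loop")
    with sess_context_manager as sess:
        while not sv.should_stop():
            batch = batcher.next_batch()  # get the next training batch
            t0 = time.time()
            results = model.run_train_step(sess, batch)
            t1 = time.time()
            tf.logging.info('seconds for training step: %.3f', t1 - t0)
            tf.logging.info('loss: %f', results['loss'])
            train_step = results['global_step']
            summary_writer.add_summary(results['summaries'], train_step)
            if train_step % 100 == 0:  # flush summaries periodically
                summary_writer.flush()
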
def restore_best_model():
    """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
    tf.logging.info("Restoring bestmodel for training...")

    # Initialize all vars in the model
    sess = tf.Session(config=utils.get_config())
    print("Initializing all variables...")
    sess.run(tf.global_variables_initializer())

    # Restore the best model from eval dir
    saver = tf.train.Saver(
        [v for v in tf.global_variables() if "Adagrad" not in v.name])
    print("Restoring all non-adagrad variables from best model in eval dir...")
    curr_ckpt = utils.load_ckpt(saver, sess, "eval")
    print("Restored %s." % curr_ckpt)

    # Save this model to train dir and quit
    new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
    new_fname = os.path.join(PARAMS.log_root, "train", new_model_name)
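    # e.g. "<log_root>/eval/bestmodel-55000" becomes "<log_root>/train/model-55000"
    # (the step number here is only illustrative)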
    print("Saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver()  # this saver saves all variables that now exist, including Adagrad variables
    new_saver.save(sess, new_fname)
    print("Saved.")
    exit()


def run_eval(model, batcher, vocab):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph()  # build the graph
    saver = tf.train.Saver(max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=utils.get_config())
    eval_dir = os.path.join(
        PARAMS.log_root, "eval")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = None  # will hold the best loss achieved so far

    while True:
        _ = utils.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch

        # run eval on the batch
        t0 = time.time()
        results = model.run_eval_step(sess, batch)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['loss']
        tf.logging.info('loss: %f', loss)
        if PARAMS.coverage:
            coverage_loss = results['coverage_loss']
            tf.logging.info("coverage_loss: %f", coverage_loss)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        running_avg_loss = calc_running_avg_loss(np.asscalar(loss),
                                                 running_avg_loss,
                                                 summary_writer, train_step)

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        if best_loss is None or running_avg_loss < best_loss:
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = running_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
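

# calc_running_avg_loss (used above) is defined elsewhere in the project. A
# minimal sketch of an exponentially smoothed running average that also logs
# the value to TensorBoard; the decay constant and summary tag are assumptions.
def calc_running_avg_loss_sketch(loss, running_avg_loss, summary_writer, step,
                                 decay=0.99):
    if running_avg_loss == 0:  # on the first iteration, just take the loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    summary = tf.Summary()
    summary.value.add(tag='running_avg_loss/decay=%f' % decay,
                      simple_value=running_avg_loss)
    summary_writer.add_summary(summary, step)
    tf.logging.info('running_avg_loss: %f', running_avg_loss)
    return running_avg_loss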


class BeamSearchDecoder(object):
    """Beam search decoder."""

    def __init__(self, model, batcher, vocab):
        """Initialize decoder.

        Args:
          model: a Seq2SeqAttentionModel object.
          batcher: a Batcher object.
          vocab: Vocabulary object
        """
        self._model = model
        self._model.build_graph()
        self._batcher = batcher
        self._vocab = vocab
        self._saver = tf.train.Saver()  # we use this to load checkpoints for decoding
        self._sess = tf.Session(config=util.get_config())

        # Load an initial checkpoint to use for decoding
        ckpt_path = util.load_ckpt(self._saver, self._sess)

        if hps.single_pass:
            # Make a descriptive decode directory name
            ckpt_name = "ckpt-" + ckpt_path.split('-')[
                -1]  # this is something of the form "ckpt-123456"
            self._decode_dir = os.path.join(hps.log_root,
                                            get_decode_dir_name(ckpt_name))
            if os.path.exists(self._decode_dir):
                raise Exception(
                    "single_pass decode directory %s should not already exist"
                    % self._decode_dir)

        else:  # Generic decode dir name
            self._decode_dir = os.path.join(hps.log_root, "decode")

        # Make the decode dir if necessary
        if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

        if hps.single_pass:
            # Make the dirs to contain output written in the correct format for pyrouge
            self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
            if not os.path.exists(self._rouge_ref_dir):
                os.mkdir(self._rouge_ref_dir)
            self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
            if not os.path.exists(self._rouge_dec_dir):
                os.mkdir(self._rouge_dec_dir)
        if hps.rouge_eval_only:
            if not hps.eval_path:
                raise Exception(
                    "Must specify path to folder containing decoded files for evaluation"
                )
            else:
                self._rouge_dec_dir = hps.eval_path
                if not os.path.exists(self._rouge_dec_dir):
                    raise Exception(
                        "Folder containing decoded files for evaluation does not exist!"
                    )
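
# In single_pass mode, the constructor above ends up creating a layout like:
#   <log_root>/<get_decode_dir_name(ckpt_name)>/             decode directory
#   <log_root>/<get_decode_dir_name(ckpt_name)>/reference/   reference summaries for pyrouge
#   <log_root>/<get_decode_dir_name(ckpt_name)>/decoded/     decoded summaries for pyrouge
# In non-single_pass mode everything is written under <log_root>/decode/ instead,
# unless rouge_eval_only points _rouge_dec_dir at an existing eval_path.
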
def setup_and_run_decoding(FLAGS, hps):
    # raise ValueError("Pay attention to dropout is set or not")
    if os.path.exists(FLAGS.decode_output_file):
        raise ValueError("`decode_output_file` exists")

    decode_model_hps = hps._replace(mode="decode", batch_size=FLAGS.beam_size)
    train_dir = os.path.join(FLAGS.log_root, "train")
    model_creators = [_model_factory(name) for name in FLAGS.names]

    print("Loading Decoding Data from %s " % FLAGS.decode_data_dir)
    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)
    decode_batchers = MultitaskBatcher(
        data_paths=[FLAGS.decode_data_dir],
        vocabs=[vocab],
        hps=decode_model_hps,
        single_pass=True)

    # the decode model is built only for the first model/task in FLAGS.names
    decode_models = MultitaskBaseModel(
        names=[FLAGS.names[0]],
        all_hparams=[decode_model_hps._replace(max_dec_steps=1)],
        mixing_ratios=None,
        model_creators=[model_creators[0]],
        logdir=train_dir,
        soft_sharing_coef=FLAGS.soft_sharing_coef,
        # additional args
        vocab=vocab)

    with decode_models.graph.as_default():
        decoder = BeamSearchDecoder(model=decode_models,
                                    batcher=decode_batchers,
                                    vocab=vocab,
                                    ckpt_dir=train_dir,
                                    decode_dir=FLAGS.decode_output_file,
                                    FLAGS=FLAGS)
        decode_sess = tf.Session(graph=decode_models.graph,
                                 config=misc_utils.get_config())
        decoder.build_graph(decode_sess)

        # run decode for calculating scores
        decoder.decode(ckpt_file=FLAGS.decode_ckpt_file)

    scores = evaluate(
        mode="test",
        gen_file=decoder._decode_dir,
        ref_file=FLAGS.eval_target_dir,
        execute_dir=FLAGS.eval_folder_dir,
        source_file=FLAGS.eval_source_dir,
        evaluation_task=FLAGS.names[0])

    print(scores)
    def initialize_or_restore_session(self, ckpt_file=None):
        """Initialize or restore session

        Args:
            ckpt_file: path to a specific checkpoint file (optional)
        """
        # restore from latest checkpoint or a specific file
        with self._graph.as_default():
            self._sess = tf.Session(
                graph=self._graph, config=misc_utils.get_config())
            self._sess.run(tf.global_variables_initializer())

            if self._logdir or ckpt_file:
                # restore from latest checkpoint or a specific file if provided
                misc_utils.load_ckpt(saver=self._saver,
                                     sess=self._sess,
                                     ckpt_dir=self._logdir,
                                     ckpt_file=ckpt_file)
                return
                                n_classes=len(test_data.y_indices),
                                **config['model'])
            predicted = model.test(extracted_features)

        gold = np.array(list(test_data.get_correct()))
        test_data.store_results(predicted, config['test_output_path'])

        results = evaluate_results(config['test_output_path'],
                                   config['resources'][config['test']]['path'],
                                   print_report=config['print_report'])
        logger.info("Finished testing!")
        return results, train_time.elapsed_time, test_time.elapsed_time


if __name__ == '__main__':
    base_config = get_config('config.yaml')
    model_config = get_config(argv[1])
    config_ = {**base_config, **model_config}
    logger = get_logger(__name__, config=config_['logging'])

    project_config = {x: y for x, y in base_config.items()
                      if x not in {'project_name', 'description', 'results_db_uri',
                                   'logging', 'deploy'}}
    project = Project(project_name=config_['project_name'],
                      description=config_['description'], # Project description
                      project_config=project_config,  # Base configuration for project
                      mongodb_uri=config_['results_db_uri'],
                      force_clean_repo=False)  # Crash program if git status doesn't return clean repo


def setup_training(FLAGS, hps):
    """Does setup before starting training (run_training)"""
    
    # Setting up the Multitask Wrapper
    # ----------------------------------------
    if FLAGS.autoMR:
        # for decode, we can still use this one
        # since both are essentially the same
        # except no auto-MR feature
        MultitaskModel = MultitaskAutoMRModel
    else:
        MultitaskModel = MultitaskBaseModel

    # Setting up the models and directories
    # ----------------------------------------
    num_models = len(FLAGS.names)
    # train_dir is a folder, decode_dir is a file
    train_dir = os.path.join(FLAGS.log_root, "train")
    decode_dir = os.path.join(FLAGS.log_root, "decode")
    model_creators = [_model_factory(name) for name in FLAGS.names]
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)

    # Setting up the batchers and data readers
    # ----------------------------------------
    print("Loading Training Data from %s " % FLAGS.train_data_dirs)
    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)
    train_batchers = MultitaskBatcher(
        data_paths=FLAGS.train_data_dirs,
        vocabs=[vocab for _ in range(num_models)],
        hps=hps, single_pass=False)
    # not using decode_model_hps, which has batch_size = beam_size
    val_batchers = MultitaskBatcher(
        data_paths=[FLAGS.val_data_dir],
        vocabs=[vocab], hps=hps, single_pass=False)

    # Setting up the task selectors
    # ----------------------------------------
    Q_initial = -1
    if FLAGS.reward_scaling_factor > 0.0:
        Q_initial = Q_initial / FLAGS.reward_scaling_factor
        tf.logging.info("Normalization %.2f" % FLAGS.reward_scaling_factor)

    # Build
    # ----------------------------------------
    print("Mixing ratios are %s " % FLAGS.mixing_ratios)
    train_models = MultitaskModel(
        names=FLAGS.names,
        all_hparams=[hps for _ in range(num_models)],
        mixing_ratios=FLAGS.mixing_ratios,
        model_creators=model_creators,
        logdir=train_dir,
        soft_sharing_coef=FLAGS.soft_sharing_coef,
        data_generators=train_batchers,
        val_data_generator=val_batchers,
        vocab=vocab,
        selector_Q_initial=Q_initial,
        alpha=FLAGS.selector_alpha,
        temperature_anneal_rate=None)

    # Note that this uses a different decode batcher
    
    # The model is configured with max_dec_steps=1 because we only ever run
    # one step of the decoder at a time (to do beam search). Note that the
    # batcher is initialized with max_dec_steps equal to e.g. 100 because
    # the batches need to contain the full summaries

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need
    # to make a batch of these hypotheses.
    decode_model_hps = hps._replace(mode="decode", batch_size=FLAGS.beam_size)

    # we need to constantly re-initialize this generator
    # so save arguments as a namedtuple
    print("Loading Validation Data from %s " % FLAGS.val_data_dir)
    decode_batcher_args = MultitaskBatcherArgs(
        data_paths=[FLAGS.val_data_dir],
        vocabs=[vocab],
        hps=decode_model_hps,
        single_pass=True)
    
    decode_batchers = MultitaskBatcher(**decode_batcher_args._asdict())

    # the decode model is built only for the first model/task in FLAGS.names
    decode_models = MultitaskBaseModel(
        names=[FLAGS.names[0]],
        all_hparams=[decode_model_hps._replace(max_dec_steps=1)],
        mixing_ratios=None,
        model_creators=[model_creators[0]],
        logdir=train_dir,
        soft_sharing_coef=FLAGS.soft_sharing_coef,
        vocab=vocab)

    with decode_models.graph.as_default():
        decoder = BeamSearchDecoder(model=decode_models,
                                    batcher=decode_batchers,
                                    vocab=vocab,
                                    ckpt_dir=train_dir,
                                    decode_dir=decode_dir,
                                    FLAGS=FLAGS)
        decode_sess = tf.Session(graph=decode_models.graph,
                                 config=misc_utils.get_config())
        decoder.build_graph(decode_sess)

    try:
        # this is an infinite loop until interrupted
        run_training(FLAGS=FLAGS,
                     models=train_models,
                     decoder=decoder,
                     decode_batcher_args=decode_batcher_args)
    
    except KeyboardInterrupt:
        tf.logging.info("Stopped...")