Example #1
from typing import Dict

import tensorflow as tf

# Model, Dataset, Evaluation, EvaluatorRunner and AysncEvaluatorRunner are
# assumed to come from the surrounding docqa project.


def test(model: Model,
         evaluators,
         datasets: Dict[str, Dataset],
         loader,
         checkpoint,
         ema=True,
         aysnc_encoding=None,
         sample=None) -> Dict[str, Evaluation]:
    print("Setting up model")
    model.set_inputs(list(datasets.values()), loader)

    if aysnc_encoding:
        evaluator_runner = AysncEvaluatorRunner(evaluators, model,
                                                aysnc_encoding)
        inputs = evaluator_runner.dequeue_op
    else:
        evaluator_runner = EvaluatorRunner(evaluators, model)
        inputs = model.get_placeholders()
    input_dict = {p: x for p, x in zip(model.get_placeholders(), inputs)}
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    with sess.as_default():
        pred = model.get_predictions_for(input_dict)
    evaluator_runner.set_input(pred)

    print("Restoring variables")
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)

    if ema:
        # FIXME: this loads the variables twice, but restoring the EMA values
        # directly proved fiddly, so restore normally first and then overwrite
        # with the averaged weights
        averager = tf.train.ExponentialMovingAverage(0)
        reader = tf.train.NewCheckpointReader(checkpoint)
        expected_ema_names = {
            averager.average_name(x): x
            for x in tf.trainable_variables()
            if reader.has_tensor(averager.average_name(x))
        }
        if expected_ema_names:
            print("Restoring EMA variables")
            saver = tf.train.Saver(expected_ema_names)
            saver.restore(sess, checkpoint)

    tf.get_default_graph().finalize()

    print("Begin evaluation")

    dataset_outputs = {}
    for name, dataset in datasets.items():
        dataset_outputs[name] = evaluator_runner.run_evaluators(
            sess, dataset, name, sample, {})
    return dataset_outputs
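
A hypothetical call to test, shown only to illustrate the argument shapes; my_model, dev_data, res_loader and LossEvaluator are placeholders for real docqa objects, and Evaluation is assumed to expose a scalars dict:

# Usage sketch; every name below except test itself is a placeholder
outputs = test(my_model,                      # a built Model
               [LossEvaluator()],             # evaluators to run
               {"dev": dev_data},             # dataset name -> Dataset
               res_loader,                    # resource loader for the model
               "model/save/checkpoint-1000",  # TF 1.x checkpoint path
               ema=True,                      # prefer EMA weights if present
               sample=None)                   # evaluate the full datasets
for name, evaluation in outputs.items():
    print(name, evaluation.scalars)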
Example #2
import time
from os.path import join, relpath
from typing import List, Union

import numpy as np
import tensorflow as tf

# Model, TrainingData, TrainParams, Evaluator, EvaluatorRunner, ModelDir,
# save_train_start, _build_train_ops and _train_async are assumed to come
# from the surrounding docqa project.


def _train(model: Model,
           data: TrainingData,
           checkpoint: Union[str, None],
           parameter_checkpoint: Union[str, None],
           save_start: bool,
           train_params: TrainParams,
           evaluators: List[Evaluator],
           out: ModelDir,
           notes=None,
           dry_run=False,
           start_eval=False):
    if train_params.async_encoding:
        _train_async(model, data, checkpoint, parameter_checkpoint, save_start,
                     train_params, evaluators, out, notes, dry_run, start_eval)
        return

    if train_params.best_weights is not None:
        raise NotImplementedError

    # Specialize the model for the current vocabulary / inputs / batching
    train = data.get_train()
    eval_datasets = data.get_eval()
    loader = data.get_resource_loader()
    evaluator_runner = EvaluatorRunner(evaluators, model)

    print("Training on %d batches" % len(train))
    print("Evaluation datasets: " +
          " ".join("%s (%d)" % (name, len(data))
                   for name, data in eval_datasets.items()))

    print("Init model...")
    model.set_inputs([train] + list(eval_datasets.values()), loader)

    print("Setting up model prediction / tf...")

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    with sess.as_default():
        pred = model.get_prediction()
    evaluator_runner.set_input(pred)

    if parameter_checkpoint is not None:
        print("Restoring parameters from %s" % parameter_checkpoint)
        saver = tf.train.Saver(tf.trainable_variables())
        saver.restore(sess, parameter_checkpoint)
        saver = None

    loss, summary_tensor, train_opt, global_step, _ = _build_train_ops(
        train_params)

    # Pre-compute the tensors each evaluator will need at evaluation time
    eval_tensors = []
    for ev in evaluators:
        eval_tensors.append(ev.tensors_needed(pred))

    saver = tf.train.Saver(max_to_keep=train_params.max_checkpoints_to_keep)
    summary_writer = tf.summary.FileWriter(out.log_dir)

    # Load training state or initialize the model parameters
    if checkpoint is not None:
        print("Restoring training state from checkpoint...")
        saver.restore(sess, checkpoint)
        print("Loaded checkpoint: " + str(sess.run(global_step)))
    elif parameter_checkpoint is not None:
        # Trainable variables were already restored above, so only the
        # remaining (optimizer/step) variables need initializing
        print("Initializing training variables...")
        non_trainable_vars = [
            x for x in tf.global_variables()
            if x not in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        ]
        sess.run(tf.variables_initializer(non_trainable_vars))
    else:
        print("Initializing parameters...")
        sess.run(tf.global_variables_initializer())

    # Finalize the graph so nothing in the train loop can add ops to it,
    # which would slowly grow the graph and eventually cause OOMs
    tf.get_default_graph().finalize()

    print("Start training!")

    on_step = sess.run(global_step)
    if save_start:
        summary_writer.add_graph(sess.graph, global_step=on_step)
        save_train_start(out.dir, data, on_step, evaluators, train_params,
                         notes)

    if train_params.eval_at_zero:
        print("Running evaluation...")
        start_eval = False
        for name, data in eval_datasets.items():
            n_samples = train_params.eval_samples.get(name)
            evaluation = evaluator_runner.run_evaluators(
                sess, data, name, n_samples)
            for s in evaluation.to_summaries(name + "-"):
                summary_writer.add_summary(s, on_step)

    batch_time = 0
    for epoch in range(train_params.num_epochs):
        for batch_ix, batch in enumerate(train.get_epoch()):
            t0 = time.perf_counter()
            # +1 because global_step is incremented by the train op run below
            on_step = sess.run(global_step) + 1

            get_summary = on_step % train_params.log_period == 0
            encoded = model.encode(batch, True)

            if get_summary:
                summary, _, batch_loss = sess.run(
                    [summary_tensor, train_opt, loss], feed_dict=encoded)
            else:
                summary = None
                _, batch_loss = sess.run([train_opt, loss], feed_dict=encoded)

            if np.isnan(batch_loss):
                raise RuntimeError("NaN loss!")

            batch_time += time.perf_counter() - t0
            if get_summary:
                print("on epoch=%d batch=%d step=%d time=%.3f" %
                      (epoch, batch_ix + 1, on_step, batch_time))
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag="time", simple_value=batch_time)
                    ]), on_step)
                summary_writer.add_summary(summary, on_step)
                batch_time = 0

            # Occasional checkpointing
            if on_step % train_params.save_period == 0:
                print("Checkpointing")
                saver.save(sess,
                           join(out.save_dir, "checkpoint-" + str(on_step)),
                           global_step=global_step)

            # Occasional evaluation
            if (on_step % train_params.eval_period == 0) or start_eval:
                print("Running evaluation...")
                start_eval = False
                t0 = time.perf_counter()
                for name, data in eval_datasets.items():
                    n_samples = train_params.eval_samples.get(name)
                    evaluation = evaluator_runner.run_evaluators(
                        sess, data, name, n_samples)
                    for s in evaluation.to_summaries(name + "-"):
                        summary_writer.add_summary(s, on_step)

                print("Evaluation took: %.3f seconds" %
                      (time.perf_counter() - t0))

    saver.save(sess,
               relpath(join(out.save_dir, "checkpoint-" + str(on_step))),
               global_step=global_step)
    sess.close()
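
For reference, a hypothetical driver for _train; the TrainParams and ModelDir constructor arguments below are assumptions about the surrounding project, shown purely to illustrate the call shape:

# Hypothetical driver; constructor signatures are assumed, not verified
params = TrainParams(num_epochs=10,
                     log_period=30,
                     eval_period=1000,
                     save_period=1000,
                     eval_samples={"dev": 8000},
                     max_checkpoints_to_keep=2,
                     async_encoding=None)  # take the synchronous path above
_train(my_model, my_training_data,
       checkpoint=None,            # no training state to resume
       parameter_checkpoint=None,  # no warm-start weights
       save_start=True,
       train_params=params,
       evaluators=[LossEvaluator()],  # placeholder evaluator
       out=ModelDir("model"),
       notes="baseline run")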
Example #3
from typing import Dict

import tensorflow as tf

# Model, Evaluation, EvaluatorRunner and AysncEvaluatorRunner are assumed to
# come from the surrounding docqa project.


def test(model,
         evaluators,
         datasets: Dict,
         loader,
         checkpoint,
         ema=True,
         aysnc_encoding=None,
         sample=None,
         elmo_char_cnn=True) -> Dict[str, Evaluation]:
    print("Setting up model")
    model.set_inputs(list(datasets.values()), loader)

    if aysnc_encoding:
        evaluator_runner = AysncEvaluatorRunner(evaluators, model,
                                                aysnc_encoding)
        inputs = evaluator_runner.dequeue_op
    else:
        evaluator_runner = EvaluatorRunner(evaluators, model)
        inputs = model.get_placeholders()
    input_dict = {p: x for p, x in zip(model.get_placeholders(), inputs)}

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    with sess.as_default():
        pred = model.get_predictions_for(input_dict)
    evaluator_runner.set_input(pred)

    print("Restoring variables")
    if elmo_char_cnn:
        all_vars = tf.global_variables() + tf.get_collection(
            tf.GraphKeys.SAVEABLE_OBJECTS)
        lm_var_names = {x.name for x in all_vars if x.name.startswith("bilm")}
        vars_to_restore = [x for x in all_vars if x.name not in lm_var_names]
        saver = tf.train.Saver(vars_to_restore)
        sess.run(
            tf.variables_initializer(
                [x for x in all_vars if x.name in lm_var_names]))
        saver.restore(sess, checkpoint)
    else:
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)

    if ema:
        # FIXME: this loads the variables twice, but restoring the EMA values
        # directly proved fiddly, so restore normally first and then overwrite
        # with the averaged weights
        averager = tf.train.ExponentialMovingAverage(0)
        reader = tf.train.NewCheckpointReader(checkpoint)
        expected_ema_names = {
            averager.average_name(x): x
            for x in tf.trainable_variables()
            if reader.has_tensor(averager.average_name(x))
        }
        if expected_ema_names:
            print("Restoring EMA variables")
            saver = tf.train.Saver(expected_ema_names)
            saver.restore(sess, checkpoint)

    tf.get_default_graph().finalize()

    print("Begin evaluation")

    dataset_outputs = {}
    for name, dataset in datasets.items():
        dataset_outputs[name] = evaluator_runner.run_evaluators(
            sess, dataset, name, sample, {})
    return dataset_outputs
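
The EMA trick shared by Examples #1 and #3 hinges on ExponentialMovingAverage.average_name, which returns the name a variable's shadow copy is saved under, so a second Saver can load the averaged weights into the live variables. A minimal self-contained sketch of the mapping:

import tensorflow as tf

# The decay value is irrelevant here: the averager is created only for
# its name mapping, and no averaging ops are ever run
w = tf.get_variable("w", shape=[3], initializer=tf.zeros_initializer())
averager = tf.train.ExponentialMovingAverage(0)
print(averager.average_name(w))  # -> "w/ExponentialMovingAverage"

# Restoring {shadow_name: live_variable} overwrites the live weights with
# their averaged checkpoint values:
# saver = tf.train.Saver({averager.average_name(w): w})
# saver.restore(sess, checkpoint_path)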