Code example #1
File: trainer.py Project: artiom-zayats/docqa_squad
def start_training(data: TrainingData,
                   model: Model,
                   train_params: TrainParams,
                   evaluators: List[Evaluator],
                   out: ModelDir,
                   notes: str = None,
                   initialize_from=None,
                   dry_run=False):
    """ Train a model from scratch """
    if initialize_from is None:
        print("Initializing model at: " + out.dir)
        model.init(data.get_train_corpus(), data.get_resource_loader())
    # Else we assume the model has already completed its first phase of initialization

    if not dry_run:
        init(out, model, False)

    _train(model, data, None, initialize_from, True, train_params, evaluators,
           out, notes, dry_run)
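
A hedged sketch of how this entry point might be invoked. The factory helpers and the exact TrainParams signature below are assumptions inferred from the fields the trainer reads, not the project's actual API:

model = build_model()                     # hypothetical Model factory
data = build_training_data()              # hypothetical TrainingData factory
params = TrainParams(
    num_epochs=25,
    log_period=30, save_period=1000, eval_period=1000,
    async_encoding=10,                    # queue capacity; falsy disables async
    max_checkpoints_to_keep=5,
    eval_samples={},                      # per-dataset evaluation sample caps
    best_weights=("dev", "accuracy"),     # (dataset name, scalar to track)
    eval_at_zero=False)

start_training(data, model, params,
               evaluators=[],             # project-specific Evaluator list
               out=ModelDir("models/run1"),
               notes="baseline run",
               dry_run=True)              # build the graph but skip training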
Code example #2
File: trainer.py Project: artiom-zayats/docqa_squad
def _train_async(model: Model,
                 data: TrainingData,
                 checkpoint: Union[str, None],
                 parameter_checkpoint: Union[str, None],
                 save_start: bool,
                 train_params: TrainParams,
                 evaluators: List[Evaluator],
                 out: ModelDir,
                 notes=None,
                 dry_run=False,
                 start_eval=False):
    """ Train while encoding batches on a seperate thread and storing them in a tensorflow Queue, can
    be much faster then using the feed_dict approach """

    train = data.get_train()

    eval_datasets = data.get_eval()
    loader = data.get_resource_loader()
    print("Training on %d batches" % len(train))
    print("Evaluation datasets: " +
          " ".join("%s (%d)" % (name, len(data))
                   for name, data in eval_datasets.items()))

    # spec the model for the given datasets
    model.set_inputs([train] + list(eval_datasets.values()), loader)
    placeholders = model.get_placeholders()

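    # Bounded queue of pre-encoded training batches: the enqueue_train
    # thread defined below fills it while the main loop dequeues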
    train_queue = tf.FIFOQueue(train_params.async_encoding,
                               [x.dtype for x in placeholders],
                               name="train_queue")
    evaluator_runner = AysncEvaluatorRunner(evaluators, model,
                                            train_params.async_encoding)
    train_enqueue = train_queue.enqueue(placeholders)
    train_close = train_queue.close(True)

    # A single is_train flag switches between dequeuing training batches
    # and dequeuing the evaluator's pre-encoded batches
    is_train = tf.placeholder(tf.bool, ())
    input_tensors = tf.cond(is_train, lambda: train_queue.dequeue(),
                            lambda: evaluator_runner.eval_queue.dequeue())

    # TensorFlow can't infer the shapes of tensors dequeued from an unsized
    # queue, so set them manually
    for input_tensor, pl in zip(input_tensors, placeholders):
        input_tensor.set_shape(pl.shape)

    print("Init model...")
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    with sess.as_default():
        pred = model.get_predictions_for(dict(zip(placeholders,
                                                  input_tensors)))
    evaluator_runner.set_input(pred)  # TODO(az): needs fixing

    if parameter_checkpoint is not None:
        print("Restoring parameters from %s" % parameter_checkpoint)
        saver = tf.train.Saver()
        saver.restore(sess, parameter_checkpoint)
        saver = None

    print("Setting up model prediction / tf...")
    all_vars = tf.global_variables()

    loss, summary_tensor, train_opt, global_step, weight_ema = _build_train_ops(
        train_params)

    # Pre-compute tensors we need at evaluation time
    eval_tensors = []
    for ev in evaluators:
        eval_tensors.append(ev.tensors_needed(pred))

    if train_params.best_weights is not None:
        # Include the EMA shadow variables so the "best" checkpoint also
        # stores the averaged parameters
        lst = list(all_vars)
        if weight_ema is not None:
            for x in all_vars:
                v = weight_ema.average(x)
                if v is not None:
                    lst.append(v)
        best_weight_saver = tf.train.Saver(var_list=lst, max_to_keep=1)
    else:
        best_weight_saver = None
    cur_best = None

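    # Two savers: `saver` below handles the periodic checkpoints, while
    # best_weight_saver above only persists the best-scoring weights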
    saver = tf.train.Saver(max_to_keep=train_params.max_checkpoints_to_keep)
    summary_writer = tf.summary.FileWriter(out.log_dir)

    # Load or initialize the model parameters
    if checkpoint is not None:
        print("Restoring from checkpoint...")
        saver.restore(sess, checkpoint)
        print("Loaded checkpoint: " + str(sess.run(global_step)))
    else:
        print("Initializing parameters...")
        sess.run(tf.global_variables_initializer())

    # Make sure no ops are accidentally added to the graph inside the train
    # loop, which can (eventually) cause OOMs
    tf.get_default_graph().finalize()

    if dry_run:
        return

    on_step = sess.run(global_step)

    if save_start:
        summary_writer.add_graph(sess.graph, global_step=on_step)
        save_train_start(out.dir, data, sess.run(global_step), evaluators,
                         train_params, notes)

    def enqueue_train():
        try:
            # feed data from the dataset iterator -> encoder -> queue
            for epoch in range(train_params.num_epochs):
                for batch in train.get_epoch():
                    feed_dict = model.encode(batch, True)
                    sess.run(train_enqueue, feed_dict)
        except tf.errors.CancelledError:
            # The queue close op has been run; exit gracefully
            return
        except Exception as e:
            # Close the queue so the main thread fails with a queue
            # exception instead of blocking forever, then re-raise
            sess.run(train_close)
            raise e

    train_enqueue_thread = Thread(target=enqueue_train)
    train_enqueue_thread.daemon = True  # Ensure we exit the program on an exception

    print("Start training!")

    batch_time = 0
    epoch_best = 0

    dev_acc = []
    train_acc = []

    train_dict = {is_train: True}
    eval_dict = {is_train: False}
    try:
        train_enqueue_thread.start()

        for epoch in range(train_params.num_epochs):
            for batch_ix in range(len(train)):
                t0 = time.perf_counter()
                on_step = sess.run(global_step) + 1
                get_summary = on_step % train_params.log_period == 0

                if get_summary:
                    summary, _, batch_loss = sess.run(
                        [summary_tensor, train_opt, loss],
                        feed_dict=train_dict)
                    print('batch_loss is: ' + str(batch_loss))
                else:
                    summary = None
                    _, batch_loss = sess.run([train_opt, loss],
                                             feed_dict=train_dict)

                if np.isnan(batch_loss):
                    raise RuntimeError("NaN loss!")

                batch_time += time.perf_counter() - t0
                if summary is not None:
                    print("on epoch=%d batch=%d step=%d, time=%.3f" %
                          (epoch, batch_ix + 1, on_step, batch_time))

                    summary_writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag="time",
                                             simple_value=batch_time)
                        ]), on_step)
                    summary_writer.add_summary(summary, on_step)
                    batch_time = 0

                # occasional saving
                if on_step % train_params.save_period == 0:
                    print("Checkpointing")
                    saver.save(sess,
                               join(out.save_dir,
                                    "checkpoint-" + str(on_step)),
                               global_step=global_step)

                # Occasional evaluation
                if (on_step % train_params.eval_period == 0) or start_eval:

                    print("Running evaluation...")
                    start_eval = False
                    t0 = time.perf_counter()
                    for name, data in eval_datasets.items():
                        n_samples = train_params.eval_samples.get(name)
                        evaluation = evaluator_runner.run_evaluators(
                            sess, data, name, n_samples, eval_dict)
                        for s in evaluation.to_summaries(name + "-"):
                            summary_writer.add_summary(s, on_step)

                        # Maybe save as the best weights
                        if (train_params.best_weights is not None and
                                name == train_params.best_weights[0]):
                            val = evaluation.scalars[train_params.best_weights[1]]
                            dev_acc.append(val)

                            if cur_best is None or val > cur_best:
                                epoch_best = epoch
                                send_email('epoch: %d acc: %s' % (epoch_best, val),
                                           'New Best')
                                print("Saving weights, new best (%s vs %.5f)" %
                                      ("None" if cur_best is None else
                                       "%.5f" % cur_best, val))
                                best_weight_saver.save(
                                    sess, join(out.best_weight_dir, "best"),
                                    global_step=global_step)
                                cur_best = val
                                if cur_best > 0.37:
                                    email_text = (
                                        'Best accuracy for dev data: %.3f'
                                        ' <br> On epoch n: %d out of: %d'
                                        ' <br> Folder: %s' %
                                        (cur_best, epoch_best,
                                         train_params.num_epochs, out.save_dir))
                                    send_email(email_text, 'Good News Everyone!')

                            print('Current accuracy for dev data: %.3f' % val)
                            print('Best accuracy for dev data: %.3f on epoch n: %d' %
                                  (cur_best, epoch_best))
                        elif train_params.best_weights is not None:
                            # Track the same scalar for the other datasets
                            # (e.g. train) for the final report and plot
                            val_train = evaluation.scalars[train_params.best_weights[1]]
                            train_acc.append(val_train)
                            print('Current accuracy for train data: %.3f' % val_train)

                    print("Evaluation took: %.3f seconds" %
                          (time.perf_counter() - t0))

    finally:
        # Closing the queue terminates the enqueue thread with an exception
        sess.run(train_close)

    train_enqueue_thread.join()

    email_text = ('Finished %d epochs.'
                  ' Best accuracy for dev data: %.3f'
                  ' <br> On epoch n: %d'
                  ' <br> Acc for train data last: %.3f'
                  ' <br> Folder: %s' %
                  (train_params.num_epochs, cur_best, epoch_best, val_train,
                   out.save_dir))
    email_title = 'Test Finished'
    image_path = create_train_dev_plot(dev_acc, train_acc, out.save_dir)
    send_email(email_text, email_title, image_path)

    saver.save(sess,
               relpath(join(out.save_dir, "checkpoint-" + str(on_step))),
               global_step=global_step)
    sess.close()
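
The heart of _train_async is a producer/consumer handoff: a daemon thread encodes batches and enqueues them into a bounded tf.FIFOQueue while the main thread dequeues and trains, and the finally block closes the queue to shut the producer down. Below is a minimal, self-contained sketch of that pattern on toy data, written against the tf.compat.v1 API; all names here are illustrative, not project code:

import threading
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Placeholders feed the producer thread; the queue hands batches to the
# training ops running on the main thread.
x_pl = tf.placeholder(tf.float32, [None, 4])
y_pl = tf.placeholder(tf.float32, [None, 1])
queue = tf.FIFOQueue(10, [tf.float32, tf.float32], name="train_queue")
enqueue_op = queue.enqueue([x_pl, y_pl])
close_op = queue.close(cancel_pending_enqueues=True)

x, y = queue.dequeue()
x.set_shape([None, 4])   # the queue carries no shapes, so set them manually
y.set_shape([None, 1])

w = tf.Variable(tf.zeros([4, 1]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
n_batches = 200

def producer():
    try:
        for _ in range(n_batches):
            xb = np.random.randn(32, 4).astype(np.float32)
            yb = xb.sum(axis=1, keepdims=True)
            sess.run(enqueue_op, {x_pl: xb, y_pl: yb})
    except tf.errors.CancelledError:
        pass  # queue closed while we were blocked: exit gracefully

thread = threading.Thread(target=producer, daemon=True)
thread.start()
try:
    for step in range(n_batches):
        _, batch_loss = sess.run([train_op, loss])
finally:
    sess.run(close_op)  # cancels a blocked producer, as in the finally above
thread.join()
sess.close()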
Code example #3
File: trainer.py Project: artiom-zayats/docqa_squad
def _train(model: Model,
           data: TrainingData,
           checkpoint: Union[str, None],
           parameter_checkpoint: Union[str, None],
           save_start: bool,
           train_params: TrainParams,
           evaluators: List[Evaluator],
           out: ModelDir,
           notes=None,
           dry_run=False,
           start_eval=False):
    if train_params.async_encoding:
        _train_async(model, data, checkpoint, parameter_checkpoint, save_start,
                     train_params, evaluators, out, notes, dry_run, start_eval)
        return

    if train_params.best_weights is not None:
        raise NotImplementedError

    # spec the model for the current vocab/input/batching
    train = data.get_train()
    eval_datasets = data.get_eval()
    loader = data.get_resource_loader()
    evaluator_runner = EvaluatorRunner(evaluators, model)

    print("Training on %d batches" % len(train))
    print("Evaluation datasets: " +
          " ".join("%s (%d)" % (name, len(data))
                   for name, data in eval_datasets.items()))

    print("Init model...")
    model.set_inputs([train] + list(eval_datasets.values()), loader)

    print("Setting up model prediction / tf...")

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    with sess.as_default():
        pred = model.get_prediction()
    evaluator_runner.set_input(pred)

    if parameter_checkpoint is not None:
        print("Restoring parameters from %s" % parameter_checkpoint)
        saver = tf.train.Saver(tf.trainable_variables())
        saver.restore(sess, parameter_checkpoint)
        saver = None

    loss, summary_tensor, train_opt, global_step, _ = _build_train_ops(
        train_params)

    # Pre-compute tensors we need at evaluation time
    eval_tensors = []
    for ev in evaluators:
        eval_tensors.append(ev.tensors_needed(pred))

    saver = tf.train.Saver(max_to_keep=train_params.max_checkpoints_to_keep)
    summary_writer = tf.summary.FileWriter(out.log_dir)

    # Load or initialize the model parameters
    if checkpoint is not None:
        print("Restoring training from checkpoint...")
        saver.restore(sess, checkpoint)
        print("Loaded checkpoint: " + str(sess.run(global_step)))
    else:
    else:
        if parameter_checkpoint is not None:
            print("Initializing training variables...")
            # Trainable variables were restored above, so initialize only
            # the rest (optimizer slots, global step, ...)
            init_vars = [
                x for x in tf.global_variables()
                if x not in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            ]
            sess.run(tf.variables_initializer(init_vars))
        else:
            print("Initializing parameters...")
            sess.run(tf.global_variables_initializer())

    # Make sure no ops are accidentally added to the graph inside the train
    # loop, which can (eventually) cause OOMs
    tf.get_default_graph().finalize()

    print("Start training!")

    on_step = sess.run(global_step)
    if save_start:
        summary_writer.add_graph(sess.graph, global_step=on_step)
        save_train_start(out.dir, data, on_step, evaluators, train_params,
                         notes)

    if train_params.eval_at_zero:
        print("Running evaluation...")
        start_eval = False
        for name, data in eval_datasets.items():
            n_samples = train_params.eval_samples.get(name)
            evaluation = evaluator_runner.run_evaluators(
                sess, data, name, n_samples)
            for s in evaluation.to_summaries(name + "-"):
                summary_writer.add_summary(s, on_step)

    batch_time = 0
    for epoch in range(train_params.num_epochs):
        for batch_ix, batch in enumerate(train.get_epoch()):
            t0 = time.perf_counter()
            # +1 because all calculations are done after the step increment
            on_step = sess.run(global_step) + 1

            get_summary = on_step % train_params.log_period == 0
            encoded = model.encode(batch, True)

            if get_summary:
                summary, _, batch_loss = sess.run(
                    [summary_tensor, train_opt, loss], feed_dict=encoded)
            else:
                summary = None
                _, batch_loss = sess.run([train_opt, loss], feed_dict=encoded)

            if np.isnan(batch_loss):
                raise RuntimeError("NaN loss!")

            batch_time += time.perf_counter() - t0
            if get_summary:
                print("on epoch=%d batch=%d step=%d time=%.3f" %
                      (epoch, batch_ix + 1, on_step, batch_time))
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag="time", simple_value=batch_time)
                    ]), on_step)
                summary_writer.add_summary(summary, on_step)
                batch_time = 0

            # occasional saving
            if on_step % train_params.save_period == 0:
                print("Checkpointing")
                saver.save(sess,
                           join(out.save_dir, "checkpoint-" + str(on_step)),
                           global_step=global_step)

            # Occasional evaluation
            if (on_step % train_params.eval_period == 0) or start_eval:
                print("Running evaluation...")
                start_eval = False
                t0 = time.perf_counter()
                for name, data in eval_datasets.items():
                    n_samples = train_params.eval_samples.get(name)
                    evaluation = evaluator_runner.run_evaluators(
                        sess, data, name, n_samples)
                    for s in evaluation.to_summaries(name + "-"):
                        summary_writer.add_summary(s, on_step)

                print("Evaluation took: %.3f seconds" %
                      (time.perf_counter() - t0))

    saver.save(sess,
               relpath(join(out.save_dir, "checkpoint-" + str(on_step))),
               global_step=global_step)
    sess.close()
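
The parameter_checkpoint branch in _train restores only the trainable variables and then initializes whatever the restore did not cover (optimizer slots, the global step, and similar). A minimal sketch of that pattern under the tf.compat.v1 API; the checkpoint path is a hypothetical placeholder:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

w = tf.get_variable("w", shape=[4, 1])
loss = tf.reduce_sum(tf.square(w))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)  # Adam adds slot vars

sess = tf.Session()

# Restore only the trainable variables from a pre-trained checkpoint
# (hypothetical path), exactly as the parameter_checkpoint branch does.
tf.train.Saver(tf.trainable_variables()).restore(sess, "pretrained/model.ckpt")

# Then initialize everything the restore did not touch: optimizer slots,
# the global step, EMA shadow variables, and so on.
trainable = set(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
leftover = [v for v in tf.global_variables() if v not in trainable]
sess.run(tf.variables_initializer(leftover))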