예제 #1
0
def test_load_embeddings_from_ascii():
    """Check the matrix returned by util.LoadEmbeddingsFromASCII.

    A three-word vocabulary is loaded with an embedding width of 5 from
    TEST_EMBEDDING_MATRIX. The test expects row 0
    ("strange_and_exotic_word") to be all zeros and rows 1 and 2 ("the",
    ".") to match the listed float32 values exactly.
    """
    embedding_dim = 5
    word_to_index = {"strange_and_exotic_word": 0, "the": 1, ".": 2}

    expected_rows = [
        [0, 0, 0, 0, 0],
        [0.418, 0.24968, -0.41242, 0.1217, 0.34527],
        [0.15164, 0.30177, -0.16763, 0.17684, 0.31719],
    ]
    expected = np.asarray(expected_rows, dtype=np.float32)

    actual = util.LoadEmbeddingsFromASCII(
        word_to_index, embedding_dim, TEST_EMBEDDING_MATRIX)

    # Exact (not approximate) comparison, as in the original test.
    np.testing.assert_array_equal(actual, expected)
예제 #2
0
def run(only_forward=False):
    """Load data, build the thin-stack model graph, then train or evaluate.

    All configuration is read from the module-level FLAGS object. The
    function loads training/eval data, prepares the vocabulary and optional
    pretrained embeddings, builds the Theano cost graph, restores a
    checkpoint if one exists, and then either runs the training loop or a
    forward-only evaluation that writes expanded output to disk.

    Args:
        only_forward: If true, skip training; a checkpoint must already
            exist at the resolved checkpoint path (enforced by an assert).

    Returns:
        None. Returns early, after logging, when FLAGS.data_type is not one
        of "bl", "sst" or "snli".

    Raises:
        NotImplementedError: If FLAGS.model_type is not "Model0".
    """
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Pick the data-loading module that matches the requested dataset.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    if FLAGS.model_type != "Model0":
        raise NotImplementedError("Only basic model 0 (SPINN-PI, SPINN-PI-NT) "
                                  "is supported in the thin-stack "
                                  "implementation.")

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data.
    # FLAGS.eval_data_path is a colon-separated list of files; the vocabulary
    # returned alongside each eval set is discarded.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary.
    # A falsy vocabulary from the loader selects "open vocabulary" mode, in
    # which the vocabulary is built from the data and embeddings are not
    # trained.
    if not vocabulary:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    training_data_iter = util.MakeTrainingIterator(training_data,
                                                   FLAGS.batch_size)

    # One (filename, iterator) pair per eval set, in the order given on the
    # command line.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
        eval_iterators.append(
            (filename,
             util.MakeEvalIterator(
                 (e_X, e_transitions, e_y, e_num_transitions),
                 FLAGS.batch_size)))

    # Set up the placeholders.
    y = T.vector("y", dtype="int32")
    lr = T.scalar("lr")
    training_mode = T.scalar(
        "training_mode")  # 1: Training with dropout, 0: Eval
    ground_truth_transitions_visible = T.scalar(
        "ground_truth_transitions_visible", dtype="int32")

    logger.Log("Building model.")
    vs = util.VariableStore(default_initializer=util.UniformInitializer(
        FLAGS.init_range),
                            logger=logger)

    # NOTE(review): model_type was already required to be "Model0" above, so
    # the "RNN" branch below looks unreachable -- confirm before relying on it.
    if FLAGS.model_type == "RNN":
        model_cls = spinn.plain_rnn.RNN
    else:
        model_cls = getattr(recurrences, FLAGS.model_type)

    # Generator of mask for scheduled sampling
    # (seeded from a fixed NumPy RandomState, so the seed is reproducible).
    numpy_random = np.random.RandomState(1234)
    ss_mask_gen = T.shared_randomstreams.RandomStreams(
        numpy_random.randint(999999))

    # Scheduled-sampling probability placeholder; the training loop feeds it a
    # value computed from the current step.
    ss_prob = T.scalar("ss_prob")

    if data_manager.SENTENCE_PAIR_DATA:
        # Sentence-pair data uses one extra tensor dimension compared with the
        # single-sentence branch below (itensor3/imatrix vs. matrix/vector).
        X = T.itensor3("X")
        transitions = T.itensor3("transitions")
        num_transitions = T.imatrix("num_transitions")

        premise_model, hypothesis_model, logits, zero_fn = build_sentence_pair_model(
            model_cls,
            len(vocabulary),
            FLAGS.seq_length,
            X,
            transitions,
            len(data_manager.LABEL_MAP),
            training_mode,
            ground_truth_transitions_visible,
            vs,
            initial_embeddings=initial_embeddings,
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)
        premise_stack_top = premise_model.sentence_embeddings
        hypothesis_stack_top = hypothesis_model.sentence_embeddings
        predicted_premise_transitions = premise_model.transitions_pred
        predicted_hypothesis_transitions = hypothesis_model.transitions_pred
    else:
        X = T.matrix("X", dtype="int32")
        transitions = T.imatrix("transitions")
        num_transitions = T.vector("num_transitions", dtype="int32")

        model, logits, zero_fn = build_sentence_model(
            model_cls,
            len(vocabulary),
            FLAGS.seq_length,
            X,
            transitions,
            len(data_manager.LABEL_MAP),
            training_mode,
            ground_truth_transitions_visible,
            vs,
            initial_embeddings=initial_embeddings,
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)
        stack_top = model.sentence_embeddings
        predicted_transitions = model.transitions_pred

    xent_cost, acc = build_cost(logits, y)

    # Set up L2 regularization.
    # Sum of squared entries of every trainable variable, scaled by l2_lambda.
    l2_cost = 0.0
    for var in vs.trainable_vars:
        l2_cost += FLAGS.l2_lambda * T.sum(T.sqr(vs.vars[var]))

    # Compute cross-entropy cost on action predictions.
    # predicted_transitions is only bound in the single-sentence branch above
    # and the predicted_*_transitions names only in the sentence-pair branch;
    # the short-circuiting `and` keeps the unbound name from being evaluated.
    if (not data_manager.SENTENCE_PAIR_DATA
        ) and predicted_transitions is not None:
        transition_cost, action_acc = build_transition_cost(
            predicted_transitions, transitions, num_transitions)
    elif data_manager.SENTENCE_PAIR_DATA and predicted_hypothesis_transitions is not None:
        # Index 0 of the last axis is the premise, index 1 the hypothesis.
        p_transition_cost, p_action_acc = build_transition_cost(
            predicted_premise_transitions, transitions[:, :, 0],
            num_transitions[:, 0])
        h_transition_cost, h_action_acc = build_transition_cost(
            predicted_hypothesis_transitions, transitions[:, :, 1],
            num_transitions[:, 1])
        transition_cost = p_transition_cost + h_transition_cost
        action_acc = (p_action_acc + h_action_acc
                      ) / 2.0  # TODO(SB): Average over transitions, not words.
    else:
        # No transition predictions: contribute nothing to the cost.
        transition_cost = T.constant(0.0)
        action_acc = T.constant(0.0)
    transition_cost = transition_cost * FLAGS.transition_cost_scale

    total_cost = xent_cost + l2_cost + transition_cost

    # FLAGS.ckpt_path is used as-is when it contains ".ckpt"; otherwise it is
    # treated as a directory and "<experiment_name>.ckpt" is appended.
    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path,
                                       FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        # The two extra vars mirror extra_vars=[step, best_dev_error] passed to
        # save_checkpoint in the training loop.
        step, best_dev_error = vs.load_checkpoint(
            checkpoint_path,
            num_extra_vars=2,
            skip_saved_unsavables=FLAGS.skip_saved_unsavables)
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Do an evaluation-only run.
    if only_forward:
        # Output paths: either an explicit colon-separated list (one per eval
        # set) or "<experiment_name>-<eval file basename>-parse".
        if FLAGS.eval_output_paths:
            eval_output_paths = FLAGS.eval_output_paths.strip().split(":")
            assert len(eval_output_paths) == len(
                eval_iterators), "Invalid no. of output paths."
        else:
            eval_output_paths = [
                FLAGS.experiment_name + "-" + os.path.split(eval_set[0])[1] +
                "-parse" for eval_set in eval_iterators
            ]

        # The checkpoint was already restored above; report its step count.
        logger.Log("Checkpointed model was trained for %d steps." % (step, ))

        # Generate function for forward pass.
        logger.Log("Building forward pass.")
        if data_manager.SENTENCE_PAIR_DATA:
            eval_fn = theano.function([
                X, transitions, y, num_transitions, training_mode,
                ground_truth_transitions_visible, ss_prob
            ], [
                acc, action_acc, logits, predicted_hypothesis_transitions,
                predicted_premise_transitions
            ],
                                      on_unused_input='warn',
                                      allow_input_downcast=True)
        else:
            eval_fn = theano.function([
                X, transitions, y, num_transitions, training_mode,
                ground_truth_transitions_visible, ss_prob
            ], [acc, action_acc, logits, predicted_transitions],
                                      on_unused_input='warn',
                                      allow_input_downcast=True)

        # Generate the inverse vocabulary lookup table.
        # (dict.iteritems is a Python 2 API.)
        ind_to_word = {v: k for k, v in vocabulary.iteritems()}

        # Do a forward pass and write the output to disk.
        for eval_set, eval_out_path in zip(eval_iterators, eval_output_paths):
            logger.Log("Writing eval output for %s." % (eval_set[0], ))
            evaluate_expanded(eval_fn, eval_set, eval_out_path, logger, step,
                              data_manager.SENTENCE_PAIR_DATA, ind_to_word,
                              zero_fn)
    else:
        # Train
        extra_cost_inputs = [
            y, training_mode, ground_truth_transitions_visible
        ]
        if data_manager.SENTENCE_PAIR_DATA:
            # The two models use slices of the original data.
            # Pass the original data as a non-sequence input as well.
            extra_cost_inputs += [X, transitions]

            premise_error_signal = T.grad(total_cost, premise_stack_top)
            premise_model.make_backprop_scan(
                premise_error_signal,
                extra_cost_inputs=extra_cost_inputs,
                compute_embedding_gradients=False)

            # The hypothesis backprop additionally receives the premise
            # model's stack and auxiliary stacks. This mutates the same list
            # after the premise scan has been built, so order matters here.
            extra_cost_inputs += [premise_model.stack
                                  ] + premise_model.aux_stacks
            hypothesis_error_signal = T.grad(total_cost, hypothesis_stack_top)
            hypothesis_model.make_backprop_scan(
                hypothesis_error_signal,
                extra_cost_inputs=extra_cost_inputs,
                compute_embedding_gradients=False)

            # Merge the two gradient dicts, summing gradients for keys present
            # in both. Note that this updates premise_model.gradients in place.
            gradients = premise_model.gradients
            hypothesis_gradients = hypothesis_model.gradients
            for key in hypothesis_gradients:
                if key in gradients:
                    gradients[key] += hypothesis_gradients[key]
                else:
                    gradients[key] = hypothesis_gradients[key]

            new_values = util.merge_updates(
                premise_model.scan_updates + premise_model.bscan_updates,
                hypothesis_model.scan_updates +
                hypothesis_model.bscan_updates).items()
            # Trainable variables not owned by either stack model.
            other_params = set(vs.trainable_vars.keys()) - premise_model._vars
            other_params -= hypothesis_model._vars
        else:
            error_signal = T.grad(total_cost, stack_top)
            model.make_backprop_scan(
                error_signal,
                extra_cost_inputs=extra_cost_inputs,
                compute_embedding_gradients=train_embeddings)
            if train_embeddings:
                model.gradients[model.embeddings] = model.embedding_gradients
            gradients = model.gradients

            # Concatenating .items() results with `+` relies on Python 2,
            # where items() returns a list.
            new_values = model.scan_updates.items(
            ) + model.bscan_updates.items()
            other_params = set(vs.trainable_vars.keys()) - model._vars

        # Remove null stack parameter gradients.
        # Keys are collected first so the dict is not modified while iterating.
        null_gradients = set()
        for key, val in gradients.iteritems():
            if val is None:
                null_gradients.add(key)
        if null_gradients:
            logger.Log(
                "The following parameters have null (disconnected) cost "
                "gradients and will not be trained: %s" %
                ", ".join(str(k) for k in null_gradients), logger.WARNING)
        for key in null_gradients:
            del gradients[key]

        # Calculate gradients for items before/after stack fprop.
        other_params = [vs.vars[param] for param in other_params]
        other_grads = T.grad(total_cost, wrt=other_params)
        gradients.update(zip(other_params, other_grads))

        # Append the RMSprop parameter updates and any non-gradient updates
        # registered on the variable store.
        new_values += util.RMSprop(total_cost,
                                   gradients.keys(),
                                   lr,
                                   grads=gradients.values())
        new_values += [(key, vs.nongradient_updates[key])
                       for key in vs.nongradient_updates]

        # Create training and eval functions.
        # on_unused_input='warn' is set so that num_transitions can be passed in when training Model 0,
        # which ignores it. This yields more readable code that is very slightly slower.
        logger.Log("Building update function.")
        update_fn = theano.function([
            X, transitions, y, num_transitions, lr, training_mode,
            ground_truth_transitions_visible, ss_prob
        ], [total_cost, xent_cost, transition_cost, action_acc, l2_cost, acc],
                                    updates=new_values,
                                    on_unused_input='warn',
                                    allow_input_downcast=True)
        logger.Log("Building eval function.")
        eval_fn = theano.function([
            X, transitions, y, num_transitions, training_mode,
            ground_truth_transitions_visible, ss_prob
        ], [acc, action_acc],
                                  on_unused_input='warn',
                                  allow_input_downcast=True)
        logger.Log("Training.")

        # Main training loop.
        # Resumes from the restored step (0 when there was no checkpoint).
        for step in range(step, FLAGS.training_steps):
            if step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    # Rebinds `acc` from the symbolic accuracy to the value
                    # returned by evaluate(); eval_fn and update_fn were
                    # already compiled above.
                    acc = evaluate(eval_fn, eval_set, logger, step, zero_fn)
                    # Best-dev checkpointing considers only the first eval
                    # set, requires the error (1 - acc) to drop below 99% of
                    # the best so far, and is disabled up to step 1000.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                            1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log(
                            "Checkpointing with new best dev accuracy of %f" %
                            acc)
                        vs.save_checkpoint(checkpoint_path + "_best",
                                           extra_vars=[step, best_dev_error])

            X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next(
            )
            # HACK: Drop training batches which aren't well-sized. (Will only
            # trigger for the final batch in a dataset.)
            # Note that `continue` also skips the logging, checkpointing and
            # zero_fn() calls below for this step.
            if X_batch.shape[0] != FLAGS.batch_size:
                continue

            # Learning rate is scaled by the decay factor raised to
            # (step / 10000).
            learning_rate = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps**(step / 10000.0))
            # Positional arguments after the batch follow update_fn's input
            # order: lr, training_mode=1.0, ground_truth_transitions_visible=1.0,
            # ss_prob. exp(step * log(base)) is base ** step.
            ret = update_fn(
                X_batch, transitions_batch, y_batch, num_transitions_batch,
                learning_rate, 1.0, 1.0,
                np.exp(step * np.log(FLAGS.scheduled_sampling_exponent_base)))
            total_cost_val, xent_cost_val, transition_cost_val, action_acc_val, l2_cost_val, acc_val = ret

            if step % FLAGS.statistics_interval_steps == 0:
                logger.Log("Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %5f" %
                           (step, acc_val, action_acc_val, total_cost_val,
                            xent_cost_val, transition_cost_val, l2_cost_val))

            if step % FLAGS.ckpt_interval_steps == 0 and step > 0:
                vs.save_checkpoint(checkpoint_path,
                                   extra_vars=[step, best_dev_error])

            # Zero out all auxiliary variables.
            zero_fn()
예제 #3
0
def run(only_forward=False):
    """Load data, build the model graph, then train or run a forward-only eval.

    All configuration is read from the module-level FLAGS object. The model
    class is spinn.cbow.CBOW for model_type "CBOW", spinn.plain_rnn.RNN for
    "RNN", and otherwise the attribute of spinn.fat_stack named by
    FLAGS.model_type.

    Args:
        only_forward: If true, skip training; a checkpoint must already
            exist at the resolved checkpoint path (enforced by an assert).

    Returns:
        None. Returns early, after logging, when FLAGS.data_type is not one
        of "bl", "sst" or "snli".
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Pick the data-loading module that matches the requested dataset.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data.
    # FLAGS.eval_data_path is a colon-separated list of files; the vocabulary
    # returned alongside each eval set is discarded.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary.
    # A falsy vocabulary from the loader selects "open vocabulary" mode, in
    # which the vocabulary is built from the data and embeddings are not
    # trained.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    # for_rnn is set for the two non-stack model types (RNN and CBOW).
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size)

    # One (filename, iterator) pair per eval set, in the order given on the
    # command line.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set, vocabulary, FLAGS.seq_length, data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
        eval_iterators.append((filename,
            util.MakeEvalIterator((e_X, e_transitions, e_y, e_num_transitions), FLAGS.batch_size)))

    # Set up the placeholders.

    y = T.vector("y", dtype="int32")
    lr = T.scalar("lr")
    training_mode = T.scalar("training_mode")  # 1: Training with dropout, 0: Eval
    ground_truth_transitions_visible = T.scalar("ground_truth_transitions_visible", dtype="int32")

    logger.Log("Building model.")
    vs = util.VariableStore(
        default_initializer=util.UniformInitializer(FLAGS.init_range), logger=logger)

    # Resolve the model class from FLAGS.model_type.
    if FLAGS.model_type == "CBOW":
        model_cls = spinn.cbow.CBOW
    elif FLAGS.model_type == "RNN":
        model_cls = spinn.plain_rnn.RNN
    else:
        model_cls = getattr(spinn.fat_stack, FLAGS.model_type)

    # Generator of mask for scheduled sampling
    # (seeded from a fixed NumPy RandomState, so the seed is reproducible).
    numpy_random = np.random.RandomState(1234)
    ss_mask_gen = T.shared_randomstreams.RandomStreams(numpy_random.randint(999999))

    # Scheduled-sampling probability placeholder; the training loop feeds it a
    # value computed from the current step.
    ss_prob = T.scalar("ss_prob")

    if data_manager.SENTENCE_PAIR_DATA:
        # Sentence-pair data uses one extra tensor dimension compared with the
        # single-sentence branch below (itensor3/imatrix vs. matrix/vector).
        X = T.itensor3("X")
        transitions = T.itensor3("transitions")
        num_transitions = T.imatrix("num_transitions")

        predicted_premise_transitions, predicted_hypothesis_transitions, logits = build_sentence_pair_model(
            model_cls, len(vocabulary), FLAGS.seq_length,
            X, transitions, len(data_manager.LABEL_MAP), training_mode, ground_truth_transitions_visible, vs,
            initial_embeddings=initial_embeddings, project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)
    else:
        X = T.matrix("X", dtype="int32")
        transitions = T.imatrix("transitions")
        num_transitions = T.vector("num_transitions", dtype="int32")

        predicted_transitions, logits = build_sentence_model(
            model_cls, len(vocabulary), FLAGS.seq_length,
            X, transitions, len(data_manager.LABEL_MAP), training_mode, ground_truth_transitions_visible, vs,
            initial_embeddings=initial_embeddings, project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)

    xent_cost, acc = build_cost(logits, y)

    # Set up L2 regularization.
    # Sum of squared entries of every trainable variable, scaled by l2_lambda.
    l2_cost = 0.0
    for var in vs.trainable_vars:
        l2_cost += FLAGS.l2_lambda * T.sum(T.sqr(vs.vars[var]))

    # Compute cross-entropy cost on action predictions.
    # Skipped (zero cost) for model types "Model0", "RNN" and "CBOW".
    if (not data_manager.SENTENCE_PAIR_DATA) and FLAGS.model_type not in ["Model0", "RNN", "CBOW"]:
        transition_cost, action_acc = build_transition_cost(predicted_transitions, transitions, num_transitions)
    elif data_manager.SENTENCE_PAIR_DATA and FLAGS.model_type not in ["Model0", "RNN", "CBOW"]:
        # Index 0 of the last axis is the premise, index 1 the hypothesis.
        p_transition_cost, p_action_acc = build_transition_cost(predicted_premise_transitions, transitions[:, :, 0],
            num_transitions[:, 0])
        h_transition_cost, h_action_acc = build_transition_cost(predicted_hypothesis_transitions, transitions[:, :, 1],
            num_transitions[:, 1])
        transition_cost = p_transition_cost + h_transition_cost
        action_acc = (p_action_acc + h_action_acc) / 2.0  # TODO(SB): Average over transitions, not words.
    else:
        transition_cost = T.constant(0.0)
        action_acc = T.constant(0.0)
    transition_cost = transition_cost * FLAGS.transition_cost_scale

    total_cost = xent_cost + l2_cost + transition_cost

    # FLAGS.ckpt_path is used as-is when it contains ".ckpt"; otherwise it is
    # treated as a directory and "<experiment_name>.ckpt" is appended.
    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path, FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        # The two extra vars mirror extra_vars=[step, best_dev_error] passed to
        # save_checkpoint in the training loop.
        step, best_dev_error = vs.load_checkpoint(checkpoint_path, num_extra_vars=2,
                                                  skip_saved_unsavables=FLAGS.skip_saved_unsavables)
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Do an evaluation-only run.
    if only_forward:
        # Output paths: either an explicit colon-separated list (one per eval
        # set) or "<experiment_name>-<eval file basename>-parse".
        if FLAGS.eval_output_paths:
            eval_output_paths = FLAGS.eval_output_paths.strip().split(":")
            assert len(eval_output_paths) == len(eval_iterators), "Invalid no. of output paths."
        else:
            eval_output_paths = [FLAGS.experiment_name + "-" + os.path.split(eval_set[0])[1] + "-parse"
                                  for eval_set in eval_iterators]

        # The checkpoint was already restored above; report its step count.
        logger.Log("Checkpointed model was trained for %d steps." % (step,))

        # Generate function for forward pass.
        logger.Log("Building forward pass.")
        if data_manager.SENTENCE_PAIR_DATA:
            eval_fn = theano.function(
                [X, transitions, y, num_transitions, training_mode, ground_truth_transitions_visible, ss_prob],
                [acc, action_acc, logits, predicted_hypothesis_transitions, predicted_premise_transitions],
                on_unused_input='ignore',
                allow_input_downcast=True)
        else:
            eval_fn = theano.function(
                [X, transitions, y, num_transitions, training_mode, ground_truth_transitions_visible, ss_prob],
                [acc, action_acc, logits, predicted_transitions],
                on_unused_input='ignore',
                allow_input_downcast=True)

        # Generate the inverse vocabulary lookup table.
        # (dict.iteritems is a Python 2 API.)
        ind_to_word = {v : k for k, v in vocabulary.iteritems()}

        # Do a forward pass and write the output to disk.
        # The last argument is true only for model types other than
        # Model0/RNN/CBOW -- presumably "model predicts transitions"; confirm
        # against evaluate_expanded.
        for eval_set, eval_out_path in zip(eval_iterators, eval_output_paths):
            logger.Log("Writing eval output for %s." % (eval_set[0],))
            evaluate_expanded(eval_fn, eval_set, eval_out_path, logger, step,
                              data_manager.SENTENCE_PAIR_DATA, ind_to_word, FLAGS.model_type not in ["Model0", "RNN", "CBOW"])
    else:
        # Train

        # RMSprop updates for every trainable variable, plus any non-gradient
        # updates registered on the variable store.
        new_values = util.RMSprop(total_cost, vs.trainable_vars.values(), lr)
        new_values += [(key, vs.nongradient_updates[key]) for key in vs.nongradient_updates]
        # Training open-vocabulary embeddings is a questionable idea right now. Disabled:
        # new_values.append(
        #     util.embedding_SGD(total_cost, embedding_params, embedding_lr))

        # Create training and eval functions.
        # Unused variable warnings are suppressed so that num_transitions can be passed in when training Model 0,
        # which ignores it. This yields more readable code that is very slightly slower.
        logger.Log("Building update function.")
        update_fn = theano.function(
            [X, transitions, y, num_transitions, lr, training_mode, ground_truth_transitions_visible, ss_prob],
            [total_cost, xent_cost, transition_cost, action_acc, l2_cost, acc],
            updates=new_values,
            on_unused_input='ignore',
            allow_input_downcast=True)
        logger.Log("Building eval function.")
        eval_fn = theano.function([X, transitions, y, num_transitions, training_mode, ground_truth_transitions_visible, ss_prob], [acc, action_acc],
            on_unused_input='ignore',
            allow_input_downcast=True)
        logger.Log("Training.")

        # Main training loop.
        # Resumes from the restored step (0 when there was no checkpoint).
        for step in range(step, FLAGS.training_steps):
            if step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    # Rebinds `acc` from the symbolic accuracy to the value
                    # returned by evaluate(); eval_fn and update_fn were
                    # already compiled above.
                    acc = evaluate(eval_fn, eval_set, logger, step)
                    # Best-dev checkpointing considers only the first eval
                    # set, requires the error (1 - acc) to drop below 99% of
                    # the best so far, and is disabled up to step 1000.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        vs.save_checkpoint(checkpoint_path + "_best", extra_vars=[step, best_dev_error])

            X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next()
            # Learning rate is scaled by the decay factor raised to
            # (step / 10000).
            learning_rate = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            # Positional arguments after the batch follow update_fn's input
            # order: lr, training_mode=1.0, ground_truth_transitions_visible=1.0,
            # ss_prob. exp(step * log(base)) is base ** step.
            ret = update_fn(X_batch, transitions_batch, y_batch, num_transitions_batch,
                            learning_rate, 1.0, 1.0, np.exp(step*np.log(FLAGS.scheduled_sampling_exponent_base)))
            total_cost_val, xent_cost_val, transition_cost_val, action_acc_val, l2_cost_val, acc_val = ret

            if step % FLAGS.statistics_interval_steps == 0:
                logger.Log(
                    "Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %5f"
                    % (step, acc_val, action_acc_val, total_cost_val, xent_cost_val, transition_cost_val,
                       l2_cost_val))

            if step % FLAGS.ckpt_interval_steps == 0 and step > 0:
                vs.save_checkpoint(checkpoint_path, extra_vars=[step, best_dev_error])
예제 #4
0
def run(only_forward=False):
    """Load data, build the requested model and run the training loop.

    All configuration is read from the module-level ``FLAGS`` object. The
    function loads the training and evaluation sets, prepares the vocabulary
    and (optionally) pretrained embeddings, builds a classifier trainer for
    ``FLAGS.model_type``, restores a checkpoint when one exists at the
    computed checkpoint path, and then trains until ``FLAGS.training_steps``
    is reached, logging statistics and evaluating at the configured
    intervals.

    Args:
        only_forward: Request an evaluation-only run. A checkpoint must
            already exist for this; even then the mode is not implemented
            here and an exception is raised.

    Returns:
        None. When ``FLAGS.data_type`` is not recognised, "Bad data type."
        is logged and the function returns early.

    Raises:
        AssertionError: ``only_forward`` is set but no checkpoint file exists.
        Exception: ``FLAGS.model_type`` is not a known model type, the chosen
            model module lacks the trainer/model classes needed for the data
            layout, or ``only_forward`` is set.
    """
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Pick the data-loading module. This stays an if/elif chain so that only
    # the selected module name is ever evaluated.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data; FLAGS.eval_data_path is a ":"-separated file list.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary.
    # NOTE(review): train_embeddings is set in both branches but is not read
    # again anywhere in this function.
    if not vocabulary:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # The same for_rnn setting is used for the training set and every eval
    # set, so compute it once.
    for_rnn = FLAGS.model_type in ("RNN", "CBOW")

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=for_rnn)
    training_data_iter = util.MakeTrainingIterator(training_data,
                                                   FLAGS.batch_size,
                                                   FLAGS.smart_batching,
                                                   FLAGS.use_peano)

    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=for_rnn)
        eval_iterators.append(
            (filename,
             util.MakeEvalIterator(
                 (e_X, e_transitions, e_y, e_num_transitions),
                 FLAGS.batch_size, FLAGS.eval_data_limit)))

    # Choose the module that implements the requested model type.
    logger.Log("Building model.")

    if FLAGS.model_type == "CBOW":
        model_module = spinn.cbow
    elif FLAGS.model_type == "RNN":
        model_module = spinn.plain_rnn_chainer
    elif FLAGS.model_type == "NTI":
        model_module = spinn.nti
    elif FLAGS.model_type == "SPINN":
        model_module = spinn.fat_stack
    else:
        raise Exception("Requested unimplemented model type %s" %
                        FLAGS.model_type)

    # The trainer/model class names depend on whether the data set consists
    # of sentence pairs or single sentences; everything else is shared.
    if data_manager.SENTENCE_PAIR_DATA:
        trainer_attr, model_attr = 'SentencePairTrainer', 'SentencePairModel'
    else:
        trainer_attr, model_attr = 'SentenceTrainer', 'SentenceModel'

    if not (hasattr(model_module, trainer_attr) and
            hasattr(model_module, model_attr)):
        raise Exception("Unimplemented for model type %s" %
                        FLAGS.model_type)
    trainer_cls = getattr(model_module, trainer_attr)
    model_cls = getattr(model_module, model_attr)

    num_classes = len(data_manager.LABEL_MAP)
    # NOTE(review): single-sentence models are also built through
    # build_sentence_pair_model, exactly as before this cleanup -- confirm
    # that this is intended.
    classifier_trainer = build_sentence_pair_model(
        model_cls, trainer_cls, len(vocabulary), FLAGS.model_dim,
        FLAGS.word_embedding_dim, FLAGS.seq_length, num_classes,
        initial_embeddings, FLAGS.embedding_keep_rate, FLAGS.gpu)

    # FLAGS.ckpt_path is used verbatim when it contains ".ckpt"; otherwise it
    # is treated as a directory and the experiment name supplies the file.
    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path,
                                       FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        # TODO: Check that resuming works fine with tf summaries.
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    if FLAGS.write_summaries:
        # Imported lazily so the summary logger is only needed when enabled.
        from spinn.tf_logger import TFLogger
        train_summary_logger = TFLogger(summary_dir=os.path.join(
            FLAGS.summary_dir, FLAGS.experiment_name, 'train'))
        dev_summary_logger = TFLogger(summary_dir=os.path.join(
            FLAGS.summary_dir, FLAGS.experiment_name, 'dev'))

    # Do an evaluation-only run.
    if only_forward:
        raise Exception("Not implemented for chainer.")
    else:
        # Train
        logger.Log("Training.")

        classifier_trainer.init_optimizer(
            clip=FLAGS.clipping_max_value,
            decay=FLAGS.l2_lambda,
            lr=FLAGS.learning_rate,
        )

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training",
                                         bar_length=60,
                                         enabled=FLAGS.show_progress_bar)
        # Running sums over the current statistics interval; they are turned
        # into averages and reset whenever the interval statistics are logged.
        avg_class_acc = 0
        avg_trans_acc = 0
        for step in range(step, FLAGS.training_steps):
            X_batch, transitions_batch, y_batch, num_transitions_batch = (
                training_data_iter.next())

            # NOTE(review): the decayed learning rate is computed here but is
            # not passed to the trainer or optimizer anywhere in this
            # function -- confirm whether the decay is meant to be applied.
            learning_rate = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps**(step / 10000.0))

            # Reset cached gradients.
            classifier_trainer.optimizer.zero_grads()

            # Calculate loss and update parameters.
            _, loss, class_acc, transition_acc = classifier_trainer.forward(
                {
                    "sentences": X_batch,
                    "transitions": transitions_batch,
                },
                y_batch,
                train=True,
                predict=False)

            # Boilerplate for calculating loss.
            xent_cost_val = loss.data
            transition_cost_val = 0
            avg_trans_acc += transition_acc
            avg_class_acc += class_acc

            # The "> 0" test skips interval boundaries, where the divisor
            # below would be zero.
            if FLAGS.show_intermediate_stats and step % 5 == 0 and step % FLAGS.statistics_interval_steps > 0:
                print("Accuracies so far : ",
                      avg_class_acc / (step % FLAGS.statistics_interval_steps),
                      avg_trans_acc / (step % FLAGS.statistics_interval_steps))

            total_cost_val = xent_cost_val + transition_cost_val
            loss.backward()

            if FLAGS.gradient_check:

                def get_loss():
                    # Recompute the loss on the current batch for the checker.
                    _, check_loss, _, _ = classifier_trainer.forward(
                        {
                            "sentences": X_batch,
                            "transitions": transitions_batch,
                        },
                        y_batch,
                        train=True,
                        predict=False)
                    return check_loss

                gradient_check(classifier_trainer.model, get_loss)

            try:
                classifier_trainer.update()
            except Exception:
                # Record which step failed and let the error propagate. The
                # previous bare "except" opened an interactive debugger and
                # then carried on training as if the update had succeeded.
                logger.Log("Parameter update failed at step %i." % step)
                raise

            # Accuracy on the current batch, as exposed by the model.
            acc_val = float(classifier_trainer.model.accuracy.data)

            if FLAGS.write_summaries:
                train_summary_logger.log(step=step,
                                         loss=total_cost_val,
                                         accuracy=acc_val)

            progress_bar.step(
                i=max(0, step - 1) % FLAGS.statistics_interval_steps + 1,
                total=FLAGS.statistics_interval_steps)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.finish()
                avg_class_acc /= FLAGS.statistics_interval_steps
                avg_trans_acc /= FLAGS.statistics_interval_steps
                logger.Log(
                    "Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %s" %
                    (step, avg_class_acc, avg_trans_acc, total_cost_val,
                     xent_cost_val, transition_cost_val, "l2-not-exposed"))
                avg_trans_acc = 0
                avg_class_acc = 0

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(classifier_trainer, eval_set, logger, step)
                    # Only the first eval set drives checkpointing, and only
                    # after step 1000 with at least a 1% relative drop in
                    # error.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                            1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log(
                            "Checkpointing with new best dev accuracy of %f" %
                            acc)
                        classifier_trainer.save(checkpoint_path, step,
                                                best_dev_error)
                    if FLAGS.write_summaries:
                        dev_summary_logger.log(step=step,
                                               loss=0.0,
                                               accuracy=acc)
                progress_bar.reset()

            if FLAGS.profile and step >= FLAGS.profile_steps:
                break