Пример #1
0
def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path,
                             eval_data_path):
    """Load raw datasets, resolve the vocabulary and build batch iterators.

    Args:
        FLAGS: Flag container read for data, vocabulary and batching options.
        data_manager: Provides ``load_data``, ``FIXED_VOCABULARY`` and
            ``SENTENCE_PAIR_DATA``.
        logger: Object exposing ``Log(message)``.
        training_data_path: Path handed to ``data_manager.load_data`` for the
            training set; not read when ``FLAGS.expanded_eval_only_mode`` is
            set.
        eval_data_path: One or more eval paths joined with ``:``.

    Returns:
        ``(vocabulary, initial_embeddings, training_data_iter,
        eval_iterators)``. ``initial_embeddings`` is ``None`` when
        ``FLAGS.embedding_data_path`` is empty, ``training_data_iter`` is
        ``None`` in expanded eval-only mode, and ``eval_iterators`` is a list
        of ``(path, iterator)`` pairs.
    """
    # Genre filtering is switched off in this variant: both selectors accept
    # every example.
    def choose_train(x):
        return True

    def choose_eval(x):
        return True

    is_nli = FLAGS.data_type == "nli"

    def _load(path, chooser):
        # Only the nli loader receives the example selector and the
        # tree/distance options; every other loader just gets the mode.
        if is_nli:
            return data_manager.load_data(
                path,
                FLAGS.lowercase,
                chooser,
                mode=FLAGS.transition_mode,
                tree_joint=FLAGS.tree_joint,
                distance_type=FLAGS.distance_type)
        return data_manager.load_data(
            path, FLAGS.lowercase, mode=FLAGS.transition_mode)

    # Raw training data is skipped entirely in expanded eval-only mode.
    raw_training_data = None
    if not FLAGS.expanded_eval_only_mode:
        raw_training_data = _load(training_data_path, choose_train)

    # One raw eval set per ':'-separated path, kept together with its path.
    raw_eval_sets = [(path, _load(path, choose_eval))
                     for path in eval_data_path.split(':')]

    # Resolve the vocabulary: fixed if the data manager supplies one,
    # otherwise built from the loaded data.
    if data_manager.FIXED_VOCABULARY:
        vocabulary = data_manager.FIXED_VOCABULARY
        logger.Log("In fixed vocabulary mode. Training embeddings.")
    else:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Pretrained embeddings are optional.
    initial_embeddings = None
    if FLAGS.embedding_data_path:
        vocab_size = str(len(vocabulary))
        logger.Log("Loading vocabulary with " + vocab_size +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)

    # Convert the training set to padded integer sequences and wrap it in an
    # iterator, unless there is no training data to process.
    logger.Log("Preprocessing training data.")
    if raw_training_data is None:
        training_data = None
        training_data_iter = None
    else:
        training_data = util.PreprocessDataset(
            raw_training_data,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=False,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_cropping,
            pad_from_left=pad_from_left(),
            tree_joint=FLAGS.tree_joint)
        training_data_iter = util.MakeTrainingIterator(
            training_data,
            FLAGS.batch_size,
            FLAGS.smart_batching,
            FLAGS.use_peano,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Same treatment for every eval set.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)

        # Fall back to the training sequence length when no eval-specific
        # length was configured.
        if FLAGS.eval_seq_length is not None:
            eval_length = FLAGS.eval_seq_length
        else:
            eval_length = FLAGS.seq_length

        eval_data = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            eval_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_eval_cropping,
            pad_from_left=pad_from_left(),
            tree_joint=FLAGS.tree_joint)

        # Eval batches are scaled by sample_num to keep the eval running
        # speed.
        eval_batch_size = FLAGS.batch_size * FLAGS.sample_num
        eval_it = util.MakeEvalIterator(
            eval_data,
            eval_batch_size,
            FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    return vocabulary, initial_embeddings, training_data_iter, eval_iterators
Пример #2
0
def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path,
                             eval_data_path):
    """Load raw datasets, resolve the vocabulary and build batch iterators.

    Args:
        FLAGS: Flag container read for data, vocabulary and batching options.
        data_manager: Provides ``load_data``, ``FIXED_VOCABULARY`` and
            ``SENTENCE_PAIR_DATA``.
        logger: Object exposing ``Log(message)``.
        training_data_path: Path handed to ``data_manager.load_data`` for the
            training set; not read when ``FLAGS.expanded_eval_only_mode`` is
            set.
        eval_data_path: One or more eval paths joined with ``:``.

    Returns:
        ``(vocabulary, initial_embeddings, training_data_iter,
        eval_iterators, training_data_length)``. ``training_data_iter`` is
        ``None`` and ``training_data_length`` is ``0`` in expanded eval-only
        mode; ``eval_iterators`` is a list of ``(path, iterator)`` pairs.
    """
    def choose_train(x):
        return True

    if FLAGS.train_genre is not None:

        def choose_train(x):
            return x.get('genre') == FLAGS.train_genre

    def choose_eval(x):
        return True

    if FLAGS.eval_genre is not None:

        def choose_eval(x):
            return x.get('genre') == FLAGS.eval_genre

    if not FLAGS.expanded_eval_only_mode:
        # Pass the training selector in the same positional slot the eval
        # call uses; without it FLAGS.train_genre had no effect at all.
        raw_training_data = data_manager.load_data(training_data_path,
                                                   FLAGS.lowercase,
                                                   choose_train,
                                                   eval_mode=False)
    else:
        raw_training_data = None

    # One raw eval set per ':'-separated path, kept together with its path.
    raw_eval_sets = []
    for path in eval_data_path.split(':'):
        raw_eval_data = data_manager.load_data(path,
                                               FLAGS.lowercase,
                                               choose_eval,
                                               eval_mode=True)
        raw_eval_sets.append((path, raw_eval_data))

    # Prepare the vocabulary.
    if not data_manager.FIXED_VOCABULARY:
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        vocabulary = data_manager.FIXED_VOCABULARY
        logger.Log(
            "In fixed vocabulary mode. Training embeddings from scratch.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    if raw_training_data is not None:
        # The enclosing branch already guarantees training data exists, so no
        # per-call None fallback is needed here.
        training_data = util.PreprocessDataset(
            raw_training_data,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=False,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_cropping,
            pad_from_left=pad_from_left())
        training_data_iter = util.MakeTrainingIterator(
            training_data,
            FLAGS.batch_size,
            FLAGS.smart_batching,
            FLAGS.use_peano,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
        # Length of the first preprocessed component.
        # NOTE(review): presumably the number of training examples; confirm
        # against util.PreprocessDataset.
        training_data_length = len(training_data[0])
    else:
        training_data_iter = None
        training_data_length = 0

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            # Fall back to the training sequence length when no eval-specific
            # length was configured.
            FLAGS.eval_seq_length
            if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_eval_cropping,
            pad_from_left=pad_from_left())
        eval_it = util.MakeEvalIterator(eval_data,
                                        FLAGS.batch_size,
                                        FLAGS.eval_data_limit,
                                        bucket_eval=FLAGS.bucket_eval,
                                        shuffle=FLAGS.shuffle_eval,
                                        rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    return vocabulary, initial_embeddings, training_data_iter, eval_iterators, training_data_length
Пример #3
0
def run(only_forward=False):
    """Build the Theano model described by FLAGS, then train or evaluate it.

    The function picks a data manager from ``FLAGS.data_type``, loads the
    training and eval sets, resolves the vocabulary and optional pretrained
    embeddings, builds the symbolic graph (classification cost + L2 cost +
    scaled transition cost), restores a checkpoint when one exists, and then
    either writes expanded eval output or runs the training loop.

    Args:
        only_forward: When true, skip training and only run the eval sets
            through the restored model. An existing checkpoint file is
            required in this mode (enforced with an ``assert``).

    Returns:
        None. An unknown ``FLAGS.data_type`` is logged as "Bad data type."
        and the function returns early.
    """
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select the data manager that matches the requested data type.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data. load_data returns a (data, vocabulary) pair; a falsy
    # vocabulary selects open-vocabulary mode below.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data (':'-separated paths). The vocabulary returned for
    # each eval file is discarded.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
    training_data_iter = util.MakeTrainingIterator(training_data,
                                                   FLAGS.batch_size)

    # Preprocess every eval set the same way and wrap it in an iterator,
    # keeping the source filename alongside it.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
        eval_iterators.append(
            (filename,
             util.MakeEvalIterator(
                 (e_X, e_transitions, e_y, e_num_transitions),
                 FLAGS.batch_size)))

    # Set up the placeholders.

    y = T.vector("y", dtype="int32")
    lr = T.scalar("lr")
    training_mode = T.scalar(
        "training_mode")  # 1: Training with dropout, 0: Eval
    ground_truth_transitions_visible = T.scalar(
        "ground_truth_transitions_visible", dtype="int32")

    logger.Log("Building model.")
    vs = util.VariableStore(default_initializer=util.UniformInitializer(
        FLAGS.init_range),
                            logger=logger)

    # CBOW and RNN have dedicated classes; any other model type is looked up
    # by name in spinn.fat_stack.
    if FLAGS.model_type == "CBOW":
        model_cls = spinn.cbow.CBOW
    elif FLAGS.model_type == "RNN":
        model_cls = spinn.plain_rnn.RNN
    else:
        model_cls = getattr(spinn.fat_stack, FLAGS.model_type)

    # Generator of mask for scheduled sampling
    numpy_random = np.random.RandomState(1234)
    ss_mask_gen = T.shared_randomstreams.RandomStreams(
        numpy_random.randint(999999))

    # Scheduled-sampling probability; its value is supplied on every call of
    # the compiled functions.
    ss_prob = T.scalar("ss_prob")

    # Sentence-pair data uses placeholders with one extra dimension compared
    # with single-sentence data, and the pair model returns separate premise
    # and hypothesis transition predictions.
    if data_manager.SENTENCE_PAIR_DATA:
        X = T.itensor3("X")
        transitions = T.itensor3("transitions")
        num_transitions = T.imatrix("num_transitions")

        predicted_premise_transitions, predicted_hypothesis_transitions, logits = build_sentence_pair_model(
            model_cls,
            len(vocabulary),
            FLAGS.seq_length,
            X,
            transitions,
            len(data_manager.LABEL_MAP),
            training_mode,
            ground_truth_transitions_visible,
            vs,
            initial_embeddings=initial_embeddings,
            # Embeddings are projected only in open-vocabulary mode.
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)
    else:
        X = T.matrix("X", dtype="int32")
        transitions = T.imatrix("transitions")
        num_transitions = T.vector("num_transitions", dtype="int32")

        predicted_transitions, logits = build_sentence_model(
            model_cls,
            len(vocabulary),
            FLAGS.seq_length,
            X,
            transitions,
            len(data_manager.LABEL_MAP),
            training_mode,
            ground_truth_transitions_visible,
            vs,
            initial_embeddings=initial_embeddings,
            # Embeddings are projected only in open-vocabulary mode.
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)

    xent_cost, acc = build_cost(logits, y)

    # Set up L2 regularization.
    l2_cost = 0.0
    for var in vs.trainable_vars:
        l2_cost += FLAGS.l2_lambda * T.sum(T.sqr(vs.vars[var]))

    # Compute cross-entropy cost on action predictions. Model0, RNN and CBOW
    # get a constant zero transition cost and action accuracy instead.
    if (not data_manager.SENTENCE_PAIR_DATA) and FLAGS.model_type not in [
            "Model0", "RNN", "CBOW"
    ]:
        transition_cost, action_acc = build_transition_cost(
            predicted_transitions, transitions, num_transitions)
    elif data_manager.SENTENCE_PAIR_DATA and FLAGS.model_type not in [
            "Model0", "RNN", "CBOW"
    ]:
        # Index 0 of the last axis is paired with the premise predictions and
        # index 1 with the hypothesis predictions.
        p_transition_cost, p_action_acc = build_transition_cost(
            predicted_premise_transitions, transitions[:, :, 0],
            num_transitions[:, 0])
        h_transition_cost, h_action_acc = build_transition_cost(
            predicted_hypothesis_transitions, transitions[:, :, 1],
            num_transitions[:, 1])
        transition_cost = p_transition_cost + h_transition_cost
        action_acc = (p_action_acc + h_action_acc
                      ) / 2.0  # TODO(SB): Average over transitions, not words.
    else:
        transition_cost = T.constant(0.0)
        action_acc = T.constant(0.0)
    transition_cost = transition_cost * FLAGS.transition_cost_scale

    total_cost = xent_cost + l2_cost + transition_cost

    # A ckpt_path containing ".ckpt" is used as the checkpoint file itself;
    # otherwise it is treated as a directory holding
    # "<experiment_name>.ckpt".
    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path,
                                       FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        # The two extra variables unpacked here match the
        # extra_vars=[step, best_dev_error] passed to save_checkpoint below.
        step, best_dev_error = vs.load_checkpoint(
            checkpoint_path,
            num_extra_vars=2,
            skip_saved_unsavables=FLAGS.skip_saved_unsavables)
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Do an evaluation-only run.
    if only_forward:
        # Either take one explicit output path per eval set, or derive a name
        # from the experiment name and the eval file's base name.
        if FLAGS.eval_output_paths:
            eval_output_paths = FLAGS.eval_output_paths.strip().split(":")
            assert len(eval_output_paths) == len(
                eval_iterators), "Invalid no. of output paths."
        else:
            eval_output_paths = [
                FLAGS.experiment_name + "-" + os.path.split(eval_set[0])[1] +
                "-parse" for eval_set in eval_iterators
            ]

        # The model was already restored above; just report its step count.
        logger.Log("Checkpointed model was trained for %d steps." % (step, ))

        # Generate function for forward pass.
        logger.Log("Building forward pass.")
        if data_manager.SENTENCE_PAIR_DATA:
            eval_fn = theano.function([
                X, transitions, y, num_transitions, training_mode,
                ground_truth_transitions_visible, ss_prob
            ], [
                acc, action_acc, logits, predicted_hypothesis_transitions,
                predicted_premise_transitions
            ],
                                      on_unused_input='ignore',
                                      allow_input_downcast=True)
        else:
            eval_fn = theano.function([
                X, transitions, y, num_transitions, training_mode,
                ground_truth_transitions_visible, ss_prob
            ], [acc, action_acc, logits, predicted_transitions],
                                      on_unused_input='ignore',
                                      allow_input_downcast=True)

        # Generate the inverse vocabulary lookup table.
        # NOTE: dict.iteritems() exists only in Python 2.
        ind_to_word = {v: k for k, v in vocabulary.iteritems()}

        # Do a forward pass and write the output to disk.
        for eval_set, eval_out_path in zip(eval_iterators, eval_output_paths):
            logger.Log("Writing eval output for %s." % (eval_set[0], ))
            # NOTE(review): the model-type list below omits "CBOW", unlike the
            # transition-cost checks earlier in this function -- confirm
            # whether that is intended.
            evaluate_expanded(eval_fn, eval_set, eval_out_path, logger, step,
                              data_manager.SENTENCE_PAIR_DATA, ind_to_word,
                              FLAGS.model_type not in ["Model0", "RNN"])
    else:
        # Train

        new_values = util.RMSprop(total_cost, vs.trainable_vars.values(), lr)
        new_values += [(key, vs.nongradient_updates[key])
                       for key in vs.nongradient_updates]
        # Training open-vocabulary embeddings is a questionable idea right now. Disabled:
        # new_values.append(
        #     util.embedding_SGD(total_cost, embedding_params, embedding_lr))

        # Create training and eval functions.
        # Unused variable warnings are suppressed so that num_transitions can be passed in when training Model 0,
        # which ignores it. This yields more readable code that is very slightly slower.
        logger.Log("Building update function.")
        update_fn = theano.function([
            X, transitions, y, num_transitions, lr, training_mode,
            ground_truth_transitions_visible, ss_prob
        ], [total_cost, xent_cost, transition_cost, action_acc, l2_cost, acc],
                                    updates=new_values,
                                    on_unused_input='ignore',
                                    allow_input_downcast=True)
        logger.Log("Building eval function.")
        eval_fn = theano.function([
            X, transitions, y, num_transitions, training_mode,
            ground_truth_transitions_visible, ss_prob
        ], [acc, action_acc],
                                  on_unused_input='ignore',
                                  allow_input_downcast=True)
        logger.Log("Training.")

        # Main training loop. Starts from the restored step when a checkpoint
        # was loaded, otherwise from 0.
        for step in range(step, FLAGS.training_steps):
            if step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    # NOTE: this rebinds `acc`, which until now named the
                    # symbolic accuracy; both Theano functions are already
                    # compiled at this point.
                    acc = evaluate(eval_fn, eval_set, logger, step)
                    # Save a "_best" checkpoint only for the first eval set,
                    # only after step 1000, and only when the error drops
                    # below 99% of the best error seen so far.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                            1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log(
                            "Checkpointing with new best dev accuracy of %f" %
                            acc)
                        vs.save_checkpoint(checkpoint_path + "_best",
                                           extra_vars=[step, best_dev_error])

            # NOTE: .next() is the Python 2 iterator method.
            X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next(
            )
            # Exponential decay: the base rate is multiplied by the decay
            # factor raised to (step / 10000).
            learning_rate = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps**(step / 10000.0))
            # Positional inputs after the learning rate: training_mode=1.0,
            # ground_truth_transitions_visible=1.0, then ss_prob, which
            # equals scheduled_sampling_exponent_base ** step.
            ret = update_fn(
                X_batch, transitions_batch, y_batch, num_transitions_batch,
                learning_rate, 1.0, 1.0,
                np.exp(step * np.log(FLAGS.scheduled_sampling_exponent_base)))
            total_cost_val, xent_cost_val, transition_cost_val, action_acc_val, l2_cost_val, acc_val = ret

            if step % FLAGS.statistics_interval_steps == 0:
                logger.Log("Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %5f" %
                           (step, acc_val, action_acc_val, total_cost_val,
                            xent_cost_val, transition_cost_val, l2_cost_val))

            # Periodic checkpoint; step 0 is skipped.
            if step % FLAGS.ckpt_interval_steps == 0 and step > 0:
                vs.save_checkpoint(checkpoint_path,
                                   extra_vars=[step, best_dev_error])
Пример #4
0
def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path,
                             eval_data_path):
    """Load raw datasets, resolve the vocabulary and build batch iterators.

    Args:
        FLAGS: Flag container read for data, vocabulary and batching options.
        data_manager: Provides ``load_data`` (returning a
            ``(data, vocabulary)`` pair) and ``SENTENCE_PAIR_DATA``.
        logger: Object exposing ``Log(message)``.
        training_data_path: Path of the training set.
        eval_data_path: Path of the single eval set.

    Returns:
        ``(vocabulary, initial_embeddings, training_data_iter,
        eval_iterators)``. ``initial_embeddings`` is ``None`` when
        ``FLAGS.embedding_data_path`` is empty; ``eval_iterators`` is a
        one-element list of ``(path, iterator)`` pairs.
    """
    # Example selectors: accept everything unless a genre flag is set.
    if FLAGS.train_genre is None:
        def choose_train(x):
            return True
    else:
        def choose_train(x):
            return x.get('genre') == FLAGS.train_genre

    if FLAGS.eval_genre is None:
        def choose_eval(x):
            return True
    else:
        def choose_eval(x):
            return x.get('genre') == FLAGS.eval_genre

    # Only the snli loader is handed a selector; other loaders get just the
    # path and the lowercase switch.
    train_args = (training_data_path, FLAGS.lowercase)
    eval_args = (eval_data_path, FLAGS.lowercase)
    if FLAGS.data_type == "snli":
        train_args += (choose_train,)
        eval_args += (choose_eval,)

    raw_training_data, vocabulary = data_manager.load_data(*train_args)
    # The vocabulary returned with the eval data is discarded.
    raw_eval_data, _ = data_manager.load_data(*eval_args)
    raw_eval_sets = [(eval_data_path, raw_eval_data)]

    # A loader-supplied vocabulary means fixed-vocabulary mode; otherwise the
    # vocabulary is built from the loaded data.
    if vocabulary:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
    else:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Pretrained embeddings are optional.
    initial_embeddings = None
    if FLAGS.embedding_data_path:
        vocab_size = str(len(vocabulary))
        logger.Log("Loading vocabulary with " + vocab_size +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)

    # Convert the training set to padded integer sequences and wrap it in an
    # iterator.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data,
        FLAGS.batch_size,
        FLAGS.smart_batching,
        FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Same treatment for the eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)

        # Fall back to the training sequence length when no eval-specific
        # length was configured.
        if FLAGS.eval_seq_length is not None:
            eval_length = FLAGS.eval_seq_length
        else:
            eval_length = FLAGS.seq_length

        eval_data = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            eval_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data,
            FLAGS.batch_size,
            FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    return vocabulary, initial_embeddings, training_data_iter, eval_iterators
Пример #5
0
def run(only_forward=False):
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
    training_data_iter = util.MakeTrainingIterator(training_data,
                                                   FLAGS.batch_size,
                                                   FLAGS.smart_batching,
                                                   FLAGS.use_peano)

    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
        eval_iterators.append(
            (filename,
             util.MakeEvalIterator(
                 (e_X, e_transitions, e_y, e_num_transitions),
                 FLAGS.batch_size, FLAGS.eval_data_limit)))

    # Set up the placeholders.

    logger.Log("Building model.")

    if FLAGS.model_type == "CBOW":
        model_module = spinn.cbow
    elif FLAGS.model_type == "RNN":
        model_module = spinn.plain_rnn_chainer
    elif FLAGS.model_type == "NTI":
        model_module = spinn.nti
    elif FLAGS.model_type == "SPINN":
        model_module = spinn.fat_stack
    else:
        raise Exception("Requested unimplemented model type %s" %
                        FLAGS.model_type)

    if data_manager.SENTENCE_PAIR_DATA:
        if hasattr(model_module, 'SentencePairTrainer') and hasattr(
                model_module, 'SentencePairModel'):
            trainer_cls = model_module.SentencePairTrainer
            model_cls = model_module.SentencePairModel
        else:
            raise Exception("Unimplemented for model type %s" %
                            FLAGS.model_type)

        num_classes = len(data_manager.LABEL_MAP)
        classifier_trainer = build_sentence_pair_model(
            model_cls, trainer_cls, len(vocabulary), FLAGS.model_dim,
            FLAGS.word_embedding_dim, FLAGS.seq_length, num_classes,
            initial_embeddings, FLAGS.embedding_keep_rate, FLAGS.gpu)
    else:
        if hasattr(model_module, 'SentenceTrainer') and hasattr(
                model_module, 'SentenceModel'):
            trainer_cls = model_module.SentenceTrainer
            model_cls = model_module.SentenceModel
        else:
            raise Exception("Unimplemented for model type %s" %
                            FLAGS.model_type)

        num_classes = len(data_manager.LABEL_MAP)
        classifier_trainer = build_sentence_pair_model(
            model_cls, trainer_cls, len(vocabulary), FLAGS.model_dim,
            FLAGS.word_embedding_dim, FLAGS.seq_length, num_classes,
            initial_embeddings, FLAGS.embedding_keep_rate, FLAGS.gpu)

    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path,
                                       FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        # TODO: Check that resuming works fine with tf summaries.
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    if FLAGS.write_summaries:
        from spinn.tf_logger import TFLogger
        train_summary_logger = TFLogger(summary_dir=os.path.join(
            FLAGS.summary_dir, FLAGS.experiment_name, 'train'))
        dev_summary_logger = TFLogger(summary_dir=os.path.join(
            FLAGS.summary_dir, FLAGS.experiment_name, 'dev'))

    # Do an evaluation-only run.
    if only_forward:
        raise Exception("Not implemented for chainer.")
    else:
        # Train
        logger.Log("Training.")

        classifier_trainer.init_optimizer(
            clip=FLAGS.clipping_max_value,
            decay=FLAGS.l2_lambda,
            lr=FLAGS.learning_rate,
        )

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training",
                                         bar_length=60,
                                         enabled=FLAGS.show_progress_bar)
        avg_class_acc = 0
        avg_trans_acc = 0
        for step in range(step, FLAGS.training_steps):
            X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next(
            )
            learning_rate = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps**(step / 10000.0))

            # Reset cached gradients.
            classifier_trainer.optimizer.zero_grads()

            # Calculate loss and update parameters.
            ret = classifier_trainer.forward(
                {
                    "sentences": X_batch,
                    "transitions": transitions_batch,
                },
                y_batch,
                train=True,
                predict=False)
            y, loss, class_acc, transition_acc = ret

            # Boilerplate for calculating loss.
            xent_cost_val = loss.data
            transition_cost_val = 0
            avg_trans_acc += transition_acc
            avg_class_acc += class_acc

            if FLAGS.show_intermediate_stats and step % 5 == 0 and step % FLAGS.statistics_interval_steps > 0:
                print("Accuracies so far : ",
                      avg_class_acc / (step % FLAGS.statistics_interval_steps),
                      avg_trans_acc / (step % FLAGS.statistics_interval_steps))

            total_cost_val = xent_cost_val + transition_cost_val
            loss.backward()

            if FLAGS.gradient_check:

                def get_loss():
                    _, check_loss, _, _ = classifier_trainer.forward(
                        {
                            "sentences": X_batch,
                            "transitions": transitions_batch,
                        },
                        y_batch,
                        train=True,
                        predict=False)
                    return check_loss

                gradient_check(classifier_trainer.model, get_loss)

            try:
                classifier_trainer.update()
            except:
                import ipdb
                ipdb.set_trace()
                pass

            # Accumulate accuracy for current interval.
            action_acc_val = 0.0
            acc_val = float(classifier_trainer.model.accuracy.data)

            if FLAGS.write_summaries:
                train_summary_logger.log(step=step,
                                         loss=total_cost_val,
                                         accuracy=acc_val)

            progress_bar.step(
                i=max(0, step - 1) % FLAGS.statistics_interval_steps + 1,
                total=FLAGS.statistics_interval_steps)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.finish()
                avg_class_acc /= FLAGS.statistics_interval_steps
                avg_trans_acc /= FLAGS.statistics_interval_steps
                logger.Log(
                    "Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %s" %
                    (step, avg_class_acc, avg_trans_acc, total_cost_val,
                     xent_cost_val, transition_cost_val, "l2-not-exposed"))
                avg_trans_acc = 0
                avg_class_acc = 0

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(classifier_trainer, eval_set, logger, step)
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                            1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log(
                            "Checkpointing with new best dev accuracy of %f" %
                            acc)
                        classifier_trainer.save(checkpoint_path, step,
                                                best_dev_error)
                    if FLAGS.write_summaries:
                        dev_summary_logger.log(step=step,
                                               loss=0.0,
                                               accuracy=acc)
                progress_bar.reset()

            if FLAGS.profile and step >= FLAGS.profile_steps:
                break
Пример #6
0
def run(only_forward=False):
    """Train a classifier configured by ``FLAGS``, or only evaluate it.

    Builds the data pipeline, model, optimizer and trainer from the
    module-level ``FLAGS``, restores a checkpoint when one exists on disk,
    and then either runs every eval set once (``only_forward``) or trains
    from the restored step up to ``FLAGS.training_steps`` while periodically
    logging statistics, evaluating, writing metrics and checkpointing.

    Args:
        only_forward: If true, do not train; evaluate each eval set with the
            restored model instead. Requires an existing checkpoint.

    Returns:
        None. Also returns early, after logging "Bad data type.", when
        ``FLAGS.data_type`` is not one of the recognized values.

    Raises:
        Exception: ``FLAGS.model_type`` names no supported model.
        NotImplementedError: ``FLAGS.optimizer_type`` is neither "Adam" nor
            "RMSprop".
        AssertionError: ``only_forward`` is set but no checkpoint file was
            found.
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    elif FLAGS.data_type == "arithmetic":
        data_manager = load_simple_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Make Metrics Logger.
    metrics_path = "{}/{}".format(FLAGS.metrics_path, FLAGS.experiment_name)
    if not os.path.exists(metrics_path):
        os.makedirs(metrics_path)
    metrics_logger = MetricsLogger(metrics_path)
    # M collects the values that are periodically flushed to metrics_logger;
    # a second accumulator (A, below) backs the human-readable stats line.
    M = Accumulator(maxlen=FLAGS.deque_length)

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data. Multiple eval files are separated by ":".
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    # NOTE(review): the log messages below describe whether embeddings are
    # trained, but nothing in this function passes that choice to the model.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only(),
        use_left_padding=FLAGS.use_left_padding)
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only(),
            use_left_padding=FLAGS.use_left_padding)
        eval_it = util.MakeEvalIterator(eval_data,
            FLAGS.batch_size, FLAGS.eval_data_limit, bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval, rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Choose model.
    model_specific_params = {}
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        model_module = spinn.cbow
    elif FLAGS.model_type == "RNN":
        model_module = spinn.plain_rnn
    elif FLAGS.model_type == "SPINN":
        model_module = spinn.fat_stack
    elif FLAGS.model_type == "RLSPINN":
        model_module = spinn.rl_spinn
    elif FLAGS.model_type == "RAESPINN":
        model_module = spinn.rae_spinn
    elif FLAGS.model_type == "GENSPINN":
        model_module = spinn.gen_spinn
    elif FLAGS.model_type == "ATTSPINN":
        model_module = spinn.att_spinn
        model_specific_params['using_diff_in_mlstm'] = FLAGS.using_diff_in_mlstm
        model_specific_params['using_prod_in_mlstm'] = FLAGS.using_prod_in_mlstm
        model_specific_params['using_null_in_attention'] = FLAGS.using_null_in_attention
        # Route the 'spinn.attention' logger to its own file next to the
        # main experiment log.
        attlogger = logging.getLogger('spinn.attention')
        attlogger.setLevel(logging.INFO)
        fh = logging.FileHandler(os.path.join(FLAGS.log_path, '{}.att.log'.format(FLAGS.experiment_name)))
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))
        attlogger.addHandler(fh)
    else:
        raise Exception("Requested unimplemented model type %s" % FLAGS.model_type)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    if data_manager.SENTENCE_PAIR_DATA:
        trainer_cls = model_module.SentencePairTrainer
        model_cls = model_module.SentencePairModel
        use_sentence_pair = True
    else:
        trainer_cls = model_module.SentenceTrainer
        model_cls = model_module.SentenceModel
        use_sentence_pair = False

    model = model_cls(
        model_dim=FLAGS.model_dim,
        word_embedding_dim=FLAGS.word_embedding_dim,
        vocab_size=vocab_size,
        initial_embeddings=initial_embeddings,
        num_classes=num_classes,
        mlp_dim=FLAGS.mlp_dim,
        embedding_keep_rate=FLAGS.embedding_keep_rate,
        classifier_keep_rate=FLAGS.semantic_classifier_keep_rate,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        transition_weight=FLAGS.transition_weight,
        encode_style=FLAGS.encode_style,
        encode_reverse=FLAGS.encode_reverse,
        encode_bidirectional=FLAGS.encode_bidirectional,
        encode_num_layers=FLAGS.encode_num_layers,
        use_sentence_pair=use_sentence_pair,
        use_skips=FLAGS.use_skips,
        lateral_tracking=FLAGS.lateral_tracking,
        use_tracking_in_composition=FLAGS.use_tracking_in_composition,
        use_difference_feature=FLAGS.use_difference_feature,
        use_product_feature=FLAGS.use_product_feature,
        num_mlp_layers=FLAGS.num_mlp_layers,
        mlp_bn=FLAGS.mlp_bn,
        rl_mu=FLAGS.rl_mu,
        rl_baseline=FLAGS.rl_baseline,
        rl_reward=FLAGS.rl_reward,
        rl_weight=FLAGS.rl_weight,
        predict_leaf=FLAGS.predict_leaf,
        gen_h=FLAGS.gen_h,
        model_specific_params=model_specific_params,
    )

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate, betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate, eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    classifier_trainer = trainer_cls(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available. The "best" checkpoint is only preferred
    # when explicitly requested through FLAGS.load_best.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0) for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()

    # Debug: stamp the debug flag onto the model and every submodule.
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for eval_set in eval_iterators:
            evaluate(model, eval_set, logger, metrics_logger, step, vocabulary)
    else:
        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        # Resume from the restored step (0 when starting fresh).
        for step in range(step, FLAGS.training_steps):
            model.train()

            start = time.time()

            X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = training_data_iter.next()

            if FLAGS.truncate_train_batch:
                X_batch, transitions_batch = truncate(
                    X_batch, transitions_batch, num_transitions_batch)

            total_tokens = num_transitions_batch.ravel().sum()

            # Reset cached gradients.
            optimizer.zero_grad()

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu() # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            A.add('class_acc', class_acc)
            M.add('class_acc', class_acc)

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss/accuracy.
            transition_acc = model.transition_acc if hasattr(model, 'transition_acc') else 0.0
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None
            rl_loss = model.rl_loss if hasattr(model, 'rl_loss') else None
            policy_loss = model.policy_loss if hasattr(model, 'policy_loss') else None
            # A model without a `spinn` submodule yields None for the
            # spinn-level losses instead of raising AttributeError.
            spinn_module = getattr(model, 'spinn', None)
            rae_loss = spinn_module.rae_loss if hasattr(spinn_module, 'rae_loss') else None
            leaf_loss = spinn_module.leaf_loss if hasattr(spinn_module, 'leaf_loss') else None
            gen_loss = spinn_module.gen_loss if hasattr(spinn_module, 'gen_loss') else None

            # Force Transition Loss Optimization
            if FLAGS.force_transition_loss:
                model.optimize_transition_loss = True

            # Accumulate stats for transition accuracy.
            if transition_loss is not None:
                preds = [m["t_preds"] for m in model.spinn.memories]
                truth = [m["t_given"] for m in model.spinn.memories]
                A.add('preds', preds)
                A.add('truth', truth)

            # Accumulate stats for leaf prediction accuracy.
            if leaf_loss is not None:
                A.add('leaf_acc', model.spinn.leaf_acc)

            # Accumulate stats for word prediction accuracy.
            if gen_loss is not None:
                A.add('gen_acc', model.spinn.gen_acc)

            # Note: Keep track of transition_acc, although this is a naive average.
            # Should be weighted by length of sequences in batch.
            M.add('transition_acc', transition_acc)

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Boilerplate for calculating loss values.
            xent_cost_val = xent_loss.data[0]
            transition_cost_val = transition_loss.data[0] if transition_loss is not None else 0.0
            l2_cost_val = l2_loss.data[0] if l2_loss is not None else 0.0
            rl_cost_val = rl_loss.data[0] if rl_loss is not None else 0.0
            policy_cost_val = policy_loss.data[0] if policy_loss is not None else 0.0
            rae_cost_val = rae_loss.data[0] if rae_loss is not None else 0.0
            leaf_cost_val = leaf_loss.data[0] if leaf_loss is not None else 0.0
            gen_cost_val = gen_loss.data[0] if gen_loss is not None else 0.0

            # Accumulate Total Loss Data (plain numbers, for logging only).
            total_cost_val = 0.0
            total_cost_val += xent_cost_val
            if transition_loss is not None and model.optimize_transition_loss:
                total_cost_val += transition_cost_val
            total_cost_val += l2_cost_val
            total_cost_val += rl_cost_val
            total_cost_val += policy_cost_val
            total_cost_val += rae_cost_val
            total_cost_val += leaf_cost_val
            total_cost_val += gen_cost_val

            M.add('total_cost', total_cost_val)
            M.add('xent_cost', xent_cost_val)
            M.add('transition_cost', transition_cost_val)
            M.add('l2_cost', l2_cost_val)

            # Logging for RL
            rl_keys = ['rl_loss', 'policy_loss', 'norm_rewards', 'norm_baseline', 'norm_advantage']
            for k in rl_keys:
                if hasattr(model, k):
                    val = getattr(model, k)
                    val = val.data[0] if isinstance(val, Variable) else val
                    M.add(k, val)

            # Accumulate Total Loss Variable (the quantity that is actually
            # differentiated).
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            if rl_loss is not None:
                total_loss += rl_loss
            if policy_loss is not None:
                total_loss += policy_loss
            if rae_loss is not None:
                total_loss += rae_loss
            if leaf_loss is not None:
                total_loss += leaf_loss
            if gen_loss is not None:
                total_loss += gen_loss

            # Useful for debugging gradient flow. Drops into an interactive
            # debugger on every step while FLAGS.debug is set.
            if FLAGS.debug:
                losses = [('total_loss', total_loss), ('xent_loss', xent_loss)]
                if l2_loss is not None:
                    losses.append(('l2_loss', l2_loss))
                if transition_loss is not None and model.optimize_transition_loss:
                    losses.append(('transition_loss', transition_loss))
                if rl_loss is not None:
                    losses.append(('rl_loss', rl_loss))
                if policy_loss is not None:
                    losses.append(('policy_loss', policy_loss))
                debug_gradient(model, losses)
                import ipdb; ipdb.set_trace()

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                # Parameters that did not take part in this step's loss may
                # have no gradient at all; there is nothing to clip for them.
                if p.requires_grad and p.grad is not None:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                decayed_lr = FLAGS.learning_rate * (
                    FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
                # torch.optim optimizers take the rate from their parameter
                # groups, so the decay has to be written there to take effect.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = decayed_lr
                # Kept for backward compatibility with anything that reads
                # `optimizer.lr`; it does not influence the update itself.
                optimizer.lr = decayed_lr

            # Gradient descent step.
            optimizer.step()

            end = time.time()

            total_time = end - start

            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
                progress_bar.finish()
                avg_class_acc = A.get_avg('class_acc')
                if transition_loss is not None:
                    all_preds = np.array(flatten(A.get('preds')))
                    all_truth = np.array(flatten(A.get('truth')))
                    avg_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
                else:
                    avg_trans_acc = 0.0
                if leaf_loss is not None:
                    avg_leaf_acc = A.get_avg('leaf_acc')
                else:
                    avg_leaf_acc = 0.0
                if gen_loss is not None:
                    avg_gen_acc = A.get_avg('gen_acc')
                else:
                    avg_gen_acc = 0.0
                time_metric = time_per_token(A.get('total_tokens'), A.get('total_time'))
                # Accuracies are averaged over the accumulator; the cost
                # values are those of the current step only.
                stats_args = {
                    "step": step,
                    "class_acc": avg_class_acc,
                    "transition_acc": avg_trans_acc,
                    "total_cost": total_cost_val,
                    "xent_cost": xent_cost_val,
                    "transition_cost": transition_cost_val,
                    "l2_cost": l2_cost_val,
                    "rl_cost": rl_cost_val,
                    "policy_cost": policy_cost_val,
                    "rae_cost": rae_cost_val,
                    "leaf_acc": avg_leaf_acc,
                    "leaf_cost": leaf_cost_val,
                    "gen_acc": avg_gen_acc,
                    "gen_cost": gen_cost_val,
                    "time": time_metric,
                }
                stats_str = "Step: {step}"

                # Accuracy Component.
                stats_str += " Acc: {class_acc:.5f} {transition_acc:.5f}"
                if leaf_loss is not None:
                    stats_str += " leaf{leaf_acc:.5f}"
                if gen_loss is not None:
                    stats_str += " gen{gen_acc:.5f}"

                # Cost Component.
                stats_str += " Cost: {total_cost:.5f} {xent_cost:.5f} {transition_cost:.5f} {l2_cost:.5f}"
                if rl_loss is not None:
                    stats_str += " r{rl_cost:.5f}"
                if policy_loss is not None:
                    stats_str += " p{policy_cost:.5f}"
                if rae_loss is not None:
                    stats_str += " rae{rae_cost:.5f}"
                if leaf_loss is not None:
                    stats_str += " leaf{leaf_cost:.5f}"
                if gen_loss is not None:
                    stats_str += " gen{gen_cost:.5f}"

                # Time Component.
                stats_str += " Time: {time:.5f}"
                logger.Log(stats_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, metrics_logger, step)
                    # Only the first eval set drives "best" checkpointing, and
                    # only on a relative error improvement of more than 1%.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        classifier_trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                classifier_trainer.save(standard_checkpoint_path, step, best_dev_error)

            if step % FLAGS.metrics_interval_steps == 0:
                m_keys = M.cache.keys()
                for k in m_keys:
                    metrics_logger.Log(k, M.get_avg(k), step)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
Пример #7
0
def run(only_forward=False):
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" + json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(eval_data,
            FLAGS.batch_size, FLAGS.eval_data_limit, bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval, rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    # Build trainer.
    trainer = ModelTrainer(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(model, eval_set, logger, step, vocabulary)
    else:
        # Build log format strings.
        model.train()
        X_batch, transitions_batch, y_batch, num_transitions_batch = get_batch(training_data_iter.next())[:4]
        model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )

        train_str = train_format(model)
        logger.Log("Train-Format: {}".format(train_str))
        train_extra_str = train_extra_format(model)
        logger.Log("Train-Extra-Format: {}".format(train_extra_str))

         # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()

            start = time.time()

            batch = get_batch(training_data_iter.next())
            X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]

            total_tokens = sum([(nt+1)/2 for nt in num_transitions_batch.reshape(-1)])

            # Reset cached gradients.
            optimizer.zero_grad()

            if FLAGS.model_type == "RLSPINN":
                model.spinn.epsilon = FLAGS.rl_epsilon * math.exp(-step/FLAGS.rl_epsilon_decay)

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu() # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss.
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            total_loss += auxiliary_loss(model)

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()

            total_time = end - start

            train_accumulate(model, data_manager, A, batch)
            A.add('class_acc', class_acc)
            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

                A.add('xent_cost', xent_loss.data[0])
                A.add('l2_cost', l2_loss.data[0])
                stats_args = train_stats(model, optimizer, A, step)

                logger.Log(train_str.format(**stats_args))
                logger.Log(train_extra_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, step)
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                trainer.save(standard_checkpoint_path, step, best_dev_error)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)