def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger, vocabulary):
    """Train ``model`` until the step budget is spent or early stopping fires.

    Every iteration pulls one batch from ``training_data_iter``, runs a
    forward and backward pass, clips the gradient norm and asks ``trainer`` to
    apply the update.  Statistics, transition sampling, evaluation and
    checkpointing each run on their own ``FLAGS.*_interval_steps`` cadence; a
    log entry is written whenever at least one of them fired.

    Args:
        FLAGS: Flag container supplying hyper-parameters and intervals.
        model: The model to train; invoked directly on each batch.
        trainer: Holds ``step`` / ``best_dev_step`` and wraps the optimizer
            and checkpointing calls used below.
        training_data_iter: Iterator yielding raw training batches.
        eval_iterators: Evaluation sets, each passed to ``evaluate``.
        logger: Provides ``Log`` and ``LogEntry``.
        vocabulary: Forwarded unchanged to ``evaluate``.
    """
    # Windowed store for per-step training statistics.
    running = Accumulator(maxlen=FLAGS.deque_length)

    logger.Log("Training.")

    bar = SimpleProgressBar(msg="Training",
                            bar_length=60,
                            enabled=FLAGS.show_progress_bar)
    bar.step(i=0, total=FLAGS.statistics_interval_steps)

    entry = pb.SpinnEntry()
    # The loop variable is unused: the current step is always read from
    # ``trainer.step`` (presumably advanced by ``trainer.optimizer_step()`` --
    # confirm against the trainer implementation).
    for _ in range(trainer.step, FLAGS.training_steps):
        steps_since_best = trainer.step - trainer.best_dev_step
        if steps_since_best > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        entry.Clear()
        entry.step = trainer.step
        emit_entry = False

        tic = time.time()

        batch = get_batch(next(training_data_iter))
        tokens, transitions, labels, num_transitions, _ = batch

        total_tokens = sum(
            (nt + 1) / 2 for nt in num_transitions.reshape(-1))

        # Drop gradients left over from the previous step.
        trainer.optimizer_zero_grad()

        # Pyramid-style models get a decayed (and optionally cosine-cycled)
        # temperature multiplier; every other model type gets None.
        temperature = None
        if FLAGS.model_type in ["ChoiPyramid", "Maillard", "CatalanPyramid"]:
            temperature = FLAGS.pyramid_temperature_decay_per_10k_steps**(
                trainer.step / 10000.0)
            if FLAGS.pyramid_temperature_cycle_length > 0.0:
                min_temp = 1e-5
                cycle = math.cos(
                    (trainer.step) / FLAGS.pyramid_temperature_cycle_length)
                temperature *= (cycle + 1 + min_temp) / 2

        # The same keyword arguments are reused for the sampling passes below.
        forward_kwargs = dict(
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=temperature,
            example_lengths=num_transitions)

        # Forward pass.
        output = model(tokens, transitions, labels, **forward_kwargs)

        # Class accuracy: compare the arg-max class with the gold labels.
        target = torch.from_numpy(labels).long()
        _, best_class = output.data.max(1, keepdim=False)
        pred = best_class.cpu()
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Classification loss.
        criterion = nn.CrossEntropyLoss()
        xent_loss = criterion(output,
                              to_gpu(Variable(target, volatile=False)))

        # Transition loss is only present on some models.
        if hasattr(model, 'transition_loss'):
            transition_loss = model.transition_loss
        else:
            transition_loss = None

        # Sum up the loss terms that are actually optimized.
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        total_loss += auxiliary_loss(model)

        total_loss.backward()

        # Clip the gradient norm of everything except the embedding matrix.
        clipped_params = [
            param for name, param in model.named_parameters()
            if name != "embed.embed.weight"
        ]
        nn.utils.clip_grad_norm(clipped_params, FLAGS.clipping_max_value)

        trainer.optimizer_step()

        total_time = time.time() - tic

        train_accumulate(model, running, batch)
        running.add('class_acc', class_acc)
        running.add('total_tokens', total_tokens)
        running.add('total_time', total_time)

        # Periodic training statistics.
        if trainer.step % FLAGS.statistics_interval_steps == 0:
            running.add('xent_cost', xent_loss.data[0])
            stats(model, trainer, running, entry)
            emit_entry = True
            bar.finish()

        # Periodic transition samples: re-run the same batch once in train
        # mode and once in eval mode and record both sets of predictions.
        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            emit_entry = True

            model.train()
            model(tokens, transitions, labels, **forward_kwargs)
            tr_transitions_per_example, tr_strength = \
                model.spinn.get_transitions_per_example()

            model.eval()
            model(tokens, transitions, labels, **forward_kwargs)
            ev_transitions_per_example, ev_strength = \
                model.spinn.get_transitions_per_example()

            # Sentence-pair batches carry both sentences on the last axis;
            # stack them along the batch axis so one index selects one
            # sentence.
            if model.use_sentence_pair and len(transitions.shape) == 3:
                transitions = np.concatenate(
                    [transitions[:, :, 0], transitions[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            sample_idxs = list(range(FLAGS.num_samples))
            random.shuffle(sample_idxs)
            sample_idxs = sorted(sample_idxs[:FLAGS.num_samples])
            for t_idx in sample_idxs:
                sample = entry.rl_sampling.add()
                gold = transitions[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(),
                                     dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(),
                                     dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                sample.t_idx = t_idx
                sample.crossing = crossing
                sample.gold_lb = "".join(str(t) for t in gold)
                sample.pred_tr = "".join(str(t) for t in pred_tr)
                sample.pred_ev = "".join(str(t) for t in pred_ev)
                # Drop the leading character contributed by the extra [1].
                sample.strg_tr = strength_tr[1:]
                sample.strg_ev = strength_ev[1:]

        # Periodic evaluation; only the first eval set feeds the trainer's
        # dev-accuracy bookkeeping.
        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            emit_entry = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS,
                    model,
                    eval_set,
                    entry,
                    logger,
                    trainer,
                    show_sample=(trainer.step %
                                 FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary,
                    eval_index=index)
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            bar.reset()

        # Periodic checkpoint, once past the warm-up threshold.
        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            emit_entry = True
            trainer.checkpoint()

        if emit_entry:
            logger.LogEntry(entry)

        bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
                 total=FLAGS.statistics_interval_steps)
Esempio n. 2
0
def run(only_forward=False):
    """Load data, build the model, then either evaluate or train it.

    All configuration is read from the module-level ``FLAGS``.  A checkpoint
    is restored when one exists on disk (the "best" one is preferred when
    ``FLAGS.load_best`` is set).

    Args:
        only_forward: When true, run one evaluation pass over every eval set
            and return without training.

    Raises:
        AssertionError: If ``only_forward`` is set but no checkpoint exists.
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" + json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary: build it from the data when the loader did not
    # supply one.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(eval_data,
            FLAGS.batch_size, FLAGS.eval_data_limit, bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval, rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    # Build trainer.
    # NOTE(review): this replaces the trainer returned by init_model above --
    # confirm that is intentional.
    trainer = ModelTrainer(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug: propagate the debug flag to every sub-module.
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for eval_set in eval_iterators:
            evaluate(model, eval_set, logger, step, vocabulary)
    else:
        # Build log format strings.  One forward pass is run first, before
        # train_format / train_extra_format inspect the model.
        model.train()
        # next() works on both Python 2 and Python 3 iterators, unlike the
        # Python-2-only .next() method.
        X_batch, transitions_batch, y_batch, num_transitions_batch = get_batch(next(training_data_iter))[:4]
        model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )

        train_str = train_format(model)
        logger.Log("Train-Format: {}".format(train_str))
        train_extra_str = train_extra_format(model)
        logger.Log("Train-Extra-Format: {}".format(train_extra_str))

        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()

            start = time.time()

            batch = get_batch(next(training_data_iter))
            X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]

            total_tokens = sum([(nt+1)/2 for nt in num_transitions_batch.reshape(-1)])

            # Reset cached gradients.
            optimizer.zero_grad()

            # Exponentially decay the exploration rate for the RL model.
            if FLAGS.model_type == "RLSPINN":
                model.spinn.epsilon = FLAGS.rl_epsilon * math.exp(-step/FLAGS.rl_epsilon_decay)

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu() # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss.
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            total_loss += auxiliary_loss(model)

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping (element-wise clamp of each gradient).
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()

            total_time = end - start

            train_accumulate(model, data_manager, A, batch)
            A.add('class_acc', class_acc)
            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

                A.add('xent_cost', xent_loss.data[0])
                # l2_loss is None when FLAGS.use_l2_cost is off; record 0.0
                # (no L2 term was added to the loss) instead of crashing on
                # None.data.
                A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
                stats_args = train_stats(model, optimizer, A, step)

                logger.Log(train_str.format(**stats_args))
                logger.Log(train_extra_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, step)
                    # Only the first eval set can trigger a "best" checkpoint,
                    # and only on a relative error improvement of over 1%.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                trainer.save(standard_checkpoint_path, step, best_dev_error)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
Esempio n. 3
0
def evaluate(model, eval_set, logger, step, vocabulary=None):
    """Evaluate ``model`` on one eval set and log a one-line summary.

    Args:
        model: Model to evaluate; switched to eval mode here.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over batches.
        logger: Logger whose ``Log`` method receives the summary string.
        step: Training step, used only in the summary string.
        vocabulary: Accepted for interface compatibility; not used here.

    Returns:
        Class accuracy over the whole eval set (correct / total examples).
    """
    filename, dataset = eval_set

    # Constructed but never used in this function; kept so any constructor
    # side effects are unchanged.
    _reporter = EvalReporter()

    n_correct = 0
    n_examples = 0
    n_batches = len(dataset)
    bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    bar.step(0, total=n_batches)
    total_tokens = 0
    tic = time.time()

    model.eval()

    # Per-batch transition predictions / targets, gathered only when the
    # model exposes a transition loss.
    transition_preds = []
    transition_targets = []

    for batch_no, raw_batch in enumerate(dataset, 1):
        tokens, transitions, labels, num_transitions = get_batch(raw_batch)[:4]

        # Forward pass.
        output = model(tokens, transitions, labels,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize, then take the arg-max class as the prediction.
        logits = F.log_softmax(output)
        target = torch.from_numpy(labels).long()
        pred = logits.data.max(1)[1].cpu()
        n_correct += pred.eq(target).sum()
        n_examples += target.size(0)

        if hasattr(model, 'transition_loss'):
            transition_loss = model.transition_loss
        else:
            transition_loss = None

        total_tokens += sum((nt + 1) / 2 for nt in num_transitions.reshape(-1))

        if transition_loss is not None:
            transition_preds.append([mem["t_preds"] for mem in model.spinn.memories])
            transition_targets.append([mem["t_given"] for mem in model.spinn.memories])

        bar.step(batch_no, total=n_batches)
    bar.finish()

    elapsed = time.time() - tic

    # Time per token.
    time_metric = time_per_token([total_tokens], [elapsed])

    # Class accuracy.
    eval_class_acc = n_correct / float(n_examples)

    # Transition accuracy, when any transitions were recorded.
    if transition_preds:
        all_preds = np.array(flatten(transition_preds))
        all_truth = np.array(flatten(transition_targets))
        eval_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
    else:
        eval_trans_acc = 0.0

    summary = "Step: %i Eval acc: %f %f %s Time: %5f" % (
        step, eval_class_acc, eval_trans_acc, filename, time_metric)

    # Extra component (header only; nothing is appended after it).
    summary += "\nEval Extra:"

    logger.Log(summary)

    return eval_class_acc
Esempio n. 4
0
def evaluate(FLAGS,
             model,
             data_manager,
             eval_set,
             log_entry,
             step,
             vocabulary=None,
             show_sample=False,
             eval_index=0):
    """Evaluate ``model`` on one eval set and record results in ``log_entry``.

    A new evaluation record is appended to ``log_entry.evaluation`` and filled
    by ``eval_stats``.  When ``FLAGS.write_eval_report`` is set, per-batch
    predictions are also saved through an ``EvalReporter`` and written to
    ``<log_path>/<experiment_name>.eval_set_<eval_index>.report``.

    Args:
        FLAGS: Flag container.
        model: Model to evaluate; switched to eval mode here.
        data_manager: Forwarded to ``eval_accumulate``.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over batches.
        log_entry: Log message that receives the new evaluation record.
        step: Not used by this function.
        vocabulary: Passed to ``model.get_samples`` when sampling trees.
        show_sample: Whether to store parse masks and, for SPINN with the
            internal parser, pretty-print sample trees.
        eval_index: Index of this eval set, used in the report file name.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy)`` read back from the
        evaluation record.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    # Parses can only be read back from a SPINN model that is running its
    # internal parser; this does not change between batches.
    can_sample = (FLAGS.model_type == "SPINN"
                  and FLAGS.use_internal_parser)

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch,
                       eval_transitions_batch,
                       eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            show_sample = False  # Only show one sample, regardless of the number of batches.

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2
                             for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if can_sample:
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                # The first ``batch_size`` entries are reported as sentence 1
                # and the remainder as sentence 2.
                batch_size = pred.size(0)
                if transitions_per_example is not None:
                    sent1_transitions = transitions_per_example[:batch_size]
                    sent2_transitions = transitions_per_example[batch_size:]
                else:
                    sent1_transitions = None
                    sent2_transitions = None

                if tree_strs is not None:
                    sent1_trees = tree_strs[:batch_size]
                    sent2_trees = tree_strs[batch_size:]
                else:
                    sent1_trees = None
                    sent2_trees = None
            else:
                sent1_transitions = transitions_per_example
                sent2_transitions = None

                sent1_trees = tree_strs
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids,
                                output.data.cpu().numpy(), sent1_transitions,
                                sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
Esempio n. 5
0
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, true_step,
               best_dev_error, perturbation_id, ev_step, header, root_id,
               vocabulary):
    """Train one perturbation of the model for a fixed-length episode.

    Runs ``FLAGS.es_episode_length`` optimisation steps, with periodic
    statistics, transition sampling, evaluation and checkpointing keyed on the
    global ``true_step`` counter.

    Args:
        FLAGS: Flag container supplying hyper-parameters and intervals.
        data_manager: Forwarded to ``train_accumulate`` and ``evaluate``.
        model: Model being trained.
        optimizer: Optimizer stepped once per iteration.
        trainer: Object whose ``save`` method writes checkpoints.
        training_data_iter: Iterator yielding raw training batches.
        eval_iterators: Evaluation sets, each passed to ``evaluate``.
        logger: Provides ``Log`` and ``LogEntry``.
        true_step: Global step count on entry; incremented once per
            iteration.
        best_dev_error: Best dev error seen so far.
        perturbation_id: Id of this perturbation; used to name checkpoints
            and label log entries.
        ev_step: Passed through to ``trainer.save`` and returned unchanged.
        header: Log header whose ``start_step`` and ``start_time`` are set.
        root_id: Id used to build the root label on log entries.
        vocabulary: Forwarded to ``evaluate``.

    Returns:
        ``(ev_step, true_step, perturbation_id, dev_error)``.  ``dev_error``
        is ``best_dev_error`` when a best checkpoint file exists or when no
        evaluation ran during the episode; otherwise it is ``1 - acc`` for
        the most recently evaluated set.
    """
    perturbation_name = FLAGS.experiment_name + "_p" + str(perturbation_id)
    root_name = FLAGS.experiment_name + "_p" + str(root_id)
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    header.start_step = true_step
    header.start_time = int(time.time())

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                                   perturbation_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                               perturbation_name,
                                               best=True)

    # Run one forward pass before the loop; its outputs are discarded.
    # NOTE(review): this looks like a leftover warm-up from the code that
    # built log format strings -- confirm it is still needed.
    # next() is used instead of .next() so the iterator also works on
    # Python 3.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        next(training_data_iter))

    model(X_batch,
          transitions_batch,
          y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)

    # Train.
    logger.Log("Training perturbation %s" % perturbation_id)

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    # Accuracy from the most recent evaluate() call; stays None when the
    # episode finishes without hitting an evaluation step.
    acc = None

    log_entry = pb.SpinnEntry()
    for _ in range(FLAGS.es_episode_length):
        true_step += 1
        model.train()
        log_entry.Clear()
        log_entry.step = true_step
        log_entry.model_label = perturbation_name
        log_entry.root_label = root_name
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2
                            for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch,
                       transitions_batch,
                       y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits,
                                 to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(model,
                              FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss
        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping (element-wise clamp of each gradient).
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (true_step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if true_step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            # l2_loss is None when FLAGS.use_l2_loss is off; record 0.0 (no
            # L2 term was added to the loss) instead of crashing on None.data.
            A.add('l2_cost',
                  l2_loss.data[0] if l2_loss is not None else 0.0)
            stats(model, optimizer, A, true_step, log_entry)
            should_log = True

        if true_step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Re-run the same batch in train mode and in eval mode and record
            # the transitions chosen in each.
            model.train()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example(
            )

            model.eval()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example(
            )

            # Sentence-pair batches carry both sentences on the last axis;
            # stack them along the batch axis so one index selects one
            # sentence.
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            # list(...) is required: random.shuffle needs a mutable sequence
            # and Python 3's range object is immutable.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(),
                                     dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(),
                                     dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)

                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                # Drop the leading character contributed by the extra [1].
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if true_step > 0 and true_step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS,
                    model,
                    data_manager,
                    eval_set,
                    log_entry,
                    true_step,
                    vocabulary,
                    show_sample=(true_step % FLAGS.sample_interval_steps == 0),
                    eval_index=index)
                # Only the first eval set can trigger a "best" checkpoint,
                # and only on a relative error improvement of over 1%.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and \
                    (1 - acc) < 0.99 * best_dev_error and \
                    true_step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc
                    )  # TODO: This mixes information across dev sets. Fix.
                    trainer.save(best_checkpoint_path, true_step,
                                 best_dev_error, ev_step)
            progress_bar.reset()

        if true_step > FLAGS.ckpt_step and true_step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, true_step, best_dev_error,
                         ev_step)

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(true_step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)

    # Without any evaluation in this episode there is no fresh accuracy to
    # report, so fall back to the best dev error carried in.
    if os.path.exists(best_checkpoint_path) or acc is None:
        return ev_step, true_step, perturbation_id, best_dev_error
    return ev_step, true_step, perturbation_id, (1 - acc)
Esempio n. 6
0
def evaluate(FLAGS, model, eval_set, log_entry,
             logger, trainer, vocabulary=None, show_sample=False, eval_index=0):
    """Evaluate ``model`` on one evaluation set and record the results.

    Puts the model in eval mode, runs it over every batch of the set,
    accumulates class-accuracy counts, token counts and elapsed time in an
    ``Accumulator``, and lets ``eval_stats`` fill a new record appended to
    ``log_entry.evaluation``. When ``FLAGS.write_eval_report`` is set, each
    batch is also handed to an ``EvalReporter`` and a report file is written
    to ``FLAGS.log_path``.

    Args:
        FLAGS: Configuration object. Reads ``show_progress_bar``,
            ``use_internal_parser``, ``validate_transitions``, ``model_type``,
            ``write_eval_report``, ``eval_report_use_preds``, ``cp_metric``,
            ``log_path`` and ``experiment_name``.
        model: Callable model exposing ``eval()``; ``get_samples``, ``spinn``
            and ``use_sentence_pair`` are used on the sampling/report paths.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over batches.
        log_entry: Log record whose ``evaluation.add()`` returns the per-set
            evaluation record.
        logger: Object with a ``Log(str)`` method.
        trainer: Not used by this function; kept so the signature is unchanged.
        vocabulary: Forwarded to ``model.get_samples`` when sampling trees.
        show_sample: Whether to collect sample trees (and store parse masks).
            Cleared after the first batch unless an eval report is written.
        eval_index: Index of this eval set, used in the report file name.

    Returns:
        Tuple ``(eval_class_accuracy, eval_transition_accuracy)`` read from the
        evaluation record after ``eval_stats`` has populated it.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval",
        bar_length=60,
        enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    # Running sums of the cp metric returned by the reporter; both stay 0
    # unless an eval report is written with FLAGS.cp_metric enabled.
    cpt = 0
    cpt_max = 0
    start = time.time()

    # Sampling trees / reading predicted transitions is only done for RLSPINN
    # running its internal parser. FLAGS does not change during the loop.
    can_sample = (FLAGS.model_type ==
                  "RLSPINN" and FLAGS.use_internal_parser)

    model.eval()

    # Process-wide numpy print setting; applying it once is enough.
    np.set_printoptions(threshold=np.inf)

    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        total_tokens += sum(
            (nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1))

        if FLAGS.write_eval_report:
            if can_sample:
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            # Defaults cover the single-sentence case; the sentence-pair case
            # splits each available list at the batch size.
            sent1_transitions = transitions_per_example
            sent2_transitions = None
            sent1_trees = tree_strs
            sent2_trees = None

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                if transitions_per_example is not None:
                    sent1_transitions = transitions_per_example[:batch_size]
                    sent2_transitions = transitions_per_example[batch_size:]
                if tree_strs is not None:
                    sent1_trees = tree_strs[:batch_size]
                    sent2_trees = tree_strs[batch_size:]

            report_args = (
                pred,
                target,
                eval_ids,
                output.data.cpu().numpy(),
                sent1_transitions,
                sent2_transitions,
                sent1_trees,
                sent2_trees)

            if FLAGS.cp_metric:
                cp, cp_max = reporter.save_batch(
                    *report_args, cp_metric=FLAGS.cp_metric)
                cpt += cp
                cpt_max += cp_max
            else:
                reporter.save_batch(*report_args)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    # The cp metric only exists when at least one batch contributed to
    # cpt_max; dividing unconditionally raised ZeroDivisionError whenever the
    # report or the metric was disabled.
    if cpt_max:
        logger.Log("Eval cp_acc: " + str(cpt / cpt_max))

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name +
            ".eval_set_" +
            str(eval_index) +
            ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
Esempio n. 7
0
def evaluate(FLAGS,
             model,
             eval_set,
             log_entry,
             logger,
             trainer,
             vocabulary=None,
             show_sample=False,
             eval_index=0,
             target_vocabulary=None):
    """Evaluate a sequence-generating ``model`` on one eval set using BLEU.

    Runs the model in eval mode over every batch, turns targets and model
    outputs into space-separated token-id lines, writes them to
    ``<FLAGS.log_path>/ref_file`` and ``<FLAGS.log_path>/pred_file``, and
    scores them with ``spinn/util/multi-bleu.perl``. The parsed score is
    stored in the accumulator as ``class_correct`` (with ``class_total`` 1)
    before ``eval_stats`` fills a new record appended to
    ``log_entry.evaluation``.

    Args:
        FLAGS: Configuration object. Reads ``show_progress_bar``,
            ``model_type``, the ``pyramid_temperature_*`` settings,
            ``use_internal_parser``, ``validate_transitions``,
            ``write_eval_report``, ``eval_report_use_preds``, ``log_path``
            and ``experiment_name``.
        model: Callable model exposing ``eval()`` and an ``encoder``
            attribute (used for sampling and transition reporting).
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over batches.
        log_entry: Log record whose ``evaluation.add()`` returns the per-set
            evaluation record.
        logger: Object with a ``Log(str)`` method.
        trainer: Provides ``step``, used for the pyramid temperature schedule.
        vocabulary: Forwarded to ``model.encoder.get_samples`` when sampling.
        show_sample: Whether to collect sample trees. Cleared after the first
            batch unless an eval report is written.
        eval_index: Index of this eval set, used in the report file name.
        target_vocabulary: Not used by this function; kept so the signature
            is unchanged.

    Returns:
        Tuple ``(eval_class_accuracy, eval_transition_accuracy)`` read from the
        evaluation record after ``eval_stats`` has populated it.
    """
    # Local import: the BLEU script is run without going through a shell.
    import subprocess

    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    if FLAGS.model_type in ["ChoiPyramid", "Maillard", "CatalanPyramid"]:
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps**(
            trainer.step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (math.cos(
                (trainer.step) / FLAGS.pyramid_temperature_cycle_length) + 1 +
                                               min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()

    ref_file_name = FLAGS.log_path + "/ref_file"
    pred_file_name = FLAGS.log_path + "/pred_file"
    full_ref = []
    full_pred = []
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            example_lengths=eval_num_transitions_batch)

        can_sample = FLAGS.model_type in [
            "ChoiPyramid", "Maillard", "CatalanPyramid"
        ] or (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.encoder.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Get reference translation
        ref_out = [" ".join(map(str, k[:-1])) + " ." for k in eval_y_batch]
        full_ref += ref_out

        # Get predicted translation. Each item of `output` yields one value
        # per example; values are appended to that example's prediction until
        # it emits 1 (presumably an end-of-sequence id -- confirm against the
        # target vocabulary), after which the example is ignored.
        predicted = [[] for _ in eval_y_batch]
        done = set()
        for step_values in output:
            for index, raw_value in enumerate(step_values):
                val = int(raw_value)
                if index in done:
                    continue
                if val == 1:
                    done.add(index)
                else:
                    predicted[index].append(val)
        pred_out = [" ".join(map(str, k)) + " ." for k in predicted]
        full_pred += pred_out

        eval_accumulate(model, A, batch)

        # Update Aggregate Accuracies
        total_tokens += sum(
            (nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1))

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.encoder.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            sent1_transitions = transitions_per_example
            sent2_transitions = None

            sent1_trees = tree_strs
            sent2_trees = None
            # NOTE(review): full_pred/full_ref are cumulative over all batches
            # so far while eval_ids belongs to the current batch only; this may
            # have been meant to pass pred_out/ref_out. Left unchanged because
            # save_batch(mt=True) is not visible here -- confirm.
            reporter.save_batch(full_pred,
                                full_ref,
                                eval_ids, [None],
                                sent1_transitions,
                                sent2_transitions,
                                sent1_trees,
                                sent2_trees,
                                mt=True)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    # Context managers guarantee the handles are closed even if a write fails.
    with open(ref_file_name, "w") as reference_file:
        reference_file.write("\n".join(full_ref))
    with open(pred_file_name, "w") as predict_file:
        predict_file.write("\n".join(full_pred))

    # Run the scorer with an argument list and feed predictions on stdin, so
    # paths containing spaces or shell metacharacters are handled correctly.
    try:
        with open(pred_file_name) as predictions_in:
            bleu_process = subprocess.Popen(
                ["perl", "spinn/util/multi-bleu.perl", ref_file_name],
                stdin=predictions_in,
                stdout=subprocess.PIPE,
                universal_newlines=True)
            bleu_output, _ = bleu_process.communicate()
    except OSError:
        # e.g. perl is not installed; fall through to the 0.0 default below.
        bleu_output = ""

    try:
        bleu_score = float(bleu_output)
    except ValueError:
        # Best effort: an unparseable scorer output counts as a score of 0.
        logger.Log("Could not parse BLEU score from multi-bleu output; "
                   "using 0.0.")
        bleu_score = 0.0

    end = time.time()
    total_time = end - start
    A.add('class_correct', bleu_score)
    A.add('class_total', 1)
    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)
    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
Esempio n. 8
0
def evaluate(FLAGS,
             model,
             eval_set,
             log_entry,
             logger,
             trainer,
             vocabulary=None,
             show_sample=False,
             eval_index=0):
    """Run ``model`` over one evaluation set and report classification stats.

    The model is switched to eval mode and applied to every batch. Correct and
    total prediction counts, token counts and elapsed time are collected in an
    ``Accumulator`` and turned into a new ``log_entry.evaluation`` record by
    ``eval_stats``. If ``FLAGS.write_eval_report`` is set, every batch is also
    passed to an ``EvalReporter`` and a report is written under
    ``FLAGS.log_path``.

    Args:
        FLAGS: Configuration object (progress bar, model type, pyramid
            temperature schedule, parser/validation switches, report options).
        model: Callable model exposing ``eval()``; ``get_samples``, ``spinn``
            and ``use_sentence_pair`` are used on the sampling/report paths.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` supports ``len()``
            and iteration over batches.
        log_entry: Log record whose ``evaluation.add()`` yields the per-set
            evaluation record.
        logger: Object with a ``Log(str)`` method.
        trainer: Provides ``step`` for the pyramid temperature schedule.
        vocabulary: Forwarded to ``model.get_samples`` when sampling trees.
        show_sample: Whether to collect sample trees and store parse masks.
            Cleared after the first batch unless an eval report is written.
        eval_index: Index of this eval set, used in the report file name.

    Returns:
        Tuple ``(eval_class_accuracy, eval_transition_accuracy)`` taken from
        the evaluation record once ``eval_stats`` has filled it.
    """
    filename, dataset = eval_set

    stats = Accumulator()
    len(log_entry.evaluation)  # Result is unused; kept from the original.
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Progress reporting and bookkeeping.
    num_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=num_batches)
    token_count = 0
    started_at = time.time()

    # Temperature multiplier: only defined for the pyramid model.
    pyramid_temperature_multiplier = None
    if FLAGS.model_type in ["ChoiPyramid"]:
        pyramid_temperature_multiplier = (
            FLAGS.pyramid_temperature_decay_per_10k_steps
            ** (trainer.step / 10000.0))
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            cycle_cosine = math.cos(
                trainer.step / FLAGS.pyramid_temperature_cycle_length)
            pyramid_temperature_multiplier *= (cycle_cosine + 1 + min_temp) / 2

    model.eval()
    for batch_index, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        (eval_X_batch, eval_transitions_batch, eval_y_batch,
         eval_num_transitions_batch, eval_ids) = batch

        # Forward pass.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=show_sample,
            example_lengths=eval_num_transitions_batch)

        # TODO: Restore support in Pyramid if using.
        is_pyramid = FLAGS.model_type in ["ChoiPyramid"]
        is_parsing_spinn = (FLAGS.model_type == "SPINN"
                            and FLAGS.use_internal_parser)
        can_sample = is_pyramid or is_parsing_spinn
        if show_sample and can_sample:
            samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(samples)
        if not FLAGS.write_eval_report:
            # Without a report, one sample from the first batch is enough.
            show_sample = False

        # Class accuracy: compare the arg-max prediction with the target.
        target = torch.from_numpy(eval_y_batch).long()
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, stats, batch)
        stats.add('class_correct', pred.eq(target).sum())
        stats.add('class_total', target.size(0))

        # Optional transition loss lookup; the value is not used here.
        getattr(model, 'transition_loss', None)

        # Tokens per example are derived as (num_transitions + 1) / 2.
        token_count += sum(
            (nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1))

        if FLAGS.write_eval_report:
            if is_parsing_spinn:
                report_style = (
                    "preds" if FLAGS.eval_report_use_preds else "given")
                transitions_per_example, _ = (
                    model.spinn.get_transitions_per_example(
                        style=report_style))
            else:
                transitions_per_example = None

            # Single-sentence defaults; sentence pairs split at batch size.
            sent1_transitions = transitions_per_example
            sent2_transitions = None
            sent1_trees = tree_strs
            sent2_trees = None

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                if transitions_per_example is not None:
                    sent1_transitions = transitions_per_example[:batch_size]
                    sent2_transitions = transitions_per_example[batch_size:]
                if tree_strs is not None:
                    sent1_trees = tree_strs[:batch_size]
                    sent2_trees = tree_strs[batch_size:]

            reporter.save_batch(
                pred,
                target,
                eval_ids,
                output.data.cpu().numpy(),
                sent1_transitions,
                sent2_transitions,
                sent1_trees,
                sent2_trees)

        # Advance the progress bar.
        progress_bar.step(batch_index + 1, total=num_batches)

    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    elapsed = time.time() - started_at

    stats.add('total_tokens', token_count)
    stats.add('total_time', elapsed)

    eval_stats(model, stats, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        report_name = (
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(os.path.join(FLAGS.log_path, report_name))

    return eval_log.eval_class_accuracy, eval_log.eval_transition_accuracy