Example #1
0
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger):
    """Run the main training loop, driven by ``trainer.step``.

    Each iteration draws one batch from ``training_data_iter``, runs a
    forward/backward pass, clips the gradient norm and calls
    ``trainer.optimizer_step()``. At the configured intervals it also
    records statistics, logs sampled transition sequences (train-mode vs.
    eval-mode parses of the same batch), evaluates on ``eval_iterators`` and
    checkpoints via ``trainer``.

    Training stops early once ``trainer.step - trainer.best_dev_step``
    exceeds ``FLAGS.early_stopping_steps_to_wait``.

    Args:
        FLAGS: Parsed flags controlling intervals, RL schedules, clipping
            and logging.
        model: Model being trained. The sampling branch reads
            ``model.spinn.get_transitions_per_example()``.
        trainer: Provides ``step``, ``best_dev_step``, optimizer helpers,
            ``new_dev_accuracy`` and ``checkpoint``.
        training_data_iter: Iterator yielding raw training batches.
        eval_iterators: Sequence of evaluation sets; the accuracy on index 0
            is reported to ``trainer.new_dev_accuracy``.
        logger: Object with ``Log`` and ``LogEntry`` methods.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    # One protobuf entry is reused (cleared) every step and only emitted when
    # something worth logging happened during that step.
    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        # Early stopping: too many steps without a dev improvement.
        if (trainer.step -
                trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2
                            for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # sin(pi/2 + x) == cos(x): a periodic schedule in [0, 1] with a period
        # of rl_confidence_interval steps, equal to 1 at step 0.
        temperature = math.sin(
            math.pi / 2 +
            trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch,
                       transitions_batch,
                       y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.CrossEntropyLoss()(output,
                                          to_gpu(
                                              Variable(target,
                                                       volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Clip the global gradient norm, excluding the embedding matrix.
        nn.utils.clip_grad_norm([
            param for name, param in model.named_parameters()
            if name not in ["embed.embed.weight"]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Parse the same batch in train mode and in eval mode so both
            # transition sequences can be logged side by side.
            model.train()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example(
            )

            model.eval()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example(
            )

            # Sentence-pair batches carry two transition sequences per
            # example on the last axis; stack them along the batch axis.
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(),
                                     dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(),
                                     dec_str)
                # Compare the gold parse against the predicted (eval-mode)
                # transitions; ``pred`` above holds class predictions and
                # must not be used here.
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(FLAGS,
                                  model,
                                  eval_set,
                                  log_entry,
                                  logger,
                                  trainer,
                                  eval_index=index)
                if index == 0:
                    trainer.new_dev_accuracy(acc)

            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) +
                          1,
                          total=FLAGS.statistics_interval_steps)
Example #2
0
def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter, eval_iterators, logger, step, best_dev_error):
    """Train ``model`` from ``step`` up to ``FLAGS.training_steps``.

    Each iteration draws a batch, computes the NLL classification loss plus
    optional L2, transition and auxiliary losses, clamps gradients
    element-wise to ``[-clipping_max_value, clipping_max_value]`` and steps
    ``optimizer``. At the configured intervals it logs formatted training
    statistics and metrics, samples transition sequences, evaluates on each
    of ``eval_iterators`` and saves checkpoints via ``trainer.save``.

    Args:
        FLAGS: Parsed flags controlling intervals, losses and logging.
        data_manager: Passed through to ``train_accumulate`` / ``evaluate``.
        model: Model being trained.
        optimizer: Optimizer stepped once per iteration.
        trainer: Object providing ``save(path, step, best_dev_error)``.
        training_data_iter: Iterator yielding raw training batches.
        eval_iterators: Sequence of evaluation sets; index 0 drives
            best-dev checkpointing.
        logger: Object with a ``Log`` method.
        step: Step to resume from.
        best_dev_error: Best dev error seen so far (``1 - accuracy``).
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    # NOTE(review): one priming forward pass is run before the format helpers
    # are called; presumably they depend on state set during forward — confirm.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(next(training_data_iter))
    model(X_batch, transitions_batch, y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions
            )

    logger.Log("")
    logger.Log("# ----- BEGIN: Log Configuration ----- #")

    # Preview train string template.
    train_str = train_format(model)
    logger.Log("Train-Format: {}".format(train_str))
    train_extra_str = train_extra_format(model)
    logger.Log("Train-Extra-Format: {}".format(train_extra_str))

    # Preview eval string template.
    eval_str = eval_format(model)
    logger.Log("Eval-Format: {}".format(eval_str))
    eval_extra_str = eval_extra_format(model)
    logger.Log("Eval-Extra-Format: {}".format(eval_extra_str))

    logger.Log("# ----- END: Log Configuration ----- #")
    logger.Log("")

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    for step in range(step, FLAGS.training_steps):
        model.train()

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt+1)/2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions
            )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu() # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        total_loss += auxiliary_loss(model)

        # Backward pass.
        total_loss.backward()

        # Hard (element-wise) gradient clipping. Parameters that took no part
        # in this forward pass have ``grad is None`` and are skipped.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            optimizer.lr = lr
            # torch.optim optimizers read the rate from ``param_groups``, not
            # from an ``lr`` attribute, so push the decayed value there too.
            for group in getattr(optimizer, 'param_groups', []):
                group['lr'] = lr

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            # l2_loss is None when the L2 cost is disabled; record 0.0 so the
            # statistic stays present without crashing on ``None.data``.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats_args = train_stats(model, optimizer, A, step)

            train_metrics(M, stats_args, step)

            logger.Log(train_str.format(**stats_args))
            logger.Log(train_extra_str.format(**stats_args))

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            # Parse the same batch in train mode and in eval mode so both
            # transition sequences can be logged side by side.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            transition_str = "Samples:"
            # Sentence-pair batches carry two transition sequences per
            # example on the last axis; stack them along the batch axis.
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:,:,0], transitions_batch[:,:,1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                stength_tr = sparks([1] + tr_strength[t_idx].tolist())
                stength_ev = sparks([1] + ev_strength[t_idx].tolist())
                # Compare gold against the eval-mode transitions; ``pred``
                # above holds class predictions and must not be used here.
                _, crossing = evalb.crossing(gold, pred_ev)
                transition_str += "\n{}. crossing={}".format(t_idx, crossing)
                transition_str += "\n     g{}".format("".join(map(str, gold)))
                transition_str += "\n      {}".format(stength_tr[1:].encode('utf-8'))
                transition_str += "\n    pt{}".format("".join(map(str, pred_tr)))
                transition_str += "\n      {}".format(stength_ev[1:].encode('utf-8'))
                transition_str += "\n    pe{}".format("".join(map(str, pred_ev)))
            logger.Log(transition_str)

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(FLAGS, model, data_manager, eval_set, index, logger, step)
                # Checkpoint as "best" only on a >1% relative error reduction
                # on the first eval set.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step, best_dev_error, vocabulary):
    """Train ``model`` from ``step`` up to ``FLAGS.training_steps``.

    Variant supporting the ``Pyramid`` / ``ChoiPyramid`` model types, which
    receive a decaying (and optionally cyclic) temperature multiplier. Each
    iteration draws a batch, computes the NLL classification loss plus
    optional L2, transition and auxiliary losses, clamps gradients
    element-wise and steps ``optimizer``. Statistics, transition samples,
    evaluation results and checkpoint events are written into a reused
    ``pb.SpinnEntry`` that is emitted via ``logger.LogEntry`` only on steps
    where one of them fired.

    Args:
        FLAGS: Parsed flags controlling intervals, losses and logging.
        data_manager: Passed through to ``train_accumulate`` / ``evaluate``.
        model: Model being trained.
        optimizer: Optimizer stepped once per iteration.
        trainer: Object providing ``save(path, step, best_dev_error)``.
        training_data_iter: Iterator yielding raw training batches.
        eval_iterators: Sequence of evaluation sets; index 0 drives
            best-dev checkpointing.
        logger: Object with ``Log`` and ``LogEntry`` methods.
        step: Step to resume from.
        best_dev_error: Best dev error seen so far (``1 - accuracy``).
        vocabulary: Passed through to ``evaluate``.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    # NOTE(review): a single priming forward pass runs before training;
    # its purpose is not visible here — confirm with the model code.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        next(training_data_iter))
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions,
          pyramid_temperature_multiplier=1.0,
          example_lengths=num_transitions_batch
          )

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
            # Exponential decay per 10k steps...
            pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
                step / 10000.0)
            if FLAGS.pyramid_temperature_cycle_length > 0.0:
                # ...optionally modulated by a cosine cycle whose factor lies
                # in [min_temp / 2, 1 + min_temp / 2], so it never reaches 0.
                min_temp = 1e-5
                pyramid_temperature_multiplier *= (math.cos((step) /
                                                            FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
        else:
            pyramid_temperature_multiplier = None

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       example_lengths=num_transitions_batch
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss
        # Backward pass.
        total_loss.backward()

        # Hard (element-wise) gradient clipping. Parameters that took no part
        # in this forward pass have ``grad is None`` and are skipped.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            optimizer.lr = lr
            # torch.optim optimizers read the rate from ``param_groups``, not
            # from an ``lr`` attribute, so push the decayed value there too.
            for group in getattr(optimizer, 'param_groups', []):
                group['lr'] = lr

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            A.add('xent_cost', xent_loss.data[0])
            # l2_loss is None when the L2 loss is disabled; record 0.0 so the
            # statistic stays present without crashing on ``None.data``.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats(model, optimizer, A, step, log_entry)
            should_log = True
            progress_bar.finish()

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Parse the same batch in train mode and in eval mode so both
            # transition sequences can be logged side by side.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            # Sentence-pair batches carry two transition sequences per
            # example on the last axis; stack them along the batch axis.
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step,
                                     show_sample=(
                                         step %
                                         FLAGS.sample_interval_steps == 0), vocabulary=vocabulary, eval_index=index)
                # Checkpoint as "best" only on a >1% relative error reduction
                # on the first eval set.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                        1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)  # TODO: This mixes information across dev sets. Fix.
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
Example #4
0
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step, best_dev_error):
    """Train ``model`` from ``step`` up to ``FLAGS.training_steps`` (RL variant).

    Each step sets an exponentially decaying ``model.spinn.epsilon``, an
    optional confidence-penalty temperature and an optional wake/sleep
    ``rl_weight``. It then computes the NLL classification loss plus
    optional L2, transition and auxiliary losses, clamps gradients
    element-wise and steps ``optimizer``.

    Results are written into a reused ``pb.SpinnEntry``. Steps with
    statistics, samples, evaluation or checkpoints log it at INFO level.
    Steps that only hit ``FLAGS.metrics_interval_steps`` log it at DEBUG
    level (to file, not stderr).

    Args:
        FLAGS: Parsed flags controlling intervals, RL schedules, losses and
            logging.
        data_manager: Passed through to the accumulate helpers / ``evaluate``.
        model: Model being trained; its ``spinn`` submodule receives
            ``epsilon`` and ``temperature``.
        optimizer: Optimizer stepped once per iteration.
        trainer: Object providing ``save(path, step, best_dev_error)``.
        training_data_iter: Iterator yielding raw training batches.
        eval_iterators: Sequence of evaluation sets; index 0 drives
            best-dev checkpointing.
        logger: Object with ``Log`` and ``LogEntry`` methods.
        step: Step to resume from.
        best_dev_error: Best dev error seen so far (``1 - accuracy``).
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    # NOTE(review): a single priming forward pass runs before training;
    # its purpose is not visible here — confirm with the model code.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        next(training_data_iter))
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions
          )

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Force true division: with integer flags, Python 2 would otherwise
        # floor-divide and make epsilon decay in discrete jumps.
        epsilon = FLAGS.rl_epsilon * \
            math.exp(-step / float(FLAGS.rl_epsilon_decay))

        # Epsilon Greedy w. Decay.
        model.spinn.epsilon = epsilon

        # Confidence Penalty for Transition Predictions.
        # sin(pi/2 + x) == cos(x): a periodic schedule in [0, 1] with a period
        # of rl_confidence_interval steps, equal to 1 at step 0.
        temperature = math.sin(math.pi / 2 + step /
                               float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        if FLAGS.rl_confidence_penalty:
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[
            1].cpu()  # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(
            logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(
            model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard (element-wise) gradient clipping. Parameters that took no part
        # in this forward pass have ``grad is None`` and are skipped.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            optimizer.lr = lr
            # torch.optim optimizers read the rate from ``param_groups``, not
            # from an ``lr`` attribute, so push the decayed value there too.
            for group in getattr(optimizer, 'param_groups', []):
                group['lr'] = lr

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, data_manager, A, batch)

        if step % FLAGS.statistics_interval_steps == 0 \
                or step % FLAGS.metrics_interval_steps == 0:
            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()
                # Statistics steps are reported at INFO level; without this
                # the stats written into log_entry below would be discarded
                # unless some other trigger fired on the same step.
                should_log = True

            A.add('xent_cost', xent_loss.data[0])
            # l2_loss is None when the L2 loss is disabled; record 0.0 so the
            # statistic stays present without crashing on ``None.data``.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats(model, optimizer, A, step, log_entry)

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Parse the same batch in train mode and in eval mode so both
            # transition sequences can be logged side by side.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example(
            )

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example(
            )

            # Sentence-pair batches carry two transition sequences per
            # example on the last axis; stack them along the batch axis.
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)
                # Compare gold against the eval-mode transitions; ``pred``
                # above holds class predictions and must not be used here.
                _, crossing = evalb.crossing(gold, pred_ev)

                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, step)
                # Checkpoint as "best" only on a >1% relative error reduction
                # on the first eval set.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                        1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        log_level = afs_safe_logger.ProtoLogger.INFO
        if not should_log and step % FLAGS.metrics_interval_steps == 0:
            # Log to file, but not to stderr.
            should_log = True
            log_level = afs_safe_logger.ProtoLogger.DEBUG

        if should_log:
            logger.LogEntry(log_entry, level=log_level)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
Example #5
0
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger, vocabulary, target_vocabulary):
    """Train ``model`` with ``trainer`` until the step budget or early stopping.

    Each iteration draws one batch from ``training_data_iter``, runs the
    model, and backpropagates the sum of:

    * the MT loss -- for each target timestep, ``nn.NLLLoss`` over the batch
      positions whose ``mask`` entry is nonzero, then averaged over
      timesteps (left at 0.0 while in RL-only mode);
    * the encoder's transition loss, when the encoder exposes one and
      ``model.encoder.optimize_transition_loss`` is set;
    * ``auxiliary_loss(model)[0]``.

    Gradients of every parameter except ``embed.embed.weight`` are clipped
    to ``FLAGS.clipping_max_value`` before ``trainer.optimizer_step()``.

    With ``FLAGS.rl_alternate`` the loop toggles between "MT only" and
    "RL only" modes every 1000 steps. Statistics, transition samples,
    evaluation (the first eval set feeds ``trainer.new_dev_accuracy``) and
    checkpoints are emitted at their respective ``FLAGS.*_interval_steps``.
    Training stops early once ``trainer.step - trainer.best_dev_step``
    exceeds ``FLAGS.early_stopping_steps_to_wait``.

    Args:
        FLAGS: parsed flag object supplying every hyperparameter read here.
        model: the network; called as ``model(X, transitions, y, ...)`` and
            expected to return ``(output, trg, attention, mask)``.
        trainer: owns the optimizer, the step counter, dev-accuracy tracking
            and checkpointing.
        training_data_iter: iterator of raw batches, unpacked via
            ``get_batch``.
        eval_iterators: evaluation sets, each passed to ``evaluate``.
        logger: logger exposing ``Log`` and ``LogEntry``.
        vocabulary: source vocabulary, forwarded to ``evaluate``.
        target_vocabulary: target vocabulary, forwarded to ``evaluate``.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)
    # The loss module holds no state, so build it once rather than per step.
    criterion = nn.NLLLoss()
    rl_only = False
    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if FLAGS.rl_alternate and trainer.step % 1000 == 0 and trainer.step > 0:
            rl_only = not rl_only
            if rl_only:
                logger.Log('Switching training mode: RL only.')
            else:
                logger.Log('Switching training mode: MT only.')
        if (trainer.step -
                trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum((nt + 1) / 2
                           for nt in num_transitions_batch.reshape(-1))

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # Oscillates in [0, 1] with period FLAGS.rl_confidence_interval steps.
        temperature = math.sin(
            math.pi / 2 +
            trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output, trg, attention, mask = model(
            X_batch,
            transitions_batch,
            y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            example_lengths=num_transitions_batch)

        trg_seq_len = trg.shape[0]
        mt_loss = 0.0
        if not rl_only:
            mask = to_gpu(mask)
            for i in range(trg_seq_len):
                # Batch positions with a nonzero mask at timestep i
                # (presumably the non-padding targets -- confirm).
                valid = mask[i].nonzero().squeeze(1)
                if valid.numel() == 0:
                    # NLLLoss averaged over an empty selection is NaN.
                    continue
                mt_loss += criterion(
                    output[i, :].index_select(0, valid),
                    trg[i].index_select(0, valid).view(-1))
        elif FLAGS.rl_alternate:
            # NOTE(review): RL-only mode zeroes the RL losses here, before
            # auxiliary_loss(model) runs -- confirm against auxiliary_loss
            # that this is the intended behaviour.
            model.policy_loss = 0.0
            model.value_loss = 0.0
        mt_loss = mt_loss / trg_seq_len

        # Optionally calculate transition loss.
        model.transition_loss = getattr(model.encoder, 'transition_loss', None)
        transition_loss = model.transition_loss
        model.mt_loss = mt_loss

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += mt_loss
        if transition_loss is not None and model.encoder.optimize_transition_loss:
            model.optimize_transition_loss = model.encoder.optimize_transition_loss
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss[0]

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        nn.utils.clip_grad_norm_([
            param for name, param in model.named_parameters()
            if name not in ["embed.embed.weight"]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()
        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)
        A.add('mt_loss', float(mt_loss))

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            stats(model, trainer, A, log_entry)
            should_log = True
            progress_bar.finish()

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Decode the same batch once in train mode and once in eval mode
            # so both sets of predicted transitions can be logged.
            model.train()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example(
            )

            model.eval()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example(
            )

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # Log up to FLAGS.num_samples distinct examples drawn at random
            # from the batch, in index order. Clamp to what every per-example
            # sequence can actually index.
            num_available = min(len(seq) for seq in (
                transitions_batch, tr_transitions_per_example,
                ev_transitions_per_example, tr_strength, ev_strength))
            t_idxs = sorted(random.sample(range(num_available),
                                          min(FLAGS.num_samples,
                                              num_available)))
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(),
                                     dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(),
                                     dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS,
                    model,
                    eval_set,
                    log_entry,
                    logger,
                    trainer,
                    show_sample=(trainer.step %
                                 FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary,
                    eval_index=index,
                    target_vocabulary=target_vocabulary)
                # Only the first eval set drives dev tracking / early stopping.
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) +
                          1,
                          total=FLAGS.statistics_interval_steps)
Example #6
0
    def run_pyramid(self, x, show_sample=False):
        """Collapse a sequence into one state by layered pairwise composition.

        Starting from the ``seq_len`` slices of ``x`` along dim 1, each layer
        composes every adjacent pair of states with ``self.composition_fn``.
        A layer of ``k`` states therefore yields ``k - 1`` states. After
        ``seq_len - 1`` layers a single state remains and is returned.

        When ``self.gated`` is set, ``self.selection_fn`` scores each
        composition and the scores are softmax-normalised across positions.
        Each new state is then a convex mix (the weights sum to 1) of three
        terms:

        * the left input, weighted by the selection mass to the right of
          the position;
        * the right input, weighted by the selection mass to its left;
        * the composition itself, weighted by its own probability.

        Args:
            x: tensor unpacked as (batch_size, seq_len, model_dim).
            show_sample: currently ignored; it is forced to False below.

        Returns:
            The final state. For ``seq_len == 1`` this is the single
            (batch_size, 1, model_dim) slice of ``x``. Otherwise it is the
            last layer's output, presumably (batch_size, model_dim); confirm
            against ``composition_fn``.
        """
        _, seq_len, _ = x.size()

        def _state(layer_states, index):
            # The first layer holds (batch, 1, model_dim) chunks of x. Drop
            # only that chunk axis: a bare squeeze() would also collapse a
            # batch (or model) dimension of size 1 and break the shapes.
            state = layer_states[index]
            return state.squeeze(1) if state.dim() == 3 else state

        all_state_pairs = [torch.chunk(x, seq_len, 1)]

        # Temp fix: sample printing is disabled regardless of the argument.
        show_sample = False

        if show_sample:
            print('')

        for layer in range(seq_len - 1, 0, -1):
            previous = all_state_pairs[-1]
            composition_results = [
                self.composition_fn(_state(previous, position),
                                    _state(previous, position + 1))
                for position in range(layer)]

            if self.gated:
                selection_logits = torch.cat(
                    [self.selection_fn(result)
                     for result in composition_results], 1)

                if show_sample:
                    selection_probs = F.softmax(selection_logits, dim=1)
                    print(sparks(
                        np.transpose(
                            selection_probs[0, :].data.cpu().numpy()).tolist()))

                if self.training and self.selection_keep_rate is not None:
                    # NOTE(review): bernoulli(p=selection_keep_rate) draws 1
                    # for a position, which then gets -1000 (i.e. is dropped),
                    # so positions are dropped with probability
                    # selection_keep_rate. Confirm this is not inverted.
                    noise = torch.bernoulli(
                        (to_gpu(torch.ones(1, 1)) * self.selection_keep_rate
                         ).expand_as(selection_logits)) * -1000.
                    # Out-of-place so the autograd graph of torch.cat's
                    # output is never modified in place.
                    selection_logits = selection_logits + Variable(noise)
                selection_probs = F.softmax(selection_logits, dim=1)

                layer_state_pairs = []
                for position in range(layer):
                    # keepdim=True keeps (batch, 1) so expand_as(left) works.
                    if position < (layer - 1):
                        copy_left = torch.sum(
                            selection_probs[:, position + 1:], 1, keepdim=True)
                    else:
                        copy_left = to_gpu(Variable(torch.zeros(1, 1)))
                    if position > 0:
                        copy_right = torch.sum(
                            selection_probs[:, :position], 1, keepdim=True)
                    else:
                        copy_right = to_gpu(Variable(torch.zeros(1, 1)))
                    select = selection_probs[:, position].unsqueeze(1)

                    left = _state(previous, position)
                    right = _state(previous, position + 1)
                    composition_result = composition_results[position]
                    new_state_pair = (
                        copy_left.expand_as(left) * left
                        + copy_right.expand_as(right) * right
                        + select.expand_as(composition_result)
                        * composition_result)
                    layer_state_pairs.append(new_state_pair)
            else:
                layer_state_pairs = composition_results

            all_state_pairs.append(layer_state_pairs)

        return all_state_pairs[-1][-1]