Example #1
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(
        log_path(FLAGS), print_formatter=create_log_formatter(
            True, False), write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators, training_data_length = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)
    '''
    f = open("./vocab.txt", "w")
    for k in vocabulary:
        f.write("{0}\t{1}\n".format(k, vocabulary[k]))
    f.close()
    '''
    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(set(data_manager.LABEL_MAP.values()))

    model = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager, header)
    epoch_length = int(training_data_length / FLAGS.batch_size)
    trainer = ModelTrainer(model, logger, epoch_length, vocabulary, FLAGS)    

    header.start_step = trainer.step
    header.start_time = int(time.time())

    # Do an evaluation-only run.
    logger.LogHeader(header)  # Start log_entry logging.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(
                FLAGS,
                model,
                eval_set,
                log_entry,
                logger,
                trainer,
                vocabulary,
                show_sample=True,
                eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        train_loop(
            FLAGS,
            model,
            trainer,
            training_data_iter,
            eval_iterators,
            logger,
            vocabulary)
Example #2
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step -
                trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2
                            for nt in num_transitions_batch.reshape(-1)])
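        # Each sentence with n tokens has 2n - 1 shift/reduce transitions, so
        # (num_transitions + 1) / 2 recovers its token count.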

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        temperature = math.sin(
            math.pi / 2 +
            trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2
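        # The two lines above trace a cosine cycle rescaled to [0, 1] with a
        # period of FLAGS.rl_confidence_interval steps.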

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch,
                       transitions_batch,
                       y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.CrossEntropyLoss()(output,
                                          to_gpu(
                                              Variable(target,
                                                       volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
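        # (clip_grad_norm rescales the combined gradient norm of the listed
        # parameters; the word-embedding table is excluded from clipping.)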
        nn.utils.clip_grad_norm([
            param for name, param in model.named_parameters()
            if name not in ["embed.embed.weight"]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

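            # Stack the premise and hypothesis transition sequences along the
            # batch axis so the gold and predicted sequences index the same way.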
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(),
                                     dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(),
                                     dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
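            # Only the first eval set (the dev set) is passed to
            # trainer.new_dev_accuracy.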
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(FLAGS,
                                  model,
                                  eval_set,
                                  log_entry,
                                  logger,
                                  trainer,
                                  eval_index=index)
                if index == 0:
                    trainer.new_dev_accuracy(acc)

            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) +
                          1,
                          total=FLAGS.statistics_interval_steps)
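
A minimal standalone sketch of the confidence-penalty schedule used in the loop above; the flag values below are illustrative placeholders, not the repository defaults:

import math

# Illustrative placeholder values; the real runs read these from FLAGS.
rl_confidence_interval = 1000
rl_epsilon = 1.0
rl_epsilon_decay = 10000
rl_confidence_penalty = 0.1

for step in (0, 250, 500, 750, 1000):
    # Cyclic temperature in [0, 1], as computed in train_loop above.
    temperature = math.sin(
        math.pi / 2 + step / float(rl_confidence_interval) * 2 * math.pi)
    temperature = (temperature + 1) / 2
    # Exponentially decaying epsilon scales the penalty strength.
    epsilon = rl_epsilon * math.exp(-step / float(rl_epsilon_decay))
    spinn_temperature = max(
        1e-3, 1 + (temperature - .5) * rl_confidence_penalty * epsilon)
    print(step, round(temperature, 3), round(spinn_temperature, 3))
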
Example #3
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(log_path(FLAGS),
                                         print_formatter=create_log_formatter(True, False),
                                         write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))
    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(set(data_manager.LABEL_MAP.values()))

    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager, header)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log(
            "Resuming at step: {} with best dev accuracy: {}".format(
                step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log(
            "Resuming at step: {} with best dev accuracy: {}".format(
                step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0
    header.start_step = step
    header.start_time = int(time.time())

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
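    # model.apply() calls set_debug on every sub-module, propagating the
    # debug flag through the whole network.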
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Do an evaluation-only run.
    logger.LogHeader(header)  # Start log_entry logging.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step, vocabulary, show_sample=True, eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        train_loop(FLAGS, data_manager, model, optimizer, trainer,
                   training_data_iter, eval_iterators, logger, step, best_dev_error, vocabulary)
Example #4
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step, best_dev_error, vocabulary):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions,
          pyramid_temperature_multiplier=1.0,
          example_lengths=num_transitions_batch
          )

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
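            # The multiplier decays exponentially per 10k steps and can
            # additionally follow a cosine cycle of length
            # FLAGS.pyramid_temperature_cycle_length.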
            pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
                step / 10000.0)
            if FLAGS.pyramid_temperature_cycle_length > 0.0:
                min_temp = 1e-5
                pyramid_temperature_multiplier *= (math.cos((step) /
                                                            FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
        else:
            pyramid_temperature_multiplier = None

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       example_lengths=num_transitions_batch
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss
        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            A.add('xent_cost', xent_loss.data[0])
            if l2_loss is not None:
                A.add('l2_cost', l2_loss.data[0])
            stats(model, optimizer, A, step, log_entry)
            should_log = True
            progress_bar.finish()

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step,
                                     show_sample=(
                                         step %
                                         FLAGS.sample_interval_steps == 0), vocabulary=vocabulary, eval_index=index)
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                        1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)  # TODO: This mixes information across dev sets. Fix.
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
Example #5
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(log_path(FLAGS),
                                         print_formatter=create_log_formatter(
                                             True, False),
                                         write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    if not FLAGS.expanded_eval_only_mode:
        # Get Data and Embeddings for training
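        # Preprocessed data is cached on disk so repeated runs can skip the
        # expensive preprocessing step.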
        preprocessed_data_path = os.path.join(
            FLAGS.ckpt_path,
            'allnli_preprocessed_data_prpn-{}_train-{:d}-valid-{:d}_batch-{:d}_dist-{}.dat'
            .format(FLAGS.prpn_name, FLAGS.seq_length, FLAGS.eval_seq_length,
                    FLAGS.batch_size, FLAGS.tree_joint))
        if os.path.isfile(preprocessed_data_path):
            print 'Reading dumped preprocessed data'
            vocabulary, initial_embeddings, picked_train_iter_pack, eval_iterators = cPickle.load(
                open(preprocessed_data_path, "rb"))
        else:
            vocabulary, initial_embeddings, picked_train_iter_pack, eval_iterators = \
                load_data_and_embeddings(FLAGS, data_manager, logger,
                                        FLAGS.training_data_path, FLAGS.eval_data_path,
                                        )
            print 'Dumping data'
            cPickle.dump(
                (vocabulary, initial_embeddings, picked_train_iter_pack,
                 list(eval_iterators)), open(preprocessed_data_path, 'wb'))
            print 'Dumping done'
        train_sources, train_batches = picked_train_iter_pack

        def unpack_pickled_train_iter(sources, batches):
            '''
            Yield batches from the pickled sources indefinitely, reshuffling
            the batch order at the start of every epoch.
            '''
            num_batches = len(batches)
            idx = -1
            order = range(num_batches)
            random.shuffle(order)

            while True:
                idx += 1
                if idx >= num_batches:
                    # Start another epoch.
                    num_batches = len(batches)
                    idx = 0
                    order = range(num_batches)
                    random.shuffle(order)
                batch_indices = batches[order[idx]]
                # yield tuple(source[batch_indices] for source in sources if source is not None)
                yield tuple(
                    source[batch_indices] if source is not None else None
                    for source in
                    sources)  # for gumbel tree model, the dist will be None

        training_data_iter = unpack_pickled_train_iter(train_sources,
                                                       train_batches)

    else:
        # Get Data and Embeddings for test only
        vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
            load_data_and_embeddings(FLAGS, data_manager, logger,
                                    FLAGS.training_data_path, FLAGS.eval_data_path,
                                    )

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(set(data_manager.LABEL_MAP.values()))

    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings,
                                           vocab_size, num_classes,
                                           data_manager, header)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                                   FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                               FLAGS.experiment_name,
                                               best=True)
    best_parsing_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                                       FLAGS.experiment_name,
                                                       best=True,
                                                       parsing=True)
    sl_checkpoint_path = get_checkpoint_path_for_sl(FLAGS.ckpt_path,
                                                    FLAGS.experiment_name,
                                                    step=FLAGS.load_sl_step)

    # Load checkpoint if available.
    if FLAGS.customize_ckpt:
        customize_ckpt_path = FLAGS.customize_ckpt_path
        logger.Log("Found pretrained customized checkpoint, restoring.")
        step, best_dev_error, best_dev_f1_error = trainer.load(
            customize_ckpt_path,
            cpu=FLAGS.gpu < 0,
            continue_train=FLAGS.continue_train)
        best_dev_step = 0
    elif FLAGS.load_best:
        if FLAGS.test_type == 'classification' and os.path.isfile(
                best_checkpoint_path):
            logger.Log("Found best classification checkpoint, restoring.")
            step, best_dev_error, dev_f1_error = trainer.load(
                best_checkpoint_path, cpu=FLAGS.gpu < 0)
            logger.Log(
                "Resuming at step: {} best dev accuracy: {} with dev f1: {}".
                format(step, 1. - best_dev_error, 1. - dev_f1_error))
            step = 0
            best_dev_step = 0
            best_dev_f1_error = dev_f1_error
        elif os.path.isfile(best_parsing_checkpoint_path):
            logger.Log("Found best parsing checkpoint, restoring.")
            step, dev_error, best_dev_f1_error = trainer.load(
                best_parsing_checkpoint_path, cpu=FLAGS.gpu < 0)
            logger.Log(
                "Resuming at step: {} best f1: {} with dev accuracy: {}".
                format(step, 1. - best_dev_f1_error, 1. - dev_error))
        else:
            raise ValueError('Can\'t find the best checkpoint.')
    elif FLAGS.load_sl:
        logger.Log(
            "Found pretrained SL checkpoint at step {:d}, restoring.".format(
                FLAGS.load_sl_step))
        step, best_dev_error, best_dev_f1_error = trainer.load(
            standard_checkpoint_path,
            cpu=FLAGS.gpu < 0,
            continue_train=FLAGS.continue_train)
        best_dev_step = 0
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error, best_dev_f1_error = trainer.load(
            standard_checkpoint_path, cpu=FLAGS.gpu < 0)
        logger.Log(
            "Resuming at step: {} previously best dev accuracy: {} and previously best f1: {}"
            .format(step, 1. - best_dev_error, 1. - best_dev_f1_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0
        best_dev_step = 0
        best_dev_f1_error = 1.0  # for best parsing checkpoint
    header.start_step = step
    header.start_time = int(time.time())

    # # Right-branching trick.
    # DefaultUniformInitializer(model.binary_tree_lstm.comp_query.weight)
    # set temperature
    model.binary_tree_lstm.temperature_param.data = torch.Tensor([[0.2]])

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug

    model.apply(set_debug)

    # Do an evaluation-only run.
    logger.LogHeader(header)  # Start log_entry logging.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS,
                     model,
                     data_manager,
                     eval_set,
                     log_entry,
                     logger,
                     step,
                     vocabulary,
                     show_sample=True,
                     eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        best_dev_step = 0
        train_loop(FLAGS, data_manager, model, optimizer, trainer,
                   training_data_iter, eval_iterators, logger, step,
                   best_dev_error, best_dev_step, best_dev_f1_error,
                   vocabulary)
Example #6
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(log_path(FLAGS))
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))
    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings,
                                           vocab_size, num_classes,
                                           data_manager, header)

    # Check whether the experiment with perturbation id 0 has a checkpoint.
    perturbation_name = FLAGS.experiment_name + "_p" + '0'
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                               perturbation_name,
                                               best=True)
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                                   perturbation_name,
                                                   best=False)

    ckpt_names = []
    if os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=True)
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found standard checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=False)
    else:
        assert not only_forward, "Can't run an eval-only run without best checkpoints. Supply best checkpoint(s)."
        true_step = 0
        best_dev_error = 1.0
        reload_ev_step = 0
    header.start_step = true_step
    header.start_time = int(time.time())
    header.model_label = perturbation_name

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug

    model.apply(set_debug)

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        assert len(
            ckpt_names
        ) != 0, "Cannot run a forward pass without best checkpoints supplied."
        log_entry = pb.SpinnEntry()

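        # Spawn one restore process per perturbation checkpoint; each process
        # puts its result tuple on the shared queue.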
        restore_queue = mp.Queue()
        processes_restore = []
        while ckpt_names:
            pert_name = ckpt_names.pop()
            path = os.path.join(FLAGS.ckpt_path, pert_name)
            name = pert_name.replace('.ckpt_best', '')
            p_restore = mp.Process(target=restore,
                                   args=(logger, trainer, restore_queue, FLAGS,
                                         name, path))
            p_restore.start()
            processes_restore.append(p_restore)
        assert len(ckpt_names) == 0
        results = [restore_queue.get() for p in processes_restore]
        reload_ev_step = results[0][0]

        all_models = results  # assumed layout: each result is (ev_step, true_step, model, ...)
        while all_models:
            p_checkpoint = all_models.pop()
            p_model = p_checkpoint[2]
            true_step = p_checkpoint[1]
            for index, eval_set in enumerate(eval_iterators):
                log_entry.Clear()
                evaluate(FLAGS, p_model, data_manager, eval_set, log_entry,
                         true_step, vocabulary)
                print(log_entry)
                logger.LogEntry(log_entry)

    else:
        # Restore model, i.e. perturbation spawns, from best checkpoint, if it exists, or standard checkpoint.
        # Get dev-set accuracies so we can select which models to use for the next evolution step.
        if len(ckpt_names) != 0:
            logger.Log("Restoring models from best or standard checkpoints")
            processes_restore = []
            restore_queue = mp.Queue()
            while ckpt_names:
                pert_name = ckpt_names.pop()
                path = os.path.join(FLAGS.ckpt_path, pert_name)
                name = pert_name.replace('.ckpt_best', '')
                p_restore = mp.Process(target=restore,
                                       args=(logger, trainer, restore_queue,
                                             FLAGS, name, path))
                p_restore.start()
                processes_restore.append(p_restore)
            assert len(ckpt_names) == 0
            results = [restore_queue.get() for p in processes_restore]
            reload_ev_step = results[0][0] + 1  # the next evolution step

        else:
            id_ = "B"
            chosen_models = [(reload_ev_step, true_step, id_, best_dev_error)]
            base = True  # This is the "base" model
            results = []

        for ev_step in range(reload_ev_step, FLAGS.es_steps):
            logger.Log("Evolution step: %i" % ev_step)

            # Choose root models for next generation using dev-set accuracy
            if len(results) != 0:
                base = False
                chosen_models = []
                acc_order = [
                    i[0] for i in sorted(enumerate(results),
                                         key=lambda x: x[1][3],
                                         reverse=True)
                ]
                for i in range(FLAGS.es_num_episodes):
                    id_ = acc_order[i]
                    logger.Log(
                        "Picking model %s to perturb for next evolution step."
                        % results[id_][2])
                    chosen_models.append(results[id_])

            # Flush results from the previous generation
            results = []
            processes = []
            queue = mp.Queue()
            all_seeds, all_models = [], []
            all_steps = []
            all_dev_errs = []
            for chosen_model in chosen_models:
                perturbation_id = chosen_model[2]
                random_seed, models = generate_seeds_and_models(
                    trainer, model, perturbation_id, base=base)
                for i in range(len(models)):
                    all_seeds.append(random_seed)
                    all_steps.append(chosen_model[1])
                    all_dev_errs.append(chosen_model[3])
                all_models += models
            assert len(all_seeds) == len(all_models)
            assert len(all_steps) == len(all_seeds)

            perturbation_id = 0
            while all_models:
                perturbed_model = all_models.pop()
                true_step = all_steps.pop()
                best_dev_error = all_dev_errs.pop()
                p = mp.Process(target=rollout,
                               args=(queue, perturbed_model, FLAGS,
                                     data_manager, model, optimizer, trainer,
                                     training_data_iter, eval_iterators,
                                     logger, true_step, best_dev_error,
                                     perturbation_id, ev_step))
                p.start()
                processes.append(p)
                perturbation_id += 1
            assert len(all_models) == 0, "Not all models were trained!"

            # Run the queued processes and collect their results.
            results = [queue.get() for p in processes]

            # Check that the correct number of models were trained and saved
            if ev_step == 0:
                assert len(results) == FLAGS.es_num_episodes
            else:
                assert len(results) == FLAGS.es_num_episodes**2
Example #7
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(log_path(FLAGS),
                                         print_formatter=create_log_formatter(
                                             True, False),
                                         write_proto=FLAGS.write_proto_to_log)

    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))
    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings,
                                           vocab_size, num_classes,
                                           data_manager, header)

    # Check whether the experiment with perturbation id 0 has a checkpoint.
    perturbation_name = FLAGS.experiment_name + "_p" + '0'
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                               perturbation_name,
                                               best=True)
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path,
                                                   perturbation_name,
                                                   best=False)

    ckpt_names = []
    if os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=True)
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found standard checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=False)
    else:
        assert not only_forward, "Can't run an eval-only run without best checkpoints. Supply best checkpoint(s)."
        true_step = 0
        best_dev_error = 1.0
        best_dev_step = 0
        reload_ev_step = 0

    if FLAGS.mirror:
        true_num_episodes = FLAGS.es_num_episodes * 2
    else:
        true_num_episodes = FLAGS.es_num_episodes

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug

    model.apply(set_debug)

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        assert len(
            ckpt_names
        ) != 0, "Can not run forward pass without best checkpoints supplied."
        log_entry = pb.SpinnEntry()

        restore_queue = mp.Queue()
        processes_restore = []
        while ckpt_names:
            pert_name = ckpt_names.pop()
            path = os.path.join(FLAGS.ckpt_path, pert_name)
            name = pert_name.replace('.ckpt_best', '')
            p_restore = mp.Process(target=restore,
                                   args=(logger, trainer, restore_queue, FLAGS,
                                         name, path))
            p_restore.start()
            processes_restore.append(p_restore)
        assert len(ckpt_names) == 0

        results = [restore_queue.get() for p in processes_restore]
        assert len(results) != 0

        acc_order = [
            i[0] for i in sorted(enumerate(results), key=lambda x: x[1][3])
        ]
        best_id = acc_order[0]
        best_name = FLAGS.experiment_name + "_p" + str(best_id)
        best_path = os.path.join(FLAGS.ckpt_path, best_name + ".ckpt_best")
        ev_step, true_step, dev_error, best_dev_step = trainer.load(
            best_path, cpu=FLAGS.gpu < 0)

        print("Picking best perturbation/model %s to run evaluation, with best dev accuracy of %f" %
              (best_name, 1. - dev_error))

        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS,
                     model,
                     eval_set,
                     log_entry,
                     true_step,
                     vocabulary,
                     show_sample=True,
                     eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)

    # Train the model.
    else:
        # Restore model, i.e. perturbation spawns, from best checkpoint.
        # Get dev-set accuracies so we can select which models to use for the
        # next evolution step.
        if len(ckpt_names) != 0:
            logger.Log("Restoring models from best checkpoints")
            processes_restore = []
            restore_queue = mp.Queue()
            while ckpt_names:
                pert_name = ckpt_names.pop()
                path = os.path.join(FLAGS.ckpt_path, pert_name)
                name = pert_name.replace('.ckpt_best', '')
                p_restore = mp.Process(target=restore,
                                       args=(logger, trainer, restore_queue,
                                             FLAGS, name, path))
                p_restore.start()
                processes_restore.append(p_restore)
            assert len(ckpt_names) == 0
            results = [restore_queue.get() for p in processes_restore]
            reload_ev_step = results[0][0] + 1  # the next evolution step

        else:
            id_ = "B"
            chosen_models = [(reload_ev_step, true_step, id_, best_dev_error,
                              best_dev_step)]
            base = True  # This is the "base" model
            results = []

        for ev_step in range(reload_ev_step, FLAGS.es_steps):
            logger.Log("Evolution step: %i" % ev_step)

            # Downsample dev-set for evaluation runs during training
            eval_iterators_ = []
            if FLAGS.eval_sample_size is not None:
                for eval_filename, eval_batches in eval_iterators:
                    full = len(eval_batches)
                    subsample = int(full * FLAGS.eval_sample_size)
                    eval_batches = random.sample(eval_batches, subsample)
                    eval_iterators_.append((eval_filename, eval_batches))
            else:
                eval_iterators_ = eval_iterators

            # Choose root models for next generation using dev-set accuracy
            if len(results) != 0:
                base = False
                chosen_models = []
                acc_order = [
                    i[0]
                    for i in sorted(enumerate(results), key=lambda x: x[1][3])
                ]
                for i in range(FLAGS.es_num_roots):
                    id_ = acc_order[i]
                    logger.Log(
                        "Picking model %s to perturb for next evolution step."
                        % results[id_][2])
                    chosen_models.append(results[id_])

                # Early stopping based on current best model
                best_current = chosen_models[0]
                best_current_step = best_current[1]  # true_step
                best_current_dev_step = best_current[4]  # best_dev_step
                if (best_current_step - best_current_dev_step
                    ) > FLAGS.early_stopping_steps_to_wait:
                    logger.Log('No improvement after ' +
                               str(FLAGS.early_stopping_steps_to_wait) +
                               ' steps. Stopping training.')
                    break

            # Flush results from the previous generation
            results = []
            processes = []
            queue = mp.Queue()
            all_seeds, all_models, all_roots, all_steps, all_dev_errs, all_best_dev_steps = (
                [] for i in range(6))
            for chosen_model in chosen_models:
                perturbation_id = chosen_model[2]
                random_seed, models, true_step, best_dev_step = generate_seeds_and_models(
                    trainer, model, perturbation_id, base=base)
                for i in range(len(models)):
                    all_seeds.append(random_seed)
                    all_steps.append(true_step)
                    all_dev_errs.append(chosen_model[3])
                    all_roots.append(perturbation_id)
                    all_best_dev_steps.append(best_dev_step)
                all_models += models
            assert len(all_seeds) == len(all_models)
            assert len(all_steps) == len(all_seeds)

            perturbation_id = 0
            j = 0
            while all_models:
                perturbed_model = all_models.pop()
                true_step = all_steps.pop()
                best_dev_error = all_dev_errs.pop()
                root_id = all_roots.pop()
                best_dev_step = all_best_dev_steps.pop()
                p = mp.Process(
                    target=rollout,
                    args=(queue, perturbed_model, FLAGS, model, optimizer,
                          trainer, training_data_iter, eval_iterators_, logger,
                          true_step, best_dev_error, perturbation_id, ev_step,
                          header, root_id, vocabulary, best_dev_step))
                p.start()
                processes.append(p)
                perturbation_id += 1
                j += 1
            assert len(all_models) == 0, "Not all models were trained!"

            for p in processes:
                p.join()

            results = [queue.get() for p in processes]

            # Check that the correct number of models were trained and saved
            if ev_step == 0:
                assert len(results) == true_num_episodes
            else:
                assert len(results) == true_num_episodes * FLAGS.es_num_roots
Example #8
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step, best_dev_error):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions
          )

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        epsilon = FLAGS.rl_epsilon * math.exp(-step / FLAGS.rl_epsilon_decay)

        # Epsilon Greedy w. Decay.
        model.spinn.epsilon = epsilon

        # Confidence Penalty for Transition Predictions.
        temperature = math.sin(math.pi / 2 + step /
                               float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        if FLAGS.rl_confidence_penalty:
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[
            1].cpu()  # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(
            logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(
            model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, data_manager, A, batch)

        if step % FLAGS.statistics_interval_steps == 0 \
                or step % FLAGS.metrics_interval_steps == 0:
            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            if l2_loss is not None:
                A.add('l2_cost', l2_loss.data[0])
            stats(model, optimizer, A, step, log_entry)

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)

                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, step)
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                        1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        log_level = afs_safe_logger.ProtoLogger.INFO
        if not should_log and step % FLAGS.metrics_interval_steps == 0:
            # Log to file, but not to stderr.
            should_log = True
            log_level = afs_safe_logger.ProtoLogger.DEBUG

        if should_log:
            logger.LogEntry(log_entry, level=log_level)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
Example #9
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(log_path(FLAGS),
                                         print_formatter=create_log_formatter(
                                             True, False),
                                         write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators, training_data_length, target_vocabulary = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 "", FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    if FLAGS.data_type != "mt":
        num_classes = len(set(data_manager.LABEL_MAP.values()))
    else:
        num_classes = None

    model = init_model(FLAGS,
                       logger,
                       initial_embeddings,
                       vocab_size,
                       num_classes,
                       data_manager,
                       header,
                       target_vocabulary=target_vocabulary)
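    # Wait at most one epoch's worth of steps, capped at 10,000, before
    # lowering the learning rate.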
    time_to_wait_to_lower_lr = min(
        10000, int(training_data_length / FLAGS.batch_size))

    trainer = ModelTrainer(model, logger, time_to_wait_to_lower_lr, vocabulary,
                           FLAGS)

    header.start_step = trainer.step
    header.start_time = int(time.time())

    # Do an evaluation-only run.
    logger.LogHeader(header)  # Start log_entry logging.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS,
                     model,
                     eval_set,
                     log_entry,
                     logger,
                     trainer,
                     vocabulary,
                     show_sample=True,
                     eval_index=index,
                     target_vocabulary=target_vocabulary)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
                   logger, vocabulary, target_vocabulary)
Example #10
0
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger, vocabulary, target_vocabulary):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)
    rl_only = False
    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
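        # With rl_alternate, flip between RL-only and MT-only training every
        # 1,000 steps.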
        if FLAGS.rl_alternate and trainer.step % 1000 == 0 and trainer.step > 0:
            rl_only = not rl_only
            if rl_only:
                logger.Log('Switching training mode: RL only.')
            else:
                logger.Log('Switching training mode: MT only.')
        if (trainer.step -
                trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

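        # Each example with n tokens has 2n - 1 transitions (n shifts, n - 1
        # reduces), so tokens = (transitions + 1) / 2.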
        total_tokens = sum([(nt + 1) / 2
                            for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

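        # Sinusoidal schedule: temperature starts at 1 and oscillates in
        # [0, 1] with a period of FLAGS.rl_confidence_interval steps.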
        temperature = math.sin(
            math.pi / 2 +
            trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
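            # Epsilon decays exponentially with the training step; the
            # oscillating temperature then nudges the transition softmax
            # temperature above or below 1 (floored at 1e-3).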
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output, trg, attention, mask = model(
            X_batch,
            transitions_batch,
            y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            example_lengths=num_transitions_batch)

        criterion = nn.NLLLoss()
        batch_size = len(y_batch)
        trg_seq_len = trg.shape[0]
        mt_loss = 0.0
        if not rl_only:
            num_classes = output.shape[-1]
            mask = to_gpu(mask)

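            # Accumulate NLL loss per target time step, restricted to the
            # batch positions still active according to the mask.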
            for i in range(trg_seq_len):
                mt_loss += criterion(
                    output[i, :].index_select(0, mask[i].nonzero().squeeze(1)),
                    trg[i].index_select(0,
                                        mask[i].nonzero().squeeze(1)).view(-1))
        elif FLAGS.rl_alternate:
            model.policy_loss = 0.0
            model.value_loss = 0.0
        mt_loss = mt_loss / trg_seq_len
        # Optionally calculate transition loss.
        model.transition_loss = model.encoder.transition_loss if hasattr(
            model.encoder, 'transition_loss') else None
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None
        model.mt_loss = mt_loss

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += mt_loss
        if transition_loss is not None and model.encoder.optimize_transition_loss:
            model.optimize_transition_loss = model.encoder.optimize_transition_loss
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss[0]

        # Backward pass.
        total_loss.backward()

        # Hard gradient clipping over all parameters except the embedding weights.
        nn.utils.clip_grad_norm_([
            param for name, param in model.named_parameters()
            if name not in ["embed.embed.weight"]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()
        bb = list(model.parameters())[-1].clone()  # Snapshot of the last parameter (appears unused; likely leftover debugging code).
        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)
        A.add('mt_loss', float(mt_loss))

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            stats(model, trainer, A, log_entry)
            should_log = True
            progress_bar.finish()

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
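            # Run the same batch twice, in train mode and then in eval mode,
            # so both sets of predicted transitions (and their strengths) can
            # be logged against the gold sequence.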
            model.train()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            tr_transitions_per_example, tr_strength = \
                model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            ev_transitions_per_example, ev_strength = \
                model.spinn.get_transitions_per_example()

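            # For sentence-pair inputs, stack the two sentences' gold
            # transition sequences along the batch axis so they line up with
            # the flattened per-example predictions.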
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(),
                                     dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(),
                                     dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS,
                    model,
                    eval_set,
                    log_entry,
                    logger,
                    trainer,
                    show_sample=(trainer.step %
                                 FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary,
                    eval_index=index,
                    target_vocabulary=target_vocabulary)
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(
            i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
            total=FLAGS.statistics_interval_steps)