Example #1
    def test_save_sup_load_rl(self):
        model_to_save = MockModel(spinn.pyramid.Pyramid, default_args())
        opt_to_save = optim.SGD(model_to_save.parameters(), lr=0.1)
        trainer_to_save = ModelTrainer(model_to_save, opt_to_save)

        model_to_load = MockModel(spinn.pyramid.Pyramid, default_args())
        opt_to_load = optim.SGD(model_to_load.parameters(), lr=0.1)
        trainer_to_load = ModelTrainer(model_to_load, opt_to_load)

        # Save to and load from temporary file.
        temp = tempfile.NamedTemporaryFile()
        trainer_to_save.save(temp.name, 0, 0)
        trainer_to_load.load(temp.name)

        compare_models(model_to_save, model_to_load)

        # Cleanup temporary file.
        temp.close()
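
Example #1 exercises a save/load round trip through ModelTrainer; the ModelTrainer, MockModel, and compare_models helpers are defined elsewhere in the spinn test suite and are not shown here. A minimal, self-contained sketch of the same round-trip pattern in plain PyTorch follows; TinyModel, save_checkpoint, and load_checkpoint are hypothetical stand-ins, not spinn APIs.

import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


class TinyModel(nn.Module):
    """Hypothetical stand-in for MockModel; any nn.Module works."""

    def __init__(self):
        super(TinyModel, self).__init__()
        self.fc = nn.Linear(4, 2)


def save_checkpoint(path, model, optimizer, step, best_dev_error):
    # Persist model and optimizer state together, roughly what a trainer's save() would do.
    torch.save({
        'step': step,
        'best_dev_error': best_dev_error,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['step'], checkpoint['best_dev_error']


model_to_save = TinyModel()
model_to_load = TinyModel()
opt_to_save = optim.SGD(model_to_save.parameters(), lr=0.1)
opt_to_load = optim.SGD(model_to_load.parameters(), lr=0.1)

# Save to and load from a temporary file, mirroring the test above.
with tempfile.NamedTemporaryFile() as temp:
    save_checkpoint(temp.name, model_to_save, opt_to_save, 0, 0)
    load_checkpoint(temp.name, model_to_load, opt_to_load)

# After loading, the two models' parameters should match exactly.
for saved, loaded in zip(model_to_save.parameters(), model_to_load.parameters()):
    assert torch.equal(saved.data, loaded.data)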
Example #2
    def test_save_sup_load_rl(self):
        scalar = 11
        other_scalar = 0

        model_to_save = MockModel(spinn.fat_stack.BaseModel, default_args())
        opt_to_save = optim.SGD(model_to_save.parameters(), lr=0.1)
        trainer_to_save = ModelTrainer(model_to_save, opt_to_save)

        model_to_load = MockModel(spinn.rl_spinn.BaseModel, default_args())
        opt_to_load = optim.SGD(model_to_load.parameters(), lr=0.1)
        trainer_to_load = ModelTrainer(model_to_load, opt_to_load)

        # Save to and load from temporary file.
        temp = tempfile.NamedTemporaryFile()
        trainer_to_save.save(temp.name, 0, 0)
        trainer_to_load.load(temp.name)

        compare_models(model_to_save, model_to_load)

        # Cleanup temporary file.
        temp.close()
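
compare_models is referenced in both tests but not defined in these examples. A minimal sketch of what such a helper might look like, assuming it simply asserts that the loaded model ends up with exactly the saved model's parameters:

import torch


def compare_models(model_to_save, model_to_load):
    # Hypothetical sketch: after loading, both models should expose identical parameters.
    saved_params = list(model_to_save.parameters())
    loaded_params = list(model_to_load.parameters())
    assert len(saved_params) == len(loaded_params)
    for saved, loaded in zip(saved_params, loaded_params):
        assert torch.equal(saved.data, loaded.data)

Note that Example #2 saves a supervised fat_stack model and restores it into an rl_spinn model, so a parameter-wise check like this presumes the two architectures expose their parameters in a compatible layout.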
Example #3
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x): return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None

    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(FLAGS.model_dim // 2,
                                     tracker_size=FLAGS.tracking_lstm_hidden_dim,
                                     use_tracking_in_composition=FLAGS.use_tracking_in_composition,
                                     composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate, eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0) for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer
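
In the "gru" and "attn" branches above, reshape_input folds the flat (batch_size * seq_length, dim) embedding matrix into a (batch_size, seq_length, dim) tensor for the sequence encoder, and reshape_context flattens the encoder output back to one row per token. A small sketch of that shape contract; the concrete dimensions below are illustrative, not values taken from the spinn flags.

import torch

batch_size, seq_length, word_dim = 2, 3, 4

# Flat token embeddings: one row per token, as the input encoder receives them.
x = torch.randn(batch_size * seq_length, word_dim)


def reshape_input(x, batch_size, seq_length):
    return x.view(batch_size, seq_length, -1)


def reshape_context(x, batch_size, seq_length):
    return x.view(batch_size * seq_length, -1)


folded = reshape_input(x, batch_size, seq_length)        # (2, 3, 4): ready for a GRU or attention encoder
context = folded                                         # stand-in for the encoder's output
flat = reshape_context(context, batch_size, seq_length)  # (6, 4): back to one row per token

assert folded.shape == (batch_size, seq_length, word_dim)
assert flat.shape == (batch_size * seq_length, word_dim)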
Example #4
def run(only_forward=False):
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" + json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval, shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(model, eval_set, logger, step, vocabulary)
    else:
        # Build log format strings.
        model.train()
        X_batch, transitions_batch, y_batch, num_transitions_batch = get_batch(training_data_iter.next())[:4]
        model(X_batch, transitions_batch, y_batch,
              use_internal_parser=FLAGS.use_internal_parser,
              validate_transitions=FLAGS.validate_transitions)

        train_str = train_format(model)
        logger.Log("Train-Format: {}".format(train_str))
        train_extra_str = train_extra_format(model)
        logger.Log("Train-Extra-Format: {}".format(train_extra_str))

        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()

            start = time.time()

            batch = get_batch(training_data_iter.next())
            X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]

            total_tokens = sum([(nt + 1) // 2 for nt in num_transitions_batch.reshape(-1)])

            # Reset cached gradients.
            optimizer.zero_grad()

            if FLAGS.model_type == "RLSPINN":
                model.spinn.epsilon = FLAGS.rl_epsilon * math.exp(-step/FLAGS.rl_epsilon_decay)

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                           use_internal_parser=FLAGS.use_internal_parser,
                           validate_transitions=FLAGS.validate_transitions)

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu() # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss.
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            total_loss += auxiliary_loss(model)

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()

            total_time = end - start

            train_accumulate(model, data_manager, A, batch)
            A.add('class_acc', class_acc)
            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

                A.add('xent_cost', xent_loss.data[0])
                if l2_loss is not None:
                    A.add('l2_cost', l2_loss.data[0])
                stats_args = train_stats(model, optimizer, A, step)

                logger.Log(train_str.format(**stats_args))
                logger.Log(train_extra_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, step)
                    if (FLAGS.ckpt_on_best_dev_error and index == 0
                            and (1 - acc) < 0.99 * best_dev_error
                            and step > FLAGS.ckpt_step):
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                trainer.save(standard_checkpoint_path, step, best_dev_error)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
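
The inner loop above follows the standard PyTorch training pattern: zero the gradients, run the model, compute a negative log-likelihood loss over log-softmax outputs, backpropagate, clamp each gradient to a hard range, and take an optimizer step. A condensed, self-contained sketch of just that core, using a toy classifier and random data in place of the SPINN model and batch iterator; the clip value is illustrative, standing in for FLAGS.clipping_max_value.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

model = nn.Linear(10, 3)                       # toy classifier standing in for the SPINN model
optimizer = optim.SGD(model.parameters(), lr=0.1)
clip = 5.0                                     # analogous to FLAGS.clipping_max_value

X_batch = torch.randn(8, 10)                   # toy inputs
y_batch = torch.randint(0, 3, (8,))            # toy class labels

optimizer.zero_grad()                          # reset cached gradients
logits = F.log_softmax(model(X_batch), dim=1)  # normalize output
loss = nn.NLLLoss()(logits, y_batch)           # class loss
loss.backward()                                # backward pass

# Hard gradient clipping, as in the loop above.
for p in model.parameters():
    if p.requires_grad and p.grad is not None:
        p.grad.data.clamp_(min=-clip, max=clip)

optimizer.step()                               # gradient descent step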