Example #1
def evaluate(args, utterance_data, interaction_data, model):
    accuracy, _, _ = utterance_accuracy(model,
                                        utterance_data,
                                        fsa_builder=TangramsFSA,
                                        logfile=args.results_file)
    print(accuracy)
    accuracy = interaction_accuracy(model,
                                    interaction_data,
                                    fsa_builder=TangramsFSA)
    print(accuracy)
Example #2
def evaluate(args, utterance_data, interaction_data, model):
    accuracy, _, _ = utterance_accuracy(
        model,
        utterance_data,
        fsa_builder=AlchemyFSA,
        logfile=args.results_file + "_utt.log")
    print(accuracy)
    accuracy = interaction_accuracy(
        model,
        interaction_data,
        fsa_builder=AlchemyFSA,
        logfile=args.results_file + "_int.log")
    print(accuracy)
Example #3
def do_one_epoch(model,
                 train_set,
                 val_set,
                 val_ints,
                 fsa_builder,
                 args,
                 epoch,
                 trainer):
    """ Performs one epoch of update to the model.

    Inputs:
        model (SconeModel): Model to update.
        train_set (list of Example): Examples to train on.
        val_set (list of Example): Examples to compute held out accuracy on.
        fsa_builder (lambda WorldState : ExecutionFSA): Creates an FSA from a
            world state.
        args (kwargs)

    Returns:
        float, the validation accuracy computed on the val_set after the epoch
    """
    epoch_loss = 0.0
    batches = chunks(train_set, args.batch_size)
    for batch in tqdm.tqdm(batches,
                           desc='Training epoch %d' % (epoch),
                           ncols=80):
        epoch_loss += do_one_batch(batch, model, trainer, fsa_builder)
    print('Epoch mean loss: %.4f' % (epoch_loss / len(batches)))
    # At the end of each epoch, run on the validation data.
    val_loss = get_set_loss(model, val_set, fsa_builder, args.batch_size)
    val_accuracy, _, val_token_accuracy = utterance_accuracy(
        model, val_set, fsa_builder, '%s/val-epoch%d.log' %
        (args.logdir, epoch))
    val_int_acc = interaction_accuracy(
        model, val_ints, fsa_builder,
        '%s/val-int-epoch-%d.log' % (args.logdir, epoch))
    print(
        'Validation: loss=%.4f, accuracy=%.4f, int_acc=%.4f, token_acc=%.4f' %
        (val_loss, val_accuracy, val_int_acc, val_token_accuracy))
    return val_token_accuracy, val_int_acc
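
Both do_one_epoch above and reinforcement_learning in Example #6 depend on a
chunks helper that is not shown. Because both call len() on its result, it
must return a list rather than a generator; a minimal sketch under that
assumption:

def chunks(items, chunk_size):
    """Splits items into consecutive batches of at most chunk_size elements.

    Returns a list (not a generator) so that callers can take len() of it.
    """
    return [items[i:i + chunk_size]
            for i in range(0, len(items), chunk_size)]
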
Example #4
def train_and_evaluate(model,
                       train_set,
                       val_set,
                       val_interactions,
                       dev_set,
                       dev_interactions,
                       args,
                       fsa_builder=None):
    model.set_dropout(args.supervised_dropout)
    training_subset = train_set
    if args.supervised_ratio < 1.0 or args.supervised_amount > 0:
        train_ids = [example.id[:-1] for example in train_set]
        random.shuffle(train_ids)
        if args.supervised_ratio < 1.0:
            ids_to_use = train_ids[:int(len(train_ids)
                                        * args.supervised_ratio)]
        elif args.supervised_amount > 0:
            ids_to_use = train_ids[:args.supervised_amount]
        ids_to_use = set(ids_to_use)  # fast membership tests below
        training_subset = [example for example in train_set
                           if example.id[:-1] in ids_to_use]
    print("Training with %d examples" % len(training_subset))
    training(model,
             training_subset,
             val_set,
             val_interactions,
             args,
             fsa_builder=fsa_builder)
    model.save_params(args.logdir + "/supervised_model.dy")
    dev_accuracy, _, _ = utterance_accuracy(
        model,
        dev_set,
        fsa_builder=fsa_builder,
        logfile=args.logdir + "/dev_utterances.log")
    dev_interaction_accuracy = interaction_accuracy(
        model,
        dev_interactions,
        fsa_builder=fsa_builder,
        logfile=args.logdir + "/dev_interactions.log")
    print("Development accuracy: %.4f (single), %.4f (interaction)" %
          (dev_accuracy, dev_interaction_accuracy))
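
The subsampling in train_and_evaluate assumes that example.id ends in a
single turn character, so that example.id[:-1] identifies the interaction an
utterance belongs to; sampling ids at that granularity keeps or drops whole
interactions. A small illustration with made-up ids:

# Hypothetical utterance ids: an interaction id plus one trailing turn digit.
example_ids = ["a10", "a11", "a12", "b20", "b21"]
interaction_ids = [utt_id[:-1] for utt_id in example_ids]
print(interaction_ids)  # ['a1', 'a1', 'a1', 'b2', 'b2']
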
Example #5
def main():
    """Actually does the training."""
    args = interpret_args()
    train_data, dev_data, val_data, test_data, in_vocab, _ = \
        load_data('../data/alchemy/train_sequences.json',
                  '../data/alchemy/dev_sequences.json',
                  '../data/alchemy/test_sequences.json',
                  AlchemyState,
                  args)
    train, train_interactions = train_data
    dev, dev_interactions = dev_data
    val, val_interactions = val_data
    test, test_interactions = test_data

    model = ConstrainedContextSeq2SeqEmbeddings(in_vocab,
                                                (["push", "pop"],
                                                 [str(x+1) for x in range(NUM_BEAKERS)],
                                                 list(COLORS)),
                                                AlchemyStateEncoder,
                                                valid_action_fn,
                                                args)

    train.sort(key=lambda x: (x.turn, len(x.utterance)))

    # Train.
    # Using development set for validation -- currently only to pick the best
    # model.
    if args.evaluate:
        model.load_params(args.saved_model)
        splits = {
            "train": (train, train_interactions),
            "val": (val, val_interactions),
            "dev": (dev, dev_interactions),
            "test": (test, test_interactions)
        }
        if args.evaluate_split not in splits:
            raise ValueError("Unexpected split name " + args.evaluate_split)
        split_to_eval, interactions_to_eval = splits[args.evaluate_split]
        evaluate(args, split_to_eval, interactions_to_eval, model)

    if args.evaluate_attention:
        model.load_params(args.saved_model)
        for example in dev:
            attention_analysis(model,
                               example,
                               AlchemyFSA,
                               name=args.logdir + "/attention/" + example.id)
    if args.supervised:
        train_and_evaluate(model,
                           train,
                           val,
                           val_interactions,
                           dev,
                           dev_interactions,
                           args,
                           fsa_builder=AlchemyFSA)
    if args.rl:
        if args.pretrained:
            model.load_params(args.logdir + "/supervised_model.dy")
            # Optionally evaluate the pretrained model on the training set:
            # train_acc, _, _ = utterance_accuracy(
            #     model,
            #     train,
            #     fsa_builder=AlchemyFSA,
            #     syntax_restricted=args.syntax_restricted,
            #     logfile=args.logdir + "/supervised_train.log")
            # print('Training accuracy: ' + str(train_acc) + " (single)")

        model.set_dropout(args.rl_dropout)
        reinforcement_learning(model,
                               train,
                               val,
                               val_interactions,
                               args.logdir,
                               AlchemyFSA,
                               reward_with_shaping,
                               model.compute_entropy,
                               args,
                               epochs=args.max_epochs,
                               batch_size=20,
                               single_head=False,
                               explore_with_fsa=False)

        # Save final model.
        model.save_params(args.logdir + '/model-final.dy')

    # Test results.
    if test:
        test_accuracy, _, _ = utterance_accuracy(
            model, test, AlchemyFSA, args.logdir + '/test.log')
        test_interaction_accuracy = interaction_accuracy(
            model, test_interactions, AlchemyFSA, args.logdir + '/test.interactions.log')
        print('Test accuracy: %.4f (single), %.4f (interaction)' %
              (test_accuracy, test_interaction_accuracy))
Example #6
def reinforcement_learning(model,
                           train_set,
                           val_set,
                           val_interactions,
                           log_dir,
                           fsa_builder,
                           reward_fn,
                           entropy_function,
                           args,
                           batch_size=1,
                           epochs=20,
                           single_head=True,
                           explore_with_fsa=False):
    """Performs training with exploration.

    Inputs:
        model (Model): Model to train.
        train_set (list of Examples): The set of training examples.
        val_set (list of Examples): The set of validation examples.
        val_interactions (list of Interactions): Full interactions for validation.
        log_dir (str): Location to log.

    """
    trainer = dy.RMSPropTrainer(model.get_params())
    trainer.set_clip_threshold(1)

    mode = get_rl_mode(args.rl_mode)

    best_val_accuracy = 0.0
    best_val_reward = -float('inf')
    best_model = None

    try:
        from pycrayon import CrayonClient
        crayon = CrayonClient(hostname="localhost")
        experiment = crayon.create_experiment(log_dir)
    except (ImportError, ValueError):
        print("If you want to use Crayon, please use "
              "`pip install pycrayon` to install it.")
        experiment = None

    num_batches = 0
    train_file = open(os.path.join(log_dir, "train.log"), "w")

    patience = args.patience
    countdown = patience

    for epoch in range(epochs):
        random.shuffle(train_set)
        batches = chunks(train_set, batch_size)

        num_examples = 0
        num_tokens = 0
        num_tokens_zero = 0
        progbar = progressbar.ProgressBar(maxval=len(batches),
                                          widgets=[
                                              "Epoch " + str(epoch),
                                              progressbar.Bar('=', '[', ']'),
                                              ' ',
                                              progressbar.Percentage(), ' ',
                                              progressbar.ETA()
                                          ])
        progbar.start()
        for i, batch in enumerate(batches):
            dy.renew_cg()

            prob_seqs, predictions = model.sample_sequences(
                batch,
                length=args.sample_length_limit,
                training=True,
                fsa_builder=fsa_builder)

            batch_entropy_sum = dy.inputTensor([0.])
            batch_rewards = []
            processed_predictions = []

            train_file.write("--- NEW BATCH # " + str(num_batches) + " ---\n")
            action_probabilities = {}
            for action in model.output_action_vocabulary:
                if action != BEG:
                    action_probabilities[action] = []

            for example, prob_seq, prediction in zip(batch, prob_seqs,
                                                     predictions):
                # Get reward (and other evaluation information)
                prediction = process_example(example, prediction, prob_seq,
                                             reward_fn, entropy_function,
                                             model, args, fsa_builder)
                for distribution in prob_seq:
                    action_probability = model.action_probabilities(
                        distribution)
                    for action, prob_exp in action_probability.items():
                        action_probabilities[action].append(prob_exp)

                batch_rewards.extend(prediction.reward_expressions)
                batch_entropy_sum += dy.esum(prediction.entropies)
                processed_predictions.append(prediction)

                num_examples += 1

            # Now backpropagate given these rewards
            batch_action_probabilities = {}
            for action, prob_exps in action_probabilities.items():
                batch_action_probabilities[action] = dy.esum(prob_exps) / len(
                    batch_rewards)

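            # REINFORCE-style update: sum the reward-weighted log-probability
            # expressions (assuming that is what reward_expressions hold), add
            # an optional entropy bonus, and minimize the negated mean.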
            num_reward_exps = len(batch_rewards)
            loss = dy.esum(batch_rewards)
            if args.entropy_coefficient > 0:
                loss += args.entropy_coefficient * batch_entropy_sum
            loss = -loss / num_reward_exps
            loss.backward()
            try:
                trainer.update()
            except RuntimeError as r:
                print(loss.npvalue())
                for lookup_param in model._pc.lookup_parameters_list():
                    print(lookup_param.name())
                    print(lookup_param.grad_as_array())
                for param in model._pc.parameters_list():
                    print(param.name())
                    print(param.grad_as_array())
                print(r)
                exit()

            # Calculate metrics.
            per_token_metrics = compute_metrics(processed_predictions,
                                                num_reward_exps,
                                                ["entropy", "reward"], args)
            gold_token_metrics = compute_metrics(
                processed_predictions,
                sum([len(ex.actions) for ex in batch]) + len(batch),
                ["gold_probability"],
                args,
                model=model)
            per_example_metrics = compute_metrics(
                processed_predictions,
                len(batch), [
                    "distance", "completion", "invalid", "num_tokens",
                    "prefix_length"
                ],
                args,
                model=model)

            for prediction in processed_predictions:
                train_file.write(str(prediction) + "\n")
            train_file.write("=====\n")
            log_metrics({"loss": loss.npvalue()[0]}, train_file, experiment,
                        num_batches)
            log_metrics(per_token_metrics, train_file, experiment, num_batches)
            log_metrics(gold_token_metrics, train_file, experiment,
                        num_batches)
            log_metrics(per_example_metrics, train_file, experiment,
                        num_batches)
            train_file.flush()

            num_batches += 1
            progbar.update(i)

        progbar.finish()
        train_acc, _, _ = utterance_accuracy(model,
                                             train_set,
                                             fsa_builder=fsa_builder,
                                             logfile=log_dir + "/rl-train" +
                                             str(epoch) + ".log")
        val_acc, val_reward, _ = utterance_accuracy(
            model,
            val_set,
            fsa_builder=fsa_builder,
            logfile=log_dir + "/rl-val-" + str(epoch) + ".log",
            args=args,
            reward_function=reward_fn)

        val_int_acc = interaction_accuracy(model,
                                           val_interactions,
                                           fsa_builder=fsa_builder,
                                           logfile=log_dir + "/rl-val-int-" +
                                           str(epoch) + ".log")

        log_metrics(
            {
                "train_accuracy": train_acc,
                "validation_accuracy": val_acc,
                "validation_int_acc": val_int_acc,
                "validation_reward": val_reward,
                "countdown": countdown
            }, train_file, experiment, num_batches)
        if experiment is not None:
            experiment.to_zip(
                os.path.join(log_dir, "crayon-" + str(epoch) + ".zip"))
        model_file_name = log_dir + "/model-rl-epoch" + str(epoch) + ".dy"
        model.save_params(model_file_name)
        if val_int_acc > best_val_accuracy or best_model is None:
            best_model = model_file_name
            best_val_accuracy = val_int_acc

        if val_reward > best_val_reward:
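            # Validation reward improved: grow the patience budget slightly
            # and reset the early-stopping countdown.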
            patience *= 1.005
            countdown = patience
            best_val_reward = val_reward
        else:
            countdown -= 1
        if countdown <= 0:
            print("Patience ran out -- stopping")
            break
    train_file.close()
    print('Loading parameters from best model: %s' % (best_model))
    model.load_params(best_model)
    model.save_params(log_dir + "/best_rl_model.dy")
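    # Sanity check: print the first training example and the model's
    # generation for it.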
    print(train_set[0])
    print(
        model.generate(train_set[0].utterance, train_set[0].initial_state,
                       train_set[0].history)[0])
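
Examples #5 and #7 pass reward_with_shaping as reward_fn. Its implementation
is not shown; as a reference for the technique only, here is a minimal sketch
of potential-based reward shaping, where the signature and the distance_to
helper are illustrative assumptions rather than the repository's actual API:

def shaped_reward_sketch(prev_state, new_state, goal, gamma=1.0):
    """Potential-based shaping (Ng et al., 1999): add gamma * phi(s') - phi(s)
    to the raw task reward, so intermediate steps toward the goal are rewarded
    without changing the optimal policy."""
    def phi(state):
        return -state.distance_to(goal)  # hypothetical distance heuristic
    raw_reward = 1.0 if new_state == goal else 0.0
    return raw_reward + gamma * phi(new_state) - phi(prev_state)
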
Example #7
def main():
    """Actually does the training."""
    args = interpret_args()
    train_data, dev_data, val_data, test_data, in_vocab, _ = \
        load_data('../tangram/train_sequences.json',
                  '../tangram/dev_sequences.json',
                  '../tangram/test_sequences.json',
                  TangramsState,
                  args,
                  sort=False)
    train, train_interactions = train_data
    dev, dev_interactions = dev_data
    val, val_interactions = val_data
    test, test_interactions = test_data

    model = ConstrainedContextSeq2SeqEmbeddings(
        in_vocab,
        (["insert", "remove"],
         [str(x + 1) for x in range(NUM_POSITIONS)],
         list(SHAPES)),
        TangramsStateEncoder,
        valid_action_fn,
        args)

    train.sort(key=lambda x: (x.turn, len(x.utterance)))

    # Train.
    # Using development set for validation -- currently only to pick the best
    # model.
    if args.evaluate:
        model.load_params(args.saved_model)
        splits = {
            "train": (train, train_interactions),
            "val": (val, val_interactions),
            "dev": (dev, dev_interactions),
            "test": (test, test_interactions)
        }
        if args.evaluate_split not in splits:
            raise ValueError("Unexpected split name " + args.evaluate_split)
        split_to_eval, interactions_to_eval = splits[args.evaluate_split]
        evaluate(args, split_to_eval, interactions_to_eval, model)

    if args.evaluate_attention:
        model.load_params(args.saved_model)
        for example in dev:
            attention_analysis(model,
                               example,
                               TangramsFSA,
                               name=args.logdir + "/attention/" + example.id)
    if args.supervised:
        train_and_evaluate(model,
                           train,
                           val,
                           val_interactions,
                           dev,
                           dev_interactions,
                           args,
                           fsa_builder=TangramsFSA)
    if args.rl:
        if args.pretrained:
            model.load_params(args.logdir + "/supervised_model.dy")

        model.set_dropout(args.rl_dropout)
        reinforcement_learning(model,
                               train,
                               val,
                               val_interactions,
                               args.logdir,
                               TangramsFSA,
                               reward_with_shaping,
                               model.compute_entropy,
                               args,
                               epochs=200,
                               batch_size=20,
                               single_head=False,
                               explore_with_fsa=False)

        # Save final model.
        model.save_params(args.logdir + '/model-final.dy')

    # Test results.
    if test:
        test_accuracy, _, _ = utterance_accuracy(model, test, TangramsFSA,
                                                 args.logdir + '/test.log')
        test_interaction_accuracy = interaction_accuracy(
            model, test_interactions, TangramsFSA,
            args.logdir + '/test.interactions.log')
        print('Test accuracy: %.4f (single), %.4f (interaction)' %
              (test_accuracy, test_interaction_accuracy))