Пример #1
0
def plot_generated(trainer, example_idx=4):
    """Plot one true training series next to the model's generated version.

    Builds an unshuffled ``ArrayIterator`` over ``dataset.train`` and runs
    ``trainer.predict`` on its first batch. It then takes example
    ``example_idx`` from both the prediction and the matching true batch and
    draws one panel per channel: operating modes first, then sensors, three
    panels per row. True values are blue and generated values are red. The
    figure is saved to ``<args.results_dir>/generated_series.png``.

    Relies on the module-level ``dataset``, ``batch_size``, ``args`` and
    ``ArrayIterator`` names.

    Args:
        trainer: object exposing ``predict(iterator, num_batches=...)``.
        example_idx: index of the example within the first batch to plot.
            Defaults to 4, the value that was previously hard-coded.
    """
    # Get a batch from the train set. shuffle=False makes the batch we
    # predict on the same one we later read the true values from.
    train_set_one_epoch = ArrayIterator(dataset.train, batch_size, shuffle=False)
    gen_series = trainer.predict(train_set_one_epoch, num_batches=1)
    train_set_one_epoch.reset()

    # Get one example from the batch. Each example is indexed as
    # [step, channel] below (assumed (timesteps, channels); confirm).
    gen_series = gen_series[example_idx]

    if args.backward:
        # With args.backward the autoencoder produces the input sequence in
        # reverse. Flip it back along the first axis to match the true series.
        gen_series = gen_series[::-1, :]

    true_series = next(train_set_one_epoch)['X'][example_idx]

    # Channel columns are operating modes first, then sensors.
    titles = (["Operating mode {}".format(i + 1)
               for i in range(dataset.n_operating_modes)] +
              ["Sensor {}".format(i + 1)
               for i in range(dataset.n_sensors)])
    n_panels = len(titles)
    n_per_row = 3
    # Use true (float) division before ceil. The former `// 3` floored the row
    # count, so plt.subplot raised for the last panels whenever n_panels was
    # not a multiple of 3. float() keeps this correct under Python 2 as well.
    nrows = max(1, int(np.ceil(n_panels / float(n_per_row))))
    fig, axes = plt.subplots(nrows, n_per_row, squeeze=False, figsize=(10, 20))
    axes = axes.ravel()

    for ch, (ax, title) in enumerate(zip(axes, titles)):
        # Label only the first panel so fig.legend() shows one true/gen pair.
        true_kw = {"label": "true"} if ch == 0 else {}
        gen_kw = {"label": "gen"} if ch == 0 else {}
        ax.plot(true_series[:, ch], color="blue", **true_kw)
        ax.plot(gen_series[:, ch], color="red", **gen_kw)
        ax.set_title(title)

    # Hide the unused trailing panels in the last row.
    for ax in axes[n_panels:]:
        ax.axis("off")

    fig.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(args.results_dir, "generated_series.png"))
Пример #2
0
            error = np.mean(data['answer'].argmax(axis=1) != preds)
            test_error.append(error)

        val_cost_str = "Epoch {}: validation_cost {}, validation_error {}".format(
            e, np.mean(test_loss), np.mean(test_error))
        print(val_cost_str)
        if args.save_log:
            with open(log_file, 'a') as f:
                f.write(val_cost_str + '\n')

        # Shuffle training set and reset the others
        shuf_idx = np.random.permutation(
            range(train_set.data_arrays['memory'].shape[0]))
        train_set.data_arrays = {k: v[shuf_idx]
                                 for k, v in train_set.data_arrays.items()}
        train_set.reset()
        dev_set.reset()

    print('Training Complete.')

    if args.interactive:
        interactive_loop(interactive_computation, babi)

    if args.test:
        # Final evaluation on test set
        test_loss = []
        test_error = []
        for idx, data in enumerate(test_set):
            test_output = loss_computation(data)
            test_loss.append(np.sum(test_output['test_cross_ent_loss']))
            preds = np.argmax(test_output['test_preds'], axis=1)
Пример #3
0
            for idx, data in enumerate(test_set):
                test_output = eval_computation(data)
                test_loss.append(np.sum(test_output['test_cross_ent_loss']))
                preds = np.argmax(test_output['test_preds'], axis=0)
                error = np.mean(data['answer'] != preds)
                test_error.append(error)
            print("Epoch {}, Test_loss {}, test_batch_error {}".format(
                e, np.mean(test_loss), np.mean(test_error)))
            # Shuffle training set and reset the others
            shuf_idx = np.random.permutation(
                range(train_set.data_arrays['query'].shape[0]))
            train_set.data_arrays = {
                k: v[shuf_idx]
                for k, v in train_set.data_arrays.items()
            }
            train_set.reset()
            test_set.reset()

            if (model_file is not None and e % 50 == 0):
                print('Saving model to: ', model_file)
                weight_saver.save(filename=model_file)
else:
    print('Loading saved model')
    with closing(ngt.make_transformer()) as transformer:
        eval_computation = make_bound_computation(transformer, eval_outputs,
                                                  inputs)
        if args.interactive:
            interactive_computation = make_bound_computation(
                transformer, interactive_outputs, inputs)
        weight_saver.setup_restore(transformer=transformer,
                                   computation=eval_outputs,