예제 #1
0
def evaluate(model='ram', width=60, n_distractors=4, N=10):
    """Test a trained model on larger / noisier MNIST images.

    Configures the module-level ``config`` from ``FLAGS`` and the arguments,
    builds the requested network, restores its weights from ``FLAGS.load``,
    optionally writes visualizations, and finally evaluates it.

    Args:
        model: network to build -- 'ram', 'dram' or 'dram_loc'. Any other
            value prints an error and exits the process with status 1.
        width: assigned to ``config.new_size`` and to the visualization task.
        n_distractors: assigned to ``config.n_distractors`` and to the
            visualization task.
        N: forwarded to ``net.visualize`` (only used when FLAGS.visualize).

    Returns:
        The ``(test, val)`` pair returned by ``net.evaluate`` (presumably test
        and validation performance -- TODO confirm against the model classes).
    """
    import sys

    # data
    mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

    # set parameters (config is shared module state; it is mutated in place)
    config.num_glimpses = FLAGS.num_glimpses
    config.n_patches = FLAGS.n_patches
    config.use_context = FLAGS.use_context
    config.convnet = FLAGS.convnet
    # product of glimpse area, number of patches and number of channels
    config.sensor_size = (config.glimpse_size**2 * config.n_patches *
                          config.num_channels)
    config.N = mnist.train.num_examples  # number of training examples

    config.new_size = width
    config.n_distractors = n_distractors

    # init model
    print('\n-- Model: {} --'.format(model))
    print('Setting sampling SD to {:.4e}'.format(config.loc_std))
    if model == 'ram':
        net = RAM(config)
    elif model == 'dram':
        net = DRAM(config)
    elif model == 'dram_loc':
        net = DRAMl(config)
    else:
        print('Unknown model {}'.format(model))
        # Non-zero status signals failure; bare exit() reported success and
        # depends on the `site` module being loaded.
        sys.exit(1)

    net.load(FLAGS.load)  # restore
    net.count_params()

    if FLAGS.visualize:
        # create plot directory for current parameters; makedirs tolerates an
        # existing directory and creates missing parents of a nested plot_dir
        plot_dir = os.path.join(FLAGS.load, FLAGS.plot_dir)
        os.makedirs(plot_dir, exist_ok=True)

        task = {
            'variant': FLAGS.task,
            'width': width,
            'n_distractors': n_distractors
        }
        net.visualize(data=mnist,
                      task=task,
                      config=config,
                      N=N,
                      plot_dir=plot_dir)

    # evaluate -- this call was previously commented out, which left `test`
    # and `val` undefined so the return statement always raised NameError.
    test, val = net.evaluate(data=mnist, task=FLAGS.task)

    return test, val
예제 #2
0
    # NOTE(review): this snippet starts part-way through an if/elif chain that
    # selects the model from FLAGS.model; the opening `if` branch is not
    # visible here.
    elif FLAGS.model == 'dram':
        print('\n\n\nTraining DRAM\n\n\n')
        model = DRAM(config, logdir=FLAGS.logdir)
    elif FLAGS.model == 'dram_loc':
        print('\n\n\nTraining DRAM with location ground truth\n\n\n')
        model = DRAMl(config, logdir=FLAGS.logdir)
    else:
        # Unrecognised model name: report it and terminate.
        # NOTE(review): bare exit() exits with status 0 (success) and relies
        # on the `site` builtin; sys.exit(1) would signal the failure.
        print(('Unknown model {}'.format(FLAGS.model)))
        exit()

    # load if specified
    # When a checkpoint path is given, restore it and immediately render
    # visualizations for a fixed cluttered task (width 60, 4 distractors,
    # 10 examples), writing plots into the current directory.
    # NOTE(review): config=[] is passed here, whereas evaluate() passes the
    # real config object -- confirm visualize() does not rely on it.
    if FLAGS.load is not None:
        model.load(FLAGS.load)
        model.visualize(config=[],
                        data=mnist,
                        task={
                            'variant': 'cluttered',
                            'width': 60,
                            'n_distractors': 4
                        },
                        plot_dir='.',
                        N=10,
                        seed=None)
    # display # parameters
    model.count_params()

    # train
    # Training runs whether or not a checkpoint was loaded above; presumably
    # it continues from the restored weights -- TODO confirm in train().
    model.train(mnist, FLAGS.task)

    # Final evaluation on FLAGS.task; the return value is discarded here.
    model.evaluate(data=mnist, task=FLAGS.task)
예제 #3
0
def evaluate_numglimpses(model='dram_loc', visualize=False, N=10,
                         n_glimpses=(1, 2, 3, 4, 5, 6, 7, 8),
                         width=100, n_distractors=4):
    """Evaluate a trained model while sweeping the number of glimpses.

    For each glimpse count in ``n_glimpses`` the module-level ``config`` is
    updated, a fresh TensorFlow graph is built for ``model``, weights are
    restored from ``FLAGS.load`` and the network is evaluated with
    ``evaluate_repeatedly`` on images of size ``width`` containing
    ``n_distractors`` distractors. The collected results are pickled to
    ``<FLAGS.load>/glimpses<FLAGS.num_glimpses>_results.pickle``.

    Args:
        model: network to build -- 'ram', 'dram' or 'dram_loc'. Any other
            value prints an error and exits the process with status 1.
        visualize: write visualizations for every glimpse count. Either this
            argument or ``FLAGS.visualize`` enables it. When enabled, only one
            evaluation repetition is run per glimpse count.
        N: number of evaluation repetitions (when not visualizing); also
            forwarded to ``net.visualize``.
        n_glimpses: glimpse counts to sweep.
        width: assigned to ``config.new_size`` and to the visualization task.
        n_distractors: assigned to ``config.n_distractors`` and to the
            visualization task.

    Returns:
        dict mapping each glimpse count to the accuracy value returned by
        ``evaluate_repeatedly`` (the same dict that is pickled).
    """
    import sys

    # data
    mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

    # `visualize` used to be accepted but ignored; OR-ing it with the flag keeps
    # existing FLAGS-driven invocations unchanged.
    do_visualize = visualize or FLAGS.visualize
    # A single evaluation pass per glimpse count suffices when visualizing.
    n_reps = 1 if do_visualize else N

    results = {}

    for n in n_glimpses:

        # set parameters for this glimpse count (config is shared module state)
        config.num_glimpses = n
        config.n_patches = FLAGS.n_patches
        config.use_context = FLAGS.use_context
        config.convnet = FLAGS.convnet

        # product of glimpse area, number of patches and number of channels
        config.sensor_size = (config.glimpse_size**2 * config.n_patches *
                              config.num_channels)
        config.N = mnist.train.num_examples  # number of training examples

        config.new_size = width
        config.n_distractors = n_distractors

        # init model
        print('\n-- Model: {} --'.format(model))
        print('Setting sampling SD to {:.4e}'.format(config.loc_std))
        # Each glimpse count gets a freshly built network in an empty graph.
        tf.reset_default_graph()
        if model == 'ram':
            net = RAM(config)
        elif model == 'dram':
            net = DRAM(config)
        elif model == 'dram_loc':
            net = DRAMl(config)
        else:
            print('Unknown model {}'.format(model))
            # Non-zero status signals failure (bare exit() reported success).
            sys.exit(1)
        net.load(FLAGS.load)  # restore

        if do_visualize:
            # create plot directory (and missing parents) for current parameters
            # NOTE(review): plot_dir does not depend on `n`, so every glimpse
            # count writes into the same directory -- confirm net.visualize
            # uses distinct file names per glimpse count.
            plot_dir = os.path.join(
                FLAGS.load, FLAGS.plot_dir,
                'w={}_n_distractors={}'.format(width, n_distractors))
            os.makedirs(plot_dir, exist_ok=True)

            task = {
                'variant': FLAGS.task,
                'width': width,
                'n_distractors': n_distractors
            }
            net.visualize(data=mnist,
                          task=task,
                          config=config,
                          plot_dir=plot_dir,
                          N=N)

        # evaluate (n_reps) times
        acc, _ = evaluate_repeatedly(ram=net,
                                     data=mnist,
                                     task=FLAGS.task,
                                     n_reps=n_reps)
        print(acc)

        # store results
        results[n] = acc

    # save dictionary; the file name uses FLAGS.num_glimpses (presumably the
    # glimpse count of the loaded checkpoint), not the swept values
    results_path = os.path.join(
        FLAGS.load, 'glimpses{}_results.pickle'.format(FLAGS.num_glimpses))
    with open(results_path, 'wb') as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return results