Пример #1
0
def eval(opts, data=None):
    """Rebuild the RNN graph, load the saved model, and save its activity.

    The whole data set is fed through the model as a single unshuffled batch.
    The model is loaded with ``model.load()`` and ``save_activity`` is called
    with the data and ``opts.activity_name``.

    NOTE: the name shadows the built-in ``eval``; it is kept unchanged so
    existing callers keep working.

    :param opts: configuration object. Reads ``save_path``, ``rnn_size`` and
        ``activity_name``. Mutates ``opts.n_inputs`` and ``opts.batch_size``,
        setting both to the number of samples so one batch covers the set.
    :param data: optional pre-built sequence ``(X, Y, N, vel)``. When None,
        the data set is generated with ``inputs.create_inputs(opts)``.
        ``vel`` is unpacked but not used.
    """
    print('eval start')
    save_path = opts.save_path
    # exist_ok=True avoids the check-then-create race when several runs
    # share the same save_path.
    os.makedirs(save_path, exist_ok=True)

    print('graph start')
    tf.reset_default_graph()
    # Explicit None check: truth-testing the caller's data is ambiguous for
    # array-like inputs (NumPy raises ValueError) and depends on its length.
    if data is not None:
        X, Y, N, _vel = data
    else:
        X, Y, N, _vel = inputs.create_inputs(opts)

    # Evaluate everything in a single batch.
    # NOTE(review): train() reads `opts.n_input` (singular); confirm that
    # `n_inputs` here is the intended attribute.
    opts.n_inputs = X.shape[0]
    opts.batch_size = opts.n_inputs
    X_pl, Y_pl, N_pl = create_placeholders(X.shape[-1], Y.shape[-1], opts.rnn_size, X.shape[1])
    train_iter, next_element = create_tf_dataset(X_pl, Y_pl, N_pl, opts.batch_size, shuffle=False)

    print('rnn start')
    model = RNN(next_element, opts, training=False)

    save_name = opts.activity_name
    print('[*] Testing')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(train_iter.initializer, feed_dict={X_pl: X, Y_pl: Y, N_pl: N})
        print('loading saved')
        # Load after the initializers have run; presumably model.load()
        # restores the saved checkpoint over the fresh values.
        model.load()
        save_activity(model, X, Y, N, save_path, save_name)
Пример #2
0
def train(opts):
    """Build the RNN graph, train it for ``opts.epoch`` epochs, and save results.

    The data set comes from ``inputs.create_inputs(opts)``. It is fed through
    a TF dataset iterator in batches of ``opts.batch_size``. After training,
    the following are written out:

    - the model (``model.save()``);
    - the parameters (``utils.save_parameters``);
    - the weights (``model.save_weights()``);
    - the activity (``save_activity``);
    - the loss log, pickled to ``<save_path>/<log_name>.pkl``.

    NOTE(review): an earlier docstring listed ``inputs``/``labels``
    parameters, but this function only takes ``opts``.

    :param opts: configuration object. Reads ``epoch``, ``save_path``,
        ``n_input``, ``batch_size``, ``load_checkpoint``, ``parameter_name``
        and ``log_name``.
    :return: nothing explicit. See the NOTE(review) near the end of the body
        about the trailing plotting fragment, which cannot run.
    """
    n_epoch = opts.epoch
    save_path = opts.save_path
    # Integer division: any remainder samples are never visited in an epoch.
    n_batch_per_epoch = opts.n_input // opts.batch_size

    with tf.Graph().as_default() as graph:
        if not os.path.exists(opts.save_path):
            os.makedirs(opts.save_path)

        X, Y = inputs.create_inputs(opts)
        X_pl, Y_pl = create_placeholders(opts)
        train_iter, next_element = create_tf_dataset(X_pl, Y_pl,
                                                     opts.batch_size)
        model = RNN(next_element[0], next_element[1], opts, training=True)

        # Missing keys default to an empty list, so loss series can be appended.
        logger = defaultdict(list)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(train_iter.initializer, feed_dict={X_pl: X, Y_pl: Y})
            if opts.load_checkpoint:
                model.load()
            else:
                rnn_helper.initialize_weights(opts)

            for ep in range(n_epoch):
                for b in range(n_batch_per_epoch):
                    cur_loss, xe_loss, weight_loss, activity_loss, _ = sess.run(
                        [
                            model.total_loss, model.xe_loss, model.weight_loss,
                            model.activity_loss, model.train_op
                        ])

                # The losses below are those of the epoch's final batch.
                # If n_batch_per_epoch is 0 they are never bound, so these
                # branches (reached once ep > 0) would raise UnboundLocalError.
                # `ep % 1 == 0` is always true, so this logs every epoch except
                # epoch 0. 'epoch' is overwritten, so it keeps only the last one.
                if (ep % 1 == 0 and ep > 0):  # save to loss file
                    logger['epoch'] = ep
                    logger['loss'].append(cur_loss)
                    logger['xe_loss'].append(xe_loss)
                    logger['activity_loss'].append(activity_loss)
                    logger['weight_loss'].append(weight_loss)
                if (ep % 25 == 0 and ep > 0):  # display in terminal every 25 epochs
                    print(
                        '[*] Epoch %d  total_loss=%.2f xe_loss=%.2f a_loss=%.2f, w_loss=%.2f'
                        % (ep, cur_loss, xe_loss, activity_loss, weight_loss))

            # save latest
            model.save()
            utils.save_parameters(opts,
                                  os.path.join(save_path, opts.parameter_name))
            model.save_weights()
            # NOTE(review): eval() elsewhere in this file calls save_activity
            # with (model, X, Y, N, save_path, save_name); confirm which
            # signature this version of save_activity expects.
            save_activity(model, next_element[0], next_element[1])
            with open(os.path.join(save_path, opts.log_name + '.pkl'),
                      'wb') as f:
                pkl.dump(logger, f)
    # NOTE(review): everything from here to the end of this function appears
    # to be a fragment of a separate plotting function spliced in by mistake
    # (compare the plot_stationary_inputs call in the __main__ block).
    # - `rc` is not defined in this function.
    # - `inputs` is the object used above as `inputs.create_inputs`, so
    #   subscripting it is presumably a TypeError.
    # - `labels` is read on the same line that first assigns it, which raises
    #   UnboundLocalError.
    # As written this tail cannot run. Move it back into its own function
    # with the proper (inputs, labels, ...) parameters.
    fig, ax = plt.subplots(rc[0], rc[1])
    state = inputs[:rc[0]]
    labels = labels[:rc[0]]
    i = 0
    for batch in zip(state, labels):
        for d in batch:
            plot_ix = np.unravel_index(i, rc)
            cur_ax = ax[plot_ix]
            adjust(cur_ax)
            plt.sca(cur_ax)
            # cbarBoo = True if i %2==1 else True
            plt.imshow(d, cmap='RdBu_r', vmin=-.3, vmax=.3)
            plt.xticks([0, 19])
            plt.yticks([0, 49])
            cb = plt.colorbar()
            cb.set_ticks([0, .3])
            i += 1
    path = os.path.join('./lab_meeting/images', 'input_stationary')
    plt.savefig(path + '.png', dpi=300)


if __name__ == '__main__':
    # Build both input configurations and give them the same 50-step horizon.
    stationary = config.stationary_input_config()
    non_stationary = config.non_stationary_input_config()
    stationary.time_steps = non_stationary.time_steps = 50

    # Plot the stationary inputs first, then the moving (non-stationary) ones.
    x, y = inputs.create_inputs(stationary)
    plot_stationary_inputs(x, y, stationary)

    x, y = inputs.create_inputs(non_stationary)
    plot_moving_inputs(x, y, non_stationary)
Пример #4
0
 def __init__(self, opts):
     """Generate the data set for ``opts`` and keep it as ``self.X`` / ``self.Y``."""
     # NOTE(review): this unpacks exactly two values. Elsewhere in this file
     # create_inputs is unpacked into either two or four values depending on
     # the version; confirm which applies here.
     self.X, self.Y = inputs.create_inputs(opts)
Пример #5
0
def train(opts, seed=False):
    """Train the RNN on a freshly generated data set and save all artifacts.

    Steps:
      1. Reset the default TF graph and, optionally, fix the NumPy/TF seeds.
      2. Generate ``X, Y, N`` with ``inputs.create_inputs(opts)`` and wire
         them into a TF dataset iterator of ``opts.batch_size`` samples.
      3. Run ``opts.epoch`` epochs of ``opts.n_input // opts.batch_size``
         optimisation steps. Log the losses of each epoch's last batch, and
         print them with the elapsed time per epoch.
      4. Save the model, its weights, the loss log
         (``<save_path>/<log_name>.pkl``), the training set
         (``<save_path>/training_set.pkl``) and the parameters.

    :param opts: configuration object. Reads ``epoch``, ``save_path``,
        ``n_input``, ``batch_size``, ``rnn_size``, ``load_checkpoint``,
        ``log_name`` and ``parameter_name``.
    :param seed: when True, seed NumPy and TensorFlow with the fixed value 2
        for reproducible runs.
    :return: ``(opts, save_name)``, where ``save_name`` is the path passed to
        ``utils.save_parameters``.
    :raises ValueError: if ``opts.epoch > 0`` but ``opts.n_input`` is smaller
        than ``opts.batch_size``. There would be no batch per epoch, hence no
        loss to log or print.
    :raises AssertionError: if the error loss becomes NaN (the message asks
        the caller to retry). It is raised explicitly so it still fires under
        ``python -O``.
    """
    n_epoch = opts.epoch
    print('n epoch', n_epoch)
    save_path = opts.save_path
    n_batch_per_epoch = opts.n_input // opts.batch_size
    if n_epoch > 0 and n_batch_per_epoch < 1:
        # Without at least one batch the per-epoch loss variables are never
        # bound, and the epoch print would fail with UnboundLocalError.
        raise ValueError(
            f'opts.n_input ({opts.n_input}) must be >= opts.batch_size '
            f'({opts.batch_size}) to train for {n_epoch} epochs')

    tf.reset_default_graph()
    # exist_ok=True avoids the check-then-create race between parallel runs.
    os.makedirs(save_path, exist_ok=True)

    if seed:
        seednum = 2
        np.random.seed(seednum)
        tf.set_random_seed(seednum)

    X, Y, N, _ = inputs.create_inputs(opts)
    X_pl, Y_pl, N_pl = create_placeholders(X.shape[-1], Y.shape[-1], opts.rnn_size, X.shape[1])
    train_iter, next_element = create_tf_dataset(X_pl, Y_pl, N_pl, opts.batch_size)
    model = RNN(next_element, opts, training=True)

    # Missing keys default to an empty list, so loss series can be appended.
    logger = defaultdict(list)
    print('starting')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(train_iter.initializer, feed_dict={X_pl: X, Y_pl: Y, N_pl: N})
        if opts.load_checkpoint:
            model.load()

        t = time.perf_counter()
        for ep in range(n_epoch):
            for _batch in range(n_batch_per_epoch):
                cur_loss, error_loss, weight_loss, activity_loss, _ = sess.run(
                    [model.total_loss, model.error_loss, model.weight_loss,
                     model.activity_loss, model.train_op])
                if np.isnan(error_loss):
                    raise AssertionError("Error is NaN, retry")

            # Log the final batch's losses for every epoch except epoch 0.
            # 'epoch' is overwritten, so it ends up holding the last logged epoch.
            if ep > 0:
                logger['epoch'] = ep
                logger['loss'].append(cur_loss)
                logger['error_loss'].append(error_loss)
                logger['activity_loss'].append(activity_loss)
                logger['weight_loss'].append(weight_loss)

            print('[*] Epoch %d  total_loss=%.2f error_loss=%.2f a_loss=%.2f, w_loss=%.2f'
                  % (ep, cur_loss, error_loss, activity_loss, weight_loss))
            tnew = time.perf_counter()
            print(f'{tnew - t} seconds elapsed')
            t = tnew

        model.save(save_path)
        model.save_weights(save_path)
        with open(os.path.join(save_path, opts.log_name + '.pkl'), 'wb') as f:
            pkl.dump(logger, f)

    data = {'X': X, 'Y': Y, 'N': N}
    train_path = os.path.join(save_path, 'training_set.pkl')
    with open(train_path, 'wb') as f:
        pkl.dump(data, f)
    save_name = os.path.join(save_path, opts.parameter_name)
    utils.save_parameters(opts, save_name)
    return opts, save_name