Exemplo n.º 1
0
    def build_model(self, p):
        """Create the two Q-networks and compile the training/prediction functions.

        ``p`` is indexed with the keys ``'input_shape'``, ``'discount'``,
        ``'num_actions'`` and ``'learning_rate'`` and is also forwarded to
        ``self.build_cnn_model``.

        Side effects: sets ``self.Q_model``, ``self.Q_old_model`` (started
        with a copy of ``Q_model``'s weights), ``self.training_func``
        (inputs: state, action, reward, terminate flag, next state; output:
        the loss; applies the optimizer updates) and ``self.Q_func``
        (state -> output of ``Q_model``).
        """
        # Symbolic inputs for one batch of transitions.
        state_in = Input(p['input_shape'], name='input_state')
        action_in = Input((1,), name='input_action', dtype='int32')
        reward_in = Input((1,), name='input_reward')
        terminal_in = Input((1,), name='input_terminate', dtype='int32')
        next_state_in = Input(p['input_shape'], name='input_next_sate')

        # Online network and the "old" network (Q hat in the paper); the old
        # one starts out as an exact weight copy of the online one.
        online_net = self.Q_model = self.build_cnn_model(p)
        target_net = self.Q_old_model = self.build_cnn_model(p, False)
        target_net.set_weights(online_net.get_weights())

        q_values = online_net(state_in)  # batch * actions
        # The gradient cut is not strictly needed: the updates below are only
        # computed for the online network's trainable weights.
        next_q_values = disconnected_grad(target_net(next_state_in))

        # Regression target: reward + discount * (1 - terminate) * max Q(next).
        not_terminal = 1 - terminal_in
        future_weight = p['discount'] * not_terminal
        best_next_q = K.max(next_q_values, axis=1, keepdims=True)
        target = reward_in + future_weight * best_next_q  # batch * 1

        # Mask that keeps only the Q-value of the action that was taken.
        action_ids = Tht.arange(p['num_actions']).reshape((1, -1))
        taken = action_in.reshape((-1, 1))
        action_mask = K.equal(action_ids, taken)
        q_taken = K.sum(q_values * action_mask, axis=1).reshape((-1, 1))

        td_error = q_taken - target
        loss = K.sum(td_error ** 2)  # sum could also be mean()

        updates = adam(p['learning_rate']).get_updates(
            online_net.trainable_weights, [], loss)

        self.training_func = K.function(
            [state_in, action_in, reward_in, terminal_in, next_state_in],
            loss,
            updates=updates)
        self.Q_func = K.function([state_in], q_values)
Exemplo n.º 2
0
    def build_model(self, p):
        """Build the Q-network pair and the compiled backend functions.

        Reads ``p['input_shape']``, ``p['discount']``, ``p['num_actions']``
        and ``p['learning_rate']``; ``p`` itself is passed on to
        ``self.build_cnn_model``.

        After the call the instance holds:
          * ``Q_model``       - network whose trainable weights are optimised,
          * ``Q_old_model``   - second network, initialised with the same weights,
          * ``training_func`` - takes [state, action, reward, terminate,
                                next state], returns the loss and applies the
                                optimizer updates,
          * ``Q_func``        - takes [state], returns ``Q_model``'s output.
        """
        states = Input(p['input_shape'], name='input_state')
        actions = Input((1, ), name='input_action', dtype='int32')
        rewards = Input((1, ), name='input_reward')
        terminals = Input((1, ), name='input_terminate', dtype='int32')
        next_states = Input(p['input_shape'], name='input_next_sate')
        all_inputs = [states, actions, rewards, terminals, next_states]

        self.Q_model = self.build_cnn_model(p)
        self.Q_old_model = self.build_cnn_model(p, False)  # Q hat in paper
        # Start with Q' = Q.
        current_weights = self.Q_model.get_weights()
        self.Q_old_model.set_weights(current_weights)

        q_now = self.Q_model(states)  # batch * actions
        # Cutting the gradient here is not necessary (only Q_model's weights
        # are handed to the optimizer below), but it makes the intent explicit.
        q_next = self.Q_old_model(next_states)
        q_next = disconnected_grad(q_next)

        # Target value; the (1 - terminate) factor zeroes the bootstrap term
        # wherever the terminate flag is 1.
        bootstrap_scale = p['discount'] * (1 - terminals)
        y = rewards + bootstrap_scale * K.max(
            q_next, axis=1, keepdims=True)  # batch * 1

        # Compare every action index against the chosen action to get a mask.
        all_action_ids = Tht.arange(p['num_actions']).reshape((1, -1))
        chosen_action = actions.reshape((-1, 1))
        chosen_mask = K.equal(all_action_ids, chosen_action)
        selected_q = K.sum(q_now * chosen_mask, axis=1)
        selected_q = selected_q.reshape((-1, 1))

        squared_error = (selected_q - y)**2
        loss = K.sum(squared_error)  # sum could also be mean()

        opt = adam(p['learning_rate'])
        weights_to_train = self.Q_model.trainable_weights
        weight_updates = opt.get_updates(weights_to_train, [], loss)

        self.training_func = K.function(
            all_inputs, loss, updates=weight_updates)
        self.Q_func = K.function([states], q_now)
Exemplo n.º 3
0
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = digit

# Display the image grid assembled above, then write it to disk.
fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
fig.savefig('x_{}.png'.format(use_loss))

# Data imputation demo: blank out a slice of every input vector, run the
# corrupted batch through the model and compare the three versions.
x = x_test[:batch_size, :]
x_corupted = np.copy(x)
x_corupted[:, 300:400] = 0  # zero entries 300..399 of each flattened sample

# View every flat vector as a (digit_size, digit_size) image.
image_shape = (-1, digit_size, digit_size)
x_encoded = vae.predict(x_corupted, batch_size=batch_size).reshape(image_shape)
x = x.reshape(image_shape)
x_corupted = x_corupted.reshape(image_shape)

# Canvas with three rows of n digits:
#   row 0 - original, row 1 - corrupted input, row 2 - model output.
figure = np.zeros((digit_size * 3, digit_size * n))
for i in range(n):
    xi, xi_c, xi_e = x[i], x_corupted[i], x_encoded[i]
    left, right = i * digit_size, (i + 1) * digit_size
    figure[:digit_size, left:right] = xi
    figure[digit_size:2 * digit_size, left:right] = xi_c
    figure[2 * digit_size:, left:right] = xi_e

fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
fig.savefig('i_{}.png'.format(use_loss))
Exemplo n.º 4
0
# Container for results; nothing in the visible part of the script appends to it.
results = []

# Training loop: args.epochs epochs, each made of batches_per_epoch batches.
for ii in xrange(args.epochs):
    # Epoch start time and loss accumulators, reset at the top of every epoch.
    # NOTE(review): none of the three is read in the visible code - presumably
    # they are consumed further down in the loop; confirm.
    t0 = time.time()
    loss, max_margin_loss = 0., 0.

    print("epoch: ", ii)
    for b in tqdm(xrange(batches_per_epoch)):
        # Pull the next batch from each generator (Python 2 style `.next()`).
        sen_input = sen_gen.next()

        neg_input = neg_gen.next()
        #sen_input2 = sen_gen2.next()

        # Reshape to the layout passed to the model:
        #   sen_input -> (batch_size, kstep, node_size)
        #   neg_input -> (batch_size, neg_size, kstep, node_size)
        #sen_input = sen_input2[:,0:-1].reshape((args.batch_size, args.kstep, node_size))
        sen_input = sen_input.reshape((args.batch_size, args.kstep, node_size))
        neg_input = neg_input.reshape(
            (args.batch_size, args.neg_size, args.kstep, node_size))
        # Disabled experiment: take the last column of sen_input2 as a label
        # and one-hot encode it over 17 classes.
        # lable_data = sen_input2[:,-1]
        #
        #
        # one_hot_lable = np.zeros((lable_data.shape[0], 17))
        # for i in range(lable_data.shape[0]):
        #     x = lable_data[i]
        #     one_hot_lable[i, x.astype(int)] = 1

        # Disabled variants of the training call that also fed the label data.
        #batch_loss, batch_max_margin_loss = model_auto.train_on_batch([sen_input, neg_input, sen_input2[:,-1]], [np.ones((args.batch_size, 1)), np.ones((args.batch_size, 1))])
        #batch_loss, batch_max_margin_loss = model_auto.train_on_batch([sen_input, one_hot_lable],
        #                                                                 np.ones((args.batch_size, 1)))

        # One optimisation step on this batch. The target is an all-ones
        # column of shape (batch_size, 1); the call's result is unpacked into
        # two values.
        # NOTE(review): the all-ones target looks like a placeholder for a
        # loss that ignores y_true - confirm against the model definition.
        batch_loss, batch_max_margin_loss = model_auto.train_on_batch(
            [sen_input, neg_input], np.ones((args.batch_size, 1)))
        #batch_loss, batch_max_margin_loss = model_lable.train_on_batch(input_data2, one_hot_lable)
Exemplo n.º 5
0
    plt.yticks([])
    plt.show()
    fig.savefig('./fig/x_{}_latent_{}_ep_{}_n_{}.png'.format(
        use_loss, dim_latent, epoch_real, n))

#%% Visualization: Reconstruction

if plot_on:

    n = 20  # number of test digits shown side by side
    m = int(np.sqrt(dim_x))  # digit size: side length of one square image

    # Canvas with two rows of n digits each:
    # originals on top, reconstructions underneath.
    figure = np.zeros((m * 2, m * n))
    x = x_test[:batch_size, :]
    # NOTE(review): the loop below reads x[i] for every i < n, so this
    # assumes batch_size >= n - confirm.
    x_recon = vae.predict(x, batch_size=batch_size).reshape((-1, m, m))
    x = x.reshape((-1, m, m))
    for i in range(n):
        figure[:m, i * m:(i + 1) * m] = x[i]
        figure[m:2 * m, i * m:(i + 1) * m] = x_recon[i]

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.title('Image Reconstruction')
    plt.xticks([])
    # One tick per image row, centred vertically on that row. The canvas has
    # exactly two rows, so there are two tick positions for the two labels
    # (a third position at 2.5 * m would lie outside the image and would not
    # match the number of labels).
    plt.yticks(m * np.array([.5, 1.5]), ['Origin', 'Re-con'])
    # Save before show() so the file is written while the figure is still open.
    fig.savefig('./fig/re_{}_latent_{}_ep_{}_n_{}.png'.format(
        use_loss, dim_latent, epoch_real, n))
    plt.show()

#%% Visualization: Image Imputation
Exemplo n.º 6
0
    samples_image.append(imgs)

# Persist the collected sample images.
with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples_image, f)

# Load them back. Only the load happens inside the `with`, so the file is
# closed before any plotting starts (plt.show() may block for a long time).
# NOTE: pickle.load is only safe here because the file was written just above
# by this same script; never unpickle data from an untrusted source.
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)

# view_samples(-1, samples)
# 300 epochs in total - keep these indices within range.
epoch_idx = [0, 5, 10, 20, 40, 60, 80, 100, 150, 250]
show_imgs = [samples[i] for i in epoch_idx]

# Shape of the image grid: one row per selected entry of `samples`,
# `cols` images in each row.
rows, cols = 10, 25
fig, axes = plt.subplots(figsize=(30, 12),
                         nrows=rows,
                         ncols=cols,
                         sharex=True,
                         sharey=True)

# NOTE(review): idx is not used anywhere in the visible code; kept in case
# code outside this view reads it.
idx = range(0, cfg.EPOCH_NUM, int(cfg.EPOCH_NUM / rows))

for sample, ax_row in zip(show_imgs, axes):
    # Take evenly spaced images from this sample. The stride is clamped to at
    # least 1: with fewer than `cols` images the plain quotient would be 0,
    # and a slice step of 0 raises ValueError.
    stride = max(1, len(sample) // cols)
    for img, ax in zip(sample[::stride], ax_row):
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

plt.show()