def get_model(sess, image_shape=(80, 160, 3), gf_dim=64, df_dim=64, batch_size=64,
              name="transition", gpu=0):
    """Build the latent-space transition model and return its helper closures.

    A frozen encoder ``E`` / generator ``G`` pair (built by ``autoencoder``)
    maps frames to and from ``z_dim``-sized codes.  The transition network
    ``T`` is the only part that is optimised: its output is regressed (mean
    squared error) onto the encoder codes shifted one time step ahead.

    NOTE(review): a later ``get_model`` definition in this file rebinds the
    same name, so this version appears to be shadowed -- confirm before
    relying on it.

    Relies on module-level names that are not visible in this block:
    ``time``, ``out_leng``, ``z_dim``, ``learning_rate``, ``beta1``,
    ``G_file_path``, ``E_file_path``, ``transition``, ``autoencoder``,
    ``load`` and ``save``.

    Args:
        sess: TensorFlow session; registered with Keras and used to run every
            op.  It is presumably expected to be the default session, because
            variable initialisation uses ``.run()`` without an explicit
            session -- confirm against the caller.
        image_shape: ``(rows, cols, channels)`` of a single frame.
        gf_dim: Width argument forwarded to ``autoencoder.generator``.
        df_dim: Width argument forwarded to ``autoencoder.encoder``.
        batch_size: Number of sequences per batch.
        name: Variable-scope name; also used in the checkpoint and log paths.
        gpu: Unused in this function.

    Returns:
        ``(train_g, train_d, sampler, f_save, f_load, [G, E, T])``.
    """
    K.set_session(sess)
    checkpoint_dir = './outputs/results_' + name
    image_shape = tuple(image_shape)
    seq_len = time + out_leng
    # Shape used to flatten any (batch, step, ...) frame array into frames.
    frame_shape = (-1,) + image_shape

    with tf.variable_scope(name):
        # sizes
        ch = image_shape[2]
        scales = [16, 8, 4, 2, 1]
        # Floor division keeps these integers on Python 3 as well; on
        # Python 2 integers it is identical to the former "/".
        rows = [image_shape[0] // s for s in scales]
        cols = [image_shape[1] // s for s in scales]

        # The generator only decodes the first 7 predicted sequences (see
        # G_dec below), each seq_len - 1 steps long.
        G = autoencoder.generator(7*(seq_len-1), gf_dim, ch, rows, cols)
        G.compile("sgd", "mse")
        E = autoencoder.encoder(batch_size*seq_len, df_dim, ch, rows, cols)
        E.compile("sgd", "mse")

        # The autoencoder is used as a fixed feature extractor / decoder.
        G.trainable = False
        E.trainable = False

        # nets
        T = transition(batch_size)
        T.compile("sgd", "mse")
        t_vars = T.trainable_weights
        # Two spaces reproduce the output of the former Python 2 statement
        # `print "T.shape: ", T.output_shape`.
        print("T.shape:  %s" % (T.output_shape,))

        Img = Input(batch_shape=(batch_size, seq_len) + image_shape)
        I = K.reshape(Img, (batch_size*seq_len,) + image_shape)
        # Kept under its own name so `sampler` can fetch it without calling
        # E(I) again, which would add new ops to the graph on every call.
        flat_code = E(I)[0]
        code = K.reshape(flat_code, (batch_size, seq_len, z_dim))
        target = code[:, 1:, :]
        inp = code[:, :time, :]
        out = T(inp)
        G_dec = G(K.reshape(out[:7, :, :], (-1, z_dim)))

        # costs
        loss = tf.reduce_mean(tf.square(target - out))
        print("Transition variables:")
        for v in t_vars:
            print(v.name)

        t_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(
            loss, var_list=t_vars)

        tf.initialize_all_variables().run()

    # NOTE(review): the original indentation was lost; the variable scope is
    # assumed to end after variable initialisation -- confirm.

    # summaries
    sum_loss = tf.scalar_summary("loss", loss)
    sum_e_mean = tf.histogram_summary("e_mean", code)
    sum_out = tf.histogram_summary("out", out)
    sum_dec = tf.image_summary("E", G_dec)

    # saver
    saver = tf.train.Saver()
    t_sum = tf.merge_summary([sum_e_mean, sum_out, sum_dec, sum_loss])
    writer = tf.train.SummaryWriter("/tmp/logs/"+name, sess.graph)

    # functions
    def train_d(images, z, counter, sess=sess):
        """No-op placeholder; always returns ``(0, 0, 0)``."""
        return 0, 0, 0

    def train_g(images, z, counter, sess=sess):
        """Run one optimisation step of ``T`` and write summaries.

        Returns ``(loss, decoded_samples, input_frames)``; both image arrays
        are flattened to frames and truncated to the first 64.  ``z`` is
        accepted but not used by this variant.
        """
        outputs = [loss, G_dec, t_sum, t_optim]
        outs = sess.run(outputs, feed_dict={Img: images, K.learning_phase(): 1})
        gl, samples, sums = outs[:3]
        writer.add_summary(sums, counter)
        images = images.reshape(frame_shape)[:64]
        samples = samples.reshape(frame_shape)[:64]
        return gl, samples, images

    def f_load():
        """Restore from the TF checkpoint, falling back to Keras weights for T."""
        try:
            return load(sess, saver, checkpoint_dir, name)
        except Exception:
            # Deliberate best-effort fallback; KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            print("Loading weights via Keras")
            T.load_weights(checkpoint_dir+"/T_weights.keras")

    def f_save(step):
        """Write a TF checkpoint and the Keras weights of T."""
        save(sess, saver, checkpoint_dir, step, name)
        T.save_weights(checkpoint_dir+"/T_weights.keras", True)

    def sampler(z, x):
        """Roll the model forward for 128 steps, feeding predictions back in.

        ``x`` is reshaped to ``(batch_size, time+out_leng) + image_shape``.
        NOTE: when that reshape yields a view, the caller's array is modified
        in place.  ``z`` is accepted but not used by this variant.

        Returns two arrays holding the even- and odd-indexed generated frames.
        """
        n_frames = 128
        video = np.zeros((n_frames,) + image_shape)
        print("Sampling...")
        for i in range(n_frames):
            print(i)
            x = x.reshape(frame_shape)
            latent = sess.run([flat_code],
                              feed_dict={I: x, K.learning_phase(): 1})[0]
            latent = latent.reshape((batch_size, seq_len, z_dim))
            history = latent[:, :time]
            outs = T.predict(history, batch_size=batch_size)
            imgs = sess.run([G_dec],
                            feed_dict={out: outs, K.learning_phase(): 1})[0]
            video[i] = imgs[0]
            # Slide the first sequence one step and append the new frame.
            x = x.reshape((batch_size, seq_len) + image_shape)
            x[0, :-1] = x[0, 1:]
            x[0, -1] = imgs[0]

        # NOTE(review): 128 frames only split into (batch_size, 2) pairs when
        # batch_size == 64.
        video = video.reshape((batch_size, 2) + image_shape)
        return video[:, 0], video[:, 1]

    G.load_weights(G_file_path)
    E.load_weights(E_file_path)

    return train_g, train_d, sampler, f_save, f_load, [G, E, T]
def get_model(sess, image_shape=(80, 160, 3), gf_dim=64, df_dim=64, batch_size=64,
              name="transition", gpu=0):
    """Build the control-conditioned transition model and its helper closures.

    A frozen encoder ``E`` / generator ``G`` pair (built by ``autoencoder``)
    maps frames to and from ``z_dim``-sized codes.  The transition network
    ``T`` receives the codes concatenated with a 2-channel control signal and
    is the only part that is optimised: its output is regressed (mean squared
    error) onto the encoder codes of the steps from index ``time`` onwards.

    Relies on module-level names that are not visible in this block:
    ``time``, ``out_leng``, ``z_dim``, ``learning_rate``, ``beta1``,
    ``G_file_path``, ``E_file_path``, ``transition``, ``autoencoder``,
    ``load`` and ``save``.

    Args:
        sess: TensorFlow session; registered with Keras and used to run every
            op.  It is presumably expected to be the default session, because
            variable initialisation uses ``.run()`` without an explicit
            session -- confirm against the caller.
        image_shape: ``(rows, cols, channels)`` of a single frame.
        gf_dim: Width argument forwarded to ``autoencoder.generator``.
        df_dim: Width argument forwarded to ``autoencoder.encoder``.
        batch_size: Number of sequences per batch.
        name: Variable-scope name; also used in the checkpoint and log paths.
        gpu: Unused in this function.

    Returns:
        ``(train_g, train_d, sampler, f_save, f_load, [G, E, T])``.
    """
    K.set_session(sess)
    checkpoint_dir = './results_' + name
    image_shape = tuple(image_shape)
    seq_len = time + out_leng
    # Shape used to flatten any (batch, step, ...) frame array into frames.
    frame_shape = (-1,) + image_shape

    with tf.variable_scope(name):
        # sizes
        ch = image_shape[2]
        scales = [16, 8, 4, 2, 1]
        # Floor division keeps these integers on Python 3 as well; on
        # Python 2 integers it is identical to the former "/".
        rows = [image_shape[0] // s for s in scales]
        cols = [image_shape[1] // s for s in scales]

        G = autoencoder.generator(batch_size*out_leng, gf_dim, ch, rows, cols)
        G.compile("sgd", "mse")
        E = autoencoder.encoder(batch_size*seq_len, df_dim, ch, rows, cols)
        E.compile("sgd", "mse")

        # The autoencoder is used as a fixed feature extractor / decoder.
        G.trainable = False
        E.trainable = False

        # nets
        T = transition(batch_size)
        T.compile("sgd", "mse")
        t_vars = T.trainable_weights
        # Two spaces reproduce the output of the former Python 2 statement
        # `print "T.shape: ", T.output_shape`.
        print("T.shape:  %s" % (T.output_shape,))

        Img = Input(batch_shape=(batch_size, seq_len) + image_shape)
        Z = Input(batch_shape=(batch_size, seq_len, 2))  # controls signal
        I = K.reshape(Img, (batch_size*seq_len,) + image_shape)
        # Kept under its own name so `sampler` can fetch it without calling
        # E(I) again, which would add new ops to the graph on every call.
        flat_code = E(I)[0]
        code = K.reshape(flat_code, (batch_size, seq_len, z_dim))
        inp = K.concatenate([Z, code], axis=2)
        target = code[:, time:, :]
        out = T(inp)
        G_dec = G(K.reshape(out, (batch_size*out_leng, z_dim)))

        # costs
        loss = tf.reduce_mean(tf.square(target - out))
        print("Transition variables:")
        for v in t_vars:
            print(v.name)

        t_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(
            loss, var_list=t_vars)

        tf.initialize_all_variables().run()

    # NOTE(review): the original indentation was lost; the variable scope is
    # assumed to end after variable initialisation -- confirm.

    # summaries
    sum_loss = tf.scalar_summary("loss", loss)
    sum_e_mean = tf.histogram_summary("e_mean", code)
    sum_out = tf.histogram_summary("out", out)
    sum_dec = tf.image_summary("E", G_dec)

    # saver
    saver = tf.train.Saver()
    t_sum = tf.merge_summary([sum_e_mean, sum_out, sum_dec, sum_loss])
    writer = tf.train.SummaryWriter("/tmp/logs/"+name, sess.graph)

    # functions
    def train_d(images, z, counter, sess=sess):
        """No-op placeholder; always returns ``(0, 0, 0)``."""
        return 0, 0, 0

    def train_g(images, z, counter, sess=sess):
        """Run one optimisation step of ``T`` and write summaries.

        Returns ``(loss, decoded_samples, target_frames)``; the target frames
        are the input frames from step ``time`` onwards.  Both image arrays
        are flattened to frames and truncated to the first 64.
        """
        outputs = [loss, G_dec, t_sum, t_optim]
        outs = sess.run(outputs,
                        feed_dict={Img: images, Z: z, K.learning_phase(): 1})
        gl, samples, sums = outs[:3]
        writer.add_summary(sums, counter)
        images = images[:, time:].reshape(frame_shape)[:64]
        samples = samples.reshape(frame_shape)[:64]
        return gl, samples, images

    def f_load():
        """Restore from the TF checkpoint, falling back to Keras weights for T."""
        try:
            return load(sess, saver, checkpoint_dir, name)
        except Exception:
            # Deliberate best-effort fallback; KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            print("Loading weights via Keras")
            T.load_weights(checkpoint_dir+"/T_weights.keras")

    def f_save(step):
        """Write a TF checkpoint and the Keras weights of T."""
        save(sess, saver, checkpoint_dir, step, name)
        T.save_weights(checkpoint_dir+"/T_weights.keras", True)

    def sampler(z, x):
        """Roll the model forward for 128 steps, feeding predictions back in.

        ``x`` is reshaped to ``(batch_size, time+out_leng) + image_shape`` and
        ``z`` is concatenated with the codes along the last axis.
        NOTE: ``z`` is shifted in place on every step, and ``x`` is modified
        in place too whenever the reshape yields a view of the caller's array.

        Returns two arrays holding the even- and odd-indexed generated frames.
        """
        n_frames = 128
        video = np.zeros((n_frames,) + image_shape)
        print("Sampling...")
        for i in range(n_frames):
            print(i)
            x = x.reshape(frame_shape)
            latent = sess.run([flat_code],
                              feed_dict={I: x, Z: z, K.learning_phase(): 1})[0]
            latent = latent.reshape((batch_size, seq_len, z_dim))
            history = np.concatenate([z, latent], axis=-1)
            outs = T.predict(history, batch_size=batch_size)
            imgs = sess.run([G_dec],
                            feed_dict={out: outs, Z: z, K.learning_phase(): 1})[0]
            video[i] = imgs[0]
            # Slide the first sequence (frames and controls) one step and
            # append the newly generated frame.
            x = x.reshape((batch_size, seq_len) + image_shape)
            x[0, :-1] = x[0, 1:]
            x[0, -1] = imgs[0]
            z[0, :-1] = z[0, 1:]

        # NOTE(review): 128 frames only split into (batch_size, 2) pairs when
        # batch_size == 64.
        video = video.reshape((batch_size, 2) + image_shape)
        return video[:, 0], video[:, 1]

    G.load_weights(G_file_path)
    E.load_weights(E_file_path)

    return train_g, train_d, sampler, f_save, f_load, [G, E, T]