Example #1
with tf.variable_scope('optimizer'):
    # Adam with a 1e-4 step for both networks; var_list restricts each
    # minimize op to its own sub-network's parameters.
    d_adam = tf.train.AdamOptimizer(learning_rate=1e-4)
    g_adam = tf.train.AdamOptimizer(learning_rate=1e-4)
    d_optim = d_adam.minimize(train_d_loss, var_list=d_param)
    g_optim = g_adam.minimize(train_g_loss, var_list=g_param)
    saver = tf.train.Saver()

# GPU memory options: let the allocator grow GPU memory on demand instead
# of reserving the whole device up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# NOTE(review): the original additionally did
#   session = tf.Session(config=config)
# creating a session that was never used and never closed (a leaked
# resource), while the session actually used below was built WITHOUT this
# config, so allow_growth never took effect. The orphan session is removed
# here; pass config=config to the real tf.Session(...) below.

with tf.Session() as sess:
    # NOTE(review): this session ignores the `config` built above
    # (gpu_options.allow_growth); pass config=config here for it to apply.
    if Is_train == False:# restore a previously trained model for testing
        saver.restore(sess,'./GANmodel.ckpt')
        writer = tf.summary.FileWriter('./graphs', sess.graph)
    if Is_train == True:# train the model
        training_data = data_loader_special.load_data_wrapper() # load the training set
        # One temperature per sample: n_temps values in [0.1, 1.9], each
        # repeated 10000 times (assumes 10000 samples per temperature in
        # training_data -- TODO confirm against the loader).
        tvals = np.repeat(np.linspace(0.1,1.9,n_temps),10000)
        c = list(zip(training_data,tvals))
        random.shuffle(c) # pair and shuffle the data and temperature together
        training_data, tvals = zip(*c)
        print(len(training_data),len(tvals))

        # Placeholders feeding the tf.data pipeline; shapes come from
        # externally-defined datapoints / lattice_size / n_z. The +2 on the
        # second placeholder presumably accounts for padding -- verify.
        m = tf.placeholder(tf.float32,[datapoints, lattice_size, lattice_size,1])
        n = tf.placeholder(tf.float32,[datapoints,lattice_size+2,lattice_size+2,1])
        b = tf.placeholder(tf.float32,[datapoints, n_z])
        # Feed data via placeholders + tf.data (prefetch and batching) to
        # prevent memory issues from materializing the whole set at once.
        dataset = tf.data.Dataset.from_tensor_slices((m,n,b))
        dataset = dataset.prefetch(buffer_size=100)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_initializable_iterator()
        # NOTE(review): `next` shadows the builtin next() for the rest of
        # this scope.
        next = iterator.get_next()
Example #2
# 4. Update weights
# Select each sub-network's trainable variables by scope name so the
# optimizers below can update them independently.
g_param = tf.trainable_variables(scope='Generator')
d_param = tf.trainable_variables(scope='Discriminator')
# Sanity check: dump every trainable variable currently in the graph.
all_trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(all_trainable)

with tf.name_scope('optimizer'):
    # Separate Adam optimizers: the discriminator steps at 1e-4 while the
    # generator uses a larger 1e-3 rate; var_list confines each update to
    # its own sub-network's variables.
    disc_adam = tf.train.AdamOptimizer(learning_rate=1e-4)
    gen_adam = tf.train.AdamOptimizer(learning_rate=1e-3)
    d_optim = disc_adam.minimize(train_d_loss, var_list=d_param)
    g_optim = gen_adam.minimize(train_g_loss, var_list=g_param)
    saver = tf.train.Saver()

with tf.Session() as sess:
    # saver.restore(sess,'./GANmodel.ckpt')
    # writer = tf.summary.FileWriter('./graphs', sess.graph)
    training_data = data_loader_special.load_data_wrapper()
    # One temperature per sample: 32 values in [0.1, 2.0], each repeated
    # 10000 times (assumes 10000 samples per temperature -- TODO confirm).
    tvals = np.repeat(np.linspace(0.1, 2.0, 32), 10000)
    c = list(zip(training_data, tvals))
    random.shuffle(c)  # shuffle the (sample, temperature) pairs together
    training_data, tvals = zip(*c)
    print(len(training_data), len(tvals))
    # Placeholders feeding the tf.data pipeline; here samples are flat
    # 128-vectors with one scalar per sample (presumably the temperature
    # label -- verify against the feed dict).
    m = tf.placeholder(tf.float32, [datapoints, 128])
    n = tf.placeholder(tf.float32, [datapoints, 1])
    dataset = tf.data.Dataset.from_tensor_slices((m, n))
    dataset = dataset.prefetch(buffer_size=1000)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    # NOTE(review): `next` shadows the builtin next() in this scope.
    next = iterator.get_next()
    print("============< WARNING >===============")
    # Re-initializes ALL variables, discarding any previously trained
    # weights -- the surrounding prints mark this as a deliberate reset.
    sess.run(tf.global_variables_initializer())
    print("==========< Model DELETED >===========")