import time
from sys import stdout

import numpy as np
import tensorflow as tf
from torch.utils.data import Dataset

# Data_loader is a project-specific helper assumed to be defined elsewhere
# in this repository.


class Texture_dataset_val(Dataset):
    def __init__(self, data_size, textures_path, max_region=10):
        self.data_size = data_size
        self.data = Data_loader(textures_path, 1, max_region)
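        # Pre-generate a fixed pool of samples so that every validation pass
        # evaluates on exactly the same data.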
        self.preload = []
        for i in range(self.data_size):
            x, y, x_ref = self.data.get_batch_data()
            # batch size is 1, so drop the leading batch dimension
            x, y, x_ref = x[0], y[0], x_ref[0]
            # HWC -> CHW, the channel layout PyTorch expects
            x = np.transpose(x, (2, 0, 1)).astype('float32')
            y = np.transpose(y, (2, 0, 1)).astype('float32')
            x_ref = np.transpose(x_ref, (2, 0, 1)).astype('float32')
            self.preload.append((x, y, x_ref))

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        return self.preload[idx]
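

# ``Data_loader`` itself is not shown in this snippet. A minimal sketch of the
# interface the datasets assume (names and shapes here are assumptions, not
# the original implementation):
#
# class Data_loader:
#     def __init__(self, textures_path, batch_size, max_region=10):
#         ...
#
#     def get_batch_data(self):
#         # returns (x, y, x_ref): numpy arrays of shape (batch, H, W, C)
#         ...
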
class Texture_dataset_train(Dataset):
    def __init__(self, data_size, textures_path, max_region=10):
        self.data_size = data_size
        self.data = Data_loader(textures_path, 1, max_region)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        x, y, x_ref = self.data.get_batch_data()
        # batch size is 1, so drop the leading batch dimension
        x, y, x_ref = x[0], y[0], x_ref[0]
        # HWC -> CHW, the channel layout PyTorch expects
        x = np.transpose(x, (2, 0, 1)).astype('float32')
        y = np.transpose(y, (2, 0, 1)).astype('float32')
        x_ref = np.transpose(x_ref, (2, 0, 1)).astype('float32')
        return x, y, x_ref
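

# Example usage, as a sketch (the texture path is hypothetical):
#
#   from torch.utils.data import DataLoader
#   train_set = Texture_dataset_train(10000, '/path/to/textures')
#   loader = DataLoader(train_set, batch_size=8, num_workers=4)
#   for x, y, x_ref in loader:
#       ...  # float32 tensors of shape (8, C, H, W)
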
def train(data_path, batch_size, max_steps, eval_n_step, init_lr=1e-5):
	is_training = tf.placeholder(tf.bool)
	texture, ref, label, decode_mask = get_model(is_training)
	dl_train = Data_loader(data_path['train'], batch_size)
	dl_val = Data_loader(data_path['val'], batch_size)
	lr = tf.placeholder(tf.float32)
	learning_rate = init_lr
	# opt = tf.train.AdamOptimizer(learning_rate=lr)
	opt = tf.contrib.opt.NadamOptimizer(learning_rate=lr)

	# loss = total_loss(decode_mask, label)
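	# Per-pixel binary cross-entropy between the ground-truth mask and the
	# predicted mask, averaged over all pixels in the batch.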
	loss = tf.keras.backend.binary_crossentropy(label, decode_mask)
	loss = tf.reduce_mean(loss)
	print(decode_mask.shape, label.shape, loss.shape)  # sanity check on tensor shapes
	# If the model uses batch normalization, the minimize call should run under
	# the UPDATE_OPS dependency below so the moving statistics are updated
	# during training:
	# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
	# with tf.control_dependencies(update_ops):
	# 	train_op = opt.minimize(loss=loss, global_step=tf.train.get_global_step())
	train_op = opt.minimize(loss)

	saver = tf.train.Saver()
	saver_best = tf.train.Saver()
	init = tf.global_variables_initializer()

	best_loss = np.inf
	loss_log = {'train':[], 'val':[]}
	no_best_cnt = 0
	with tf.Session() as sess:
		sess.run(init)
		# saver.restore(sess, '/fast_data/one_shot_texture_models/best_model0.51105') #55298
		cur_train_loss = 0
		tic = time.time()
		for i in range(max_steps):
			data = dl_train.get_batch_data()	# batch, mask, ref
			_, train_loss = sess.run([train_op, loss], 
				feed_dict={texture: data[0], ref:data[2], label: data[1], lr: learning_rate, is_training: True})
			stdout.write('\r%d, %.5f' % (i, train_loss))
			stdout.flush()
			# print(train_loss)
			cur_train_loss += train_loss

			if i % eval_n_step == 0:
				stdout.write('\r')
				stdout.flush()
				toc = time.time()
				# average the validation loss over 10 batches
				val_loss = 0
				for _ in range(10):
					test_data = dl_val.get_batch_data()
					val_loss += sess.run(loss, 
						feed_dict={texture: test_data[0], ref:test_data[2], label: test_data[1], is_training: False})
				val_loss /= 10
				if val_loss < best_loss:
					best_loss = val_loss
					print('saving best model (%.5f)' % best_loss)
					saver_best.save(sess, '/fast_data/one_shot_texture_models/best_model%.5f' % best_loss)
					no_best_cnt = 0
				else:
					no_best_cnt += 1

				cur_train_loss /= eval_n_step

				print('%7d/%7d training loss: %.5f, validation loss: %.5f (%d sec)' % (i, max_steps, cur_train_loss, val_loss, toc - tic))
				loss_log['train'].append(cur_train_loss)
				loss_log['val'].append(val_loss)

				cur_train_loss = 0

				saver.save(sess, '/fast_data/one_shot_texture_models/model', global_step=i)

				# if no_best_cnt > 10:
				# 	learning_rate /= 2
				# 	print('setting learning rate to:', learning_rate)
				# 	no_best_cnt = 0

				tic = time.time()

			if (i + 1) % 400 == 0:
				learning_rate /= 2
				print('setting learning rate to:', learning_rate)

	np.save('train_log', loss_log['train'])
	np.save('val_log', loss_log['val'])
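

# Example invocation, as a sketch (the directory layout is hypothetical;
# ``data_path`` maps split names to texture folders, matching how it is
# indexed above):
#
#   if __name__ == '__main__':
#       paths = {'train': '/data/textures/train', 'val': '/data/textures/val'}
#       train(paths, batch_size=8, max_steps=100000, eval_n_step=100)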