def __init__(self, checkpoints, relu_targets, vgg_path, device='/gpu:0',
             ss_patch_size=3, ss_stride=1):
    '''
    Args:
        checkpoints: List of trained decoder model checkpoint dirs
        relu_targets: List of relu target layers corresponding to decoder checkpoints
        vgg_path: Normalised VGG19 .t7 path
        device: String for device ID to load model onto
    '''
    self.ss_patch_size = ss_patch_size
    self.ss_stride = ss_stride

    graph = tf.get_default_graph()

    with graph.device(device):
        # Build the graph
        self.model = WCTModel(mode='test',
                              relu_targets=relu_targets,
                              vgg_path=vgg_path,
                              ss_patch_size=self.ss_patch_size,
                              ss_stride=self.ss_stride)

        self.content_input = self.model.content_input
        self.decoded_output = self.model.decoded_output
        # self.style_encoded = None

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.sess.run(tf.global_variables_initializer())

        # Load decoder vars one-by-one into the graph
        for relu_target, checkpoint_dir in zip(relu_targets, checkpoints):
            decoder_prefix = 'decoder_{}'.format(relu_target)
            relu_vars = [v for v in tf.trainable_variables() if decoder_prefix in v.name]

            saver = tf.train.Saver(var_list=relu_vars)

            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print('Restoring vars for {} from checkpoint {}'.format(relu_target, ckpt.model_checkpoint_path))
                saver.restore(self.sess, ckpt.model_checkpoint_path)
            else:
                raise Exception('No checkpoint found for target {} in dir {}'.format(relu_target, checkpoint_dir))
def __init__(self, checkpoints, relu_targets, vgg_path, device='/gpu:0',
             ss_patch_size=3, ss_stride=1, alpha=0.5, beta=0.5):
    '''
    Args:
        checkpoints: List of trained decoder model checkpoint dirs
        relu_targets: List of relu target layers corresponding to decoder checkpoints
        vgg_path: Normalised VGG19 .t7 path
        device: String for device ID to load model onto
    '''
    self.ss_patch_size = ss_patch_size
    self.ss_stride = ss_stride

    # Build the graph
    self.model = WCTModel(relu_targets=relu_targets,
                          vgg_path=vgg_path,
                          alpha=alpha,
                          beta=beta)

    # Retained from the TF1 version; ConfigProto has no effect under eager execution
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    # Load decoder vars one-by-one into the graph
    for i, checkpoint_dir in enumerate(checkpoints):
        if os.path.exists(checkpoint_dir):
            checkpoint = tf.train.Checkpoint(model=self.model.decoders[i])
            checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
        else:
            raise Exception('No checkpoint found for target {} in dir {}'.format(relu_targets[i], checkpoint_dir))
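
# The TF2 __init__ above replaces the scoped tf.train.Saver with object-based
# checkpointing (tf.train.Checkpoint / CheckpointManager). A minimal,
# self-contained sketch of that restore pattern is shown below; the Sequential
# model and the './demo_ckpt' directory are stand-ins for illustration, not the
# real WCTModel decoder or checkpoint layout.
import tensorflow as tf

# Stand-in for one decoder; the real object comes from WCTModel.decoders[i].
decoder = tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding='same')])
decoder.build(input_shape=(None, 256, 256, 64))

ckpt = tf.train.Checkpoint(model=decoder)
manager = tf.train.CheckpointManager(ckpt, './demo_ckpt', max_to_keep=1)
manager.save()  # save once so there is something to restore in this sketch

# Restore the latest checkpoint and verify every tracked object was matched
status = ckpt.restore(tf.train.latest_checkpoint('./demo_ckpt'))
status.assert_existing_objects_matched()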
def train():
    batch_shape = (args.batch_size, 256, 256, 3)

    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        ### Setup data loading queue
        queue_input_content = tf.placeholder(tf.float32, shape=batch_shape)
        queue_input_val = tf.placeholder(tf.float32, shape=batch_shape)
        queue = tf.FIFOQueue(capacity=100,
                             dtypes=[tf.float32, tf.float32],
                             shapes=[[256, 256, 3], [256, 256, 3]])
        enqueue_op = queue.enqueue_many([queue_input_content, queue_input_val])
        dequeue_op = queue.dequeue()
        content_batch_op, val_batch_op = tf.train.batch(dequeue_op,
                                                        batch_size=args.batch_size,
                                                        capacity=100)

        def enqueue(sess):
            # Feeder thread: keeps the FIFOQueue filled with content/val batches
            content_images = batch_gen(args.content_path, batch_shape)
            val_images = batch_gen(args.val_path, batch_shape)

            while True:
                content_batch = next(content_images)
                val_batch = next(val_images)
                sess.run(enqueue_op, feed_dict={queue_input_content: content_batch,
                                                queue_input_val: val_batch})

        ### Build the model graph and train/summary ops
        model = WCTModel(mode='train',
                         relu_targets=[args.relu_target],
                         vgg_path=args.vgg_path,
                         batch_size=args.batch_size,
                         feature_weight=args.feature_weight,
                         pixel_weight=args.pixel_weight,
                         tv_weight=args.tv_weight,
                         learning_rate=args.learning_rate,
                         lr_decay=args.lr_decay).encoder_decoders[0]

        saver = tf.train.Saver(max_to_keep=None)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            enqueue_thread = threading.Thread(target=enqueue, args=[sess])
            enqueue_thread.daemon = True  # isDaemon() only reads the flag; set it so the thread dies with the main process
            enqueue_thread.start()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            log_path = args.log_path if args.log_path is not None else os.path.join(args.checkpoint, 'log')
            summary_writer = tf.summary.FileWriter(log_path, sess.graph)

            sess.run(tf.global_variables_initializer())

            def load_latest():
                if os.path.exists(os.path.join(args.checkpoint, 'checkpoint')):
                    print("Restoring checkpoint")
                    saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint))

            load_latest()

            for iteration in range(args.max_iter):
                start = time.time()

                content_batch = sess.run(content_batch_op)

                fetches = {
                    'train':        model.train_op,
                    'global_step':  model.global_step,
                    # 'summary':    model.summary_op,
                    'lr':           model.learning_rate,
                    'feature_loss': model.feature_loss,
                    'pixel_loss':   model.pixel_loss,
                    'tv_loss':      model.tv_loss
                }

                feed_dict = {model.content_input: content_batch}

                try:
                    results = sess.run(fetches, feed_dict=feed_dict)
                except Exception as e:
                    print(e)
                    import IPython
                    IPython.embed()
                    # Sometimes training NaNs out and has to be restarted.
                    # In that case, reload the last checkpoint and resume.
                    print("Exception encountered, re-loading latest checkpoint")
                    load_latest()
                    continue

                ### Run a val batch and log the summaries
                if iteration % args.summary_iter == 0:
                    val_batch = sess.run(val_batch_op)
                    summary = sess.run(model.summary_op, feed_dict={model.content_input: val_batch})
                    summary_writer.add_summary(summary, results['global_step'])

                ### Save checkpoint
                if iteration % args.save_iter == 0:
                    save_path = saver.save(sess, os.path.join(args.checkpoint, 'model.ckpt'), results['global_step'])
                    print("Model saved in file: %s" % save_path)

                ### Log training stats
                print("Step: {} LR: {:.7f} Feature: {:.5f} Pixel: {:.5f} TV: {:.5f} Time: {:.5f}".format(
                    results['global_step'], results['lr'], results['feature_loss'],
                    results['pixel_loss'], results['tv_loss'], time.time() - start))

            # Last save
            save_path = saver.save(sess, os.path.join(args.checkpoint, 'model.ckpt'), results['global_step'])
            print("Model saved in file: %s" % save_path)
def train():
    batch_shape = (args.batch_size, 256, 256, 3)

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Wrap the single relu target in a list for consistency with the TF1 version above
    model = WCTModel(relu_targets=[args.relu_target], vgg_path=args.vgg_path)

    # Retained from the TF1 version; ConfigProto is not used under eager execution
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    content_images = batch_gen(args.content_path, batch_shape)

    loss_object = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.Adam()

    test_loss = tf.keras.metrics.Mean()
    train_loss = tf.keras.metrics.Mean()

    checkpoint = tf.train.Checkpoint(model=model.decoders[0])
    manager = tf.train.CheckpointManager(checkpoint, args.checkpoint, max_to_keep=1)
    checkpoint.restore(manager.latest_checkpoint)

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions, decoded_encoded, content_encoded = model(images, training=True)
            pixel_loss = loss_object(labels, predictions)
            feature_loss = loss_object(content_encoded, decoded_encoded)
            loss = pixel_loss + feature_loss
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)

    @tf.function
    def test_step(images, labels):
        # Evaluation pass, so run the model in inference mode
        predictions, decoded_encoded, content_encoded = model(images, training=False)
        t_pixel_loss = loss_object(labels, predictions)
        t_feature_loss = loss_object(content_encoded, decoded_encoded)
        t_loss = t_pixel_loss + t_feature_loss

        test_loss(t_loss)

    for iteration in range(args.max_iter):
        start = time.time()

        batch = next(content_images)
        train_step(batch, batch)

        test_batch = next(content_images)

        elapsed = time.time() - start
        print(f"Iteration {iteration} took {elapsed:.4}s")

        if iteration % 5 == 0:
            test_step(test_batch, test_batch)
            print(f"After iteration {iteration}:")
            print(f"Test loss: {float(test_loss.result())}")
            print(f"Train loss: {float(train_loss.result())}")

            # Reset the metrics for the next epoch
            train_loss.reset_states()
            test_loss.reset_states()

    # Last save
    manager.save()
    print(f"Model saved in file: {manager.latest_checkpoint}")
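
# Both train() versions rely on a batch_gen() helper that is not shown in this
# excerpt. The sketch below is an assumption about its behaviour, not the
# original implementation: it yields endless float32 batches of randomly chosen
# images resized to the batch shape, scaled to [0, 1]. The real helper may
# crop, augment, or scale differently.
import os
import random

import numpy as np
from PIL import Image


def batch_gen(image_dir, batch_shape):
    """Yield endless float32 image batches with shape batch_shape (assumed behaviour)."""
    batch_size, height, width, _ = batch_shape
    paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
             if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    while True:
        batch = np.zeros(batch_shape, dtype=np.float32)
        for i in range(batch_size):
            img = Image.open(random.choice(paths)).convert('RGB').resize((width, height))
            batch[i] = np.asarray(img, dtype=np.float32) / 255.0  # [0, 1] scaling is an assumption
        yield batch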