Example #1
    def __init__(self,
                 checkpoints,
                 relu_targets,
                 vgg_path,
                 device='/gpu:0',
                 ss_patch_size=3,
                 ss_stride=1):
        '''
            Args:
                checkpoints: List of trained decoder model checkpoint dirs
                relu_targets: List of relu target layers corresponding to decoder checkpoints
                vgg_path: Normalised VGG19 .t7 path
                device: String for device ID to load model onto
                ss_patch_size: Patch size for the style-swap op
                ss_stride: Stride for the style-swap op
        '''
        self.ss_patch_size = ss_patch_size
        self.ss_stride = ss_stride

        graph = tf.get_default_graph()

        with graph.device(device):
            # Build the graph
            self.model = WCTModel(mode='test',
                                  relu_targets=relu_targets,
                                  vgg_path=vgg_path,
                                  ss_patch_size=self.ss_patch_size,
                                  ss_stride=self.ss_stride)

            self.content_input = self.model.content_input
            self.decoded_output = self.model.decoded_output

            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)

            self.sess.run(tf.global_variables_initializer())

            # Load decoder vars one-by-one into the graph
            for relu_target, checkpoint_dir in zip(relu_targets, checkpoints):
                decoder_prefix = 'decoder_{}'.format(relu_target)
                relu_vars = [
                    v for v in tf.trainable_variables()
                    if decoder_prefix in v.name
                ]

                saver = tf.train.Saver(var_list=relu_vars)

                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print('Restoring vars for {} from checkpoint {}'.format(
                        relu_target, ckpt.model_checkpoint_path))
                    saver.restore(self.sess, ckpt.model_checkpoint_path)
                else:
                    raise Exception(
                        'No checkpoint found for target {} in dir {}'.format(
                            relu_target, checkpoint_dir))
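
For context, a usage sketch. The wrapper class name WCT and the paths below are placeholders for illustration, not taken from the source:

# Hypothetical usage; class name, checkpoint dirs and VGG path are assumptions.
wct = WCT(checkpoints=['models/relu5_1', 'models/relu4_1'],
          relu_targets=['relu5_1', 'relu4_1'],
          vgg_path='models/vgg_normalised.t7',
          device='/gpu:0')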
Example #2
    def __init__(self,
                 checkpoints,
                 relu_targets,
                 vgg_path,
                 device='/gpu:0',
                 ss_patch_size=3,
                 ss_stride=1,
                 alpha=0.5,
                 beta=0.5):
        '''
            Args:
                checkpoints: List of trained decoder model checkpoint dirs
                relu_targets: List of relu target layers corresponding to decoder checkpoints
                vgg_path: Normalised VGG19 .t7 path
                device: String for device ID to load model onto
                ss_patch_size: Patch size for the style-swap op
                ss_stride: Stride for the style-swap op
                alpha: Blending weight passed through to WCTModel
                beta: Blending weight passed through to WCTModel
        '''
        self.ss_patch_size = ss_patch_size
        self.ss_stride = ss_stride

        # Build the graph
        self.model = WCTModel(relu_targets=relu_targets,
                              vgg_path=vgg_path,
                              alpha=alpha,
                              beta=beta)

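        # Note: this ConfigProto only takes effect if passed to a
        # tf.compat.v1.Session; in eager mode it is otherwise unused here.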
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # Load decoder vars one-by-one into the graph
        for i, checkpoint_dir in enumerate(checkpoints):
            # latest_checkpoint returns None if the dir has no checkpoint
            latest = tf.train.latest_checkpoint(checkpoint_dir)
            if latest is not None:
                checkpoint = tf.train.Checkpoint(model=self.model.decoders[i])
                checkpoint.restore(latest)
            else:
                raise Exception(
                    'No checkpoint found for target {} in dir {}'.format(
                        relu_targets[i], checkpoint_dir))
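
For reference, the tf.train.Checkpoint / tf.train.latest_checkpoint pattern used above works like this in isolation. A minimal sketch with a toy Keras layer, assuming nothing beyond the standard TF2 API:

import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build((None, 8))  # create variables so there is something to save

ckpt = tf.train.Checkpoint(model=layer)
ckpt.save('/tmp/demo_ckpt/ckpt')  # writes ckpt-1.index / ckpt-1.data-*
ckpt.restore(tf.train.latest_checkpoint('/tmp/demo_ckpt'))  # newest save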
Example #3
def train():
    batch_shape = (args.batch_size, 256, 256, 3)

    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        ### Setup data loading queue
        queue_input_content = tf.placeholder(tf.float32, shape=batch_shape)
        queue_input_val = tf.placeholder(tf.float32, shape=batch_shape)
        queue = tf.FIFOQueue(capacity=100,
                             dtypes=[tf.float32, tf.float32],
                             shapes=[[256, 256, 3], [256, 256, 3]])
        enqueue_op = queue.enqueue_many([queue_input_content, queue_input_val])
        dequeue_op = queue.dequeue()
        content_batch_op, val_batch_op = tf.train.batch(
            dequeue_op, batch_size=args.batch_size, capacity=100)

        def enqueue(sess):
            content_images = batch_gen(args.content_path, batch_shape)
            val_images = batch_gen(args.val_path, batch_shape)
            while True:
                content_batch = next(content_images)
                val_batch = next(val_images)

                sess.run(enqueue_op,
                         feed_dict={
                             queue_input_content: content_batch,
                             queue_input_val: val_batch
                         })

        ### Build the model graph and train/summary ops
        model = WCTModel(mode='train',
                         relu_targets=[args.relu_target],
                         vgg_path=args.vgg_path,
                         batch_size=args.batch_size,
                         feature_weight=args.feature_weight,
                         pixel_weight=args.pixel_weight,
                         tv_weight=args.tv_weight,
                         learning_rate=args.learning_rate,
                         lr_decay=args.lr_decay).encoder_decoders[0]

        saver = tf.train.Saver(max_to_keep=None)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            enqueue_thread = threading.Thread(target=enqueue, args=[sess])
            enqueue_thread.daemon = True  # isDaemon() only reads the flag; this actually sets it
            enqueue_thread.start()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            log_path = args.log_path if args.log_path is not None else os.path.join(
                args.checkpoint, 'log')
            summary_writer = tf.summary.FileWriter(log_path, sess.graph)

            sess.run(tf.global_variables_initializer())

            def load_latest():
                if os.path.exists(os.path.join(args.checkpoint, 'checkpoint')):
                    print("Restoring checkpoint")
                    saver.restore(sess,
                                  tf.train.latest_checkpoint(args.checkpoint))

            load_latest()

            for iteration in range(args.max_iter):
                start = time.time()

                content_batch = sess.run(content_batch_op)

                fetches = {
                    'train': model.train_op,
                    'global_step': model.global_step,
                    # 'summary':      model.summary_op,
                    'lr': model.learning_rate,
                    'feature_loss': model.feature_loss,
                    'pixel_loss': model.pixel_loss,
                    'tv_loss': model.tv_loss
                }

                feed_dict = {model.content_input: content_batch}

                try:
                    results = sess.run(fetches, feed_dict=feed_dict)
                except Exception as e:
                    # Training sometimes NaNs out and has to be restarted.
                    # In that case, reload the last checkpoint and resume.
                    print(e)
                    print("Exception encountered, re-loading latest checkpoint")
                    load_latest()
                    continue

                ### Run a val batch and log the summaries
                if iteration % args.summary_iter == 0:
                    val_batch = sess.run(val_batch_op)
                    summary = sess.run(
                        model.summary_op,
                        feed_dict={model.content_input: val_batch})
                    summary_writer.add_summary(summary, results['global_step'])

                ### Save checkpoint
                if iteration % args.save_iter == 0:
                    save_path = saver.save(
                        sess, os.path.join(args.checkpoint, 'model.ckpt'),
                        results['global_step'])
                    print("Model saved in file: %s" % save_path)

                ### Log training stats
                print(
                    "Step: {}  LR: {:.7f}  Feature: {:.5f}  Pixel: {:.5f}  TV: {:.5f}  Time: {:.5f}"
                    .format(results['global_step'], results['lr'],
                            results['feature_loss'], results['pixel_loss'],
                            results['tv_loss'],
                            time.time() - start))

            # Last save
            save_path = saver.save(sess,
                                   os.path.join(args.checkpoint, 'model.ckpt'),
                                   results['global_step'])
            print("Model saved in file: %s" % save_path)
Example #4
def train():
    batch_shape = (args.batch_size, 256, 256, 3)
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Memory growth must be configured before the first GPU op runs,
    # so set it up before building the model (a ConfigProto has no
    # effect here without a tf.compat.v1.Session).
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    model = WCTModel(relu_targets=[args.relu_target], vgg_path=args.vgg_path)

    content_images = batch_gen(args.content_path, batch_shape)

    loss_object = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.Adam()
    test_loss = tf.keras.metrics.Mean()
    train_loss = tf.keras.metrics.Mean()
    checkpoint = tf.train.Checkpoint(model=model.decoders[0])
    manager = tf.train.CheckpointManager(checkpoint,
                                         args.checkpoint,
                                         max_to_keep=1)
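    # restore() is a no-op when latest_checkpoint is None (first run)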
    checkpoint.restore(manager.latest_checkpoint)

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions, decoded_encoded, content_encoded = model(
                images, training=True)
            pixel_loss = loss_object(labels, predictions)
            feature_loss = loss_object(content_encoded, decoded_encoded)
            loss = pixel_loss + feature_loss
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)

    @tf.function
    def test_step(images, labels):
        # Evaluation pass: no gradient tape, and training=False
        predictions, decoded_encoded, content_encoded = model(images,
                                                              training=False)
        t_pixel_loss = loss_object(labels, predictions)
        t_feature_loss = loss_object(content_encoded, decoded_encoded)
        t_loss = t_pixel_loss + t_feature_loss

        test_loss(t_loss)

    for iteration in range(args.max_iter):
        start = time.time()

        batch = next(content_images)
        train_step(batch, batch)

        test_batch = next(content_images)
        elapsed = time.time() - start
        print(f"Iteration {iteration} took {elapsed:.4}s")

        if iteration % 5 == 0:
            test_step(test_batch, test_batch)

            print(f"After iteration {iteration}:")
            print(f"Test loss:  {float(test_loss.result())}")
            print(f"Train loss:  {float(train_loss.result())}")

            # Reset the metrics for the next logging interval
            train_loss.reset_states()
            test_loss.reset_states()

    # Last save
    manager.save()
    print(f"Model saved in file: {manager.latest_checkpoint}")