def test_model_train_step(self):
        """Run one `train_step` per sample size and check a dict comes back."""
        FLAGS(['', '--feat_model=fast'])
        feat_model = scm.make_feat_model([32, 32, 3])

        for sample_size in (None, 64):
            model = scm.SCModel(feat_model,
                                sample_size=sample_size,
                                loss_warmup=0)
            style_losses = [
                tf.keras.losses.MeanSquaredError() for _ in range(2)
            ]
            model.compile(None, 'adam', loss={'style': style_losses})

            # Images are drawn as int32: random uniform doesn't support uint8
            images = tuple(
                tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
                for _ in range(2))
            # Two feature maps (16x16 and 8x8) for each of style and content
            feats = {
                key: [
                    tf.random.uniform([1, side, side, 3]) for side in (16, 8)
                ]
                for key in ('style', 'content')
            }

            metrics = model.train_step((images, feats))
            self.assertIsInstance(metrics, dict)
 def test_model_call(self):
     """Smoke test: calling the model on an image pair must not raise."""
     FLAGS(['', '--feat_model=fast'])
     model = scm.SCModel(scm.make_feat_model([32, 32, 3]),
                         sample_size=None,
                         loss_warmup=0)
     # Images are drawn as int32: random uniform doesn't support uint8
     image_pair = tuple(
         tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
         for _ in range(2))
     model(image_pair)
    def test_model_no_warmup(self):
        """With `loss_warmup=0` the warmup alpha is all ones after compile."""
        FLAGS(['', '--feat_model=fast'])
        model = scm.SCModel(scm.make_feat_model([32, 32, 3]),
                            sample_size=None,
                            loss_warmup=0)
        style_losses = [tf.keras.losses.MeanSquaredError() for _ in range(2)]
        model.compile(None, 'adam', loss={'style': style_losses})

        # No warmup is configured, so alpha is expected to be 1 from the start
        warmup_alpha = model.get_loss_warmup_alpha()
        tf.debugging.assert_equal(tf.ones_like(warmup_alpha), warmup_alpha)
    def test_model_warmup(self):
        """Check alpha goes 0 -> 0.01 -> 1 when `loss_warmup=100`."""
        FLAGS(['', '--feat_model=fast'])
        model = scm.SCModel(scm.make_feat_model([32, 32, 3]),
                            sample_size=None,
                            loss_warmup=100)
        style_losses = [tf.keras.losses.MeanSquaredError() for _ in range(2)]
        model.compile(None, 'adam', loss={'style': style_losses})

        # Before any training step alpha is expected to be 0
        warmup_alpha = model.get_loss_warmup_alpha()
        tf.debugging.assert_equal(tf.zeros_like(warmup_alpha), warmup_alpha)

        # One training batch, reused for every step below
        images = tuple(
            tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
            for _ in range(2))
        feats = {
            key: [tf.random.uniform([1, side, side, 3]) for side in (16, 8)]
            for key in ('style', 'content')
        }
        batch = (images, feats)

        # A single step out of 100 is expected to move alpha to 0.01
        model.train_step(batch)
        warmup_alpha = model.get_loss_warmup_alpha()
        tf.debugging.assert_equal(0.01 * tf.ones_like(warmup_alpha),
                                  warmup_alpha)

        # After 200 more steps alpha is expected to be capped at 1
        for _ in range(200):
            model.train_step(batch)

        warmup_alpha = model.get_loss_warmup_alpha()
        tf.debugging.assert_equal(tf.ones_like(warmup_alpha), warmup_alpha)
def main(argv):
    """Run the style-content pipeline from image loading to saved metrics.

    Steps, in order: run `setup()`, load the style and content images, build
    and configure the style-content model inside the strategy scope, plot the
    feature model and gram matrices under `./out`, call `train`, save the
    style/content/generated images to `loss_dir`, evaluate once on the
    projected features and once on the raw features, append the raw metrics
    to `raw_metrics.csv`, and plot the losses read from `logs.csv`.

    Args:
        argv: Command-line arguments; deleted immediately and never read.
    """
    del argv  # Unused.

    # Setup
    strategy, loss_dir = setup()

    # Load style/content image
    logging.info('loading images')
    style_image, content_image = load_sc_images()

    # Create the style-content model
    logging.info('making style-content model')
    # Drop the leading axis of the style image shape (presumably the batch
    # axis -- confirm against load_sc_images)
    image_shape = style_image.shape[1:]
    with strategy.scope():
        raw_feat_model = scm.make_feat_model(image_shape)
        sc_model = scm.SCModel(raw_feat_model, FLAGS.sample_size,
                               FLAGS.loss_warmup)

        # Configure the model to the style and content images
        sc_model.configure(style_image, content_image)

    # Plot the feature model structure
    tf.keras.utils.plot_model(sc_model.feat_model, './out/feat_model.jpg')

    # Get the style and content features, once from the raw feature model and
    # once from the feature model held by sc_model after configure()
    raw_feats_dict = raw_feat_model((style_image, content_image),
                                    training=False)
    feats_dict = sc_model.feat_model((style_image, content_image),
                                     training=False)

    # Make the dataset
    ds = make_dataset(strategy, (style_image, content_image), feats_dict)

    # Log distribution statistics of the style image
    log_feat_distribution(raw_feats_dict, 'raw layer average style moments')
    log_feat_distribution(feats_dict, 'projected layer average style moments')

    # Plot the gram matrices
    plot_layer_grams(raw_feats_dict, feats_dict, filepath='./out/gram.jpg')

    # Reset gen image and recompile
    sc_model.reinit_gen_image()
    compile_sc_model(strategy,
                     sc_model,
                     FLAGS.loss,
                     with_metrics=FLAGS.train_metrics)

    # Style transfer
    logging.info(f'loss function: {FLAGS.loss}')
    train(sc_model, ds, loss_dir)

    # Save the images to disk; the generated image is named after the loss
    gen_image = sc_model.get_gen_image()
    for filename, image in [('style.jpg', style_image),
                            ('content.jpg', content_image),
                            (f'{FLAGS.loss}.jpg', gen_image)]:
        # Axis 0 is squeezed away before saving (it must have size 1)
        tf.keras.preprocessing.image.save_img(f'{loss_dir}/{filename}',
                                              tf.squeeze(image, 0))
    logging.info(f'images saved to {loss_dir}')

    # Sanity evaluation: recompile with metrics always on and run one step
    logging.info('evaluating on projected features')
    compile_sc_model(strategy, sc_model, FLAGS.loss, with_metrics=True)
    sc_model.evaluate(ds, steps=1, return_dict=True)

    # Metrics: read logs.csv from loss_dir (presumably written by train() --
    # confirm); it is only used by plot_loss at the end
    logs_df = pd.read_csv(f'{loss_dir}/logs.csv')

    logging.info('evaluating on raw features')
    # Temporarily swap in the raw feature model; the original is put back
    # after the loop below.
    # NOTE(review): the swap is not guarded by try/finally, so an exception
    # in between leaves sc_model.feat_model pointing at raw_feat_model.
    orig_feat_model = sc_model.feat_model
    sc_model.feat_model = raw_feat_model
    compile_sc_model(strategy, sc_model, FLAGS.loss, with_metrics=True)
    all_raw_metrics = sc_model.evaluate(ds, steps=1, return_dict=True)
    all_raw_metrics = pd.Series(all_raw_metrics)
    # For each metric suffix: select the entries whose name contains it, add
    # their sum as a 'total..._loss' entry, and append the group to the CSV.
    # NOTE(review): the file is opened in append mode, so re-running with the
    # same loss_dir accumulates rows.
    for metric in ['_mean', '_var', '_covar', '_gram', '_skew', '_wass']:
        raw_metrics = all_raw_metrics.filter(like=metric)
        raw_metrics[f'total{metric}_loss'] = raw_metrics.sum()
        filepath = f'{loss_dir}/raw_metrics.csv'
        raw_metrics.to_csv(filepath, mode='a', header=False)
        # A newline is written after each metric group
        with open(filepath, mode='a') as f:
            f.write('\n')
    sc_model.feat_model = orig_feat_model

    plot_loss(logs_df, path=f'{loss_dir}/plots.jpg')
    logging.info(f'metrics saved to {loss_dir}')