def test_model_train_step(self):
    FLAGS(['', '--feat_model=fast'])
    feat_model = scm.make_feat_model([32, 32, 3])
    for sample_size in [None, 64]:
        sc_model = scm.SCModel(feat_model, sample_size=sample_size, loss_warmup=0)
        sc_model.compile(None, 'adam', loss={
            'style': [
                tf.keras.losses.MeanSquaredError(),
                tf.keras.losses.MeanSquaredError()
            ]
        })

        # Random uniform doesn't support uint8
        x = tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
        y = tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
        feats = {
            'style': [
                tf.random.uniform([1, 16, 16, 3]),
                tf.random.uniform([1, 8, 8, 3])
            ],
            'content': [
                tf.random.uniform([1, 16, 16, 3]),
                tf.random.uniform([1, 8, 8, 3])
            ]
        }

        metrics = sc_model.train_step(((x, y), feats))
        self.assertIsInstance(metrics, dict)
def test_model_call(self):
    FLAGS(['', '--feat_model=fast'])
    feat_model = scm.make_feat_model([32, 32, 3])
    sc_model = scm.SCModel(feat_model, sample_size=None, loss_warmup=0)

    # Random uniform doesn't support uint8
    x = tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
    y = tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)

    _ = sc_model((x, y))
def test_model_no_warmup(self):
    FLAGS(['', '--feat_model=fast'])
    feat_model = scm.make_feat_model([32, 32, 3])
    sc_model = scm.SCModel(feat_model, sample_size=None, loss_warmup=0)
    sc_model.compile(None, 'adam', loss={
        'style': [
            tf.keras.losses.MeanSquaredError(),
            tf.keras.losses.MeanSquaredError()
        ]
    })

    # With no warmup, alpha should already be 1
    alpha = sc_model.get_loss_warmup_alpha()
    tf.debugging.assert_equal(tf.ones_like(alpha), alpha)
def test_model_warmup(self):
    FLAGS(['', '--feat_model=fast'])
    feat_model = scm.make_feat_model([32, 32, 3])
    sc_model = scm.SCModel(feat_model, sample_size=None, loss_warmup=100)
    sc_model.compile(None, 'adam', loss={
        'style': [
            tf.keras.losses.MeanSquaredError(),
            tf.keras.losses.MeanSquaredError()
        ]
    })

    # Initial alpha value should be 0
    alpha = sc_model.get_loss_warmup_alpha()
    tf.debugging.assert_equal(tf.zeros_like(alpha), alpha)

    # Alpha warms up linearly towards 1
    x = tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
    y = tf.random.uniform([1, 32, 32, 3], maxval=255, dtype=tf.int32)
    feats = {
        'style': [
            tf.random.uniform([1, 16, 16, 3]),
            tf.random.uniform([1, 8, 8, 3])
        ],
        'content': [
            tf.random.uniform([1, 16, 16, 3]),
            tf.random.uniform([1, 8, 8, 3])
        ]
    }
    _ = sc_model.train_step(((x, y), feats))
    alpha = sc_model.get_loss_warmup_alpha()
    tf.debugging.assert_equal(0.01 * tf.ones_like(alpha), alpha)

    # Alpha is capped at 1
    for _ in range(200):
        _ = sc_model.train_step(((x, y), feats))
    alpha = sc_model.get_loss_warmup_alpha()
    tf.debugging.assert_equal(tf.ones_like(alpha), alpha)
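# Note: the assertions above imply a warmup schedule of roughly
#   alpha = min(step / loss_warmup, 1), with alpha fixed at 1 when loss_warmup == 0.
# This is only inferred from the test expectations (0 at step 0, 0.01 after one
# step with loss_warmup=100, capped at 1); the actual schedule lives in
# scm.SCModel.get_loss_warmup_alpha and may differ in detail.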
def main(argv):
    del argv  # Unused.

    # Setup
    strategy, loss_dir = setup()

    # Load style/content image
    logging.info('loading images')
    style_image, content_image = load_sc_images()

    # Create the style-content model
    logging.info('making style-content model')
    image_shape = style_image.shape[1:]
    with strategy.scope():
        raw_feat_model = scm.make_feat_model(image_shape)
        sc_model = scm.SCModel(raw_feat_model, FLAGS.sample_size, FLAGS.loss_warmup)

        # Configure the model to the style and content images
        sc_model.configure(style_image, content_image)

    # Plot the feature model structure
    tf.keras.utils.plot_model(sc_model.feat_model, './out/feat_model.jpg')

    # Get the style and content features
    raw_feats_dict = raw_feat_model((style_image, content_image), training=False)
    feats_dict = sc_model.feat_model((style_image, content_image), training=False)

    # Make the dataset
    ds = make_dataset(strategy, (style_image, content_image), feats_dict)

    # Log distribution statistics of the style image
    log_feat_distribution(raw_feats_dict, 'raw layer average style moments')
    log_feat_distribution(feats_dict, 'projected layer average style moments')

    # Plot the gram matrices
    plot_layer_grams(raw_feats_dict, feats_dict, filepath='./out/gram.jpg')

    # Reset gen image and recompile
    sc_model.reinit_gen_image()
    compile_sc_model(strategy, sc_model, FLAGS.loss, with_metrics=FLAGS.train_metrics)

    # Style transfer
    logging.info(f'loss function: {FLAGS.loss}')
    train(sc_model, ds, loss_dir)

    # Save the images to disk
    gen_image = sc_model.get_gen_image()
    for filename, image in [('style.jpg', style_image),
                            ('content.jpg', content_image),
                            (f'{FLAGS.loss}.jpg', gen_image)]:
        tf.keras.preprocessing.image.save_img(f'{loss_dir}/{filename}', tf.squeeze(image, 0))
    logging.info(f'images saved to {loss_dir}')

    # Sanity evaluation on the projected features
    logging.info('evaluating on projected features')
    compile_sc_model(strategy, sc_model, FLAGS.loss, with_metrics=True)
    sc_model.evaluate(ds, steps=1, return_dict=True)

    # Metrics
    logs_df = pd.read_csv(f'{loss_dir}/logs.csv')

    # Evaluate on the raw (unprojected) features
    logging.info('evaluating on raw features')
    orig_feat_model = sc_model.feat_model
    sc_model.feat_model = raw_feat_model
    compile_sc_model(strategy, sc_model, FLAGS.loss, with_metrics=True)
    all_raw_metrics = sc_model.evaluate(ds, steps=1, return_dict=True)
    all_raw_metrics = pd.Series(all_raw_metrics)
    for metric in ['_mean', '_var', '_covar', '_gram', '_skew', '_wass']:
        raw_metrics = all_raw_metrics.filter(like=metric)
        raw_metrics[f'total{metric}_loss'] = raw_metrics.sum()
        filepath = f'{loss_dir}/raw_metrics.csv'
        raw_metrics.to_csv(filepath, mode='a', header=False)
        with open(filepath, mode='a') as f:
            f.write('\n')
    sc_model.feat_model = orig_feat_model

    plot_loss(logs_df, path=f'{loss_dir}/plots.jpg')
    logging.info(f'metrics saved to {loss_dir}')