def main(unused_argv):
  """Builds the Piano Genie training graph and runs the training loop.

  Creates `FLAGS.train_dir` if it does not exist, writes the resolved
  configuration summary to `cfg.txt` inside it, builds the data loader, the
  model, the summaries, the hybrid loss and an Adam optimizer, and then runs
  the train op inside a `tf.train.MonitoredTrainingSession` that checkpoints
  to `FLAGS.train_dir`.

  Args:
    unused_argv: Ignored.
  """
  if not tf.gfile.IsDirectory(FLAGS.train_dir):
    tf.gfile.MakeDirs(FLAGS.train_dir)

  cfg, cfg_summary = get_named_config(FLAGS.model_cfg,
                                      FLAGS.model_cfg_overrides)
  # Persist the resolved config next to the checkpoints.
  with tf.gfile.Open(os.path.join(FLAGS.train_dir, "cfg.txt"), "w") as f:
    f.write(cfg_summary)

  # Load data
  with tf.name_scope("loader"):
    feat_dict = load_noteseqs(
        FLAGS.dataset_fp,
        cfg.train_batch_size,
        cfg.train_seq_len,
        max_discrete_times=cfg.data_max_discrete_times,
        max_discrete_velocities=cfg.data_max_discrete_velocities,
        augment_stretch_bounds=cfg.train_augment_stretch_bounds,
        augment_transpose_bounds=cfg.train_augment_transpose_bounds,
        randomize_chord_order=cfg.data_randomize_chord_order,
        repeat=True)

  # Summarize data
  tf.summary.image(
      "piano_roll",
      util.discrete_to_piano_roll(
          util.demidify(feat_dict["midi_pitches"]), 88))

  # Build model
  with tf.variable_scope("phero_model"):
    model_dict = build_genie_model(
        feat_dict,
        cfg,
        cfg.train_batch_size,
        cfg.train_seq_len,
        is_training=True)

  # Summarize quantized step embeddings
  if cfg.stp_emb_vq:
    tf.summary.scalar("codebook_perplexity",
                      model_dict["stp_emb_vq_codebook_ppl"])
    tf.summary.image(
        "genie",
        util.discrete_to_piano_roll(
            model_dict["stp_emb_vq_discrete"],
            cfg.stp_emb_vq_codebook_size,
            dilation=max(1, 88 // cfg.stp_emb_vq_codebook_size)))
    tf.summary.scalar("loss_vqvae", model_dict["stp_emb_vq_loss"])

  # Summarize integer-quantized step embeddings
  if cfg.stp_emb_iq:
    tf.summary.scalar("discrete_perplexity",
                      model_dict["stp_emb_iq_discrete_ppl"])
    tf.summary.scalar("iq_valid_p", model_dict["stp_emb_iq_valid_p"])
    tf.summary.image(
        "genie",
        util.discrete_to_piano_roll(
            model_dict["stp_emb_iq_discrete"],
            cfg.stp_emb_iq_nbins,
            dilation=max(1, 88 // cfg.stp_emb_iq_nbins)))
    tf.summary.scalar("loss_iq_range",
                      model_dict["stp_emb_iq_range_penalty"])
    tf.summary.scalar("loss_iq_contour",
                      model_dict["stp_emb_iq_contour_penalty"])
    tf.summary.scalar("loss_iq_deviate",
                      model_dict["stp_emb_iq_deviate_penalty"])

  if cfg.stp_emb_vq or cfg.stp_emb_iq:
    tf.summary.scalar("contour_violation", model_dict["contour_violation"])
    tf.summary.scalar("deviate_violation", model_dict["deviate_violation"])

  # Summarize VAE sequence embeddings
  if cfg.seq_emb_vae:
    tf.summary.scalar("loss_kl", model_dict["seq_emb_vae_kl"])

  # Summarize output
  tf.summary.image(
      "decoder_scores",
      util.discrete_to_piano_roll(model_dict["dec_recons_scores"], 88))
  tf.summary.image(
      "decoder_preds",
      util.discrete_to_piano_roll(model_dict["dec_recons_preds"], 88))
  if cfg.dec_pred_velocity:
    tf.summary.scalar("loss_recons_velocity",
                      model_dict["dec_recons_velocity_loss"])
    tf.summary.scalar("ppl_recons_velocity",
                      tf.exp(model_dict["dec_recons_velocity_loss"]))

  # Reconstruction loss
  tf.summary.scalar("loss_recons", model_dict["dec_recons_loss"])
  tf.summary.scalar("ppl_recons", tf.exp(model_dict["dec_recons_loss"]))

  # Build hybrid loss: reconstruction loss plus every enabled auxiliary term
  # whose weight in the config is strictly positive.
  loss = model_dict["dec_recons_loss"]
  if cfg.stp_emb_vq and cfg.train_loss_vq_err_scalar > 0:
    loss += (cfg.train_loss_vq_err_scalar * model_dict["stp_emb_vq_loss"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_range_scalar > 0:
    loss += (cfg.train_loss_iq_range_scalar *
             model_dict["stp_emb_iq_range_penalty"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_contour_scalar > 0:
    loss += (cfg.train_loss_iq_contour_scalar *
             model_dict["stp_emb_iq_contour_penalty"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_deviate_scalar > 0:
    loss += (cfg.train_loss_iq_deviate_scalar *
             model_dict["stp_emb_iq_deviate_penalty"])
  if cfg.seq_emb_vae and cfg.train_loss_vae_kl_scalar > 0:
    loss += (cfg.train_loss_vae_kl_scalar * model_dict["seq_emb_vae_kl"])
  # The velocity reconstruction loss is added unweighted.
  if cfg.dec_pred_velocity:
    loss += model_dict["dec_recons_velocity_loss"]
  tf.summary.scalar("loss", loss)

  # Construct optimizer
  opt = tf.train.AdamOptimizer(learning_rate=cfg.train_lr)
  train_op = opt.minimize(
      loss, global_step=tf.train.get_or_create_global_step())

  # Train
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=FLAGS.train_dir,
      save_checkpoint_secs=600,
      save_summaries_secs=FLAGS.summary_every_nsecs) as sess:
    # Honor stop requests from hooks/the coordinator instead of looping
    # unconditionally: calling run() after a stop was requested raises.
    while not sess.should_stop():
      sess.run(train_op)
def main(unused_argv):
  """Sets up and runs Piano Genie training.

  Ensures `FLAGS.train_dir` exists and stores the config summary there as
  `cfg.txt`, then builds the input pipeline, the model graph, TensorBoard
  summaries, the combined training loss and an Adam train op. Training runs
  in a `tf.train.MonitoredTrainingSession` that saves checkpoints every 600
  seconds and summaries every `FLAGS.summary_every_nsecs` seconds.

  Args:
    unused_argv: Ignored.
  """
  if not tf.gfile.IsDirectory(FLAGS.train_dir):
    tf.gfile.MakeDirs(FLAGS.train_dir)

  cfg, cfg_summary = get_named_config(FLAGS.model_cfg,
                                      FLAGS.model_cfg_overrides)
  with tf.gfile.Open(os.path.join(FLAGS.train_dir, "cfg.txt"), "w") as f:
    f.write(cfg_summary)

  # Load data
  with tf.name_scope("loader"):
    feat_dict = load_noteseqs(
        FLAGS.dataset_fp,
        cfg.train_batch_size,
        cfg.train_seq_len,
        max_discrete_times=cfg.data_max_discrete_times,
        max_discrete_velocities=cfg.data_max_discrete_velocities,
        augment_stretch_bounds=cfg.train_augment_stretch_bounds,
        augment_transpose_bounds=cfg.train_augment_transpose_bounds,
        randomize_chord_order=cfg.data_randomize_chord_order,
        repeat=True)

  # Summarize data
  tf.summary.image(
      "piano_roll",
      util.discrete_to_piano_roll(
          util.demidify(feat_dict["midi_pitches"]), 88))

  # Build model
  with tf.variable_scope("phero_model"):
    model_dict = build_genie_model(
        feat_dict,
        cfg,
        cfg.train_batch_size,
        cfg.train_seq_len,
        is_training=True)

  # Summarize quantized step embeddings
  if cfg.stp_emb_vq:
    tf.summary.scalar("codebook_perplexity",
                      model_dict["stp_emb_vq_codebook_ppl"])
    tf.summary.image(
        "genie",
        util.discrete_to_piano_roll(
            model_dict["stp_emb_vq_discrete"],
            cfg.stp_emb_vq_codebook_size,
            dilation=max(1, 88 // cfg.stp_emb_vq_codebook_size)))
    tf.summary.scalar("loss_vqvae", model_dict["stp_emb_vq_loss"])

  # Summarize integer-quantized step embeddings
  if cfg.stp_emb_iq:
    tf.summary.scalar("discrete_perplexity",
                      model_dict["stp_emb_iq_discrete_ppl"])
    tf.summary.scalar("iq_valid_p", model_dict["stp_emb_iq_valid_p"])
    tf.summary.image(
        "genie",
        util.discrete_to_piano_roll(
            model_dict["stp_emb_iq_discrete"],
            cfg.stp_emb_iq_nbins,
            dilation=max(1, 88 // cfg.stp_emb_iq_nbins)))
    tf.summary.scalar("loss_iq_range",
                      model_dict["stp_emb_iq_range_penalty"])
    tf.summary.scalar("loss_iq_contour",
                      model_dict["stp_emb_iq_contour_penalty"])
    tf.summary.scalar("loss_iq_deviate",
                      model_dict["stp_emb_iq_deviate_penalty"])

  # Contour/deviation statistics exist whenever a discrete step embedding
  # (VQ or IQ) is enabled.
  if cfg.stp_emb_vq or cfg.stp_emb_iq:
    tf.summary.scalar("contour_violation", model_dict["contour_violation"])
    tf.summary.scalar("deviate_violation", model_dict["deviate_violation"])

  # Summarize VAE sequence embeddings
  if cfg.seq_emb_vae:
    tf.summary.scalar("loss_kl", model_dict["seq_emb_vae_kl"])

  # Summarize output
  tf.summary.image(
      "decoder_scores",
      util.discrete_to_piano_roll(model_dict["dec_recons_scores"], 88))
  tf.summary.image(
      "decoder_preds",
      util.discrete_to_piano_roll(model_dict["dec_recons_preds"], 88))
  if cfg.dec_pred_velocity:
    tf.summary.scalar("loss_recons_velocity",
                      model_dict["dec_recons_velocity_loss"])
    tf.summary.scalar("ppl_recons_velocity",
                      tf.exp(model_dict["dec_recons_velocity_loss"]))

  # Reconstruction loss
  tf.summary.scalar("loss_recons", model_dict["dec_recons_loss"])
  tf.summary.scalar("ppl_recons", tf.exp(model_dict["dec_recons_loss"]))

  # Build hybrid loss. Each auxiliary term is only added when its feature is
  # enabled and its scalar weight is greater than zero.
  loss = model_dict["dec_recons_loss"]
  if cfg.stp_emb_vq and cfg.train_loss_vq_err_scalar > 0:
    loss += (cfg.train_loss_vq_err_scalar * model_dict["stp_emb_vq_loss"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_range_scalar > 0:
    loss += (
        cfg.train_loss_iq_range_scalar *
        model_dict["stp_emb_iq_range_penalty"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_contour_scalar > 0:
    loss += (
        cfg.train_loss_iq_contour_scalar *
        model_dict["stp_emb_iq_contour_penalty"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_deviate_scalar > 0:
    loss += (
        cfg.train_loss_iq_deviate_scalar *
        model_dict["stp_emb_iq_deviate_penalty"])
  if cfg.seq_emb_vae and cfg.train_loss_vae_kl_scalar > 0:
    loss += (cfg.train_loss_vae_kl_scalar * model_dict["seq_emb_vae_kl"])
  if cfg.dec_pred_velocity:
    loss += model_dict["dec_recons_velocity_loss"]
  tf.summary.scalar("loss", loss)

  # Construct optimizer
  opt = tf.train.AdamOptimizer(learning_rate=cfg.train_lr)
  train_op = opt.minimize(
      loss, global_step=tf.train.get_or_create_global_step())

  # Train
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=FLAGS.train_dir,
      save_checkpoint_secs=600,
      save_summaries_secs=FLAGS.summary_every_nsecs) as sess:
    # Exit cleanly once the session asks to stop; an unconditional loop would
    # call run() on a session that has already requested a stop.
    while not sess.should_stop():
      sess.run(train_op)
def build_genie_model(feat_dict,
                      cfg,
                      batch_size,
                      seq_len,
                      is_training=True,
                      seq_varlens=None,
                      dtype=tf.float32):
  """Builds a Piano Genie model.

  Depending on `cfg`, encodes the input pitches (plus optional auxiliary
  features) with an RNN encoder, derives one or more latent representations
  (unconstrained / VQ / integer-quantized step embeddings, unconstrained /
  VAE sequence embeddings, low-rate embeddings), and decodes them back to
  pitch logits (and optionally velocity logits).

  Args:
    feat_dict: Dictionary containing input tensors. Reads "midi_pitches" and
      "velocities", and "delta_times_int" when it is listed in
      `cfg.enc_aux_feats` or `cfg.dec_aux_feats`.
    cfg: Configuration object.
    batch_size: Number of items in batch.
    seq_len: Length of each batch item.
    is_training: Set to False for evaluation.
    seq_varlens: If not None, a tensor with the batch sequence lengths.
      Ignored when `is_training` and `cfg.train_randomize_seq_len` are both
      set, in which case lengths are sampled randomly.
    dtype: Model weight type.

  Returns:
    A dict containing tensors for relevant model config. Always contains the
    "dec_*" decoder entries; "stp_emb_*", "seq_emb_*", "lor_emb_*",
    "contour_violation" and "deviate_violation" entries are present only
    when the corresponding `cfg` options are enabled.

  Raises:
    NotImplementedError: If `cfg.stp_emb_iq` is set and
      `cfg.stp_emb_iq_contour_exp`, `cfg.stp_emb_iq_contour_comp` or
      `cfg.stp_emb_iq_deviate_exp` holds an unsupported value.
    AssertionError: If VQ step embeddings are combined with variable
      sequence lengths, if `seq_len` is not divisible by `cfg.lor_emb_n`
      for low-rate embeddings, if "velocities" is a decoder aux feature
      while `cfg.dec_pred_velocity` is set, or if no decoder features are
      configured.
  """
  out_dict = {}

  # Parse features
  pitches = util.demidify(feat_dict["midi_pitches"])
  velocities = feat_dict["velocities"]
  # Rescale pitch indices from [0, 87] to [-1, 1].
  pitches_scalar = ((tf.cast(pitches, tf.float32) / 87.) * 2.) - 1.

  # Create sequence lens
  if is_training and cfg.train_randomize_seq_len:
    seq_lens = tf.random_uniform(
        [batch_size],
        minval=cfg.train_seq_len_min,
        maxval=seq_len + 1,
        dtype=tf.int32)
    stp_varlen_mask = tf.sequence_mask(
        seq_lens, maxlen=seq_len, dtype=tf.float32)
  elif seq_varlens is not None:
    seq_lens = seq_varlens
    stp_varlen_mask = tf.sequence_mask(
        seq_varlens, maxlen=seq_len, dtype=tf.float32)
  else:
    # Fixed-length batch: no masking is needed.
    seq_lens = tf.ones([batch_size], dtype=tf.int32) * seq_len
    stp_varlen_mask = None

  # Encode (only when at least one latent representation is requested)
  if (cfg.stp_emb_unconstrained or cfg.stp_emb_vq or cfg.stp_emb_iq or
      cfg.seq_emb_unconstrained or cfg.seq_emb_vae or
      cfg.lor_emb_unconstrained):
    # Build encoder features
    enc_feats = []
    if cfg.enc_pitch_scalar:
      enc_feats.append(tf.expand_dims(pitches_scalar, axis=-1))
    else:
      enc_feats.append(tf.one_hot(pitches, 88))
    if "delta_times_int" in cfg.enc_aux_feats:
      enc_feats.append(
          tf.one_hot(feat_dict["delta_times_int"],
                     cfg.data_max_discrete_times + 1))
    if "velocities" in cfg.enc_aux_feats:
      enc_feats.append(
          tf.one_hot(velocities, cfg.data_max_discrete_velocities + 1))
    enc_feats = tf.concat(enc_feats, axis=2)

    with tf.variable_scope("encoder"):
      enc_stp, enc_seq = simple_lstm_encoder(
          enc_feats,
          seq_lens,
          rnn_celltype=cfg.rnn_celltype,
          rnn_nlayers=cfg.rnn_nlayers,
          rnn_nunits=cfg.rnn_nunits,
          rnn_bidirectional=cfg.enc_rnn_bidirectional,
          dtype=dtype)

  latents = []

  # Step embeddings (single vector per timestep)
  if cfg.stp_emb_unconstrained:
    with tf.variable_scope("stp_emb_unconstrained"):
      stp_emb_unconstrained = tf.layers.dense(
          enc_stp, cfg.stp_emb_unconstrained_embedding_dim)

    out_dict["stp_emb_unconstrained"] = stp_emb_unconstrained
    latents.append(stp_emb_unconstrained)

  # Quantized step embeddings with VQ-VAE
  if cfg.stp_emb_vq:
    import sonnet as snt  # pylint:disable=g-import-not-at-top,import-outside-toplevel
    with tf.variable_scope("stp_emb_vq"):
      with tf.variable_scope("pre_vq"):
        # pre_vq_encoding is tf.float32 of [batch_size, seq_len, embedding_dim]
        pre_vq_encoding = tf.layers.dense(enc_stp,
                                          cfg.stp_emb_vq_embedding_dim)

      with tf.variable_scope("quantizer"):
        # The VQ path does not support variable-length masking.
        assert stp_varlen_mask is None
        vq_vae = snt.nets.VectorQuantizer(
            embedding_dim=cfg.stp_emb_vq_embedding_dim,
            num_embeddings=cfg.stp_emb_vq_codebook_size,
            commitment_cost=cfg.stp_emb_vq_commitment_cost)
        vq_vae_output = vq_vae(pre_vq_encoding, is_training=is_training)

        stp_emb_vq_quantized = vq_vae_output["quantize"]
        stp_emb_vq_discrete = tf.reshape(
            tf.argmax(
                vq_vae_output["encodings"], axis=1, output_type=tf.int32),
            [batch_size, seq_len])
        stp_emb_vq_codebook = tf.transpose(vq_vae.embeddings)

    out_dict["stp_emb_vq_quantized"] = stp_emb_vq_quantized
    out_dict["stp_emb_vq_discrete"] = stp_emb_vq_discrete
    out_dict["stp_emb_vq_loss"] = vq_vae_output["loss"]
    out_dict["stp_emb_vq_codebook"] = stp_emb_vq_codebook
    out_dict["stp_emb_vq_codebook_ppl"] = vq_vae_output["perplexity"]
    latents.append(stp_emb_vq_quantized)

    # This tensor retrieves continuous embeddings from codebook. It should
    # *never* be used during training.
    out_dict["stp_emb_vq_quantized_lookup"] = tf.nn.embedding_lookup(
        stp_emb_vq_codebook, stp_emb_vq_discrete)

  # Integer-quantized step embeddings with straight-through
  if cfg.stp_emb_iq:
    with tf.variable_scope("stp_emb_iq"):
      with tf.variable_scope("pre_iq"):
        # pre_iq_encoding is tf.float32 of [batch_size, seq_len]
        pre_iq_encoding = tf.layers.dense(enc_stp, 1)[:, :, 0]

      def iqst(x, n):
        """Integer quantization with straight-through estimator.

        Maps `x` from [-1, 1] to [0, 1] (clipped with a small epsilon),
        rounds to one of `n` integer levels, and rescales back to [-1, 1].

        Returns:
          A tuple of the rounded level (as float) and the rescaled
          quantized value whose gradient w.r.t. `x` is the identity.
        """
        eps = 1e-7
        s = float(n - 1)
        xp = tf.clip_by_value((x + 1) / 2.0, -eps, 1 + eps)
        xpp = tf.round(s * xp)
        xppp = 2 * (xpp / s) - 1
        return xpp, x + tf.stop_gradient(xppp - x)

      with tf.variable_scope("quantizer"):
        # Pass rounded vals to decoder w/ straight-through estimator
        stp_emb_iq_discrete_f, stp_emb_iq_discrete_rescaled = iqst(
            pre_iq_encoding, cfg.stp_emb_iq_nbins)
        # The small offset guards the float->int cast against values that
        # sit marginally below an integer.
        stp_emb_iq_discrete = tf.cast(stp_emb_iq_discrete_f + 1e-4, tf.int32)
        stp_emb_iq_discrete_f = tf.cast(stp_emb_iq_discrete, tf.float32)
        stp_emb_iq_quantized = tf.expand_dims(
            stp_emb_iq_discrete_rescaled, axis=2)

        # Determine which elements round to valid indices
        stp_emb_iq_inrange = tf.logical_and(
            tf.greater_equal(pre_iq_encoding, -1),
            tf.less_equal(pre_iq_encoding, 1))
        stp_emb_iq_inrange_mask = tf.cast(stp_emb_iq_inrange, tf.float32)
        stp_emb_iq_valid_p = weighted_avg(stp_emb_iq_inrange_mask,
                                          stp_varlen_mask)

        # Regularize to encourage encoder to output in range
        stp_emb_iq_range_penalty = weighted_avg(
            tf.square(tf.maximum(tf.abs(pre_iq_encoding) - 1, 0)),
            stp_varlen_mask)

        # Regularize to correlate latent finite differences to input
        stp_emb_iq_dlatents = pre_iq_encoding[:, 1:] - pre_iq_encoding[:, :-1]
        if cfg.stp_emb_iq_contour_dy_scalar:
          stp_emb_iq_dnotes = pitches_scalar[:, 1:] - pitches_scalar[:, :-1]
        else:
          stp_emb_iq_dnotes = tf.cast(pitches[:, 1:] - pitches[:, :-1],
                                      tf.float32)
        if cfg.stp_emb_iq_contour_exp == 1:
          power_func = tf.identity
        elif cfg.stp_emb_iq_contour_exp == 2:
          power_func = tf.square
        else:
          raise NotImplementedError()
        if cfg.stp_emb_iq_contour_comp == "product":
          comp_func = tf.multiply
        elif cfg.stp_emb_iq_contour_comp == "quotient":
          comp_func = lambda x, y: tf.divide(x, y + 1e-6)
        else:
          raise NotImplementedError()

        # Hinge on margin - comp(dnotes, dlatents); the mask drops the first
        # step because finite differences have seq_len - 1 entries.
        stp_emb_iq_contour_penalty = weighted_avg(
            power_func(
                tf.maximum(
                    cfg.stp_emb_iq_contour_margin - comp_func(
                        stp_emb_iq_dnotes, stp_emb_iq_dlatents), 0)),
            None if stp_varlen_mask is None else stp_varlen_mask[:, 1:])

        # Regularize to maintain note consistency
        stp_emb_iq_note_held = tf.cast(
            tf.equal(pitches[:, 1:] - pitches[:, :-1], 0), tf.float32)
        if cfg.stp_emb_iq_deviate_exp == 1:
          power_func = tf.abs
        elif cfg.stp_emb_iq_deviate_exp == 2:
          power_func = tf.square
        else:
          # Without this branch an unsupported exponent would silently reuse
          # the contour penalty's power_func chosen above.
          raise NotImplementedError()

        if stp_varlen_mask is None:
          mask = stp_emb_iq_note_held
        else:
          mask = stp_varlen_mask[:, 1:] * stp_emb_iq_note_held
        stp_emb_iq_deviate_penalty = weighted_avg(
            power_func(stp_emb_iq_dlatents), mask)

        # Calculate perplexity of discrete encoder posterior
        if stp_varlen_mask is None:
          mask = stp_emb_iq_inrange_mask
        else:
          mask = stp_varlen_mask * stp_emb_iq_inrange_mask
        stp_emb_iq_discrete_oh = tf.one_hot(stp_emb_iq_discrete,
                                            cfg.stp_emb_iq_nbins)
        stp_emb_iq_avg_probs = weighted_avg(
            stp_emb_iq_discrete_oh,
            mask,
            axis=[0, 1],
            expand_mask=True)
        # exp(entropy) of the average one-hot distribution; the 1e-10 keeps
        # log() finite for unused bins.
        stp_emb_iq_discrete_ppl = tf.exp(-tf.reduce_sum(
            stp_emb_iq_avg_probs * tf.log(stp_emb_iq_avg_probs + 1e-10)))

    out_dict["stp_emb_iq_quantized"] = stp_emb_iq_quantized
    out_dict["stp_emb_iq_discrete"] = stp_emb_iq_discrete
    out_dict["stp_emb_iq_valid_p"] = stp_emb_iq_valid_p
    out_dict["stp_emb_iq_range_penalty"] = stp_emb_iq_range_penalty
    out_dict["stp_emb_iq_contour_penalty"] = stp_emb_iq_contour_penalty
    out_dict["stp_emb_iq_deviate_penalty"] = stp_emb_iq_deviate_penalty
    out_dict["stp_emb_iq_discrete_ppl"] = stp_emb_iq_discrete_ppl
    latents.append(stp_emb_iq_quantized)

    # This tensor converts discrete values to continuous.
    # It should *never* be used during training.
    out_dict["stp_emb_iq_quantized_lookup"] = tf.expand_dims(
        2. * (stp_emb_iq_discrete_f / (cfg.stp_emb_iq_nbins - 1.)) - 1.,
        axis=2)

  # Sequence embedding (single vector per sequence)
  if cfg.seq_emb_unconstrained:
    with tf.variable_scope("seq_emb_unconstrained"):
      seq_emb_unconstrained = tf.layers.dense(
          enc_seq, cfg.seq_emb_unconstrained_embedding_dim)

    out_dict["seq_emb_unconstrained"] = seq_emb_unconstrained

    # Repeat the per-sequence vector at every timestep for the decoder.
    seq_emb_unconstrained = tf.stack([seq_emb_unconstrained] * seq_len,
                                     axis=1)
    latents.append(seq_emb_unconstrained)

  # Sequence embeddings (variational w/ reparameterization trick)
  if cfg.seq_emb_vae:
    with tf.variable_scope("seq_emb_vae"):
      seq_emb_vae = tf.layers.dense(enc_seq,
                                    cfg.seq_emb_vae_embedding_dim * 2)

      # First half of the projection is the mean, second half parameterizes
      # a strictly positive stddev.
      mean = seq_emb_vae[:, :cfg.seq_emb_vae_embedding_dim]
      stddev = 1e-6 + tf.nn.softplus(
          seq_emb_vae[:, cfg.seq_emb_vae_embedding_dim:])
      # NOTE: the sample is drawn regardless of `is_training`.
      seq_emb_vae = mean + stddev * tf.random_normal(
          tf.shape(mean), 0, 1, dtype=dtype)

      kl = tf.reduce_mean(0.5 * tf.reduce_sum(
          tf.square(mean) + tf.square(stddev) -
          tf.log(1e-8 + tf.square(stddev)) - 1,
          axis=1))

    out_dict["seq_emb_vae"] = seq_emb_vae
    out_dict["seq_emb_vae_kl"] = kl

    seq_emb_vae = tf.stack([seq_emb_vae] * seq_len, axis=1)
    latents.append(seq_emb_vae)

  # Low-rate embeddings
  if cfg.lor_emb_unconstrained:
    assert seq_len % cfg.lor_emb_n == 0

    with tf.variable_scope("lor_emb_unconstrained"):
      # Downsample step embeddings
      rnn_embedding_dim = int(enc_stp.get_shape()[-1])
      enc_lor = tf.reshape(enc_stp, [
          batch_size, seq_len // cfg.lor_emb_n,
          cfg.lor_emb_n * rnn_embedding_dim
      ])
      lor_emb_unconstrained = tf.layers.dense(
          enc_lor, cfg.lor_emb_unconstrained_embedding_dim)

      out_dict["lor_emb_unconstrained"] = lor_emb_unconstrained

      # Upsample lo-rate embeddings for decoding
      lor_emb_unconstrained = tf.expand_dims(lor_emb_unconstrained, axis=2)
      lor_emb_unconstrained = tf.tile(lor_emb_unconstrained,
                                      [1, 1, cfg.lor_emb_n, 1])
      lor_emb_unconstrained = tf.reshape(
          lor_emb_unconstrained,
          [batch_size, seq_len, cfg.lor_emb_unconstrained_embedding_dim])

      latents.append(lor_emb_unconstrained)

  # Build decoder features
  dec_feats = latents

  if cfg.dec_autoregressive:
    # Retrieve pitch numbers
    curr_pitches = pitches
    last_pitches = curr_pitches[:, :-1]
    last_pitches = tf.pad(
        last_pitches, [[0, 0], [1, 0]],
        constant_values=-1)  # Prepend <SOS> token
    out_dict["dec_last_pitches"] = last_pitches
    # Shift by one so that <SOS> (-1) maps to one-hot index 0.
    dec_feats.append(tf.one_hot(last_pitches + 1, 89))

    if cfg.dec_pred_velocity:
      curr_velocities = velocities
      last_velocities = curr_velocities[:, :-1]
      last_velocities = tf.pad(last_velocities, [[0, 0], [1, 0]])
      dec_feats.append(
          tf.one_hot(last_velocities, cfg.data_max_discrete_velocities + 1))

  if "delta_times_int" in cfg.dec_aux_feats:
    dec_feats.append(
        tf.one_hot(feat_dict["delta_times_int"],
                   cfg.data_max_discrete_times + 1))
  if "velocities" in cfg.dec_aux_feats:
    # Feeding ground-truth velocities would leak the prediction target.
    assert not cfg.dec_pred_velocity
    dec_feats.append(
        tf.one_hot(feat_dict["velocities"],
                   cfg.data_max_discrete_velocities + 1))

  assert dec_feats
  dec_feats = tf.concat(dec_feats, axis=2)

  # Decode
  with tf.variable_scope("decoder"):
    dec_stp, dec_initial_state, dec_final_state = simple_lstm_decoder(
        dec_feats,
        seq_lens,
        batch_size,
        rnn_celltype=cfg.rnn_celltype,
        rnn_nlayers=cfg.rnn_nlayers,
        rnn_nunits=cfg.rnn_nunits)

    with tf.variable_scope("pitches"):
      dec_recons_logits = tf.layers.dense(dec_stp, 88)

    dec_recons_loss = weighted_avg(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=dec_recons_logits, labels=pitches), stp_varlen_mask)

    out_dict["dec_initial_state"] = dec_initial_state
    out_dict["dec_final_state"] = dec_final_state
    out_dict["dec_recons_logits"] = dec_recons_logits
    out_dict["dec_recons_scores"] = tf.nn.softmax(dec_recons_logits, axis=-1)
    out_dict["dec_recons_preds"] = tf.argmax(
        dec_recons_logits, output_type=tf.int32, axis=-1)
    out_dict["dec_recons_midi_preds"] = util.remidify(
        out_dict["dec_recons_preds"])
    out_dict["dec_recons_loss"] = dec_recons_loss

    if cfg.dec_pred_velocity:
      with tf.variable_scope("velocities"):
        dec_recons_velocity_logits = tf.layers.dense(
            dec_stp, cfg.data_max_discrete_velocities + 1)

      dec_recons_velocity_loss = weighted_avg(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              logits=dec_recons_velocity_logits, labels=velocities),
          stp_varlen_mask)

      out_dict["dec_recons_velocity_logits"] = dec_recons_velocity_logits
      out_dict["dec_recons_velocity_loss"] = dec_recons_velocity_loss

  # Stats
  # NOTE(review): these statistics are not weighted by stp_varlen_mask, so
  # with variable-length sequences they include positions past each length.
  if cfg.stp_emb_vq or cfg.stp_emb_iq:
    discrete = out_dict[
        "stp_emb_vq_discrete" if cfg.stp_emb_vq else "stp_emb_iq_discrete"]
    dx = pitches[:, 1:] - pitches[:, :-1]
    dy = discrete[:, 1:] - discrete[:, :-1]
    # Fraction of steps where pitch and discrete code move in opposite
    # directions.
    contour_violation = tf.reduce_mean(
        tf.cast(tf.less(dx * dy, 0), tf.float32))

    # Among steps where the pitch is held, fraction where the code changes.
    dx_hold = tf.equal(dx, 0)
    deviate_violation = weighted_avg(
        tf.cast(tf.not_equal(dy, 0), tf.float32),
        tf.cast(dx_hold, tf.float32))

    out_dict["contour_violation"] = contour_violation
    out_dict["deviate_violation"] = deviate_violation

  return out_dict
def build_genie_model(feat_dict,
                      cfg,
                      batch_size,
                      seq_len,
                      is_training=True,
                      seq_varlens=None,
                      dtype=tf.float32):
  """Builds a Piano Genie model.

  Depending on `cfg`, encodes the input pitches (plus optional auxiliary
  features) with an RNN encoder, derives one or more latent representations
  (unconstrained / VQ / integer-quantized step embeddings, unconstrained /
  VAE sequence embeddings, low-rate embeddings), and decodes them back to
  pitch logits (and optionally velocity logits).

  Args:
    feat_dict: Dictionary containing input tensors. Reads "midi_pitches" and
      "velocities", and "delta_times_int" when it is listed in
      `cfg.enc_aux_feats` or `cfg.dec_aux_feats`.
    cfg: Configuration object.
    batch_size: Number of items in batch.
    seq_len: Length of each batch item.
    is_training: Set to False for evaluation.
    seq_varlens: If not None, a tensor with the batch sequence lengths.
      Ignored when `is_training` and `cfg.train_randomize_seq_len` are both
      set, in which case lengths are sampled randomly.
    dtype: Model weight type.

  Returns:
    A dict containing tensors for relevant model config. Always contains the
    "dec_*" decoder entries; "stp_emb_*", "seq_emb_*", "lor_emb_*",
    "contour_violation" and "deviate_violation" entries are present only
    when the corresponding `cfg` options are enabled.

  Raises:
    NotImplementedError: If `cfg.stp_emb_iq` is set and
      `cfg.stp_emb_iq_contour_exp`, `cfg.stp_emb_iq_contour_comp` or
      `cfg.stp_emb_iq_deviate_exp` holds an unsupported value.
    AssertionError: If VQ step embeddings are combined with variable
      sequence lengths, if `seq_len` is not divisible by `cfg.lor_emb_n`
      for low-rate embeddings, if "velocities" is a decoder aux feature
      while `cfg.dec_pred_velocity` is set, or if no decoder features are
      configured.
  """
  out_dict = {}

  # Parse features
  pitches = util.demidify(feat_dict["midi_pitches"])
  velocities = feat_dict["velocities"]
  # Rescale pitch indices from [0, 87] to [-1, 1].
  pitches_scalar = ((tf.cast(pitches, tf.float32) / 87.) * 2.) - 1.

  # Create sequence lens
  if is_training and cfg.train_randomize_seq_len:
    seq_lens = tf.random_uniform(
        [batch_size],
        minval=cfg.train_seq_len_min,
        maxval=seq_len + 1,
        dtype=tf.int32)
    stp_varlen_mask = tf.sequence_mask(
        seq_lens, maxlen=seq_len, dtype=tf.float32)
  elif seq_varlens is not None:
    seq_lens = seq_varlens
    stp_varlen_mask = tf.sequence_mask(
        seq_varlens, maxlen=seq_len, dtype=tf.float32)
  else:
    # Fixed-length batch: no masking is needed.
    seq_lens = tf.ones([batch_size], dtype=tf.int32) * seq_len
    stp_varlen_mask = None

  # Encode (only when at least one latent representation is requested)
  if (cfg.stp_emb_unconstrained or cfg.stp_emb_vq or cfg.stp_emb_iq or
      cfg.seq_emb_unconstrained or cfg.seq_emb_vae or
      cfg.lor_emb_unconstrained):
    # Build encoder features
    enc_feats = []
    if cfg.enc_pitch_scalar:
      enc_feats.append(tf.expand_dims(pitches_scalar, axis=-1))
    else:
      enc_feats.append(tf.one_hot(pitches, 88))
    if "delta_times_int" in cfg.enc_aux_feats:
      enc_feats.append(
          tf.one_hot(feat_dict["delta_times_int"],
                     cfg.data_max_discrete_times + 1))
    if "velocities" in cfg.enc_aux_feats:
      enc_feats.append(
          tf.one_hot(velocities, cfg.data_max_discrete_velocities + 1))
    enc_feats = tf.concat(enc_feats, axis=2)

    with tf.variable_scope("encoder"):
      enc_stp, enc_seq = simple_lstm_encoder(
          enc_feats,
          seq_lens,
          rnn_celltype=cfg.rnn_celltype,
          rnn_nlayers=cfg.rnn_nlayers,
          rnn_nunits=cfg.rnn_nunits,
          rnn_bidirectional=cfg.enc_rnn_bidirectional,
          dtype=dtype)

  latents = []

  # Step embeddings (single vector per timestep)
  if cfg.stp_emb_unconstrained:
    with tf.variable_scope("stp_emb_unconstrained"):
      stp_emb_unconstrained = tf.layers.dense(
          enc_stp, cfg.stp_emb_unconstrained_embedding_dim)

    out_dict["stp_emb_unconstrained"] = stp_emb_unconstrained
    latents.append(stp_emb_unconstrained)

  # Quantized step embeddings with VQ-VAE
  if cfg.stp_emb_vq:
    with tf.variable_scope("stp_emb_vq"):
      with tf.variable_scope("pre_vq"):
        # pre_vq_encoding is tf.float32 of [batch_size, seq_len, embedding_dim]
        pre_vq_encoding = tf.layers.dense(enc_stp,
                                          cfg.stp_emb_vq_embedding_dim)

      with tf.variable_scope("quantizer"):
        # The VQ path does not support variable-length masking.
        assert stp_varlen_mask is None
        vq_vae = snt.nets.VectorQuantizer(
            embedding_dim=cfg.stp_emb_vq_embedding_dim,
            num_embeddings=cfg.stp_emb_vq_codebook_size,
            commitment_cost=cfg.stp_emb_vq_commitment_cost)
        vq_vae_output = vq_vae(pre_vq_encoding, is_training=is_training)

        stp_emb_vq_quantized = vq_vae_output["quantize"]
        stp_emb_vq_discrete = tf.reshape(
            tf.argmax(
                vq_vae_output["encodings"], axis=1, output_type=tf.int32),
            [batch_size, seq_len])
        stp_emb_vq_codebook = tf.transpose(vq_vae.embeddings)

    out_dict["stp_emb_vq_quantized"] = stp_emb_vq_quantized
    out_dict["stp_emb_vq_discrete"] = stp_emb_vq_discrete
    out_dict["stp_emb_vq_loss"] = vq_vae_output["loss"]
    out_dict["stp_emb_vq_codebook"] = stp_emb_vq_codebook
    out_dict["stp_emb_vq_codebook_ppl"] = vq_vae_output["perplexity"]
    latents.append(stp_emb_vq_quantized)

    # This tensor retrieves continuous embeddings from codebook. It should
    # *never* be used during training.
    out_dict["stp_emb_vq_quantized_lookup"] = tf.nn.embedding_lookup(
        stp_emb_vq_codebook, stp_emb_vq_discrete)

  # Integer-quantized step embeddings with straight-through
  if cfg.stp_emb_iq:
    with tf.variable_scope("stp_emb_iq"):
      with tf.variable_scope("pre_iq"):
        # pre_iq_encoding is tf.float32 of [batch_size, seq_len]
        pre_iq_encoding = tf.layers.dense(enc_stp, 1)[:, :, 0]

      def iqst(x, n):
        """Integer quantization with straight-through estimator.

        Maps `x` from [-1, 1] to [0, 1] (clipped with a small epsilon),
        rounds to one of `n` integer levels, and rescales back to [-1, 1].

        Returns:
          A tuple of the rounded level (as float) and the rescaled
          quantized value whose gradient w.r.t. `x` is the identity.
        """
        eps = 1e-7
        s = float(n - 1)
        xp = tf.clip_by_value((x + 1) / 2.0, -eps, 1 + eps)
        xpp = tf.round(s * xp)
        xppp = 2 * (xpp / s) - 1
        return xpp, x + tf.stop_gradient(xppp - x)

      with tf.variable_scope("quantizer"):
        # Pass rounded vals to decoder w/ straight-through estimator
        stp_emb_iq_discrete_f, stp_emb_iq_discrete_rescaled = iqst(
            pre_iq_encoding, cfg.stp_emb_iq_nbins)
        # The small offset guards the float->int cast against values that
        # sit marginally below an integer.
        stp_emb_iq_discrete = tf.cast(stp_emb_iq_discrete_f + 1e-4, tf.int32)
        stp_emb_iq_discrete_f = tf.cast(stp_emb_iq_discrete, tf.float32)
        stp_emb_iq_quantized = tf.expand_dims(
            stp_emb_iq_discrete_rescaled, axis=2)

        # Determine which elements round to valid indices
        stp_emb_iq_inrange = tf.logical_and(
            tf.greater_equal(pre_iq_encoding, -1),
            tf.less_equal(pre_iq_encoding, 1))
        stp_emb_iq_inrange_mask = tf.cast(stp_emb_iq_inrange, tf.float32)
        stp_emb_iq_valid_p = weighted_avg(stp_emb_iq_inrange_mask,
                                          stp_varlen_mask)

        # Regularize to encourage encoder to output in range
        stp_emb_iq_range_penalty = weighted_avg(
            tf.square(tf.maximum(tf.abs(pre_iq_encoding) - 1, 0)),
            stp_varlen_mask)

        # Regularize to correlate latent finite differences to input
        stp_emb_iq_dlatents = pre_iq_encoding[:, 1:] - pre_iq_encoding[:, :-1]
        if cfg.stp_emb_iq_contour_dy_scalar:
          stp_emb_iq_dnotes = pitches_scalar[:, 1:] - pitches_scalar[:, :-1]
        else:
          stp_emb_iq_dnotes = tf.cast(pitches[:, 1:] - pitches[:, :-1],
                                      tf.float32)
        if cfg.stp_emb_iq_contour_exp == 1:
          power_func = tf.identity
        elif cfg.stp_emb_iq_contour_exp == 2:
          power_func = tf.square
        else:
          raise NotImplementedError()
        if cfg.stp_emb_iq_contour_comp == "product":
          comp_func = tf.multiply
        elif cfg.stp_emb_iq_contour_comp == "quotient":
          comp_func = lambda x, y: tf.divide(x, y + 1e-6)
        else:
          raise NotImplementedError()

        # Hinge on margin - comp(dnotes, dlatents); the mask drops the first
        # step because finite differences have seq_len - 1 entries.
        stp_emb_iq_contour_penalty = weighted_avg(
            power_func(
                tf.maximum(
                    cfg.stp_emb_iq_contour_margin - comp_func(
                        stp_emb_iq_dnotes, stp_emb_iq_dlatents), 0)),
            None if stp_varlen_mask is None else stp_varlen_mask[:, 1:])

        # Regularize to maintain note consistency
        stp_emb_iq_note_held = tf.cast(
            tf.equal(pitches[:, 1:] - pitches[:, :-1], 0), tf.float32)
        if cfg.stp_emb_iq_deviate_exp == 1:
          power_func = tf.abs
        elif cfg.stp_emb_iq_deviate_exp == 2:
          power_func = tf.square
        else:
          # Without this branch an unsupported exponent would silently reuse
          # the contour penalty's power_func chosen above.
          raise NotImplementedError()

        if stp_varlen_mask is None:
          mask = stp_emb_iq_note_held
        else:
          mask = stp_varlen_mask[:, 1:] * stp_emb_iq_note_held
        stp_emb_iq_deviate_penalty = weighted_avg(
            power_func(stp_emb_iq_dlatents), mask)

        # Calculate perplexity of discrete encoder posterior
        if stp_varlen_mask is None:
          mask = stp_emb_iq_inrange_mask
        else:
          mask = stp_varlen_mask * stp_emb_iq_inrange_mask
        stp_emb_iq_discrete_oh = tf.one_hot(stp_emb_iq_discrete,
                                            cfg.stp_emb_iq_nbins)
        stp_emb_iq_avg_probs = weighted_avg(
            stp_emb_iq_discrete_oh,
            mask,
            axis=[0, 1],
            expand_mask=True)
        # exp(entropy) of the average one-hot distribution; the 1e-10 keeps
        # log() finite for unused bins.
        stp_emb_iq_discrete_ppl = tf.exp(-tf.reduce_sum(
            stp_emb_iq_avg_probs * tf.log(stp_emb_iq_avg_probs + 1e-10)))

    out_dict["stp_emb_iq_quantized"] = stp_emb_iq_quantized
    out_dict["stp_emb_iq_discrete"] = stp_emb_iq_discrete
    out_dict["stp_emb_iq_valid_p"] = stp_emb_iq_valid_p
    out_dict["stp_emb_iq_range_penalty"] = stp_emb_iq_range_penalty
    out_dict["stp_emb_iq_contour_penalty"] = stp_emb_iq_contour_penalty
    out_dict["stp_emb_iq_deviate_penalty"] = stp_emb_iq_deviate_penalty
    out_dict["stp_emb_iq_discrete_ppl"] = stp_emb_iq_discrete_ppl
    latents.append(stp_emb_iq_quantized)

    # This tensor converts discrete values to continuous.
    # It should *never* be used during training.
    out_dict["stp_emb_iq_quantized_lookup"] = tf.expand_dims(
        2. * (stp_emb_iq_discrete_f / (cfg.stp_emb_iq_nbins - 1.)) - 1.,
        axis=2)

  # Sequence embedding (single vector per sequence)
  if cfg.seq_emb_unconstrained:
    with tf.variable_scope("seq_emb_unconstrained"):
      seq_emb_unconstrained = tf.layers.dense(
          enc_seq, cfg.seq_emb_unconstrained_embedding_dim)

    out_dict["seq_emb_unconstrained"] = seq_emb_unconstrained

    # Repeat the per-sequence vector at every timestep for the decoder.
    seq_emb_unconstrained = tf.stack([seq_emb_unconstrained] * seq_len,
                                     axis=1)
    latents.append(seq_emb_unconstrained)

  # Sequence embeddings (variational w/ reparameterization trick)
  if cfg.seq_emb_vae:
    with tf.variable_scope("seq_emb_vae"):
      seq_emb_vae = tf.layers.dense(enc_seq,
                                    cfg.seq_emb_vae_embedding_dim * 2)

      # First half of the projection is the mean, second half parameterizes
      # a strictly positive stddev.
      mean = seq_emb_vae[:, :cfg.seq_emb_vae_embedding_dim]
      stddev = 1e-6 + tf.nn.softplus(
          seq_emb_vae[:, cfg.seq_emb_vae_embedding_dim:])
      # NOTE: the sample is drawn regardless of `is_training`.
      seq_emb_vae = mean + stddev * tf.random_normal(
          tf.shape(mean), 0, 1, dtype=dtype)

      kl = tf.reduce_mean(0.5 * tf.reduce_sum(
          tf.square(mean) + tf.square(stddev) -
          tf.log(1e-8 + tf.square(stddev)) - 1,
          axis=1))

    out_dict["seq_emb_vae"] = seq_emb_vae
    out_dict["seq_emb_vae_kl"] = kl

    seq_emb_vae = tf.stack([seq_emb_vae] * seq_len, axis=1)
    latents.append(seq_emb_vae)

  # Low-rate embeddings
  if cfg.lor_emb_unconstrained:
    assert seq_len % cfg.lor_emb_n == 0

    with tf.variable_scope("lor_emb_unconstrained"):
      # Downsample step embeddings
      rnn_embedding_dim = int(enc_stp.get_shape()[-1])
      enc_lor = tf.reshape(enc_stp, [
          batch_size, seq_len // cfg.lor_emb_n,
          cfg.lor_emb_n * rnn_embedding_dim
      ])
      lor_emb_unconstrained = tf.layers.dense(
          enc_lor, cfg.lor_emb_unconstrained_embedding_dim)

      out_dict["lor_emb_unconstrained"] = lor_emb_unconstrained

      # Upsample lo-rate embeddings for decoding
      lor_emb_unconstrained = tf.expand_dims(lor_emb_unconstrained, axis=2)
      lor_emb_unconstrained = tf.tile(lor_emb_unconstrained,
                                      [1, 1, cfg.lor_emb_n, 1])
      lor_emb_unconstrained = tf.reshape(
          lor_emb_unconstrained,
          [batch_size, seq_len, cfg.lor_emb_unconstrained_embedding_dim])

      latents.append(lor_emb_unconstrained)

  # Build decoder features
  dec_feats = latents

  if cfg.dec_autoregressive:
    # Retrieve pitch numbers
    curr_pitches = pitches
    last_pitches = curr_pitches[:, :-1]
    last_pitches = tf.pad(
        last_pitches, [[0, 0], [1, 0]],
        constant_values=-1)  # Prepend <SOS> token
    out_dict["dec_last_pitches"] = last_pitches
    # Shift by one so that <SOS> (-1) maps to one-hot index 0.
    dec_feats.append(tf.one_hot(last_pitches + 1, 89))

    if cfg.dec_pred_velocity:
      curr_velocities = velocities
      last_velocities = curr_velocities[:, :-1]
      last_velocities = tf.pad(last_velocities, [[0, 0], [1, 0]])
      dec_feats.append(
          tf.one_hot(last_velocities, cfg.data_max_discrete_velocities + 1))

  if "delta_times_int" in cfg.dec_aux_feats:
    dec_feats.append(
        tf.one_hot(feat_dict["delta_times_int"],
                   cfg.data_max_discrete_times + 1))
  if "velocities" in cfg.dec_aux_feats:
    # Feeding ground-truth velocities would leak the prediction target.
    assert not cfg.dec_pred_velocity
    dec_feats.append(
        tf.one_hot(feat_dict["velocities"],
                   cfg.data_max_discrete_velocities + 1))

  assert dec_feats
  dec_feats = tf.concat(dec_feats, axis=2)

  # Decode
  with tf.variable_scope("decoder"):
    dec_stp, dec_initial_state, dec_final_state = simple_lstm_decoder(
        dec_feats,
        seq_lens,
        batch_size,
        rnn_celltype=cfg.rnn_celltype,
        rnn_nlayers=cfg.rnn_nlayers,
        rnn_nunits=cfg.rnn_nunits)

    with tf.variable_scope("pitches"):
      dec_recons_logits = tf.layers.dense(dec_stp, 88)

    dec_recons_loss = weighted_avg(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=dec_recons_logits, labels=pitches), stp_varlen_mask)

    out_dict["dec_initial_state"] = dec_initial_state
    out_dict["dec_final_state"] = dec_final_state
    out_dict["dec_recons_logits"] = dec_recons_logits
    out_dict["dec_recons_scores"] = tf.nn.softmax(dec_recons_logits, axis=-1)
    out_dict["dec_recons_preds"] = tf.argmax(
        dec_recons_logits, output_type=tf.int32, axis=-1)
    out_dict["dec_recons_midi_preds"] = util.remidify(
        out_dict["dec_recons_preds"])
    out_dict["dec_recons_loss"] = dec_recons_loss

    if cfg.dec_pred_velocity:
      with tf.variable_scope("velocities"):
        dec_recons_velocity_logits = tf.layers.dense(
            dec_stp, cfg.data_max_discrete_velocities + 1)

      dec_recons_velocity_loss = weighted_avg(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              logits=dec_recons_velocity_logits, labels=velocities),
          stp_varlen_mask)

      out_dict["dec_recons_velocity_logits"] = dec_recons_velocity_logits
      out_dict["dec_recons_velocity_loss"] = dec_recons_velocity_loss

  # Stats
  # NOTE(review): these statistics are not weighted by stp_varlen_mask, so
  # with variable-length sequences they include positions past each length.
  if cfg.stp_emb_vq or cfg.stp_emb_iq:
    discrete = out_dict[
        "stp_emb_vq_discrete" if cfg.stp_emb_vq else "stp_emb_iq_discrete"]
    dx = pitches[:, 1:] - pitches[:, :-1]
    dy = discrete[:, 1:] - discrete[:, :-1]
    # Fraction of steps where pitch and discrete code move in opposite
    # directions.
    contour_violation = tf.reduce_mean(
        tf.cast(tf.less(dx * dy, 0), tf.float32))

    # Among steps where the pitch is held, fraction where the code changes.
    dx_hold = tf.equal(dx, 0)
    deviate_violation = weighted_avg(
        tf.cast(tf.not_equal(dy, 0), tf.float32),
        tf.cast(dx_hold, tf.float32))

    out_dict["contour_violation"] = contour_violation
    out_dict["deviate_violation"] = deviate_violation

  return out_dict