Esempio n. 1
0
def main(unused_argv):
    """Evaluates model checkpoint(s) and reports metrics.

    Builds the evaluation input pipeline and model graph from the config
    selected by FLAGS.model_cfg / FLAGS.model_cfg_overrides, then:

      * If FLAGS.ckpt_fp is None, polls FLAGS.train_dir once per second and
        evaluates every new latest checkpoint, writing the merged summaries
        to FLAGS.eval_dir and re-saving the restored variables there. This
        loop never exits.
      * Otherwise, evaluates FLAGS.ckpt_fp once and prints the mean of every
        collected metric to stdout.

    Args:
      unused_argv: Ignored.
    """
    if not tf.gfile.IsDirectory(FLAGS.eval_dir):
        tf.gfile.MakeDirs(FLAGS.eval_dir)

    cfg, _ = get_named_config(FLAGS.model_cfg, FLAGS.model_cfg_overrides)

    # Load data: no augmentation, and repeat=False so one evaluation is a
    # single pass over the dataset.
    with tf.name_scope("loader"):
        feat_dict = load_noteseqs(
            FLAGS.dataset_fp,
            cfg.eval_batch_size,
            cfg.eval_seq_len,
            max_discrete_times=cfg.data_max_discrete_times,
            max_discrete_velocities=cfg.data_max_discrete_velocities,
            augment_stretch_bounds=None,
            augment_transpose_bounds=None,
            randomize_chord_order=cfg.data_randomize_chord_order,
            repeat=False)

    # Build model
    with tf.variable_scope("phero_model"):
        model_dict = build_genie_model(
            feat_dict,
            cfg,
            cfg.eval_batch_size,
            cfg.eval_seq_len,
            is_training=False)
    # Only these variables (plus the global step) are restored and saved.
    genie_vars = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="phero_model")

    # Build gold model: only when the step embedding is quantized (VQ or IQ).
    # A second copy of the model shares the "phero_model" variables
    # (reuse=True) and is fed one sequence at a time (batch size 1) through
    # placeholders. Its discrete encodings are compared against the
    # reference `gold_buttons` with masked L1/L2 distances.
    eval_gold = False
    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        eval_gold = True
        with tf.variable_scope("phero_model", reuse=True):
            gold_feat_dict = {
                "midi_pitches": tf.placeholder(tf.int32, [1, None]),
                "velocities": tf.placeholder(tf.int32, [1, None]),
                "delta_times_int": tf.placeholder(tf.int32, [1, None])
            }
            gold_seq_maxlen = gold.gold_longest()
            gold_seq_varlens = tf.placeholder(tf.int32, [1])
            gold_buttons = tf.placeholder(tf.int32, [1, None])
            gold_model_dict = build_genie_model(
                gold_feat_dict,
                cfg,
                1,
                gold_seq_maxlen,
                is_training=False,
                seq_varlens=gold_seq_varlens)

        gold_encodings = gold_model_dict[
            "stp_emb_vq_discrete" if cfg.stp_emb_vq else "stp_emb_iq_discrete"]
        # 1.0 for valid timesteps, 0.0 for padding past the sequence length.
        gold_mask = tf.sequence_mask(
            gold_seq_varlens, maxlen=gold_seq_maxlen, dtype=tf.float32)
        gold_diff = tf.cast(gold_buttons, tf.float32) - tf.cast(
            gold_encodings, tf.float32)
        gold_diff_l2 = tf.square(gold_diff)
        gold_diff_l1 = tf.abs(gold_diff)

        # Mean of `t` over the positions selected (weighted) by mask `m`.
        def weighted_avg(t, m): return tf.reduce_sum(t * m) / tf.reduce_sum(m)

        gold_diff_l2 = weighted_avg(gold_diff_l2, gold_mask)
        gold_diff_l1 = weighted_avg(gold_diff_l1, gold_mask)

        # The per-sequence gold distances computed in _eval_all are fed back
        # through these to build dataset-level summaries.
        gold_diff_l2_placeholder = tf.placeholder(tf.float32, [None])
        gold_diff_l1_placeholder = tf.placeholder(tf.float32, [None])

    # Metric name -> tensor evaluated once per eval batch in _eval_all.
    summary_name_to_batch_tensor = {}

    # Summarize quantized step embeddings
    if cfg.stp_emb_vq:
        summary_name_to_batch_tensor["codebook_perplexity"] = model_dict[
            "stp_emb_vq_codebook_ppl"]
        summary_name_to_batch_tensor["loss_vqvae"] = model_dict["stp_emb_vq_loss"]

    # Summarize integer-quantized step embeddings
    if cfg.stp_emb_iq:
        summary_name_to_batch_tensor["discrete_perplexity"] = model_dict[
            "stp_emb_iq_discrete_ppl"]
        summary_name_to_batch_tensor["iq_valid_p"] = model_dict[
            "stp_emb_iq_valid_p"]
        summary_name_to_batch_tensor["loss_iq_range"] = model_dict[
            "stp_emb_iq_range_penalty"]
        summary_name_to_batch_tensor["loss_iq_contour"] = model_dict[
            "stp_emb_iq_contour_penalty"]
        summary_name_to_batch_tensor["loss_iq_deviate"] = model_dict[
            "stp_emb_iq_deviate_penalty"]

    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        summary_name_to_batch_tensor["contour_violation"] = model_dict[
            "contour_violation"]
        summary_name_to_batch_tensor["deviate_violation"] = model_dict[
            "deviate_violation"]

    # Summarize VAE sequence embeddings
    if cfg.seq_emb_vae:
        summary_name_to_batch_tensor["loss_kl"] = model_dict["seq_emb_vae_kl"]

    # Reconstruction loss (and its exponential, reported as perplexity)
    summary_name_to_batch_tensor["loss_recons"] = model_dict["dec_recons_loss"]
    summary_name_to_batch_tensor["ppl_recons"] = tf.exp(
        model_dict["dec_recons_loss"])
    if cfg.dec_pred_velocity:
        summary_name_to_batch_tensor["loss_recons_velocity"] = model_dict[
            "dec_recons_velocity_loss"]
        summary_name_to_batch_tensor["ppl_recons_velocity"] = tf.exp(
            model_dict["dec_recons_velocity_loss"])

    # Create dataset summaries: every metric's per-batch (or per-gold-sequence)
    # values are fed through a 1-D placeholder, and the summary records their
    # mean over the whole evaluation pass.
    summaries = []
    summary_name_to_placeholder = {}
    for name in summary_name_to_batch_tensor:
        placeholder = tf.placeholder(tf.float32, [None])
        summary_name_to_placeholder[name] = placeholder
        summaries.append(tf.summary.scalar(name, tf.reduce_mean(placeholder)))
    if eval_gold:
        summary_name_to_placeholder["gold_diff_l2"] = gold_diff_l2_placeholder
        summaries.append(
            tf.summary.scalar("gold_diff_l2",
                              tf.reduce_mean(gold_diff_l2_placeholder)))
        summary_name_to_placeholder["gold_diff_l1"] = gold_diff_l1_placeholder
        summaries.append(
            tf.summary.scalar("gold_diff_l1",
                              tf.reduce_mean(gold_diff_l1_placeholder)))

    summaries = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

    # Create saver (model variables + global step; keep every saved ckpt)
    step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(genie_vars + [step], max_to_keep=None)

    def _eval_all(sess):
        """Gathers all metrics for a ckpt.

        Returns:
          A defaultdict(list) mapping each metric name to its list of values:
          one value per gold sequence for "gold_diff_l1"/"gold_diff_l2" (when
          eval_gold), and one value per eval batch for every key of
          summary_name_to_batch_tensor. Batches are pulled until the input
          pipeline raises OutOfRangeError.
        """
        # Local name; shadows the merged summary op `summaries` outside.
        summaries = collections.defaultdict(list)

        if eval_gold:
            # Every gold note is given a constant delta_times_int of 8.
            # NOTE(review): gold_feat_dict["velocities"] is never fed here. If
            # build_genie_model reads it, this sess.run fails with a
            # missing-placeholder error. Confirm that it is unused.
            for midi_notes, buttons, seq_varlen in gold.gold_iterator([-6, 6]):
                gold_diff_l1_seq, gold_diff_l2_seq = sess.run(
                    [gold_diff_l1, gold_diff_l2], {
                        gold_feat_dict["midi_pitches"]:
                            midi_notes,
                        gold_feat_dict["delta_times_int"]:
                            np.ones_like(midi_notes) * 8,
                        gold_seq_varlens: [seq_varlen],
                        gold_buttons: buttons
                    })
                summaries["gold_diff_l1"].append(gold_diff_l1_seq)
                summaries["gold_diff_l2"].append(gold_diff_l2_seq)

        while True:
            try:
                # A single run evaluates every metric on the same batch.
                batches = sess.run(summary_name_to_batch_tensor)
            except tf.errors.OutOfRangeError:
                break

            for name, scalar in batches.items():
                summaries[name].append(scalar)

        return summaries

    # Eval
    if FLAGS.ckpt_fp is None:
        # Continuous mode: re-evaluate whenever train_dir's latest checkpoint
        # changes. Never returns.
        ckpt_fp = None
        while True:
            latest_ckpt_fp = tf.train.latest_checkpoint(FLAGS.train_dir)

            if latest_ckpt_fp != ckpt_fp:
                print("Eval: {}".format(latest_ckpt_fp))

                with tf.Session() as sess:
                    sess.run(tf.local_variables_initializer())
                    saver.restore(sess, latest_ckpt_fp)

                    # `ckpt_summaries` first holds the raw metric lists, then
                    # is rebound to the serialized merged summary.
                    ckpt_summaries = _eval_all(sess)
                    ckpt_summaries, ckpt_step = sess.run(
                        [summaries, step],
                        feed_dict={
                            summary_name_to_placeholder[n]: v
                            for n, v in ckpt_summaries.items()
                        })
                    summary_writer.add_summary(ckpt_summaries, ckpt_step)

                    # Keep a copy of the evaluated checkpoint in eval_dir,
                    # tagged with its global step.
                    saver.save(
                        sess, os.path.join(FLAGS.eval_dir, "ckpt"), global_step=ckpt_step)

                print("Done")
                ckpt_fp = latest_ckpt_fp

            time.sleep(1)
    else:
        # One-shot mode: evaluate the given checkpoint and print mean metrics.
        with tf.Session() as sess:
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, FLAGS.ckpt_fp)

            ckpt_summaries = _eval_all(sess)
            ckpt_step = sess.run(step)

            print("-" * 80)
            print("Ckpt: {}".format(FLAGS.ckpt_fp))
            print("Step: {}".format(ckpt_step))
            for n, l in sorted(list(ckpt_summaries.items()), key=lambda x: x[0]):
                print("{}: {}".format(n, np.mean(l)))
Esempio n. 2
0
def saturation_matrix(B, p, v):
    """Builds a batch of randomly perturbed 4x4 saturation matrices.

    For each of the B entries, a log-scale is drawn from N(0, log 2) and kept
    with probability ``p`` (otherwise it is zeroed, giving a scale of 1). The
    result is ``v^T v + (I - v^T v) * s``: the ``v^T v`` term is left
    unscaled and its complement ``I - v^T v`` is scaled by ``s``.

    Args:
        B: Batch size.
        p: Per-entry probability of applying a random scale.
        v: Tensor for which ``transpose(v) @ v`` is 4x4 (presumably a unit
            1x4 row vector -- TODO confirm).

    Returns:
        A float32 tensor of shape [B, 4, 4].
    """
    # The uniform draw comes before the normal draw, as in the original.
    apply_mask = tf.cast(tf.random.uniform([B]) < p, tf.float32)
    log_scale = apply_mask * tf.random.normal([B], 0, tf.math.log(2.))
    scale = tf.reshape(tf.exp(log_scale), [B, 1, 1])

    projector = tf.reshape(tf.transpose(v) @ v, [1, 4, 4])
    identity = tf.reshape(tf.eye(4), [1, 4, 4])
    return projector + (identity - projector) * scale
Esempio n. 3
0
 def forward(self, weighted_input):
     """Applies the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

     Uses ``tf.sigmoid`` rather than spelling the formula out. The explicit
     form computes ``exp(-x)``, which overflows to ``inf`` for large negative
     inputs (below about -88 in float32). That turns the autodiff gradient
     into ``0 * inf = NaN`` there. ``tf.sigmoid`` computes the same function
     in a numerically stable way.

     Args:
         weighted_input: Floating-point tensor of pre-activations.

     Returns:
         A tensor of the same shape with values in [0, 1].
     """
     return tf.sigmoid(weighted_input)
 def objective(self, params, data=None, labels=None):
     """Evaluates the 2-D Ackley function at the point stored in ``params[0]``.

     ``params[0]`` is split into two equal halves along axis 0, giving the
     coordinates ``x`` and ``y``. ``data`` and ``labels`` are not used.

     Returns:
         ``-20 exp(-0.2 sqrt((x^2 + y^2) / 2))
         - exp((cos(2 pi x) + cos(2 pi y)) / 2) + e + 20``, squeezed.
     """
     x, y = tf.split(params[0], 2, axis=0)

     half_sq_norm = 0.5 * (x**2 + y**2)
     envelope = -20 * tf.exp(-0.2 * tf.sqrt(half_sq_norm))

     two_pi = 2 * np.pi
     mean_cos = 0.5 * (tf.cos(two_pi * x) + tf.cos(two_pi * y))
     ripple = tf.exp(mean_cos)

     value = envelope - ripple + tf.exp(1.0) + 20.
     return tf.squeeze(value)
Esempio n. 5
0
    def _body(i, posterior, activation, center, masses):
        """Body of the EM while loop.

        Performs one EM routing iteration. It reads these names from the
        enclosing scope: final_beta, input_activation, wx, min_var,
        sigma_biases, num_out_atoms, activation_biases, n, k, input_dim and
        output_dim.

        Args:
          i: Iteration index; anneals the inverse temperature `beta`.
          posterior: Current routing posterior (assignment weights).
          activation: Unused. It is deleted immediately, and the new `logit`
            takes its place in the return value.
          center: Previous center estimate, blended into the new one.
          masses: Unused as an input; recomputed from `posterior`.

        Returns:
          (posterior, logit, center, masses) for the next iteration.
          NOTE(review): `i` is not returned. This body therefore cannot be
          passed directly to `tf.while_loop` with `i` among the loop_vars.
          Presumably the caller iterates in Python; confirm.
        """
        del activation
        # Inverse temperature: 0.05 * final_beta at i=0, rising toward
        # final_beta as i grows.
        beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
        # beta = final_beta
        # route: [outdim, height?, width?, batch, indim]
        vote_conf = posterior * input_activation
        # masses: [batch, 1, outdim, 1, height, width, 1, 1]
        # Total confidence mass, summed over axes 1, -1 and -2. A small
        # epsilon keeps the later divisions by `masses` away from zero.
        masses = tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(
            vote_conf, axis=1, keep_dims=True),
                                             axis=-1,
                                             keep_dims=True),
                               axis=-2,
                               keep_dims=True) + 0.0000001
        preactivate_unrolled = vote_conf * wx
        # center: [batch, 1, outdim, outatom, height, width]
        # M-step mean: the confidence-weighted average of `wx`, blended with
        # 0.1 of the previous center.
        center = .9 * tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(
            preactivate_unrolled, axis=1, keep_dims=True),
                                                  axis=-1,
                                                  keep_dims=True),
                                    axis=-2,
                                    keep_dims=True) / masses + .1 * center

        # M-step variance: the confidence-weighted squared deviation of `wx`
        # from the new center. `min_var` is added as a floor.
        noise = (wx - center) * (wx - center)
        variance = min_var + tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(
            vote_conf * noise, axis=1, keep_dims=True),
                                                         axis=-1,
                                                         keep_dims=True),
                                           axis=-2,
                                           keep_dims=True) / masses
        log_variance = tf.log(variance)
        # Negative sum of log-variances over axis 3. If axis 3 is the atom
        # axis (per the shape comment above), this is -log det of a diagonal
        # covariance.
        p_i = -1 * tf.reduce_sum(log_variance, axis=3, keep_dims=True)
        log_2pi = tf.log(2 * math.pi)
        win = masses * (p_i - sigma_biases * num_out_atoms * (log_2pi + 1.0))
        logit = beta * (win - activation_biases * 5000)
        # Numerically stable log(sigmoid(logit)), using the identity
        # log sigmoid(x) = min(0, x) - log(1 + exp(-|x|)).
        activation_update = tf.minimum(
            0.0, logit) - tf.log(1 + tf.exp(-tf.abs(logit)))
        # return activation, center
        # E-step: log-activation minus the negative Gaussian log-likelihood
        # of `wx` (log-normalizer term + squared-error term). This presumes
        # num_out_atoms equals the size of axis 3.
        log_det_sigma = -1 * p_i
        sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
        exp_update = tf.reduce_sum(noise / (2 * variance),
                                   axis=3,
                                   keep_dims=True)
        prior_update = activation_update - sigma_update - exp_update
        # Subtract the max over the last four axes before exponentiating so
        # that tf.exp cannot overflow.
        max_prior_update = tf.reduce_max(tf.reduce_max(tf.reduce_max(
            tf.reduce_max(prior_update, axis=-1, keep_dims=True),
            axis=-2,
            keep_dims=True),
                                                       axis=-3,
                                                       keep_dims=True),
                                         axis=-4,
                                         keep_dims=True)
        prior_normal = tf.add(prior_update, -1 * max_prior_update)
        prior_exp = tf.exp(prior_normal)
        # Normalize prior_exp into a posterior. The steps are:
        #   1. transpose, swapping axes 5 and 6;
        #   2. reshape to a [-1, n*k, n*k, 1] image;
        #   3. pad;
        #   4. apply extract_image_patches twice, then reshape back.
        # Presumably this sums, for each input position, the unnormalized
        # weights of every output whose k x k window covers it.
        # TODO: confirm against the layer's convolution geometry.
        t_prior = tf.transpose(prior_exp, [0, 1, 2, 3, 4, 6, 5, 7])
        c_prior = tf.reshape(t_prior, [-1, n * k, n * k, 1])
        pad_prior = tf.pad(c_prior,
                           [[0, 0], [(k - 1) * (k - 1), (k - 1) * (k - 1)],
                            [(k - 1) * (k - 1),
                             (k - 1) * (k - 1)], [0, 0]], 'CONSTANT')
        patch_prior = tf.extract_image_patches(images=pad_prior,
                                               ksizes=[1, k, k, 1],
                                               strides=[1, k, k, 1],
                                               rates=[1, k - 1, k - 1, 1],
                                               padding='VALID')
        sum_prior = tf.reduce_sum(patch_prior, axis=-1, keep_dims=True)
        sum_prior_patch = tf.extract_image_patches(images=sum_prior,
                                                   ksizes=[1, k, k, 1],
                                                   strides=[1, 1, 1, 1],
                                                   rates=[1, 1, 1, 1],
                                                   padding='VALID')
        # Epsilon guards against division by zero.
        sum_prior_reshape = tf.reshape(
            sum_prior_patch,
            [-1, input_dim, output_dim, 1, n, n, k, k]) + 0.0000001
        posterior = prior_exp / sum_prior_reshape
        return (posterior, logit, center, masses)
Esempio n. 6
0
    def build_model(self, hps):
        """Define model architecture.

        Builds the sketch-rnn graph: an optional bidirectional VAE encoder
        (when hps.conditional), a recurrent decoder whose outputs
        parameterize a mixture density network, the reconstruction loss and,
        in training mode, the Adam optimizer with per-element gradient
        clipping.

        Args:
          hps: Hyperparameter object. This method reads is_training,
            dec_model, enc_model, conditional, dec_rnn_size, enc_rnn_size,
            batch_size and num_mixture from it directly. Other settings are
            read through self.hps.

        Raises:
          AssertionError: If hps.dec_model or hps.enc_model is not one of
            'lstm', 'layer_norm' or 'hyper'.

        Attributes set include:
          self.cell, self.sequence_lengths, self.input_data,
          self.output_x, self.input_x, self.batch_z, self.kl_cost,
          self.initial_state, self.final_state, the MDN outputs (self.pi,
          self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr,
          self.pen_logits, self.pen) and self.r_cost.
          In training mode it also sets self.global_step, self.lr,
          self.kl_weight, self.cost and self.train_op.
        """
        if hps.is_training:
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)

        def _cell_class(model_name):
            """Returns the rnn cell class for 'lstm', 'layer_norm' or 'hyper'."""
            if model_name == 'lstm':
                return rnn.LSTMCell
            if model_name == 'layer_norm':
                return rnn.LayerNormLSTMCell
            if model_name == 'hyper':
                return rnn.HyperLSTMCell
            # Raised explicitly rather than via `assert False`. Asserts are
            # stripped under `python -O`, which would leave the cell class
            # unbound and fail later with an UnboundLocalError.
            raise AssertionError('please choose a respectable cell')

        cell_fn = _cell_class(hps.dec_model)
        enc_cell_fn = _cell_class(hps.enc_model)

        use_recurrent_dropout = self.hps.use_recurrent_dropout
        use_input_dropout = self.hps.use_input_dropout
        use_output_dropout = self.hps.use_output_dropout

        cell = cell_fn(hps.dec_rnn_size,
                       use_recurrent_dropout=use_recurrent_dropout,
                       dropout_keep_prob=self.hps.recurrent_dropout_prob)

        if hps.conditional:  # vae mode:
            # Every encoder cell type takes the same constructor arguments,
            # so the forward and backward cells are built identically.
            self.enc_cell_fw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)
            self.enc_cell_bw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)

        # dropout:
        tf.logging.info('Input dropout mode = %s.', use_input_dropout)
        tf.logging.info('Output dropout mode = %s.', use_output_dropout)
        tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout)
        if use_input_dropout:
            tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                            self.hps.input_dropout_prob)
            cell = contrib_rnn.DropoutWrapper(
                cell, input_keep_prob=self.hps.input_dropout_prob)
        if use_output_dropout:
            tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                            self.hps.output_dropout_prob)
            cell = contrib_rnn.DropoutWrapper(
                cell, output_keep_prob=self.hps.output_dropout_prob)
        self.cell = cell

        self.sequence_lengths = tf.placeholder(dtype=tf.int32,
                                               shape=[self.hps.batch_size])
        self.input_data = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

        # The target/expected vectors of strokes
        self.output_x = self.input_data[:, 1:self.hps.max_seq_len + 1, :]
        # vectors of strokes to be fed to decoder (same as above, but lagged behind
        # one step to include initial dummy value of (0, 0, 1, 0, 0))
        self.input_x = self.input_data[:, :self.hps.max_seq_len, :]

        # either do vae-bit and get z, or do unconditional, decoder-only
        if hps.conditional:  # vae mode:
            self.mean, self.presig = self.encoder(self.output_x,
                                                  self.sequence_lengths)
            self.sigma = tf.exp(self.presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
            # Reparameterization: z = mean + sigma * eps, eps ~ N(0, 1).
            eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                                   0.0,
                                   1.0,
                                   dtype=tf.float32)
            self.batch_z = self.mean + tf.multiply(self.sigma, eps)
            # KL cost, floored at kl_tolerance.
            self.kl_cost = -0.5 * tf.reduce_mean(
                (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
            self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)
            # z is appended to the decoder input at every timestep and also
            # sets the initial decoder state.
            pre_tile_y = tf.reshape(self.batch_z,
                                    [self.hps.batch_size, 1, self.hps.z_size])
            overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
            actual_input_x = tf.concat([self.input_x, overlay_x], 2)
            self.initial_state = tf.nn.tanh(
                rnn.super_linear(self.batch_z,
                                 cell.state_size,
                                 init_w='gaussian',
                                 weight_start=0.001,
                                 input_size=self.hps.z_size))
        else:  # unconditional, decoder-only generation
            self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                    dtype=tf.float32)
            self.kl_cost = tf.zeros([], dtype=tf.float32)
            actual_input_x = self.input_x
            self.initial_state = cell.zero_state(batch_size=hps.batch_size,
                                                 dtype=tf.float32)

        self.num_mixture = hps.num_mixture

        # TODO(deck): Better understand this comment.
        # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
        # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
        # mixture weight/probability (Pi_k)
        n_out = (3 + self.num_mixture * 6)

        with tf.variable_scope('RNN'):
            output_w = tf.get_variable('output_w',
                                       [self.hps.dec_rnn_size, n_out])
            output_b = tf.get_variable('output_b', [n_out])

        # decoder module of sketch-rnn is below
        output, last_state = tf.nn.dynamic_rnn(
            cell,
            actual_input_x,
            initial_state=self.initial_state,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='RNN')

        output = tf.reshape(output, [-1, hps.dec_rnn_size])
        output = tf.nn.xw_plus_b(output, output_w, output_b)
        self.final_state = last_state

        # NB: the below are inner functions, not methods of Model
        def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
            """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
            norm1 = tf.subtract(x1, mu1)
            norm2 = tf.subtract(x2, mu2)
            s1s2 = tf.multiply(s1, s2)
            # eq 25
            z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) -
                 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2))
            neg_rho = 1 - tf.square(rho)
            result = tf.exp(tf.div(-z, 2 * neg_rho))
            denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho))
            result = tf.div(result, denom)
            return result

        def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                         z_pen_logits, x1_data, x2_data, pen_data):
            """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
            # This represents the L_R only (i.e. does not include the KL loss term).

            result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1,
                                   z_sigma2, z_corr)
            epsilon = 1e-6
            # result1 is the loss wrt pen offset (L_s in equation 9 of
            # https://arxiv.org/pdf/1704.03477.pdf)
            result1 = tf.multiply(result0, z_pi)
            result1 = tf.reduce_sum(result1, 1, keep_dims=True)
            result1 = -tf.log(result1 + epsilon)  # avoid log(0)

            fs = 1.0 - pen_data[:, 2]  # use training data for this
            fs = tf.reshape(fs, [-1, 1])
            # Zero out loss terms beyond N_s, the last actual stroke
            result1 = tf.multiply(result1, fs)

            # result2: loss wrt pen state, (L_p in equation 9)
            result2 = tf.nn.softmax_cross_entropy_with_logits(
                labels=pen_data, logits=z_pen_logits)
            result2 = tf.reshape(result2, [-1, 1])
            if not self.hps.is_training:  # eval mode, mask eos columns
                result2 = tf.multiply(result2, fs)

            result = result1 + result2
            return result

        # below is where we need to do MDN (Mixture Density Network) splitting of
        # distribution params
        def get_mixture_coef(output):
            """Returns the tf slices containing mdn dist params."""
            # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850.
            z = output
            z_pen_logits = z[:, 0:3]  # pen states
            z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(
                z[:, 3:], 6, 1)

            # process output z's into MDN parameters

            # softmax all the pi's and pen states:
            z_pi = tf.nn.softmax(z_pi)
            z_pen = tf.nn.softmax(z_pen_logits)

            # exponentiate the sigmas and also make corr between -1 and 1.
            z_sigma1 = tf.exp(z_sigma1)
            z_sigma2 = tf.exp(z_sigma2)
            z_corr = tf.tanh(z_corr)

            r = [
                z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen,
                z_pen_logits
            ]
            return r

        out = get_mixture_coef(output)
        [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
         o_pen_logits] = out

        self.pi = o_pi
        self.mu1 = o_mu1
        self.mu2 = o_mu2
        self.sigma1 = o_sigma1
        self.sigma2 = o_sigma2
        self.corr = o_corr
        self.pen_logits = o_pen_logits
        # pen state probabilities (result of applying softmax to self.pen_logits)
        self.pen = o_pen

        # reshape target data so that it is compatible with prediction shape
        target = tf.reshape(self.output_x, [-1, 5])
        [x1_data, x2_data, eos_data, eoc_data,
         cont_data] = tf.split(target, 5, 1)
        pen_data = tf.concat([eos_data, eoc_data, cont_data], 1)

        lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                                o_pen_logits, x1_data, x2_data, pen_data)

        self.r_cost = tf.reduce_mean(lossfunc)

        if self.hps.is_training:
            self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
            optimizer = tf.train.AdamOptimizer(self.lr)

            self.kl_weight = tf.Variable(self.hps.kl_weight_start,
                                         trainable=False)
            self.cost = self.r_cost + self.kl_cost * self.kl_weight

            # Clip each gradient element to [-grad_clip, grad_clip].
            gvs = optimizer.compute_gradients(self.cost)
            g = self.hps.grad_clip
            capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                          for grad, var in gvs]
            self.train_op = optimizer.apply_gradients(
                capped_gvs, global_step=self.global_step, name='train_step')
Esempio n. 7
0
def _tf_lognormal(y, mean, logstd, logsqrttwopi):
    """Element-wise log-density of ``y`` under a normal distribution.

    Args:
        y: Observed values.
        mean: Mean of the distribution.
        logstd: Natural log of the standard deviation.
        logsqrttwopi: Constant subtracted at the end; by its name it is
            expected to be ``log(sqrt(2 * pi))``.

    Returns:
        ``-0.5 * ((y - mean) / exp(logstd))**2 - logstd - logsqrttwopi``.
    """
    standardized = (y - mean) / tf.exp(logstd)
    return -0.5 * standardized**2 - logstd - logsqrttwopi
Esempio n. 8
0
def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating. This leaves
    the result unchanged mathematically but keeps ``tf.exp`` from
    overflowing.
    """
    shifted = x - tf.reduce_max(x, axis=axis, keepdims=True)
    numerators = tf.exp(shifted)
    denominator = tf.reduce_sum(numerators, axis=axis, keepdims=True)
    return numerators / denominator
Esempio n. 9
0
                    dtype=tf.float64)
# One EM iteration for a Gaussian mixture model. The covariances enter
# element-wise, so they are presumably diagonal.
# NOTE(review): the shape comments below assume `input` is [N, D],
# initial_means and initial_covariances are [K, D], initial_weights is [K],
# and ln2piD is D * log(2 * pi). None of these are defined in this chunk;
# confirm. (`input` also shadows the Python builtin of the same name.)
means = tf.Variable(initial_means, dtype=tf.float64)
covariances = tf.Variable(initial_covariances, dtype=tf.float64)
weights = tf.Variable(initial_weights, dtype=tf.float64)

# E-step: recomputing responsibilities with respect to the current parameter values
# Squared differences of every point from every mean: [K, N, D].
sq_distances = tf.squared_difference(tf.expand_dims(input, 0),
                                     tf.expand_dims(means, 1))
# Per-dimension squared distance scaled by the inverse variance, summed: [K, N].
sum_sq_dist_times_inv_cov = tf.reduce_sum(
    sq_distances / tf.expand_dims(covariances, 1), 2)
# Log normalizer per component: [K, 1].
log_coefficients = tf.expand_dims(
    ln2piD + tf.reduce_sum(tf.log(covariances), 1), 1)
# Log-likelihood of each point under each component: [K, N].
log_components = -0.5 * (log_coefficients + sum_sq_dist_times_inv_cov)
log_weighted = log_components + tf.expand_dims(tf.log(weights), 1)
# Log-sum-exp stabilization: subtract the per-point max over components.
log_shift = tf.expand_dims(tf.reduce_max(log_weighted, 0), 0)
exp_log_shifted = tf.exp(log_weighted - log_shift)
exp_log_shifted_sum = tf.reduce_sum(exp_log_shifted, 0)
# Responsibilities: each point's column sums to 1 over components.
gamma = exp_log_shifted / exp_log_shifted_sum

# M-step: maximizing parameter values with respect to the computed responsibilities
# Effective number of points assigned to each component: [K].
gamma_sum = tf.reduce_sum(gamma, 1)
# Responsibilities normalized per component (each row sums to 1).
gamma_weighted = gamma / tf.expand_dims(gamma_sum, 1)
means_ = tf.reduce_sum(
    tf.expand_dims(input, 0) * tf.expand_dims(gamma_weighted, 2), 1)
distances_ = tf.squared_difference(tf.expand_dims(input, 0),
                                   tf.expand_dims(means_, 1))
# Per-dimension weighted variance around the new means: [K, D].
covariances_ = tf.reduce_sum(distances_ * tf.expand_dims(gamma_weighted, 2), 1)
weights_ = gamma_sum / tf.cast(tf.shape(input)[0], tf.float64)

# applying prior to the computed covariances
# NOTE(review): on its own, this line only multiplies covariances_ back by
# gamma_sum. That undoes the per-component normalization of gamma_weighted
# and leaves an un-normalized scatter sum. Confirm that the additive prior
# term and the renormalizing division follow; they are not in this chunk.
covariances_ *= tf.expand_dims(gamma_sum, 1)
Esempio n. 10
0
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    Latents produced by the analysis transforms are refined per batch by
    rate-distortion (R-D) optimization: network weights are restored from the
    latest checkpoint and only the latents ``y`` and ``z`` are updated, using
    gradients of the R-D loss and a project-local ``Adam`` helper. The
    optimized latents are then quantized and the reconstruction is evaluated
    (MSE, PSNR, MS-SSIM, estimated bits per pixel).

    Args:
      args: Parsed command-line options. Attributes read here: input_file,
        num_filters, lmbda, runname, checkpoint_dir, verbose, results_dir.
        ``args.lmbda`` is overwritten when it is negative (see below).

    Side effects:
      Writes ``rd-...npz`` with per-image metrics to ``args.results_dir``;
      when ``save_opt_record`` is truthy also writes ``opt-...npz`` with the
      optimization trace. Prints progress and averaged metrics.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])  # NOTE(review): not used below
    # Pixels per image (H * W for an [N, H, W, C] array); used to normalize bits.
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    # Rescale 0..255 values to [0, 1] floats.
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    # Instead, each batch is pulled once from x_next inside the session loop and
    # fed to the placeholder x_ph, so many sess.run calls can share one batch.
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    # NOTE(review): variables are later restored from a checkpoint, so the
    # construction order below presumably has to match training — keep it.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # The latents are placeholders so they can be optimized outside the graph
    # (numpy arrays updated by Adam) and fed back in on every step.
    y = tf.placeholder('float32', y_init.shape)
    # Additive uniform noise in [-0.5, 0.5): the noisy stand-in for rounding.
    y_tilde = y + tf.random.uniform(tf.shape(y), -0.5, 0.5)

    z = tf.placeholder('float32', z_init.shape)
    # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # p(z_tilde) ("z_likelihoods")
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    # NOTE: _quantize/_likelihood used below are underscore-prefixed
    # (private) tensorflow_compression methods and may change across versions.
    z_hat = entropy_bottleneck._quantize(
        z, 'dequantize')  # rounded (with median centering)
    # The hyper-synthesis output is split channel-wise into mean and log-scale.
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    # Log-spaced table of SCALES_LEVELS scales between SCALES_MIN and SCALES_MAX.
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    # Clamp likelihoods from below so -log() below stays finite.
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    y_hat = conditional_bottleneck._quantize(
        y, 'dequantize')  # rounded (with mean centering)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # Total number of bits divided by number of pixels.
    # - log p(\tilde y | \tilde z) - log p(\tilde z) - - log q(\tilde z | \tilde y)
    # Only the first two terms are computed below (nats -> bits via / log(2)).
    # NOTE(review): the summed axes come from x's rank, so this assumes the
    # latents have the same rank as the images.
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    # A negative lmbda means "use the training value", parsed from a runname of
    # the form "...lmbda=<value>-..."; raises IndexError if "lmbda=" is absent.
    if args.lmbda < 0:
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    # Gradients w.r.t. the latent placeholders only; network weights stay fixed.
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    # (`x *= 255` rebinds x to a new tensor; x_ph still names the placeholder.)
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        # NOTE(review): there is no check that a checkpoint was actually found.
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        # NOTE(review): save_opt_record is not defined in this function;
        # presumably a module-level flag — confirm.
        log_itv = 100
        if save_opt_record:
            log_itv = 10
        # R-D optimization settings: Adam step size and iterations per batch.
        rd_lr = 0.005
        rd_opt_its = 2000
        from adam import Adam  # project-local optimizer operating on arrays

        batch_idx = 0
        # Iterate until the one-shot iterator is exhausted (OutOfRangeError).
        # NOTE(review): the try spans the whole body, so an OutOfRangeError
        # raised by any other sess.run would also end the loop silently.
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                adam_optimizer = Adam(lr=rd_lr)
                # NOTE(review): re-created for every batch, so only the last
                # batch's record is saved after the loop (and it is undefined
                # if the dataset yields no batch).
                opt_record = {
                    'its': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            # Also report metrics with the quantized latents
                            # fed in place of the noisy y_tilde / z_tilde.
                            y_hat_, z_hat_ = sess.run([y_hat, z_hat],
                                                      feed_dict={
                                                          y: y_cur,
                                                          z: z_cur
                                                      })
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: y_hat_,
                                    z_tilde: z_hat_,
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, obj, mse_, train_bpp_, psnr_,
                                   rd_loss_after_rounding, bpp_after_rounding,
                                   psnr_after_rounding))
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)

                        else:
                            print(
                                'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, obj, mse_, train_bpp_, psnr_))
                        opt_record['its'].append(it)
                        opt_record['rd_loss'].append(obj)

                print()

                # this is the latents we end up transmitting
                y_hat_, z_hat_ = sess.run([y_hat, z_hat],
                                          feed_dict={
                                              y: y_cur,
                                              z: z_cur
                                          })

                # Evaluate metrics with the quantized latents fed in place of
                # the noisy ones (always done, not optional).
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_hat_,
                                         z_tilde: z_hat_,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # The training script name is the runname prefix before the first '-'.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        # When evaluating with a different script than the one used for
        # training, also record this script's name and the lmbda used here.
        if script_name != trained_script_name:
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def multihead_invertible_1x1_conv_np(name, x, x_mask, multihead_split, inverse,
                                     dtype):
    """Multi-head invertible 1x1 convolution on x, LU-parameterized.

    The channel axis of ``x`` is divided into heads of 32 channels; each head
    is mixed by its own 32x32 matrix
    ``W = P @ (L * l_mask + I) @ (U * u_mask + diag(sign_s * exp(log_s)))``.
    P (the permutation from a scipy LU decomposition) and sign_s receive no
    gradient; L, U and log_s are learned. All of them are stored in a single
    variable ``params`` initialized with numpy/scipy.

    Args:
      name: Variable scope name (variables are created/reused with
        tf.AUTO_REUSE).
      x: Rank-3 tensor [batch_size, length, n_channels_all]; n_channels_all
        must be a multiple of 32.
      x_mask: Tensor summed over its last axis to count valid positions per
        example (presumably [batch_size, length] — confirm against callers).
      multihead_split: "a" assigns channels to heads in a strided fashion
        (channel axis viewed as [n_channels, n_heads]); "c" uses contiguous
        groups (viewed as [n_heads, n_channels]).
      inverse: If True, multiply by W^{-1} instead of W.
      dtype: dtype of the created variables.

    Returns:
      A tuple ``(x, logabsdet)``: the transformed tensor with the input's
      shape, and ``x_length * sum(log_s)``, negated when ``inverse`` is True.

    Raises:
      ValueError: If n_channels_all is not a multiple of 32 or multihead_split
        is neither "a" nor "c". Both are checked before any variable is
        created.
    """
    batch_size, length, n_channels_all = common_layers.shape_list(x)
    n_channels = 32
    # Validate with real exceptions rather than `assert` (stripped under
    # `python -O`) and before creating variables, so a bad call does not leave
    # half-built variables behind in the graph.
    if n_channels_all % n_channels != 0:
        raise ValueError(
            "Number of channels must be a multiple of %d, got %s." %
            (n_channels, n_channels_all))
    if multihead_split not in ("a", "c"):
        raise ValueError("Multihead split not supported.")
    n_1x1_heads = n_channels_all // n_channels

    def get_init_np():
        """Initializer function for multihead 1x1 parameters using numpy."""
        # Per head: LU-decompose a random orthogonal matrix (Q of a QR) and
        # stack P, L, strictly-upper U, sign(diag U) and log|diag U| row-wise
        # into a [3 * n_channels + 2, n_channels] block.
        results = []
        for _ in range(n_1x1_heads):
            random_matrix = np.random.rand(n_channels, n_channels)
            np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")
            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)[np.newaxis, :]
            np_log_s = np.log(np.abs(np_s))[np.newaxis, :]
            np_u = np.triu(np_u, k=1)
            results.append(
                np.concatenate([np_p, np_l, np_u, np_sign_s, np_log_s],
                               axis=0))
        return tf.convert_to_tensor(np.stack(results, axis=0))

    def get_mask_init():
        # Strictly-lower and strictly-upper triangular masks of ones.
        ones = tf.ones([n_1x1_heads, n_channels, n_channels], dtype=dtype)
        l_mask = tf.matrix_band_part(ones, -1, 0) - tf.matrix_band_part(
            ones, 0, 0)
        u_mask = tf.matrix_band_part(ones, 0, -1) - tf.matrix_band_part(
            ones, 0, 0)
        return tf.stack([l_mask, u_mask], axis=0)

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        params = tf.get_variable("params",
                                 initializer=get_init_np,
                                 dtype=dtype)
        mask_params = tf.get_variable("mask_params",
                                      initializer=get_mask_init,
                                      dtype=dtype,
                                      trainable=False)

        # Unpack the row-stacked parameter block (see get_init_np).
        p = tf.stop_gradient(params[:, :n_channels, :])
        l = params[:, n_channels:2 * n_channels, :]
        u = params[:, 2 * n_channels:3 * n_channels, :]
        sign_s = tf.stop_gradient(params[:, 3 * n_channels, :])
        log_s = params[:, 3 * n_channels + 1, :]

        l_mask = mask_params[0]
        u_mask = mask_params[1]

        # Unit-diagonal lower factor and upper factor with diagonal
        # sign_s * exp(log_s); W = P @ L @ U per head.
        l_diag = l * l_mask + (tf.eye(
            n_channels, n_channels, [n_1x1_heads], dtype=dtype))
        u_diag = u * u_mask + (tf.matrix_diag(sign_s * tf.exp(log_s)))
        w = tf.matmul(p, tf.matmul(l_diag, u_diag))

        if multihead_split == "a":
            x = tf.reshape(x, [batch_size, length, n_channels, n_1x1_heads])
            x = tf.transpose(x, [3, 0, 1, 2])
        else:  # "c" (validated above)
            x = tf.reshape(x, [batch_size, length, n_1x1_heads, n_channels])
            x = tf.transpose(x, [2, 0, 1, 3])
        # [n_1x1_heads, batch_size, length, n_channels]

        if not inverse:
            # [n_1x1_heads, 1, n_channels, n_channels]
            x = tf.matmul(x, w[:, tf.newaxis, :, :])
        else:
            w_inv = tf.matrix_inverse(w)
            x = tf.matmul(x, w_inv[:, tf.newaxis, :, :])

        if multihead_split == "a":
            x = tf.transpose(x, [1, 2, 3, 0])
            x = tf.reshape(x, [batch_size, length, n_channels * n_1x1_heads])
        else:  # "c"
            x = tf.transpose(x, [1, 2, 0, 3])
            x = tf.reshape(x, [batch_size, length, n_1x1_heads * n_channels])

        # log|det W| summed over all heads, times the number of valid positions.
        x_length = tf.reduce_sum(x_mask, -1)
        logabsdet = x_length * tf.reduce_sum(log_s)
        if inverse:
            logabsdet *= -1
    return x, logabsdet
Esempio n. 12
0
def exp2(x: TfExpressionEx) -> TfExpression:
    """Return 2**x elementwise, implemented as exp(x * ln(2)).

    ln(2) is used as a float32 NumPy scalar; the ops are grouped under the
    "Exp2" name scope.
    """
    log_two = np.float32(np.log(2.0))
    with tf.name_scope("Exp2"):
        scaled = x * log_two
        return tf.exp(scaled)
    def __init__(self, args):
        """Builds the TF1 graph for this model.

        Creates input placeholders and embeddings, a GRU encoder producing a
        Gaussian latent ``z``, a bank of ``args.mem_num`` Gaussian clusters
        whose means/log-variances are shifted by a deltas computed from the
        time inputs, an attention from (s, d) inputs to clusters, and a GRU
        decoder. With ``args.eval`` set only ``self.batch_likelihood`` is
        built (decoding from a selected cluster mean); otherwise losses and
        three Adam train ops are built as well. A Saver over all trainable
        variables is always created.

        Args:
          args: Hyper-parameter namespace. Fields read here: batch_size,
            map_size, x_latent_size, mem_num, rnn_size, neg_size, eval,
            learning_rate.

        Attributes set (among others):
          input_form: [inputs, time_inputs, mask, seq_length] placeholders.
          s_inputs, d_inputs: int32 placeholders of shape [batch_size]
            (meaning of s/d not visible here — presumably source/destination).
          cluster_init, init_mu_c_op, init_sigma_c_op, init_pi_op:
            placeholders and assign ops for initializing the cluster variables.
          save, restore: bound ``tf.train.Saver`` methods.
        """
        self.args = args
        # NOTE: dense layers below are created without explicit names, so their
        # variable names depend on creation order; keep the order stable for
        # checkpoint compatibility.
        dense = tf.layers.dense

        inputs = tf.placeholder(shape=(args.batch_size, None),
                                dtype=tf.int32,
                                name='inputs')
        time_inputs = tf.placeholder(shape=(args.batch_size, None),
                                     dtype=tf.int32,
                                     name='time_inputs')
        mask = tf.placeholder(shape=(args.batch_size, None),
                              dtype=tf.float32,
                              name='inputs_mask')
        # Float lengths; NOTE(review): dynamic_rnn is given these as
        # sequence_length — confirm it casts them to int for your TF version.
        seq_length = tf.placeholder(shape=args.batch_size,
                                    dtype=tf.float32,
                                    name='seq_length')

        self.s_inputs = s_inputs = tf.placeholder(shape=args.batch_size,
                                                  dtype=tf.int32,
                                                  name='s_inputs')
        self.d_inputs = d_inputs = tf.placeholder(shape=args.batch_size,
                                                  dtype=tf.int32,
                                                  name='d_inputs')

        self.input_form = [inputs, time_inputs, mask, seq_length]

        # Teacher forcing: decoder inputs are the tokens shifted right by a
        # leading 0, targets are the tokens followed by a trailing 0, and the
        # mask is padded with 0 so the extra final step is not counted.
        decoder_inputs = tf.concat(
            [tf.zeros(shape=(args.batch_size, 1), dtype=tf.int32), inputs],
            axis=1)
        decoder_targets = tf.concat(
            [inputs,
             tf.zeros(shape=(args.batch_size, 1), dtype=tf.int32)],
            axis=1)
        decoder_mask = tf.concat(
            [mask,
             tf.zeros(shape=(args.batch_size, 1), dtype=tf.float32)],
            axis=1)

        # Vocabulary size = number of map cells (map_size[0] * map_size[1]).
        x_size = out_size = args.map_size[0] * args.map_size[1]
        self.embeddings = embeddings = tf.Variable(tf.random_uniform(
            [x_size, args.x_latent_size], -1.0, 1.0),
                                                   dtype=tf.float32)
        self.encoder_inputs_embedded = encoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, inputs)
        self.decoder_inputs_embedded = decoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, decoder_inputs)

        # 49-row time embedding table: time_inputs values must lie in [0, 49).
        self.time_embeddings = time_embeddings = tf.Variable(tf.random_uniform(
            [49, args.x_latent_size], -1.0, 1.0),
                                                             dtype=tf.float32)
        self.encoder_time_inputs_embedded = encoder_time_inputs_embedded = tf.nn.embedding_lookup(
            time_embeddings, time_inputs)

        # Time-conditioned shifts of the cluster parameters: project the
        # mean time embedding, average over the batch, reshape to
        # [mem_num, rnn_size].
        time_mean = tf.reduce_mean(encoder_time_inputs_embedded, axis=1)
        mu_c_delta = dense(time_mean,
                           args.mem_num * args.rnn_size,
                           activation=None)
        mu_c_delta = tf.reduce_mean(mu_c_delta, axis=0)
        mu_c_delta = tf.reshape(mu_c_delta, [args.mem_num, args.rnn_size])

        log_sigma_sq_c_delta = dense(time_mean,
                                     args.mem_num * args.rnn_size,
                                     activation=None)
        log_sigma_sq_c_delta = tf.reduce_mean(log_sigma_sq_c_delta, axis=0)
        log_sigma_sq_c_delta = tf.reshape(log_sigma_sq_c_delta,
                                          [args.mem_num, args.rnn_size])

        with tf.variable_scope("encoder"):
            encoder_cell = tf.nn.rnn_cell.GRUCell(args.rnn_size)
            _, encoder_final_state = tf.nn.dynamic_rnn(
                encoder_cell,
                encoder_inputs_embedded,
                sequence_length=seq_length,
                dtype=tf.float32,
            )

        with tf.variable_scope("clusters"):
            mu_c = tf.get_variable("mu_c", [args.mem_num, args.rnn_size],
                                   initializer=tf.random_uniform_initializer(
                                       0.0, 1.0))
            # Stores log-variances despite the variable name "sigma_sq_c".
            # NOTE(review): this and log_pi_prior are trainable=False, so the
            # Saver at the end (TRAINABLE_VARIABLES only) does not save them.
            log_sigma_sq_c = tf.get_variable(
                "sigma_sq_c", [args.mem_num, args.rnn_size],
                initializer=tf.constant_initializer(0.0),
                trainable=False)
            log_pi_prior = tf.get_variable(
                "log_pi_prior",
                args.mem_num,
                initializer=tf.constant_initializer(0.0),
                trainable=False)
            pi_prior = tf.nn.softmax(log_pi_prior)

            # Externally supplied initial values for the cluster variables.
            init_mu_c = tf.placeholder(shape=(args.mem_num, args.rnn_size),
                                       dtype=tf.float32,
                                       name='init_mu_c')
            init_sigma_c = tf.placeholder(shape=(args.mem_num, args.rnn_size),
                                          dtype=tf.float32,
                                          name='init_sigma_c')
            init_pi = tf.placeholder(shape=args.mem_num,
                                     dtype=tf.float32,
                                     name='init_pi')
            self.cluster_init = [init_mu_c, init_sigma_c, init_pi]

            self.init_mu_c_op = tf.assign(mu_c, init_mu_c)
            self.init_sigma_c_op = tf.assign(log_sigma_sq_c, init_sigma_c)
            self.init_pi_op = tf.assign(log_pi_prior, init_pi)

            # Expose the raw variables (before the time-conditioned shift).
            self.mu_c = mu_c
            self.sigma_c = log_sigma_sq_c
            self.pi = pi_prior

            # From here on mu_c / log_sigma_sq_c are shifted tensors, not the
            # variables themselves.
            mu_c += mu_c_delta
            log_sigma_sq_c += log_sigma_sq_c_delta
            # Tile over the batch: [batch_size, mem_num, rnn_size].
            stack_mu_c = tf.stack([mu_c] * args.batch_size, axis=0)
            stack_log_sigma_sq_c = tf.stack([log_sigma_sq_c] * args.batch_size,
                                            axis=0)

        with tf.variable_scope("latent"):
            mu_z = dense(encoder_final_state, args.rnn_size,
                         activation=None)  # shape=(128, 256)
            log_sigma_sq_z = dense(encoder_final_state,
                                   args.rnn_size,
                                   activation=None)  # shape=(128, 256)

            # Reparameterized sample z = mu + std * eps, eps ~ N(0, 1).
            eps_z = tf.random_normal(shape=tf.shape(log_sigma_sq_z),
                                     mean=0,
                                     stddev=1,
                                     dtype=tf.float32)
            z = mu_z + tf.sqrt(tf.exp(log_sigma_sq_z)) * eps_z

            # Tile over clusters: [batch_size, mem_num, rnn_size].
            stack_mu_z = tf.stack([mu_z] * args.mem_num, axis=1)
            stack_log_sigma_sq_z = tf.stack([log_sigma_sq_z] * args.mem_num,
                                            axis=1)
            stack_z = tf.stack([z] * args.mem_num, axis=1)
            self.batch_post_embedded = z

        with tf.variable_scope("sd_attention"):
            # Cluster attention predicted from the (s, d) pair alone.
            s_embeddings = tf.Variable(tf.random_uniform(
                [x_size, args.rnn_size], -1.0, 1.0),
                                       dtype=tf.float32)
            d_embeddings = tf.Variable(tf.random_uniform(
                [x_size, args.rnn_size], -1.0, 1.0),
                                       dtype=tf.float32)
            s = tf.nn.embedding_lookup(s_embeddings, s_inputs)
            d = tf.nn.embedding_lookup(d_embeddings, d_inputs)
            sd = tf.concat([s, d], axis=1)
            hsd1 = dense(sd, args.rnn_size, activation=tf.nn.relu)
            # NOTE(review): ReLU on logits restricts them to >= 0; confirm
            # this is intended.
            sd_logits = dense(hsd1, args.mem_num, activation=tf.nn.relu)
            sd_att = tf.nn.softmax(sd_logits)

        # for batch_latent_loss
        with tf.variable_scope("attention"):
            # Posterior cluster assignment from the variance-scaled squared
            # distance of z to each (shifted) cluster mean.
            att_logits = -tf.reduce_sum(
                tf.square(stack_z - stack_mu_c) / tf.exp(stack_log_sigma_sq_c),
                axis=-1)
            # Epsilon presumably keeps tf.log in batch_cate_loss finite.
            att = tf.nn.softmax(att_logits) + 1e-10
            self.batch_att = att

        def generation(h):
            """Decodes from initial state h.

            Returns (batch_rec_loss, batch_latent_loss, batch_cate_loss,
            batch_likelihood). Variables are shared across calls (AUTO_REUSE).
            """
            with tf.variable_scope("generation", reuse=tf.AUTO_REUSE):
                with tf.variable_scope("decoder"):
                    decoder_init_state = h
                    decoder_cell = tf.nn.rnn_cell.GRUCell(args.rnn_size)
                    # decoder_outputs.shape=(128, None, 256)
                    decoder_outputs, _ = tf.nn.dynamic_rnn(
                        decoder_cell,
                        decoder_inputs_embedded,
                        initial_state=decoder_init_state,
                        sequence_length=seq_length,
                        dtype=tf.float32,
                    )
                with tf.variable_scope("outputs"):
                    out_w = tf.get_variable(
                        "out_w", [out_size, args.rnn_size], tf.float32,
                        tf.random_normal_initializer(stddev=0.02))
                    out_b = tf.get_variable(
                        "out_b", [out_size],
                        tf.float32,
                        initializer=tf.constant_initializer(0.0))

                    # Per-example reconstruction loss: sampled-softmax loss per
                    # step, masked, then averaged over all steps (padded steps
                    # count in the mean's denominator).
                    batch_rec_loss = tf.reduce_mean(
                        decoder_mask * tf.reshape(
                            tf.nn.sampled_softmax_loss(
                                weights=out_w,
                                biases=out_b,
                                labels=tf.reshape(decoder_targets,
                                                  [-1, 1]),  # shape=(None, 1)
                                inputs=tf.reshape(
                                    decoder_outputs,
                                    [-1, args.rnn_size]),  # shape=(None, 256)
                                num_sampled=args.neg_size,
                                num_classes=out_size),
                            [args.batch_size, -1]),
                        axis=-1)
                    # Per-example score: masked mean of log-sigmoid of each
                    # target token's logit (not a full-softmax likelihood).
                    target_out_w = tf.nn.embedding_lookup(
                        out_w, decoder_targets)  # shape=(128, None, 256)
                    target_out_b = tf.nn.embedding_lookup(
                        out_b, decoder_targets)  # shape=(128, None)
                    batch_likelihood = tf.reduce_mean(
                        decoder_mask * tf.log_sigmoid(
                            tf.reduce_sum(decoder_outputs * target_out_w, -1) +
                            target_out_b),
                        axis=-1,
                        name="batch_likelihood")

                    # Attention-weighted KL-style term between q(z) and each
                    # cluster Gaussian (means over the latent dim, not sums).
                    batch_latent_loss = 0.5 * tf.reduce_sum(
                        att * tf.reduce_mean(
                            stack_log_sigma_sq_c + tf.exp(stack_log_sigma_sq_z)
                            / tf.exp(stack_log_sigma_sq_c) +
                            tf.square(stack_mu_z - stack_mu_c) /
                            tf.exp(stack_log_sigma_sq_c),
                            axis=-1),
                        axis=-1) - 0.5 * tf.reduce_mean(1 + log_sigma_sq_z,
                                                        axis=-1)
                    # Mean of p*log(p) for the batch-averaged attention, i.e.
                    # a negative entropy; minimizing it favors balanced usage.
                    batch_cate_loss = tf.reduce_mean(
                        tf.reduce_mean(att, axis=0) *
                        tf.log(tf.reduce_mean(att, axis=0)))
                return batch_rec_loss, batch_latent_loss, batch_cate_loss, batch_likelihood

        if args.eval:
            # Decode from the (time-shifted) mean of the cluster with the
            # highest sd attention.
            sd_z = tf.matmul(
                tf.one_hot(tf.argmax(sd_att, axis=-1),
                           depth=args.mem_num,
                           axis=-1), mu_c)
            # sd_z = tf.matmul(tf.one_hot(tf.argmax(att-1e-10, axis=-1), depth=args.mem_num, axis=-1), mu_c)
            results = generation(sd_z)
            self.batch_likelihood = results[-1]
        else:
            results = generation(z)
            self.batch_likelihood = results[-1]
            self.rec_loss = rec_loss = tf.reduce_mean(results[0])
            self.latent_loss = latent_loss = tf.reduce_mean(results[1])
            self.cate_loss = cate_loss = results[2]

            # Train sd attention to match the posterior attention (soft labels).
            self.sd_loss = sd_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=att,
                                                           logits=sd_logits))

            self.loss = loss = rec_loss + latent_loss + 0.1 * cate_loss
            self.pretrain_loss = pretrain_loss = rec_loss

            # Partition trainable variables by scope. NOTE(review): mu_c (the
            # only trainable variable under "clusters") is excluded from every
            # var_list below, so these ops never update it.
            all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            sd_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope='sd_attention')
            cluster_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='clusters')
            vae_vars = list(set(all_vars) - set(sd_vars) - set(cluster_vars))

            self.pretrain_op = tf.train.AdamOptimizer(
                args.learning_rate).minimize(pretrain_loss, var_list=vae_vars)
            self.train_op = tf.train.AdamOptimizer(
                args.learning_rate).minimize(loss, var_list=vae_vars)
            self.sd_train_op = tf.train.AdamOptimizer(
                args.learning_rate).minimize(sd_loss, var_list=sd_vars)

        # Saves/restores trainable variables only (non-trainable cluster
        # variables and optimizer slots are not included).
        saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES),
                               max_to_keep=100)
        self.save, self.restore = saver.save, saver.restore
Esempio n. 14
0
def cross_batch_softmax(logits, cross_blocks_eq_mask, num_replicas=None):
    """Computes softmax across the whole (global) batch.

    The computations are independent with respect to the 3rd, innermost
    dimension. In case of the span prediction, the size of this dimension is
    K=2, which corresponds to beginnings and ends of annotations.

    Args:
      logits: <float32>[batch_size, seq_len, K] Tensor of logits.
      cross_blocks_eq_mask: <float32>[batch_size, global_batch_size] The mask
        which indicates which samples in the batch have the same block IDs.
      num_replicas: Optional[int]. If provided the function performs
        computations over the global (multi-devices) batch. Should be equal to
        the number of devices.

    Returns:
      probs: <float32>[batch_size, seq_len, K]
    """

    def _gather_global(per_sample, concat_name):
        # [batch_size, K] -> [global_batch_size, K] when running on several
        # replicas; unchanged otherwise.
        if not num_replicas:
            return per_sample
        return tpu_utils.cross_replica_concat(
            tensor=per_sample, num_replicas=num_replicas, name=concat_name)

    # (1) Max-trick: every sample is shifted by the largest logit found among
    # all samples sharing its block ID, which keeps exp() from overflowing.
    # [1, global_batch_size, K]
    global_max = tf.expand_dims(
        _gather_global(tf.math.reduce_max(logits, axis=1),
                       'max_logits_per_sample_concat'), 0)
    # [batch_size, global_batch_size, 1]: +1 for same block, -1 otherwise.
    sign_mask = 2 * tf.expand_dims(cross_blocks_eq_mask, 2) - 1
    # Other-block entries become -inf so they never win the max below.
    candidate_max = tf.minimum(global_max, sign_mask * np.inf)
    # [batch_size, K]
    shift = tf.reduce_max(candidate_max, axis=1)
    shifted_logits = logits - tf.expand_dims(shift, 1)

    # (2) Exponentiate: [batch_size, seq_len, K].
    exp_logits = tf.exp(shifted_logits)

    # (3) Normalization constant: per-sample sums over the sequence, then
    # summed over all samples with the same block ID.
    global_sums = _gather_global(
        tf.math.reduce_sum(exp_logits, axis=1),
        'softmax_denominator_per_sample_concat')
    denominator = tf.matmul(cross_blocks_eq_mask, global_sums)  # [batch_size, K]

    # (4) Normalize.
    return exp_logits / tf.expand_dims(denominator, 1)
Esempio n. 15
0
def log_sum_exp(xs):
    """Computes the log sum exp value of a tensor along its last axis.

    Uses log(sum(exp(xs))) = m + log(sum(exp(xs - m))). m is the maximum taken
    along the last axis for each slice separately. The previous version used a
    single maximum over the whole tensor, which let slices far below it
    underflow to log(0) = -inf.

    Non-finite maxima are replaced by 0, as in scipy.special.logsumexp. This
    covers, for example, a slice that is all -inf. Such a slice now yields -inf
    (or +inf) instead of NaN.

    Args:
      xs: Float tensor of rank >= 1.

    Returns:
      Tensor equal to ``xs`` with its last axis reduced away.
    """
    maxes = tf.reduce_max(xs, axis=-1, keepdims=True)
    # Avoid inf - inf = NaN in the shift below.
    maxes = tf.where(tf.math.is_finite(maxes), maxes, tf.zeros_like(maxes))
    return tf.squeeze(maxes, [-1]) + tf.log(
        tf.reduce_sum(tf.exp(xs - maxes), -1))
Esempio n. 16
0
def main(unused_argv):
    """Trains the model named by `FLAGS.model_cfg`, writing to `FLAGS.train_dir`.

    Writes the config summary to `cfg.txt`, builds the (repeating, augmented)
    input pipeline and the model, registers TensorBoard summaries, assembles
    the hybrid loss from the enabled terms, and runs Adam updates inside a
    `MonitoredTrainingSession` (checkpoint every 600 s, summaries every
    `FLAGS.summary_every_nsecs` s) until the session requests a stop.

    Args:
      unused_argv: Command-line arguments from the app runner (ignored).
    """
    if not tf.gfile.IsDirectory(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)

    cfg, cfg_summary = get_named_config(FLAGS.model_cfg,
                                        FLAGS.model_cfg_overrides)
    with tf.gfile.Open(os.path.join(FLAGS.train_dir, "cfg.txt"), "w") as f:
        f.write(cfg_summary)

    # Load data
    with tf.name_scope("loader"):
        feat_dict = load_noteseqs(
            FLAGS.dataset_fp,
            cfg.train_batch_size,
            cfg.train_seq_len,
            max_discrete_times=cfg.data_max_discrete_times,
            max_discrete_velocities=cfg.data_max_discrete_velocities,
            augment_stretch_bounds=cfg.train_augment_stretch_bounds,
            augment_transpose_bounds=cfg.train_augment_transpose_bounds,
            randomize_chord_order=cfg.data_randomize_chord_order,
            repeat=True)

    # Summarize data
    tf.summary.image(
        "piano_roll",
        util.discrete_to_piano_roll(util.demidify(feat_dict["midi_pitches"]),
                                    88))

    # Build model
    with tf.variable_scope("phero_model"):
        model_dict = build_genie_model(feat_dict,
                                       cfg,
                                       cfg.train_batch_size,
                                       cfg.train_seq_len,
                                       is_training=True)

    # Summarize quantized step embeddings
    if cfg.stp_emb_vq:
        tf.summary.scalar("codebook_perplexity",
                          model_dict["stp_emb_vq_codebook_ppl"])
        tf.summary.image(
            "genie",
            util.discrete_to_piano_roll(
                model_dict["stp_emb_vq_discrete"],
                cfg.stp_emb_vq_codebook_size,
                dilation=max(1, 88 // cfg.stp_emb_vq_codebook_size)))
        tf.summary.scalar("loss_vqvae", model_dict["stp_emb_vq_loss"])

    # Summarize integer-quantized step embeddings
    if cfg.stp_emb_iq:
        tf.summary.scalar("discrete_perplexity",
                          model_dict["stp_emb_iq_discrete_ppl"])
        tf.summary.scalar("iq_valid_p", model_dict["stp_emb_iq_valid_p"])
        tf.summary.image(
            "genie",
            util.discrete_to_piano_roll(model_dict["stp_emb_iq_discrete"],
                                        cfg.stp_emb_iq_nbins,
                                        dilation=max(
                                            1, 88 // cfg.stp_emb_iq_nbins)))
        tf.summary.scalar("loss_iq_range",
                          model_dict["stp_emb_iq_range_penalty"])
        tf.summary.scalar("loss_iq_contour",
                          model_dict["stp_emb_iq_contour_penalty"])
        tf.summary.scalar("loss_iq_deviate",
                          model_dict["stp_emb_iq_deviate_penalty"])

    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        tf.summary.scalar("contour_violation", model_dict["contour_violation"])
        tf.summary.scalar("deviate_violation", model_dict["deviate_violation"])

    # Summarize VAE sequence embeddings
    if cfg.seq_emb_vae:
        tf.summary.scalar("loss_kl", model_dict["seq_emb_vae_kl"])

    # Summarize output
    tf.summary.image(
        "decoder_scores",
        util.discrete_to_piano_roll(model_dict["dec_recons_scores"], 88))
    tf.summary.image(
        "decoder_preds",
        util.discrete_to_piano_roll(model_dict["dec_recons_preds"], 88))
    if cfg.dec_pred_velocity:
        tf.summary.scalar("loss_recons_velocity",
                          model_dict["dec_recons_velocity_loss"])
        tf.summary.scalar("ppl_recons_velocity",
                          tf.exp(model_dict["dec_recons_velocity_loss"]))

    # Reconstruction loss
    tf.summary.scalar("loss_recons", model_dict["dec_recons_loss"])
    tf.summary.scalar("ppl_recons", tf.exp(model_dict["dec_recons_loss"]))

    # Build hybrid loss: reconstruction plus each enabled, positively-weighted
    # auxiliary term.
    loss = model_dict["dec_recons_loss"]
    if cfg.stp_emb_vq and cfg.train_loss_vq_err_scalar > 0:
        loss += (cfg.train_loss_vq_err_scalar * model_dict["stp_emb_vq_loss"])
    if cfg.stp_emb_iq and cfg.train_loss_iq_range_scalar > 0:
        loss += (cfg.train_loss_iq_range_scalar *
                 model_dict["stp_emb_iq_range_penalty"])
    if cfg.stp_emb_iq and cfg.train_loss_iq_contour_scalar > 0:
        loss += (cfg.train_loss_iq_contour_scalar *
                 model_dict["stp_emb_iq_contour_penalty"])
    if cfg.stp_emb_iq and cfg.train_loss_iq_deviate_scalar > 0:
        loss += (cfg.train_loss_iq_deviate_scalar *
                 model_dict["stp_emb_iq_deviate_penalty"])
    if cfg.seq_emb_vae and cfg.train_loss_vae_kl_scalar > 0:
        loss += (cfg.train_loss_vae_kl_scalar * model_dict["seq_emb_vae_kl"])
    if cfg.dec_pred_velocity:
        loss += model_dict["dec_recons_velocity_loss"]
    tf.summary.scalar("loss", loss)

    # Construct optimizer
    opt = tf.train.AdamOptimizer(learning_rate=cfg.train_lr)
    train_op = opt.minimize(loss,
                            global_step=tf.train.get_or_create_global_step())

    # Train
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir,
            save_checkpoint_secs=600,
            save_summaries_secs=FLAGS.summary_every_nsecs) as sess:
        # Honour stop requests (e.g. from hooks); calling run() after a stop
        # has been requested raises instead of exiting cleanly.
        while not sess.should_stop():
            sess.run(train_op)
Esempio n. 17
0
def get_sampling_probability(hparams, is_training):
    """Returns the sampling probability as a tensor based on the hparams.

    Supports three sampling schedules (`hparams.sampling_schedule`):
      constant: `hparams.sampling_rate` is the sampling probability. Must be in
        the interval [0, 1].
      exponential: `hparams.sampling_rate` is the base of the decay exponential.
        Must be in the interval (0, 1). Larger values imply a slower increase in
        sampling.
      inverse_sigmoid: `hparams.sampling_rate` is in the interval [1, inf).
        Larger values imply a slower increase in sampling.

    A constant value of 0 is returned if `hparams.sampling_schedule` is
    undefined.

    If not training and a non-0 sampling schedule is defined, a constant value
    of 1 is returned since this is assumed to be a test/eval job associated with
    a scheduled sampling trainer. Note that in this case `hparams` is mutated in
    place (schedule set to 'constant', rate set to 1.0).

    Only the time-varying schedules (`exponential`, `inverse_sigmoid`) read the
    graph's global step; a `constant` schedule works without one.

    Args:
      hparams: An HParams object containing model hyperparameters.
      is_training: Whether or not the model is being used for training.

    Raises:
      ValueError: On an invalid `sampling_schedule` or `sampling_rate` hparam,
        or if a time-varying schedule is requested but the graph has no global
        step.
    """
    if (not hasattr(hparams, 'sampling_schedule')
            or not hparams.sampling_schedule
            or (hparams.sampling_schedule == 'constant'
                and hparams.sampling_rate == 0)):
        return tf.constant(0.0)

    if not is_training:
        # This is likely an eval/test job associated with a training job using
        # scheduled sampling.
        tf.logging.warning(
            'Setting non-training sampling schedule from %s:%f to constant:1.0.',
            hparams.sampling_schedule, hparams.sampling_rate)
        hparams.sampling_schedule = 'constant'
        hparams.sampling_rate = 1.0

    schedule = hparams.sampling_schedule
    rate = hparams.sampling_rate

    def _float_global_step():
        # Fetched lazily so the constant schedule does not require a global step.
        global_step = tf.train.get_global_step()
        if global_step is None:
            raise ValueError(
                'Sampling schedule `%s` requires a global step, but none exists '
                'in the current graph.' % schedule)
        return tf.to_float(global_step)

    if schedule == 'constant':
        if not 0 <= rate <= 1:
            raise ValueError(
                '`constant` sampling rate must be in the interval [0, 1]. Got %f.'
                % rate)
        sampling_probability = tf.to_float(rate)
    elif schedule == 'inverse_sigmoid':
        if rate < 1:
            raise ValueError(
                '`inverse_sigmoid` sampling rate must be at least 1. Got %f.' %
                rate)
        k = tf.to_float(rate)
        sampling_probability = 1.0 - k / (k + tf.exp(_float_global_step() / k))
    elif schedule == 'exponential':
        if not 0 < rate < 1:
            raise ValueError(
                '`exponential` sampling rate must be in the interval (0, 1). Got %f.'
                % rate)
        k = tf.to_float(rate)
        sampling_probability = 1.0 - tf.pow(k, _float_global_step())
    else:
        raise ValueError('Invalid `sampling_schedule`: %s' % schedule)
    tf.summary.scalar('sampling_probability', sampling_probability)
    return sampling_probability
def coordinates_to_heatmap(y_grid,
                           x_grid,
                           y_coordinates,
                           x_coordinates,
                           sigma,
                           channel_onehot,
                           channel_weights=None):
    """Renders a set of point coordinates into a per-channel Gaussian heatmap.

    Every instance contributes a Gaussian bump, centered on its floored (y, x)
    coordinate with standard deviation `sigma`. Each bump is routed to a channel
    by `channel_onehot` and optionally scaled by `channel_weights`. Bumps that
    land in the same channel are merged with an element-wise maximum. The same
    routine can be used for object detection, where a "channel" is an object
    class, and for keypoint estimation, where a "channel" is a keypoint type.

    Args:
      y_grid: A 2D tensor with shape [height, width] which contains the grid
        y-coordinates given in the (output) image dimensions.
      x_grid: A 2D tensor with shape [height, width] which contains the grid
        x-coordinates given in the (output) image dimensions.
      y_coordinates: A 1D tensor with shape [num_instances] representing the
        y-coordinates of the instances in the output space coordinates.
      x_coordinates: A 1D tensor with shape [num_instances] representing the
        x-coordinates of the instances in the output space coordinates.
      sigma: A 1D tensor with shape [num_instances] representing the standard
        deviation of the Gaussian kernel to be applied to the point.
      channel_onehot: A 2D tensor with shape [num_instances, num_channels]
        representing the one-hot encoded channel labels for each point.
      channel_weights: A 1D tensor with shape [num_instances] corresponding to
        the weight of each instance.

    Returns:
      heatmap: A tensor of size [height, width, num_channels] representing the
        heatmap. Output (height, width) match the dimensions of the input grids.
    """
    num_instances, num_channels = (
        shape_utils.combined_static_and_dynamic_shape(channel_onehot))

    # Offsets from every grid cell to every floored instance center. The grids
    # gain a trailing axis so they broadcast against the [num_instances]
    # coordinate vectors: [height, width, num_instances].
    dx = tf.expand_dims(x_grid, 2) - tf.math.floor(x_coordinates)
    dy = tf.expand_dims(y_grid, 2) - tf.math.floor(y_coordinates)
    sq_dist = dx**2 + dy**2

    # One Gaussian per instance, with a trailing axis added for channels.
    per_instance = tf.expand_dims(tf.exp(-sq_dist / (2 * sigma * sigma)),
                                  axis=-1)

    onehot_4d = tf.reshape(channel_onehot, (1, 1, num_instances, num_channels))
    per_instance_per_channel = per_instance * onehot_4d

    if channel_weights is not None:
        per_instance_per_channel *= tf.reshape(channel_weights,
                                               (1, 1, num_instances, 1))

    # Merge instances that share a channel by taking the max over the instance
    # axis, then clamp at zero because reduce_max over zero instances is -inf.
    return tf.maximum(tf.reduce_max(per_instance_per_channel, axis=2), 0)
Esempio n. 19
0
def real_svg_top(body_output,
                 unused_targets,
                 model_hparams,
                 unused_vocab_size,
                 hard=False):
    """Applies the Mixture Density Network on top of the LSTM outputs.

    A dense layer maps `body_output` to `4 + 6 * num_mix * 3` values per
    position: 4 command logits followed by, for each of 6 arguments, `3 *
    num_mix` MDN coefficients (split by `_get_mdn_coef` into log-mixture
    weights, means and log-stds).

    Outside predict mode (and with `hard=False`) that raw dense output is
    returned unchanged. In predict mode (or when `hard=True`) the function
    samples instead:
      * a command from a temperature-scaled softmax, returned one-hot (4 values);
      * for each of the 6 arguments, a mixture component from a
        temperature-scaled softmax over the log-mixture weights, then a value
        from that component's Gaussian with its variance scaled by
        `gauss_temperature`.

    Args:
      body_output: outputs from LSTM with shape [batch, seqlen, 1, hidden_size]
      unused_targets: what the ground truth SVG outputted should be (unused).
      model_hparams: hyper-parameters, should include num_mixture,
        mix_temperature, gauss_temperature and mode.
      unused_vocab_size: unused
      hard: whether to force predict mode functionality, or return all MDN
        components

    Returns:
      The MDN output. Could be shape [batch, seqlen, 1, 10] if in predict mode
        (or hard=True) or shape [batch, seqlen, 1, 4 + 6 * num_mix * 3], in train.
    """
    # mixture of gaussians for 6 args plus 4 extra states for cmds
    num_mix = model_hparams.num_mixture
    nout = 4 + 6 * num_mix * 3

    # the 'hard' option is meant to be used if 'top' is called within body
    # AUTO_REUSE lets repeated calls share the same 'real_top/top' weights.
    with tf.variable_scope('real_top', reuse=tf.AUTO_REUSE):
        ret = tf.layers.dense(body_output, nout, name='top')
        batch_size = common_layers.shape_list(ret)[0]

        if hard or model_hparams.mode == tf.estimator.ModeKeys.PREDICT:
            temperature = model_hparams.mix_temperature

            # apply temperature, do softmax (max-subtracted for stability)
            command = tf.identity(ret[:, :, :, :4]) / temperature
            command = tf.exp(command -
                             tf.reduce_max(command, axis=[-1], keepdims=True))
            command = command / tf.reduce_sum(
                command, axis=[-1], keepdims=True)

            # sample from the given probs, this is the same as get_pi_idx,
            # and already returns not soft prob
            command = tf.distributions.Categorical(probs=command).sample()
            # this is now [batch, seq, 1], need to make it one_hot
            command = tf.one_hot(command, 4)

            arguments = ret[:, :, :, 4:]
            # args are [batch, seq, 1, 6*3*num_mix]. want [batch * seq * 6, 3*num_mix]
            arguments = tf.reshape(arguments, [-1, 3 * num_mix])

            out_logmix, out_mean, out_logstd = _get_mdn_coef(arguments)
            # these are [batch*seq*6, num_mix]

            # apply temp to logmix, then softmax (max-subtracted for stability)
            out_logmix = tf.identity(out_logmix) / temperature
            out_logmix = tf.exp(
                out_logmix -
                tf.reduce_max(out_logmix, axis=[-1], keepdims=True))
            out_logmix = out_logmix / tf.reduce_sum(
                out_logmix, axis=[-1], keepdims=True)
            # get_pi_idx
            out_logmix = tf.distributions.Categorical(
                probs=out_logmix).sample()
            # presumably one sampled component index per row, i.e. shape
            # [batch*seq*6]; the reshape below guarantees a flat vector.
            out_logmix = tf.cast(out_logmix, tf.int32)
            out_logmix = tf.reshape(out_logmix, [-1])
            # prepare for gather: pairs (row, chosen component) for gather_nd
            out_logmix = tf.stack([tf.range(tf.size(out_logmix)), out_logmix],
                                  axis=-1)

            chosen_mean = tf.gather_nd(out_mean, out_logmix)
            chosen_logstd = tf.gather_nd(out_logstd, out_logmix)

            # sample!! sqrt(gauss_temperature) * N(0, 1) scales the sampling
            # variance by gauss_temperature.
            rand_gaussian = (tf.random.normal(tf.shape(chosen_mean)) *
                             tf.sqrt(model_hparams.gauss_temperature))
            arguments = chosen_mean + tf.exp(chosen_logstd) * rand_gaussian
            arguments = tf.reshape(arguments, [batch_size, -1, 1, 6])

            # concat with the command we picked! -> 4 one-hot + 6 args = 10
            ret = tf.concat([command, arguments], axis=-1)

    return ret
Esempio n. 20
0
def main():
    """Builds the energy-based model graph, then trains and/or tests it.

    The steps are:
      * Set up Horovod-aware logging (only rank 0 writes to `logdir`).
      * Choose the dataset, input placeholders and network from
        `FLAGS.dataset`.
      * Build one tower per GPU. Each tower computes positive energies on
        `X`. It also runs a `tf.while_loop` of `FLAGS.num_steps` noisy
        gradient (Langevin) steps, or HMC, starting from `X_NOISE` to get
        negatives, and then applies the loss selected by `FLAGS.objective`.
      * Average the tower gradients and initialize or restore the variables.
      * Broadcast the variables from rank 0.
      * Call `train` (if `FLAGS.train`) and then `test`.

    Raises:
      ValueError: If `FLAGS.dataset` is unknown, if `FLAGS.proj_norm_type` is
        unsupported while `FLAGS.proj_norm != 0`, or if `FLAGS.objective` is
        unknown while training.
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    else:
        # Fail here with a clear message instead of a NameError on `model` /
        # `X` further down.
        raise ValueError("Unsupported dataset: {}".format(FLAGS.dataset))

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training: inputs are split evenly across GPU towers.
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            # Score each image under all 10 labels and sample a label from the
            # resulting energies (Gumbel-max over -energy). Shapes below are
            # hard-coded for 32x32x3 inputs with 10 classes.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape(
                np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                (FLAGS.batch_size * 10, 10)),
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []

        energy_negs.extend([
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True)
        ])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        # Traced once by tf.while_loop below (within this iteration of `j`),
        # so the closure over `j` is bound correctly.
        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                     axis=0)

            x_grad, _ = tf.gradients(FLAGS.temperature * energy_noise,
                                     [x_mod, LABEL_SPLIT[j]])

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    # An `assert False` here is stripped under `python -O`,
                    # which would silently skip clipping; raise instead.
                    raise ValueError(
                        "Other types of projection are not supported!!! "
                        "Got proj_norm_type={!r}".format(FLAGS.proj_norm_type))

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        # NOTE(review): uses the full-batch LABEL rather than LABEL_SPLIT[j];
        # the two only coincide when FLAGS.num_gpus == 1 — confirm intent.
        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL,
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            # Entropy of the label distribution in the first split (diagnostic).
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            if FLAGS.objective == 'logsumexp':
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))
            else:
                # Otherwise `loss_ml` would be undefined below.
                raise ValueError(
                    "Unsupported objective: {}".format(FLAGS.objective))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    print("Initializing variables...")
    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        optimistic_restore(sess, model_file)

    # Make every rank start from rank 0's (possibly restored) variables.
    print("Start broadcast")
    sess.run(hvd.broadcast_global_variables(0))
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
 def objective(self, params, data=None, labels=None):
     """Returns log(exp(x+3y-0.1) + exp(x-3y-0.1) + exp(-x-0.1) + 1).

     `params[0]` is split into two equal halves along axis 0, giving `x` and
     `y`; `data` and `labels` are unused. The expression is evaluated with
     `tf.reduce_logsumexp` over the stacked exponents, where the constant 1 is
     the exp(0) term. This avoids the float overflow of the naive form:
     exp(...) becomes inf for large |x| or |y| before the log can bring the
     value back into range.

     Args:
       params: A list whose first element is a tensor with an even-sized
         leading axis (presumably size 2, holding the point (x, y)).
       data: Unused.
       labels: Unused.

     Returns:
       The objective value with all size-1 dimensions squeezed out.
     """
     x, y = tf.split(params[0], 2, axis=0)
     exponents = tf.stack(
         [x + 3. * y - 0.1, x - 3. * y - 0.1, -x - 0.1,
          tf.zeros_like(x)],
         axis=0)
     obj = tf.reduce_logsumexp(exponents, axis=0)
     return tf.squeeze(obj)
Esempio n. 22
0
def focal_loss(logits, targets, alpha, gamma, normalizer):
    """Compute the focal loss between `logits` and the golden `targets` values.

    Focal loss = -(1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.

    Args:
      logits: A float32 tensor of size
        [batch, height_in, width_in, num_predictions].
      targets: A float32 tensor of size
        [batch, height_in, width_in, num_predictions].
      alpha: A float32 scalar multiplying alpha to the loss from positive
        examples (targets == 1.0) and (1-alpha) to the loss from all other
        examples.
      gamma: A float32 scalar modulating loss from hard and easy examples.
      normalizer: A float32 scalar that the per-element loss is divided by.

    Returns:
      loss: A float32 Tensor of size [batch, height_in, width_in, num_predictions]
        representing normalized loss on the prediction map.
    """
    with tf.name_scope('focal_loss'):
        positive_label_mask = tf.equal(targets, 1.0)
        cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=targets, logits=logits)
        # Below are comments/derivations for computing modulator.
        # For brevity, let x = logits, z = targets, r = gamma, and
        # p_t = sigmoid(x) for positive samples and 1 - sigmoid(x) for
        # negative samples.
        #
        # The modulator, defined as (1 - p_t)^r, is a critical part in focal
        # loss computation. For r > 0, it puts more weight on hard examples
        # and less weight on easy ones. However, if it is computed directly as
        # (1 - p_t)^r, its back-propagation is not stable when r < 1. The
        # implementation here resolves the issue.
        #
        # For positive samples (labels being 1),
        #    (1 - p_t)^r
        #  = (1 - sigmoid(x))^r
        #  = (1 - (1 / (1 + exp(-x))))^r
        #  = (exp(-x) / (1 + exp(-x)))^r
        #  = exp(log((exp(-x) / (1 + exp(-x)))^r))
        #  = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
        #  = exp(- r * x - r * log(1 + exp(-x)))
        #
        # For negative samples (labels being 0),
        #    (1 - p_t)^r
        #  = (sigmoid(x))^r
        #  = (1 / (1 + exp(-x)))^r
        #  = exp(log((1 / (1 + exp(-x)))^r))
        #  = exp(-r * log(1 + exp(-x)))
        #
        # Therefore one unified form for positive (z = 1) and negative (z = 0)
        # samples is:
        #      (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
        neg_logits = -1.0 * logits
        # log(1 + exp(-x)) is evaluated with softplus, which stays finite for
        # any x. Writing it as tf.log1p(tf.exp(-x)) overflows to inf once
        # -x exceeds ~88 (float32). That makes the modulator of a positive
        # sample exp(-inf) = 0 instead of ~1, silently dropping
        # confidently-wrong positives from the loss.
        modulator = tf.exp(gamma * targets * neg_logits -
                           gamma * tf.nn.softplus(neg_logits))
        loss = modulator * cross_entropy
        weighted_loss = tf.where(positive_label_mask, alpha * loss,
                                 (1.0 - alpha) * loss)
        weighted_loss /= normalizer
    return weighted_loss
def exp2(x):
    """Return 2 ** x elementwise, evaluated as exp(x * ln 2).

    The multiply and exp ops are created under an 'Exp2' name scope; the
    constant ln 2 is passed as a numpy float32 scalar.
    """
    ln_two = np.float32(np.log(2.0))
    with tf.name_scope('Exp2'):
        scaled = x * ln_two
        return tf.exp(scaled)
Esempio n. 24
0
def ppo_policy_loss(neg_logprobs_old,
                    actions,
                    advantages,
                    dist_new,
                    policy_gradient_enable=False,
                    mcts_sampling=False,
                    clipping_coeff=0.2,
                    mcts_clipping_coeff=0.9,
                    tanh_action_clipping=False):
    """Compute the policy loss using the PPO baseline formulation.

    Paper: https://arxiv.org/abs/1707.06347

    Args:
      neg_logprobs_old: negative log-probabilities of `actions` under the old
        policy.
      actions: actions taken by the old policy.
      advantages: advantage estimates for `actions`.
      dist_new: the latest trained policy distribution; it must provide
        `negative_log_prob(actions)`.
      policy_gradient_enable: if True, use the vanilla policy-gradient loss
        `advantages * neg_logprob` instead of the clipped PPO surrogate.
      mcts_sampling: if True, the samples were generated with MCTS and
        `mcts_clipping_coeff` is used as the ratio clipping range. The choice
        is made in-graph with tf.cond, so a boolean tensor is also accepted.
      clipping_coeff: clipping range for the probability ratio.
      mcts_clipping_coeff: clipping range for the probability ratio when the
        data are sampled using MCTS.
      tanh_action_clipping: if True, add the tanh change-of-variables term
        sum(log(1 - tanh(actions)^2 + 1e-6)) over axis 1 to the new negative
        log-probabilities. Tanh squashing bounds actions to [-1, 1].
        Paper: https://arxiv.org/pdf/1801.01290.pdf

    Returns:
      A tuple `(pg_loss, approxkl, clipfrac, p_ratio)`:
        pg_loss: scalar policy loss.
        approxkl: 0.5 * mean squared difference of new and old negative
          log-probabilities (an approximate KL between the two policies).
        clipfrac: fraction of ratios with |ratio - 1| above the clip range.
        p_ratio: per-sample probability ratio new / old.
    """
    new_nlp = dist_new.negative_log_prob(actions)

    clip_range = tf.cond(tf.equal(mcts_sampling, True),
                         lambda: tf.constant(mcts_clipping_coeff),
                         lambda: tf.constant(clipping_coeff))

    if tanh_action_clipping:
        # log|d tanh(u)/du| = log(1 - tanh(u)^2); the 1e-6 keeps log() finite
        # when tanh saturates.
        squash_jacobian = 1 - tf.tanh(actions)**2 + 1e-6
        new_nlp = new_nlp + tf.reduce_sum(tf.log(squash_jacobian), axis=1)

    # pi_new / pi_old = exp(log pi_new - log pi_old).
    ratio = tf.exp(neg_logprobs_old - new_nlp, name='ratio')

    if not policy_gradient_enable:
        # PPO clipped surrogate. Minimising -J is equivalent to maximising J,
        # so the pessimistic (larger) of the two negated terms is taken.
        unclipped_term = -advantages * ratio
        clipped_term = -advantages * tf.clip_by_value(
            ratio, 1. - clip_range, 1. + clip_range)
        pg_loss = tf.reduce_mean(tf.maximum(unclipped_term, clipped_term),
                                 name='policy_loss')
    else:
        pg_loss = tf.reduce_mean(advantages * new_nlp, name='policy_loss')

    approxkl = .5 * tf.reduce_mean(tf.square(new_nlp - neg_logprobs_old))
    outside_clip = tf.greater(tf.abs(ratio - 1.), clip_range)
    clipfrac = tf.reduce_mean(tf.to_float(outside_clip))

    return pg_loss, approxkl, clipfrac, ratio
Esempio n. 25
0
        def _body(i, posterior, center, wx, activation_biases, sigma_biases,
                  input_activation, tile_filter):
            """Body of EM while loop: one M-step followed by one E-step.

            Reads these names from the enclosing scope (not visible here):
            final_beta, layer_name, h, ih, k, s, b, c, input_dim, min_var,
            num_out_atoms, plus the `math` and `utils` modules.

            Args:
              i: Iteration counter; beta is annealed from it.
              posterior: Current routing assignments (input -> output).
              center: Previous centers; the new estimate is blended 0.9/0.1
                with it.
              wx: Votes from the input capsules.
              activation_biases: Bias subtracted inside the activation logit.
              sigma_biases: Enters the cost as log(sigma_biases**2 + min_var).
              input_activation: Activations of the input capsules.
              tile_filter: Filter for conv2d_transpose that sums per-patch
                values back onto the ih x ih input grid.

            Returns:
              (i + 1, posterior, logit, center, masses). NOTE(review): this is
              not the same structure as the argument list, so the function
              cannot be handed to tf.while_loop unchanged; confirm how the
              caller unpacks it.
            """
            tf.logging.info('  Wx: %s', wx)

            # Inverse temperature: grows from 0.05 * final_beta at i = 0
            # towards final_beta as i increases.
            beta = final_beta * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))

            # Debug print of the posterior range on every iteration.
            posterior = tf.Print(posterior, [
                layer_name, i, h, ih,
                tf.reduce_min(posterior),
                tf.reduce_max(posterior)
            ],
                                 message='posterior')
            # route: [outdim, height?, width?, batch, indim]
            # NOTE(review): author's shape note; the final reshape suggests an
            # 8-D layout [?, input_dim, ?, ?, h, h, k, k] -- verify.
            with tf.name_scope('vote_conf'):
                # Routing weight times input activation, floored at 0.
                vote_conf = posterior * input_activation
                vote_conf = tf.maximum(vote_conf, 0.0)

            # masses: [batch, 1, outdim, 1, height, width, 1, 1]
            # (author's note). Sum of vote confidence over axis 1 and the last
            # two axes; the 1e-7 keeps the divisions below finite.
            with tf.name_scope('masses'):
                masses = tf.reduce_sum(vote_conf,
                                       axis=[1, -1, -2],
                                       keepdims=True,
                                       name='masses_calculation') + 0.0000001
            with tf.name_scope('preactivate_unrolled'):
                preactivate_unrolled = vote_conf * wx

            # center: [batch, 1, outdim, outatom, height, width]
            # (author's note). Confidence-weighted mean of the votes, smoothed
            # with the previous center.
            with tf.name_scope('center'):
                center = .9 * tf.reduce_sum(
                    preactivate_unrolled, axis=[1, -1, -2],
                    keepdims=True) / masses + .1 * center

            # Rematerialization to save GPU memory. (+22ms/-1.6GB)
            # To enable it, decorate with: @tf.contrib.layers.recompute_grad
            def compute_noise_and_variance(wx, center, vote_conf, masses):
                """Return squared vote deviations and their weighted variance."""
                noise = tf.squared_difference(wx, center)
                variance = min_var + tf.reduce_sum(
                    vote_conf * noise,
                    axis=[1, -1, -2],
                    keepdims=True,
                    name='variance_calculation') / masses
                return noise, variance

            with tf.name_scope('compute_noise_and_variance'):
                noise, variance = compute_noise_and_variance(
                    wx, center, vote_conf, masses)

            with tf.name_scope('win'):
                log_variance = tf.log(variance)
                # Negative sum of log-variances over axis 3 (presumably the
                # num_out_atoms pose axis -- confirm).
                p_i = -1 * tf.reduce_sum(log_variance, axis=3, keepdims=True)
                log_2pi = tf.log(2 * math.pi)
                sigma_b = tf.log(sigma_biases * sigma_biases + min_var)
                win = masses * (p_i - num_out_atoms *
                                (sigma_b + log_2pi + 1.0))
            with tf.name_scope('logit'):
                logit = beta * (win - activation_biases * 50 * num_out_atoms)
            with tf.name_scope('activation_update'):
                # Stable log(sigmoid(logit)) = min(0, x) - log(1 + exp(-|x|)).
                # log1p is more accurate than log(1 + .) when exp(-|x|) is tiny.
                activation_update = tf.minimum(
                    0.0, logit) - tf.log1p(tf.exp(-tf.abs(logit)))
            with tf.name_scope('sigma_update'):
                log_det_sigma = -1 * p_i
                sigma_update = (num_out_atoms * log_2pi + log_det_sigma) / 2.0
            with tf.name_scope('exp_update'):
                # `keepdims` rather than the deprecated `keep_dims` alias, to
                # match every other reduction in this function.
                exp_update = tf.reduce_sum(noise / (2 * variance),
                                           axis=3,
                                           keepdims=True)
            # log(activation) + log of a diagonal Gaussian density of the votes
            # around `center` with variance `variance`.
            prior_update = tf.subtract(activation_update - sigma_update,
                                       exp_update,
                                       name='prior_update_sub')
            # Shift by the max before exp() so the exponentials cannot overflow.
            max_prior_update = tf.reduce_max(prior_update,
                                             axis=[2, 3, 4, 5, 6, 7],
                                             keepdims=True,
                                             name='max_prior_opdate')
            prior_normal = tf.add(prior_update, -1 * max_prior_update)
            prior_exp = tf.exp(prior_normal)
            prior_exp_out = tf.reduce_sum(prior_exp,
                                          axis=2,
                                          keepdims=True,
                                          name='prior_exp_out')
            prior_exp_reshape = tf.reshape(prior_exp_out, [-1, h, h, k * k],
                                           name='prior_exp_reshape')

            # Scatter the per-patch sums back onto the ih x ih input grid
            # (presumably so each input position accumulates every patch
            # covering it), then clamp away zeros before dividing.
            sum_prior = tf.nn.conv2d_transpose(prior_exp_reshape,
                                               tile_filter,
                                               output_shape=[b * c, ih, ih, 1],
                                               strides=[1, s, s, 1],
                                               padding='VALID')
            sum_prior = tf.maximum(1e-6, sum_prior)

            sum_prior_patch = utils.kernel_tile(sum_prior,
                                                k,
                                                s,
                                                1,
                                                name='sum_prior_patch')

            # E-step: normalize the exponentiated scores into the new posterior.
            with utils.maybe_jit_scope(), tf.name_scope('posterior'):
                sum_prior_reshape = tf.reshape(
                    sum_prior_patch, [-1, input_dim, 1, 1, h, h, k, k])
                posterior = prior_exp / sum_prior_reshape

            return (i + 1, posterior, logit, center, masses)
Esempio n. 26
0
# Radial-basis centers: `rbf_count` evenly spaced values from rbf_low to
# rbf_high inclusive, cast to FLOAT_TYPE. rbf_low / rbf_high / rbf_count /
# rbf_spacing / FLOAT_TYPE are defined outside this chunk.
centers = tf.cast(tf.lin_space(rbf_low, rbf_high, rbf_count), FLOAT_TYPE)

# In[23]:

# r : [N, 3] input coordinates; N is fixed to 4 by this placeholder.
r = tf.placeholder(FLOAT_TYPE, shape=(4, 3))

# rij : [N, N, 3] (presumably pairwise differences r_i - r_j -- see utils)
rij = utils.difference_matrix(r)

# dij : [N, N] (presumably pairwise distances -- see utils)
dij = utils.distance_matrix(r)

# rbf : [N, N, rbf_count]
# Gaussian expansion of each distance: exp(-gamma * (dij - center)^2) for
# every center, with gamma = 1 / rbf_spacing.
gamma = 1. / rbf_spacing
rbf = tf.exp(-gamma * tf.square(tf.expand_dims(dij, axis=-1) - centers))

# layer_dims[0] is passed to the embedding layer below; layer_dims[1:] are
# iterated by the layer loop that follows (presumably per-layer channel
# counts -- confirm with `layers`).
layer_dims = [1, 4, 4, 4]
num_layers = len(layer_dims) - 1

# embed : [N, layer1_dim, 1]
# Every one of the 4 points starts from the same all-ones (4, 1, 1) feature,
# passed through a bias-free self-interaction layer. variable_scope(None,
# "embed") uses "embed" as a default name, uniquified if already taken.
with tf.variable_scope(None, "embed"):
    embed = layers.self_interaction_layer_without_biases(
        tf.ones(shape=(4, 1, 1)), layer_dims[0])

# Feature tensors keyed by an integer; key 0 holds the initial embedding.
# NOTE(review): the key is presumably the rotation order -- confirm.
input_tensor_list = {0: [embed]}

for layer, layer_dim in enumerate(layer_dims[1:]):
    with tf.variable_scope(None,
                           'layer' + str(layer),
                           values=[input_tensor_list]):
Esempio n. 27
0
def geometry_augment(images, p=1, PADDING="REFLECT"):
    """Randomly warp a batch of images with a composed 2-D affine transform.

    The batch is upsampled with `upsample`, and a per-example 3x3 transform is
    composed from these stages:
      * horizontal flip (probability p);
      * a random number of quarter turns (non-identity with probability 3p/4);
      * integer translation (p);
      * isotropic scaling (p);
      * pre-rotation (p_rot);
      * anisotropic scaling (p);
      * post-rotation (p_rot);
      * fractional translation (p).
    Here p_rot = 1 - sqrt(1 - p), so at least one of the two rotations fires
    with probability p. Every output pixel is fetched from the upsampled batch
    at integer-truncated source coordinates. The result is then 2x2
    average-pooled.

    Args:
      images: 4-D batch [B, H, W, C] with a fully defined static shape.
        Only the upsampled height is used to build the sampling grid, so
        square images are assumed.
      p: Augmentation probability.
      PADDING: "REFLECT" mirrors source coordinates that fall outside the
        image. Any other value passes them straight to tf.gather_nd, whose
        out-of-range behavior is device dependent (error on CPU, zeros on
        GPU).

    Returns:
      Tensor of shape [B, DIM // 2, DIM // 2, C], where DIM is the upsampled
      height.

    Relies on upsample, get_scale_mat, get_rot_mat and get_shift_mat defined
    elsewhere in the file.
    """
    B, real_H, real_W, C = images.shape.as_list()
    images = upsample(images)
    B, H, W, C = images.shape.as_list()
    # The sampling grid below is DIM x DIM on both axes.
    DIM = H
    ones = tf.ones([B])

    m = tf.reshape(tf.eye(3), [1, 3, 3])
    # Transform the doubled coordinates back to the original ones (S^-1).
    m = get_scale_mat(ones / 2, ones / 2) @ m
    # Horizontal flip.
    toss = tf.cast(tf.random.uniform([B]) < p, tf.float32)
    m = get_scale_mat(ones, ones - 2 * toss) @ m
    # 90-degree rotation: k quarter turns, P(k = 0) = 1 - 3p/4, else p/4 each.
    toss = tf.reshape(
        tf.random.categorical(
            tf.tile(tf.math.log([[1 - p + p / 4, p / 4, p / 4, p / 4]]),
                    [B, 1]), 1), [B])
    m = get_rot_mat(tf.cast(toss, tf.float32) * np.pi / 2) @ m
    # Integer translation of up to +-1/8 of the original size.
    toss = tf.cast(tf.random.uniform([B]) < p, tf.float32)
    x_offset = tf.round(toss * tf.random.uniform([B], -.125, .125) * real_W)
    y_offset = tf.round(toss * tf.random.uniform([B], -.125, .125) * real_H)
    m = get_shift_mat(y_offset, x_offset) @ m
    # Isotropic scaling by 2**N(0, 0.2).
    toss = tf.cast(tf.random.uniform([B]) < p, tf.float32)
    scale = tf.exp(toss * tf.random.normal([B], 0, 0.2 * tf.math.log(2.)))
    m = get_scale_mat(scale, scale) @ m
    # Pre-rotation. Both rotation stages use p_rot so that at least one of
    # them fires with probability 1 - (1 - p_rot)**2 == p.
    p_rot = 1. - tf.sqrt(1. - p)
    toss = tf.cast(tf.random.uniform([B]) < p_rot, tf.float32)
    m = get_rot_mat(toss * tf.random.uniform([B], -np.pi, np.pi)) @ m
    # Anisotropic scaling; y_scale = 1 / x_scale keeps the area unchanged.
    toss = tf.cast(tf.random.uniform([B]) < p, tf.float32)
    x_scale = tf.exp(toss * tf.random.normal([B], 0, 0.2 * tf.math.log(2.)))
    m = get_scale_mat(1. / x_scale, x_scale) @ m
    # Post-rotation.
    toss = tf.cast(tf.random.uniform([B]) < p_rot, tf.float32)
    m = get_rot_mat(toss * tf.random.uniform([B], -np.pi, np.pi)) @ m
    # Fractional translation with std 1/8 of the original size.
    toss = tf.cast(tf.random.uniform([B]) < p, tf.float32)
    x_offset = toss * tf.random.normal([B], 0, .125) * real_W
    y_offset = toss * tf.random.normal([B], 0, .125) * real_H
    m = get_shift_mat(y_offset, x_offset) @ m
    # Transform to the coordinates of the upsampled images (S).
    m = get_scale_mat(ones * 2, ones * 2) @ m

    # Destination pixel coordinates, centred on the image, in homogeneous
    # form: x is the (decreasing) row coordinate, y the column coordinate.
    x = tf.reshape(tf.transpose([tf.range(DIM // 2, -DIM // 2, -1)] * DIM),
                   [-1])
    y = tf.tile(tf.range(-DIM // 2, DIM // 2), [DIM])
    z = tf.ones([DIM * DIM], dtype='int32')
    idx = tf.stack([x, y, z])

    # Map destination pixels onto source pixels; shape (B, 3, DIM ** 2).
    idx2 = tf.cast(tf.linalg.matmul(m, tf.cast(idx, dtype='float32')),
                   dtype='int32')
    # Centred coordinates -> array (row, col) indices; shape (B, 2, DIM ** 2).
    idx3 = tf.stack([DIM // 2 - idx2[:, 0], DIM // 2 + idx2[:, 1]], axis=1)
    if PADDING == "REFLECT":
        # TPU-compatible reflection padding, based on
        # https://www.kaggle.com/psaikko/augmentations-with-reflect-padding
        out_of_bounds = tf.math.logical_or(tf.math.less(idx3, 0),
                                           tf.math.greater(idx3, DIM - 1))
        mirrored = tf.math.subtract(DIM - 1, tf.math.floormod(idx3, DIM - 1))
        idx3 = tf.where(out_of_bounds, mirrored, idx3)

    # Gather with indices of shape (B, DIM ** 2, 2), one set per example.
    d = tf.gather_nd(images, tf.transpose(idx3, perm=[0, 2, 1]), batch_dims=1)
    # Use the real channel count instead of assuming 3 channels.
    d = tf.reshape(d, (B, DIM, DIM, C))
    # Halve the resolution (undoing the doubling done by `upsample`).
    d = tf.nn.avg_pool(d,
                       ksize=[1, 2, 2, 1],
                       strides=[1, 2, 2, 1],
                       padding='SAME',
                       data_format='NHWC')
    return tf.reshape(d, (B, DIM // 2, DIM // 2, C))
Esempio n. 28
0
    def __init__(self,
                 sess,
                 model,
                 batch_size=1,
                 confidence=CONFIDENCE,
                 targeted=False,
                 learning_rate=LEARNING_RATE,
                 binary_search_steps=BINARY_SEARCH_STEPS,
                 max_iterations=MAX_ITERATIONS,
                 abort_early=ABORT_EARLY,
                 initial_const=INITIAL_CONST,
                 boxmin=0,
                 boxmax=1,
                 epsilon=0.3):
        """Build the attack graph.

        Args:
          sess: tf.Session; it is used here to probe which pre-existing
            variables are uninitialized.
          model: Target model exposing image_size, num_channels, num_labels,
            conv1(images) and predict(images, labels).
          batch_size: Number of images attacked at once.
          confidence, targeted, learning_rate, binary_search_steps,
            max_iterations, abort_early, initial_const: Stored on the
            instance for the attack loop (not shown here).
          boxmin, boxmax: Valid pixel range; the adversarial image is
            clipped to [boxmin, boxmax].
          epsilon: L-infinity bound on the perturbation. The perturbation is
            initialized uniformly in [-epsilon, epsilon] and projected back
            into that range after every optimizer update.
        """

        image_size, num_channels, num_labels = model.image_size, model.num_channels, model.num_labels
        self.sess = sess
        self.TARGETED = targeted
        self.LEARNING_RATE = learning_rate
        self.MAX_ITERATIONS = max_iterations
        self.BINARY_SEARCH_STEPS = binary_search_steps
        self.ABORT_EARLY = abort_early
        self.CONFIDENCE = confidence
        self.initial_const = initial_const
        self.batch_size = batch_size

        self.repeat = binary_search_steps >= 10

        self.I_KNOW_WHAT_I_AM_DOING_AND_WANT_TO_OVERRIDE_THE_PRESOFTMAX_CHECK = False

        shape = (batch_size, image_size, image_size, num_channels)

        # Probe every pre-existing variable; those whose read fails as
        # uninitialized are collected. (tf.global_variables replaces the
        # deprecated tf.all_variables alias.)
        self.uninitialized_vars = []
        for var in tf.global_variables():
            try:
                self.sess.run(var)
            except tf.errors.FailedPreconditionError:
                self.uninitialized_vars.append(var)

        # these are variables to be more efficient in sending data to tf
        self.timg = tf.Variable(np.zeros(shape), dtype=tf.float32)
        self.tlab = tf.Variable(np.zeros((batch_size, num_labels)),
                                dtype=tf.float32)
        self.const = tf.Variable(np.ones(batch_size), dtype=tf.float32)

        # and here's what we use to assign them
        self.assign_timg = tf.placeholder(tf.float32, shape)
        self.assign_tlab = tf.placeholder(tf.float32, (batch_size, num_labels))
        self.assign_const = tf.placeholder(tf.float32, [batch_size])

        # The perturbation we optimize over. The uniform [-epsilon, epsilon]
        # initializer previously lived on a separate, never-used and
        # never-initialized tf.Variable. That variable leaked into
        # tf.global_variables(), while this one fell back to the default
        # initializer. The constraint is applied by the optimizer after each
        # update, keeping the perturbation inside the epsilon ball.
        self.modifier = tf.get_variable(
            'modifier',
            shape,
            initializer=tf.random_uniform_initializer(-epsilon, epsilon),
            trainable=True,
            constraint=lambda x: tf.clip_by_value(x, -epsilon, epsilon))

        # The resulting image: perturbed input clipped to [boxmin, boxmax]
        # (previously hard-coded to [0, 1], ignoring boxmin/boxmax).
        # boxmul/boxplus are not used in this constructor but remain exposed
        # as attributes.
        self.boxmul = (boxmax - boxmin) / 2.
        self.boxplus = (boxmin + boxmax) / 2.
        self.newimg = tf.clip_by_value(self.modifier + self.timg, boxmin,
                                       boxmax)

        # Mutual-information estimate between conv1 features of the clean and
        # adversarial images, in the E[T_joint] - log E[exp(T_other)] form.
        # NOTE(review): assumes MiNetwork returns the statistics on joint and
        # on shuffled/marginal pairs, in that order -- confirm.
        tconv = tf.transpose(model.conv1(self.timg), perm=[3, 1, 2, 0])
        nconv = tf.transpose(model.conv1(self.newimg), perm=[3, 1, 2, 0])
        T_xy, T_x_y = MiNetwork(tconv, nconv)
        self.MI = tf.reduce_mean(T_xy, axis=0) - tf.log(
            tf.reduce_mean(tf.exp(T_x_y)))

        real = model.predict(self.timg, self.tlab)
        self.real = real
        fake = model.predict(self.newimg, self.tlab)
        self.fake = fake
        # Hinge: nonzero only when the adversarial score exceeds the clean one.
        self.loss1 = tf.maximum(0.0, fake - real)

        # sum up the losses
        self.loss1_1 = tf.reduce_sum(self.const * self.loss1)

        self.loss = self.loss1_1 + self.MI

        # Set up the optimizers and keep track of variables we're creating.
        start_vars = set(x.name for x in tf.global_variables())
        # Variables whose name contains 'M_' (presumably MiNetwork's).
        m_var = [var for var in tf.global_variables() if 'M_' in var.name]
        optimizer = tf.train.GradientDescentOptimizer(self.LEARNING_RATE)
        # The MI estimator is trained to maximize MI; the perturbation is
        # trained separately by gradient descent on self.loss.
        self.mi_train = tf.train.AdamOptimizer(0.000015).minimize(
            -self.MI, var_list=m_var)
        self.train = optimizer.minimize(self.loss, var_list=[self.modifier])
        end_vars = tf.global_variables()
        new_vars = [x for x in end_vars if x.name not in start_vars]

        # these are the variables to initialize when we run
        self.setup = []
        self.lamda = []
        self.setup.append(self.timg.assign(self.assign_timg))
        self.setup.append(self.tlab.assign(self.assign_tlab))
        self.lamda.append(self.const.assign(self.assign_const))

        self.init = tf.variables_initializer(var_list=[self.modifier] +
                                             new_vars)
        self.mi_init = tf.variables_initializer(var_list=m_var)
Esempio n. 29
0
 def forward(self, weighted_input):
     """Tanh activation: 2 / (1 + exp(-2x)) - 1 == tanh(x).

     tf.tanh computes the same function. The explicit formula loses all
     precision for |x| near 0: 2 / (1 + exp(-2x)) rounds to exactly 1 in
     float32, so the result becomes 0 instead of ~x.
     """
     return tf.tanh(weighted_input)
Esempio n. 30
0
    def __init__(self, args):
        """Build a GRU encoder/decoder VAE over integer token sequences.

        The encoder GRU summarizes `inputs` into a Gaussian posterior (mean and
        log-variance). Both are shifted by dense projections of the mean
        embedding of `time_inputs`. The decoder GRU is initialized from a
        latent sample z and trained with a masked sampled-softmax
        reconstruction loss plus the KL term.

        NOTE: variables are created in statement order, and several are
        auto-named (tf.Variable, tf.layers.dense). Reordering statements
        changes checkpoint variable names.

        Args:
          args: Config object. This constructor reads batch_size, map_size,
            x_latent_size, rnn_size, eval, neg_size and learning_rate.

        Sets:
          input_form: the placeholders [inputs, time_inputs, mask, seq_length].
          batch_post_embedded: latent code z per sequence.
          rec_loss, latent_loss, loss: reconstruction loss, KL loss and the
            mean of the two.
          train_op: Adam step minimizing `loss`.
          batch_likelihood: per-sequence masked mean of log-sigmoid target
            scores (a score, not a normalized softmax likelihood).
          save, restore: bound methods of a Saver over trainable variables.
        """
        self.args = args
        dense = tf.layers.dense

        # inputs / time_inputs / mask: (batch_size, max_len), max_len dynamic.
        # seq_length: (batch_size,) sequence lengths. NOTE(review): declared
        # float32; tf.nn.dynamic_rnn casts sequence_length to int32 -- confirm
        # callers feed whole numbers.
        inputs = tf.placeholder(shape=(args.batch_size, None),
                                dtype=tf.int32,
                                name='inputs')
        time_inputs = tf.placeholder(shape=(args.batch_size, None),
                                     dtype=tf.int32,
                                     name='time_inputs')
        mask = tf.placeholder(shape=(args.batch_size, None),
                              dtype=tf.float32,
                              name='inputs_mask')
        seq_length = tf.placeholder(shape=args.batch_size,
                                    dtype=tf.float32,
                                    name='seq_length')

        self.input_form = [inputs, time_inputs, mask, seq_length]

        # Teacher forcing: decoder inputs are the sequence shifted right by a
        # leading 0 token; targets and mask get a trailing 0. The decoder
        # tensors are (batch_size, max_len + 1).
        encoder_inputs = inputs
        decoder_inputs = tf.concat(
            [tf.zeros(shape=(args.batch_size, 1), dtype=tf.int32), inputs],
            axis=1)
        decoder_targets = tf.concat(
            [inputs,
             tf.zeros(shape=(args.batch_size, 1), dtype=tf.int32)],
            axis=1)
        decoder_mask = tf.concat(
            [mask,
             tf.zeros(shape=(args.batch_size, 1), dtype=tf.float32)],
            axis=1)

        x_size = out_size = args.map_size[0] * args.map_size[1]
        # Shared token embedding table: (x_size, x_latent_size), initialized
        # uniformly in [-1, 1).
        embeddings = tf.Variable(tf.random_uniform(
            [x_size, args.x_latent_size], -1.0, 1.0),
                                 dtype=tf.float32)
        # Embedded sequences: (batch_size, time, x_latent_size), where time is
        # max_len for the encoder and max_len + 1 for the decoder.
        encoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, encoder_inputs)
        decoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, decoder_inputs)

        # Separate 49-row embedding table for time_inputs. NOTE(review): time
        # ids are presumably in [0, 49) -- confirm with the data pipeline.
        time_embeddings = tf.Variable(tf.random_uniform(
            [49, args.x_latent_size], -1.0, 1.0),
                                      dtype=tf.float32)
        encoder_time_inputs_embedded = tf.nn.embedding_lookup(
            time_embeddings, time_inputs)

        # Mean time embedding over the time axis, projected to offsets that are
        # added to the posterior mean and log-variance below.
        time_mean = tf.reduce_mean(encoder_time_inputs_embedded, axis=1)
        mu_delta = dense(time_mean, args.rnn_size, activation=None)
        log_sigma_sq_delta = dense(time_mean, args.rnn_size, activation=None)

        with tf.variable_scope("encoder"):
            # GRU encoder with rnn_size units.
            encoder_cell = tf.nn.rnn_cell.GRUCell(args.rnn_size)

            # Only the final state, (batch_size, rnn_size), is used; per-step
            # outputs are discarded.
            _, encoder_final_state = tf.nn.dynamic_rnn(
                encoder_cell,
                encoder_inputs_embedded,
                sequence_length=seq_length,
                dtype=tf.float32,
            )

        # Linear heads mapping the encoder state to the posterior mean and
        # log-variance.
        mu_w = tf.get_variable("mu_w", [args.rnn_size, args.rnn_size],
                               tf.float32,
                               tf.random_normal_initializer(stddev=0.02))
        mu_b = tf.get_variable("mu_b", [args.rnn_size], tf.float32,
                               tf.constant_initializer(0.0))
        sigma_w = tf.get_variable("sigma_w", [args.rnn_size, args.rnn_size],
                                  tf.float32,
                                  tf.random_normal_initializer(stddev=0.02))
        sigma_b = tf.get_variable("sigma_b", [args.rnn_size], tf.float32,
                                  tf.constant_initializer(0.0))

        # all shape=(batch_size, rnn_size)
        mu = tf.matmul(encoder_final_state, mu_w) + mu_b + mu_delta
        log_sigma_sq = tf.matmul(encoder_final_state,
                                 sigma_w) + sigma_b + log_sigma_sq_delta
        eps = tf.random_normal(shape=tf.shape(log_sigma_sq),
                               mean=0,
                               stddev=1,
                               dtype=tf.float32)

        if args.eval:
            # At evaluation z is the time-derived offset alone; the encoder
            # posterior is not used.
            # z = tf.zeros(shape=(args.batch_size, args.rnn_size), dtype=tf.float32)
            z = mu_delta
        else:
            # Re-parameterization trick: z = mu + sigma * eps, with
            # sigma = sqrt(exp(log_sigma_sq)).
            z = mu + tf.sqrt(tf.exp(log_sigma_sq)) * eps

        self.batch_post_embedded = z

        # GRU decoder initialized from z. NOTE(review): decoder_inputs is one
        # step longer than inputs but seq_length is reused, so dynamic_rnn emits
        # zero outputs from step seq_length on -- confirm the final target
        # position is not meant to be trained.
        with tf.variable_scope("decoder"):
            decoder_cell = tf.nn.rnn_cell.GRUCell(args.rnn_size)
            decoder_init_state = z
            decoder_outputs, _ = tf.nn.dynamic_rnn(
                decoder_cell,
                decoder_inputs_embedded,
                initial_state=decoder_init_state,
                sequence_length=seq_length,
                dtype=tf.float32,
            )

        # Output projection over out_size classes; weights are laid out as
        # (num_classes, dim), the layout tf.nn.sampled_softmax_loss expects.
        out_w = tf.get_variable("out_w", [out_size, args.rnn_size], tf.float32,
                                tf.random_normal_initializer(stddev=0.02))
        out_b = tf.get_variable("out_b", [out_size], tf.float32,
                                tf.constant_initializer(0.0))
        # Per-token sampled-softmax loss (a training-time approximation of the
        # full softmax loss), reshaped to (batch_size, max_len + 1), masked,
        # then averaged over time. The mean divides by max_len + 1, not by
        # the number of unmasked positions.
        batch_rec_loss = tf.reduce_mean(
            decoder_mask * tf.reshape(
                tf.nn.sampled_softmax_loss(
                    weights=out_w,
                    biases=out_b,
                    labels=tf.reshape(decoder_targets, [-1, 1]),
                    inputs=tf.reshape(decoder_outputs, [-1, args.rnn_size]),
                    num_sampled=args.neg_size,
                    num_classes=out_size), [args.batch_size, -1]),
            axis=-1  # reduce to mean along the last dimension
        )
        # Closed-form KL(N(mu, sigma^2) || N(0, I)) per sequence.
        batch_latent_loss = -0.5 * tf.reduce_sum(
            1 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=1)

        self.rec_loss = rec_loss = tf.reduce_mean(batch_rec_loss)
        self.latent_loss = latent_loss = tf.reduce_mean(batch_latent_loss)

        # Mean (not sum) of the two scalar losses.
        self.loss = loss = tf.reduce_mean([rec_loss, latent_loss])
        self.train_op = tf.train.AdamOptimizer(
            args.learning_rate).minimize(loss)

        # Output weights/bias of each true target token, used to score it.
        target_out_w = tf.nn.embedding_lookup(out_w, decoder_targets)
        target_out_b = tf.nn.embedding_lookup(out_b, decoder_targets)

        # Per-sequence masked mean (over max_len + 1 steps) of log-sigmoid of
        # the target logit.
        self.batch_likelihood = tf.reduce_mean(decoder_mask * tf.log_sigmoid(
            tf.reduce_sum(decoder_outputs * target_out_w, -1) + target_out_b),
                                               axis=-1,
                                               name="batch_likelihood")

        # Checkpoint only trainable variables (optimizer slot variables are
        # not trainable and so are not saved); keep the 10 most recent files.
        saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES),
                               max_to_keep=10)
        self.save, self.restore = saver.save, saver.restore