def encoder(self, batch, sequence_lengths):
        tf.logging.info('model - encoder part')
        unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
            self.enc_cell_fw,
            self.enc_cell_bw,
            batch,
            sequence_length=sequence_lengths,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='ENC_RNN')

        last_state_fw, last_state_bw = last_states
        last_h_fw = self.enc_cell_fw.get_output(
            last_state_fw)  # final hidden output of the forward RNN
        last_h_bw = self.enc_cell_bw.get_output(
            last_state_bw)  # final hidden output of the backward RNN
        last_h = tf.concat([last_h_fw, last_h_bw],
                           1)  # final hidden state h of the bi-directional RNN
        #===============================================================================
        mu = rnn.super_linear(
            last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
            scope='ENC_RNN_mu',
            init_w='gaussian',
            weight_start=0.001)
        presig = rnn.super_linear(
            last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
            scope='ENC_RNN_sigma',
            init_w='gaussian',
            weight_start=0.001)
        return mu, presig  # final 128-dimensional encoder outputs
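For reference, the (mu, presig) pair returned here is usually converted into a sampled latent vector with the reparameterization trick, as later examples in this listing do (e.g. Example #11). A minimal sketch, assuming the same hps.batch_size and hps.z_size hyperparameters:

    sigma = tf.exp(presig / 2.0)  # presig is the log-variance, so exp(presig / 2) is the std dev
    eps = tf.random_normal((self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32)
    batch_z = mu + sigma * eps  # differentiable sample z ~ N(mu, sigma^2)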
Example #2
    def cnn_encoder(self, img_x):
        with tf.variable_scope('ENC_CNN'):
            # high-pass filter
            x = self.high_pass_filtering(img_x)  # [N, 48, 48, 1]

            # 6 conv layers
            x = self.conv_2d('conv1', x, filter_size=2, out_filters=4, strides=self.stride_arr(2))  # [N, 24, 24, 4]
            x = tf.nn.relu(x)
            x = self.conv_2d('conv2', x, filter_size=2, out_filters=4, strides=self.stride_arr(1))  # [N, 24, 24, 4]
            x = tf.nn.relu(x)
            x = self.conv_2d('conv3', x, filter_size=2, out_filters=8, strides=self.stride_arr(2))  # [N, 12, 12, 8]
            x = tf.nn.relu(x)
            x = self.conv_2d('conv4', x, filter_size=2, out_filters=8, strides=self.stride_arr(1))  # [N, 12, 12, 8]
            x = tf.nn.relu(x)
            x = self.conv_2d('conv5', x, filter_size=2, out_filters=8, strides=self.stride_arr(2))  # [N, 6, 6, 8]
            x = tf.nn.relu(x)
            x = self.conv_2d('conv6', x, filter_size=2, out_filters=8, strides=self.stride_arr(1))  # [N, 6, 6, 8]
            x = tf.tanh(x)

            x = tf.reshape(x, shape=[x.shape[0], -1])  # [N, 6 * 6 * 8]

            mu = rnn.super_linear(
                x,
                self.hps.z_size,
                scope='ENC_CNN_mu',
                init_w='gaussian',
                weight_start=0.001)
            presig = rnn.super_linear(
                x,
                self.hps.z_size,
                scope='ENC_CNN_sigma',
                init_w='gaussian',
                weight_start=0.001)
            return mu, presig
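The helpers high_pass_filtering, stride_arr, and conv_2d used above are not included in this snippet. Below is a minimal sketch of the latter two, under the assumption that conv_2d is a thin wrapper around tf.nn.conv2d with 'SAME' padding and a per-layer variable scope (consistent with the shape comments above); the real implementation may differ:

    def stride_arr(self, stride):
        # NHWC stride vector for tf.nn.conv2d
        return [1, stride, stride, 1]

    def conv_2d(self, name, x, filter_size, out_filters, strides):
        # square convolution, 'SAME' padding, one weight tensor per named scope
        in_filters = int(x.shape[-1])
        with tf.variable_scope(name):
            kernel = tf.get_variable(
                'DW', [filter_size, filter_size, in_filters, out_filters],
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            return tf.nn.conv2d(x, kernel, strides=strides, padding='SAME')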
Example #3
    def encoder(self, batch, sequence_lengths):
        """Define the bi-directional encoder module of sketch-rnn."""
        unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
            self.enc_cell_fw,
            self.enc_cell_bw,
            batch,
            sequence_length=sequence_lengths,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='ENC_RNN')

        last_state_fw, last_state_bw = last_states
        last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
        last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
        last_h = tf.concat([last_h_fw, last_h_bw], 1)
        mu = rnn.super_linear(
            last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
            scope='ENC_RNN_mu',
            init_w='gaussian',
            weight_start=0.001)
        presig = rnn.super_linear(
            last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
            scope='ENC_RNN_sigma',
            init_w='gaussian',
            weight_start=0.001)
        return mu, presig
Example #4
    def latent_space(self, reuse):
        with tf.variable_scope("latent_space", reuse=reuse):
            mu = rnn.super_linear(
                self.last_h,
                self.hps.z_size,
                input_size=self.hps.enc_rnn_size * 2 +
                self.expr_num,  # bi-dir, so x2
                scope='ENC_RNN_mu',
                init_w='gaussian',
                weight_start=0.001)

            presig = rnn.super_linear(
                self.last_h,
                self.hps.z_size,
                input_size=self.hps.enc_rnn_size * 2 +
                self.expr_num,  # bi-dir, so x2
                scope='ENC_RNN_sigma',
                init_w='gaussian',
                weight_start=0.001)
            self.mean, self.presig = mu, presig
            self.sigma = tf.exp(self.presig /
                                2.0)  # sigma > 0. div 2.0 -> sqrt.
            eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                                   0.0,
                                   1.0,
                                   dtype=tf.float32)
            self.batch_z = self.mean + tf.multiply(self.sigma, eps)
    def encoder(self, batch, sequence_lengths):
        """Define the encoder module of sketch-rnn."""
        mu = 0
        presig = 0

        # adding CNN option
        if self.hps.enc_CNN:
            last_h = cnn.cnn_model(batch)
            mu = rnn.super_linear(last_h,
                                  self.hps.z_size,
                                  scope='ENC_CNN_mu',
                                  init_w='gaussian',
                                  weight_start=0.001)
            presig = rnn.super_linear(last_h,
                                      self.hps.z_size,
                                      scope='ENC_CNN_sigma',
                                      init_w='gaussian',
                                      weight_start=0.001)

        # bi-directional RNN
        else:
            unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(  #returns forward and backward tensors
                self.enc_cell_fw,
                self.enc_cell_bw,
                batch,
                sequence_length=sequence_lengths,
                time_major=False,
                swap_memory=True,
                dtype=tf.float32,
                scope='ENC_RNN')

            last_state_fw, last_state_bw = last_states  #split the tuple

            #gets outputs of the final states
            last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
            last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
            last_h = tf.concat([last_h_fw, last_h_bw], 1)

            mu = rnn.super_linear(
                last_h,
                self.hps.z_size,
                input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
                scope='ENC_RNN_mu',
                init_w='gaussian',
                weight_start=0.001)
            presig = rnn.super_linear(
                last_h,
                self.hps.z_size,
                input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
                scope='ENC_RNN_sigma',
                init_w='gaussian',
                weight_start=0.001)

        return mu, presig
Example #6
 def get_mu_sig(self, encoded_h):
     input_size = int(encoded_h.shape[-1])
     mu = rnn.super_linear(encoded_h,
                           self.hps.z_size,
                           input_size=input_size,
                           scope='ENC_mu',
                           init_w='gaussian',
                           weight_start=0.001)
     presig = rnn.super_linear(encoded_h,
                               self.hps.z_size,
                               input_size=input_size,
                               scope='ENC_sigma',
                               init_w='gaussian',
                               weight_start=0.001)
     return mu, presig
Example #7
 def cnn_encoder(self, goal_batch):
     image_input = goal_batch  # tf.stack([current_batch, goal_batch], axis=3)
     flattened_conv = tf.layers.Flatten()(
         conv_model.make_model(image_input))
     mu = rnn.super_linear(flattened_conv,
                           self.hps.z_size,
                           scope='ENC_RNN_mu',
                           init_w='gaussian',
                           weight_start=0.001)
     presig = rnn.super_linear(flattened_conv,
                               self.hps.z_size,
                               scope='ENC_RNN_sigma',
                               init_w='gaussian',
                               weight_start=0.001)
     return mu, presig
Example #8
 def encoder(self, batch):
     z = rnn.super_linear(batch,
                          self.hps.z_size,
                          input_size=4096,
                          scope='latent_z',
                          init_w='gaussian',
                          weight_start=0.001)
     return z
Example #9
    def encoder(self, batch, label, sequence_lengths, conditioned=True):
        """Define the bi-directional encoder module of sketch-rnn."""
        unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
            self.enc_cell_fw,
            self.enc_cell_bw,
            batch,
            sequence_length=sequence_lengths,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='ENC_RNN')

        last_state_fw, last_state_bw = last_states
        last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
        last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
        if not conditioned:
            last_h_fw = tf.zeros_like(last_h_fw)
            last_h_bw = tf.zeros_like(last_h_bw)
        label_embedding = tf.nn.embedding_lookup(
            self.input_label_embedding_matrix, label)
        last_h = tf.concat([last_h_fw, last_h_bw, label_embedding], 1)
        mu = rnn.super_linear(
            last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2 +
            self.hps.label_embedding_dim,  # bi-dir, so x2
            scope='ENC_RNN_mu',
            init_w='gaussian',
            weight_start=0.001)
        presig = rnn.super_linear(
            last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2 +
            self.hps.label_embedding_dim,  # bi-dir, so x2
            scope='ENC_RNN_sigma',
            init_w='gaussian',
            weight_start=0.001)
        return mu, presig
Example #10
    def get_decoder_inputs(self, encoded_h, is_seq=True, name_scope=None):
        mean, presig = self.get_mu_sig(encoded_h)
        sigma = tf.exp(presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
        eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                               0.0,
                               1.0,
                               dtype=tf.float32)
        batch_z = mean + tf.multiply(sigma, eps)  # [N, z_size]

        if not is_seq:
            return batch_z, mean, presig

        pre_tile_y = tf.reshape(batch_z,
                                [self.hps.batch_size, 1, self.hps.z_size])
        overlay_x = tf.tile(
            pre_tile_y,
            [1, self.hps.max_seq_len, 1])  # [N, max_seq_len, z_size]
        actual_input_x = tf.concat([self.input_x, overlay_x], 2)

        initial_state = tf.nn.tanh(
            rnn.super_linear(batch_z,
                             self.dec_cell.state_size,
                             init_w='gaussian',
                             weight_start=0.001,
                             input_size=self.hps.z_size))

        if name_scope == 'p2s':
            # print('p2s seq decoder')
            self.initial_state_p2s = initial_state
            return batch_z, self.initial_state_p2s, actual_input_x, mean, presig
        elif name_scope == 's2s':
            # print('s2s seq decoder')
            self.initial_state_s2s = initial_state
            return batch_z, self.initial_state_s2s, actual_input_x, mean, presig
        else:
            raise Exception('Unknown name_scope', name_scope)
Example #11
    def build_model(self, hps):
        if hps.is_training:
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)

        if hps.dec_model == 'lstm':
            cell_fn = rnn.LSTMCell
        elif hps.dec_model == 'layer_norm':
            cell_fn = rnn.LayerNormLSTMCell
        elif hps.dec_model == 'hyper':
            cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        if hps.enc_model == 'lstm':
            enc_cell_fn = rnn.LSTMCell  # use LSTM as the encoder cell
        elif hps.enc_model == 'layer_norm':
            enc_cell_fn = rnn.LayerNormLSTMCell
        elif hps.enc_model == 'hyper':
            enc_cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        # dropout settings
        use_recurrent_dropout = self.hps.use_recurrent_dropout
        use_input_dropout = self.hps.use_input_dropout
        use_output_dropout = self.hps.use_output_dropout

        cell = cell_fn(
            hps.dec_rnn_size,  #512
            use_recurrent_dropout=use_recurrent_dropout,
            dropout_keep_prob=self.hps.recurrent_dropout_prob)

        if hps.conditional:
            if hps.enc_model == 'hyper':
                self.enc_cell_fw = enc_cell_fn(
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)
                self.enc_cell_bw = enc_cell_fn(
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)
            else:  #enc_model = lstm
                self.enc_cell_fw = enc_cell_fn(  # forward cell of the bi-directional RNN
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)
                self.enc_cell_bw = enc_cell_fn(  # backward cell of the bi-directional RNN
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)

        # dropout:
        tf.logging.info('Input dropout mode = %s.', use_input_dropout)  #false
        tf.logging.info('Output dropout mode = %s.',
                        use_output_dropout)  #false
        tf.logging.info('Recurrent dropout mode = %s.',
                        use_recurrent_dropout)  #true
        if use_input_dropout:
            tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                            self.hps.input_dropout_prob)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, input_keep_prob=self.hps.input_dropout_prob)
        if use_output_dropout:
            tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                            self.hps.output_dropout_prob)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, output_keep_prob=self.hps.output_dropout_prob)

        #==============================decode !!!============================
        tf.logging.info('model - decoder part')
        self.cell = cell
        self.sequence_lengths = tf.placeholder(dtype=tf.int32,
                                               shape=[self.hps.batch_size
                                                      ])  # batch size
        # decoder inputs, including the start token
        self.source_input = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])
        self.target_input = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

        tf.logging.info('model - encoder input')
        self.encoder_input_x = self.source_input[:,
                                                 1:self.hps.max_seq_len + 1, :]
        tf.logging.info('model - decoder inputs and targets')
        self.output_x = self.target_input[:, 1:self.hps.max_seq_len + 1, :]
        self.decoder_input_source_x = self.source_input[:, :self.hps.
                                                        max_seq_len, :]
        self.decoder_input_target_x = self.target_input[:, :self.hps.
                                                        max_seq_len, :]

        # if conditional is true, append the latent variable z to the input
        if hps.conditional:
            self.mean, self.presig = self.encoder(self.encoder_input_x,
                                                  self.sequence_lengths)
            self.sigma = tf.exp(self.presig / 2.0)
            eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                                   0.0,
                                   1.0,
                                   dtype=tf.float32)
            self.batch_z = self.mean + tf.multiply(self.sigma, eps)
            #self.kl_cost = -0.5 * tf.reduce_mean(
            # (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
            #self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)
            # feed the input data to obtain the latent variable
            pre_tile_y = tf.reshape(self.batch_z,
                                    [self.hps.batch_size, 1, self.hps.z_size])
            overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
            part_input_x = tf.concat([self.decoder_input_target_x, overlay_x],
                                     2)
            actual_input_x = tf.concat(
                [self.decoder_input_source_x, part_input_x], 2)
            # the decoder input at each step is the source and target inputs concatenated along the feature axis
            self.initial_state = tf.nn.tanh(
                rnn.super_linear(self.batch_z,
                                 cell.state_size,
                                 init_w='gaussian',
                                 weight_start=0.001,
                                 input_size=self.hps.z_size))
        # unconditional, decoder-only generation
        else:
            self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                    dtype=tf.float32)
            actual_input_x = self.input_x
            self.initial_state = cell.zero_state(batch_size=hps.batch_size,
                                                 dtype=tf.float32)

        tf.logging.info('model - starting Gaussian mixture model sampling')
        self.num_mixture = hps.num_mixture  # number of Gaussians in the mixture model (20)
        n_out = (3 + self.num_mixture * 6)
        # the decoder output y has 5M + M + 3 dimensions: Gaussian parameters, mixture weights, and (p1, p2, p3)
        with tf.variable_scope('RNN'):
            output_w = tf.get_variable('output_w',
                                       [self.hps.dec_rnn_size, n_out])
            output_b = tf.get_variable('output_b', [n_out])

        # decoder module of sketch-rnn is below
        output, last_state = tf.nn.dynamic_rnn(
            cell,
            actual_input_x,
            initial_state=self.initial_state,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='RNN')

        output = tf.reshape(output, [-1, hps.dec_rnn_size])
        output = tf.nn.xw_plus_b(output, output_w, output_b)  # fully connected layer with n_out units
        self.final_state = last_state

        # x1 and x2 are the x- and y-axis offsets; result is the bivariate normal probability density
        def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
            tf.logging.info('model - compute the bivariate normal density using the formula from the paper')
            norm1 = tf.subtract(x1, mu1)
            norm2 = tf.subtract(x2, mu2)
            s1s2 = tf.multiply(s1, s2)
            z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) -
                 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2))
            neg_rho = 1 - tf.square(rho)
            result = tf.exp(tf.div(-z, 2 * neg_rho))
            denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho))
            result = tf.div(result, denom)
            return result

        # compute the reconstruction loss: take the log of the density above
        def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                         z_pen_logits, x1_data, x2_data, pen_data):
            tf.logging.info('model - compute the reconstruction loss')
            # loss on the sampled (x, y) offsets
            result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1,
                                   z_sigma2, z_corr)
            epsilon = 1e-6
            result1 = tf.multiply(result0, z_pi)  # weight each Gaussian by its mixture weight
            result1 = tf.reduce_sum(result1, 1, keep_dims=True)  # sum over all mixture components
            result1 = -tf.log(result1 + epsilon)  # avoid log(0)
            fs = 1.0 - pen_data[:, 2]  # 0 if this point ends the sketch, 1 otherwise
            fs = tf.reshape(fs, [-1, 1])  # reshape to a [batch, 1] column vector
            result1 = tf.multiply(result1, fs)

            # result2: loss wrt pen state, (L_p in equation 9)
            result2 = tf.nn.softmax_cross_entropy_with_logits(  # L_p is the cross-entropy over pen states
                labels=pen_data, logits=z_pen_logits)
            result2 = tf.reshape(result2, [-1, 1])
            if not self.hps.is_training:  # eval mode, mask eos columns
                result2 = tf.multiply(result2, fs)

            result = result1 + result2
            return result

        # below is where we need to do MDN (Mixture Density Network) splitting of
        # distribution params
        def get_mixture_coef(output):
            """Returns the tf slices containing mdn dist params."""
            tf.logging.info('model - split the decoder output into Gaussian mixture parameters')
            # build the Gaussian mixture model from the 6M + 3 values output by the decoder
            z = output
            z_pen_logits = z[:, 0:3]  # pen states: the first three values of z are (p1, p2, p3)
            # the rest split into six groups: weights pi, mu(x), mu(y), z_sigma1, z_sigma2, z_corr
            z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(
                z[:, 3:], 6, 1)

            # process output z's into MDN parameters

            # softmax all the pi's and pen states:
            z_pi = tf.nn.softmax(z_pi)
            z_pen = tf.nn.softmax(z_pen_logits)
            # softmax (p1, p2, p3) and the pi's so each set is positive and sums to 1

            # exponentiate the sigmas and also make corr between -1 and 1.
            z_sigma1 = tf.exp(z_sigma1)
            z_sigma2 = tf.exp(z_sigma2)
            z_corr = tf.tanh(z_corr)

            r = [
                z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen,
                z_pen_logits
            ]
            return r

        tf.logging.info('model - splitting the decoder output')
        out = get_mixture_coef(output)
        [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
         o_pen_logits] = out

        self.pi = o_pi
        self.mu1 = o_mu1
        self.mu2 = o_mu2
        self.sigma1 = o_sigma1
        self.sigma2 = o_sigma2
        self.corr = o_corr
        self.pen_logits = o_pen_logits
        self.pen = o_pen

        # reconstruction loss
        target = tf.reshape(self.output_x, [-1, 5])  # target output (x, y, p1, p2, p3)
        [x1_data, x2_data, p1, p2, p3] = tf.split(target, 5, 1)
        pen_data = tf.concat([p1, p2, p3], 1)
        lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                                o_pen_logits, x1_data, x2_data, pen_data)
        self.r_cost = tf.reduce_mean(lossfunc)

        # identify loss
        sourcess = tf.reshape(self.encoder_input_x,
                              [-1, 5])  # network input (x, y, p1, p2, p3)
        [s1_data, s2_data, sp1, sp2, sp3] = tf.split(sourcess, 5, 1)
        spen_data = tf.concat([sp1, sp2, sp3], 1)
        identyfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2,
                                  o_corr, o_pen_logits, s1_data, s2_data,
                                  spen_data)
        self.il_cost = tf.reduce_mean(identyfunc)

        if self.hps.is_training:
            tf.logging.info('model - set up the learning rate and loss')
            self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
            optimizer = tf.train.AdamOptimizer(self.lr)  # use the Adam optimizer
            #self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False)
            self.il_weight = tf.Variable(self.hps.il_weight_start,
                                         trainable=False)
            self.cost = self.r_cost + self.il_cost * self.il_weight
            gvs = optimizer.compute_gradients(self.cost)
            g = self.hps.grad_clip  # 1.0
            capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                          for grad, var in gvs]
            self.train_op = optimizer.apply_gradients(
                capped_gvs, global_step=self.global_step, name='train_step')
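At sampling time, the mixture parameters exposed above (self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.pen) are typically evaluated one timestep at a time and a point is drawn from them. A minimal NumPy sketch of that draw, assuming the parameters for a single timestep have already been fetched as 1-D arrays (this helper is hypothetical and not part of the example above):

    import numpy as np

    def sample_point(pi, mu1, mu2, sigma1, sigma2, corr, pen):
        # pick a mixture component, then draw (dx, dy) from its bivariate normal
        idx = np.random.choice(len(pi), p=pi / pi.sum())
        s1, s2, rho = sigma1[idx], sigma2[idx], corr[idx]
        cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
        dx, dy = np.random.multivariate_normal([mu1[idx], mu2[idx]], cov)
        pen_state = np.random.choice(3, p=pen / pen.sum())  # one of (p1, p2, p3)
        return dx, dy, pen_state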
Example #12
  def encoder(self, batch, sequence_lengths):
    """Define the bi-directional encoder module of sketch-rnn."""
    unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
        self.enc_cell_fw,
        self.enc_cell_bw,
        batch,
        sequence_length=sequence_lengths,
        time_major=False,
        swap_memory=True,
        dtype=tf.float32,
        scope='ENC_RNN')

    last_state_fw, last_state_bw = last_states
    last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
    last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
    last_h = tf.concat([last_h_fw, last_h_bw], 1)
    mu = rnn.super_linear(
        last_h,
        self.hps.z_size,
        input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
        scope='ENC_RNN_mu',
        init_w='gaussian',
        weight_start=0.001)
    presig = rnn.super_linear(
        last_h,
        self.hps.z_size,
        input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
        scope='ENC_RNN_sigma',
        init_w='gaussian',
        weight_start=0.001)

    # Will most likely have to change the size. Maybe create an array of the
    # super linear terms? But either way will have to separate into u, b, w.
    u_array = []
    w_array = []
    b_array = []
    for i in range(self.hps.num_flows):
      u = rnn.super_linear(
          last_h,
          # self.hps.num_flows * (self.hps.z_size * 2 + 1),
          (self.hps.z_size),
          input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
          scope='ENC_RNN_u_' + str(i),
          init_w='gaussian',
          weight_start=0.001)
      u_array.append(u)
      w = rnn.super_linear(
          last_h,
          # self.hps.num_flows * (self.hps.z_size * 2 + 1),
          (self.hps.z_size),
          input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
          scope='ENC_RNN_w_' + str(i),
          init_w='gaussian',
          weight_start=0.001)
      w_array.append(w)
      b = rnn.super_linear(
          last_h,
          # self.hps.num_flows * (self.hps.z_size * 2 + 1),
          1,
          input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
          scope='ENC_RNN_b_' + str(i),
          init_w='gaussian',
          weight_start=0.001)
      b_array.append(b)

    return mu, presig, u_array, w_array, b_array
Example #13
  def build_model(self, hps):
    """Define model architecture."""
    if hps.is_training:
      self.global_step = tf.Variable(0, name='global_step', trainable=False)

    if hps.dec_model == 'lstm':
      cell_fn = rnn.LSTMCell
    elif hps.dec_model == 'layer_norm':
      cell_fn = rnn.LayerNormLSTMCell
    elif hps.dec_model == 'hyper':
      cell_fn = rnn.HyperLSTMCell
    else:
      assert False, 'please choose a respectable cell'

    if hps.enc_model == 'lstm':
      enc_cell_fn = rnn.LSTMCell
    elif hps.enc_model == 'layer_norm':
      enc_cell_fn = rnn.LayerNormLSTMCell
    elif hps.enc_model == 'hyper':
      enc_cell_fn = rnn.HyperLSTMCell
    else:
      assert False, 'please choose a respectable cell'

    use_recurrent_dropout = self.hps.use_recurrent_dropout
    use_input_dropout = self.hps.use_input_dropout
    use_output_dropout = self.hps.use_output_dropout

    cell = cell_fn(
        hps.dec_rnn_size,
        use_recurrent_dropout=use_recurrent_dropout,
        dropout_keep_prob=self.hps.recurrent_dropout_prob)

    if hps.conditional:  # vae mode:
      if hps.enc_model == 'hyper':
        self.enc_cell_fw = enc_cell_fn(
            hps.enc_rnn_size,
            use_recurrent_dropout=use_recurrent_dropout,
            dropout_keep_prob=self.hps.recurrent_dropout_prob)
        self.enc_cell_bw = enc_cell_fn(
            hps.enc_rnn_size,
            use_recurrent_dropout=use_recurrent_dropout,
            dropout_keep_prob=self.hps.recurrent_dropout_prob)
      else:
        self.enc_cell_fw = enc_cell_fn(
            hps.enc_rnn_size,
            use_recurrent_dropout=use_recurrent_dropout,
            dropout_keep_prob=self.hps.recurrent_dropout_prob)
        self.enc_cell_bw = enc_cell_fn(
            hps.enc_rnn_size,
            use_recurrent_dropout=use_recurrent_dropout,
            dropout_keep_prob=self.hps.recurrent_dropout_prob)

    # dropout:
    tf.logging.info('Input dropout mode = %s.', use_input_dropout)
    tf.logging.info('Output dropout mode = %s.', use_output_dropout)
    tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout)
    if use_input_dropout:
      tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                      self.hps.input_dropout_prob)
      cell = tf.contrib.rnn.DropoutWrapper(
          cell, input_keep_prob=self.hps.input_dropout_prob)
    if use_output_dropout:
      tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                      self.hps.output_dropout_prob)
      cell = tf.contrib.rnn.DropoutWrapper(
          cell, output_keep_prob=self.hps.output_dropout_prob)
    self.cell = cell

    self.sequence_lengths = tf.placeholder(
        dtype=tf.int32, shape=[self.hps.batch_size])
    self.input_data = tf.placeholder(
        dtype=tf.float32,
        shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

    # The target/expected vectors of strokes
    self.output_x = self.input_data[:, 1:self.hps.max_seq_len + 1, :]
    # vectors of strokes to be fed to decoder (same as above, but lagged behind
    # one step to include initial dummy value of (0, 0, 1, 0, 0))
    self.input_x = self.input_data[:, :self.hps.max_seq_len, :]

    # either do vae-bit and get z, or do unconditional, decoder-only
    if hps.conditional:  # vae mode:
      self.mean, self.presig, self.us, self.ws, self.bs = self.encoder(
          self.output_x, self.sequence_lengths)
      self.sigma = tf.exp(self.presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
      eps = tf.random_normal(
          (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32)
      self.batch_z = self.mean + tf.multiply(self.sigma, eps)

      # Apply normalizing flow
      self.sum_log_det_jacobian = 0.
      h_prime = lambda x: 1 - tf.tanh(x) ** 2  # Derivative of nonlinearity h above

      for i in range(self.hps.num_flows):
        psi = h_prime(tf.expand_dims(tf.reduce_sum(
          self.batch_z * self.ws[i], -1), -1) + self.bs[i]) * self.ws[i]
        self.sum_log_det_jacobian += tf.log(tf.abs(1 + tf.reduce_sum(psi * self.us[i], -1)))

        self.batch_z = self.planar_flow(self.batch_z, self.ws[i], self.us[i], self.bs[i])

      # KL cost
      self.kl_cost = -0.5 * tf.reduce_mean(
          (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
      self.kl_cost -= tf.reduce_mean(self.sum_log_det_jacobian)
      self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)
      pre_tile_y = tf.reshape(self.batch_z,
                              [self.hps.batch_size, 1, self.hps.z_size])
      overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
      actual_input_x = tf.concat([self.input_x, overlay_x], 2)
      self.initial_state = tf.nn.tanh(
          rnn.super_linear(
              self.batch_z,
              cell.state_size,
              init_w='gaussian',
              weight_start=0.001,
              input_size=self.hps.z_size))
    else:  # unconditional, decoder-only generation
      self.batch_z = tf.zeros(
          (self.hps.batch_size, self.hps.z_size), dtype=tf.float32)
      self.kl_cost = tf.zeros([], dtype=tf.float32)
      actual_input_x = self.input_x
      self.initial_state = cell.zero_state(
          batch_size=hps.batch_size, dtype=tf.float32)

    self.num_mixture = hps.num_mixture

    # TODO(deck): Better understand this comment.
    # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
    # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
    # mixture weight/probability (Pi_k)
    n_out = (3 + self.num_mixture * 6)

    with tf.variable_scope('RNN'):
      output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out])
      output_b = tf.get_variable('output_b', [n_out])

    # decoder module of sketch-rnn is below
    output, last_state = tf.nn.dynamic_rnn(
        cell,
        actual_input_x,
        initial_state=self.initial_state,
        time_major=False,
        swap_memory=True,
        dtype=tf.float32,
        scope='RNN')

    output = tf.reshape(output, [-1, hps.dec_rnn_size])
    output = tf.nn.xw_plus_b(output, output_w, output_b)
    self.final_state = last_state

    # NB: the below are inner functions, not methods of Model
    def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
      """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
      norm1 = tf.subtract(x1, mu1)
      norm2 = tf.subtract(x2, mu2)
      s1s2 = tf.multiply(s1, s2)
      # eq 25
      z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) -
           2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2))
      neg_rho = 1 - tf.square(rho)
      result = tf.exp(tf.div(-z, 2 * neg_rho))
      denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho))
      result = tf.div(result, denom)
      return result

    def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                     z_pen_logits, x1_data, x2_data, pen_data):
      """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
      # This represents the L_R only (i.e. does not include the KL loss term).

      result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2,
                             z_corr)
      epsilon = 1e-6
      # result1 is the loss wrt pen offset (L_s in equation 9 of
      # https://arxiv.org/pdf/1704.03477.pdf)
      result1 = tf.multiply(result0, z_pi)
      result1 = tf.reduce_sum(result1, 1, keep_dims=True)
      result1 = -tf.log(result1 + epsilon)  # avoid log(0)

      fs = 1.0 - pen_data[:, 2]  # use training data for this
      fs = tf.reshape(fs, [-1, 1])
      # Zero out loss terms beyond N_s, the last actual stroke
      result1 = tf.multiply(result1, fs)

      # result2: loss wrt pen state, (L_p in equation 9)
      result2 = tf.nn.softmax_cross_entropy_with_logits(
          labels=pen_data, logits=z_pen_logits)
      result2 = tf.reshape(result2, [-1, 1])
      if not self.hps.is_training:  # eval mode, mask eos columns
        result2 = tf.multiply(result2, fs)

      result = result1 + result2
      return result

    # below is where we need to do MDN (Mixture Density Network) splitting of
    # distribution params
    def get_mixture_coef(output):
      """Returns the tf slices containing mdn dist params."""
      # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850.
      z = output
      z_pen_logits = z[:, 0:3]  # pen states
      z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(z[:, 3:], 6, 1)

      # process output z's into MDN parameters

      # softmax all the pi's and pen states:
      z_pi = tf.nn.softmax(z_pi)
      z_pen = tf.nn.softmax(z_pen_logits)

      # exponentiate the sigmas and also make corr between -1 and 1.
      z_sigma1 = tf.exp(z_sigma1)
      z_sigma2 = tf.exp(z_sigma2)
      z_corr = tf.tanh(z_corr)

      r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits]
      return r

    out = get_mixture_coef(output)
    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out

    self.pi = o_pi
    self.mu1 = o_mu1
    self.mu2 = o_mu2
    self.sigma1 = o_sigma1
    self.sigma2 = o_sigma2
    self.corr = o_corr
    self.pen_logits = o_pen_logits
    # pen state probabilities (result of applying softmax to self.pen_logits)
    self.pen = o_pen

    # reshape target data so that it is compatible with prediction shape
    target = tf.reshape(self.output_x, [-1, 5])
    [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1)
    pen_data = tf.concat([eos_data, eoc_data, cont_data], 1)

    lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                            o_pen_logits, x1_data, x2_data, pen_data)

    self.r_cost = tf.reduce_mean(lossfunc)

    if self.hps.is_training:
      self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
      optimizer = tf.train.AdamOptimizer(self.lr)

      self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False)
      self.cost = self.r_cost + self.kl_cost * self.kl_weight

      gvs = optimizer.compute_gradients(self.cost)
      g = self.hps.grad_clip
      capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs]
      self.train_op = optimizer.apply_gradients(
          capped_gvs, global_step=self.global_step, name='train_step')
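This example calls self.planar_flow, which is not shown in the snippet. A minimal sketch consistent with the log-det-Jacobian term computed above, i.e. the standard planar flow z' = z + u * tanh(w . z + b), assuming w and u have shape [batch, z_size] and b has shape [batch, 1]:

  def planar_flow(self, z, w, u, b):
    # z' = z + u * tanh(<z, w> + b); matches the psi / log-det terms computed above
    inner = tf.expand_dims(tf.reduce_sum(z * w, -1), -1) + b
    return z + u * tf.tanh(inner)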
Example #14
    def build_model(self, hps):
        """Define model architecture."""
        if hps.is_training:
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)

        if hps.dec_model == 'lstm':
            cell_fn = rnn.LSTMCell
        elif hps.dec_model == 'layer_norm':
            cell_fn = rnn.LayerNormLSTMCell
        elif hps.dec_model == 'hyper':
            cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        if hps.enc_model == 'lstm':
            enc_cell_fn = rnn.LSTMCell
        elif hps.enc_model == 'layer_norm':
            enc_cell_fn = rnn.LayerNormLSTMCell
        elif hps.enc_model == 'hyper':
            enc_cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        use_recurrent_dropout = self.hps.use_recurrent_dropout
        use_input_dropout = self.hps.use_input_dropout
        use_output_dropout = self.hps.use_output_dropout

        cell = cell_fn(hps.dec_rnn_size,
                       use_recurrent_dropout=use_recurrent_dropout,
                       dropout_keep_prob=self.hps.recurrent_dropout_prob)

        if hps.conditional:  # vae mode:
            self.enc_cell_fw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)
            self.enc_cell_bw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)

        # dropout:
        tf.logging.info('Input dropout mode = %s.', use_input_dropout)
        tf.logging.info('Output dropout mode = %s.', use_output_dropout)
        tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout)
        if use_input_dropout:
            tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                            self.hps.input_dropout_prob)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, input_keep_prob=self.hps.input_dropout_prob)
        if use_output_dropout:
            tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                            self.hps.output_dropout_prob)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, output_keep_prob=self.hps.output_dropout_prob)
        self.cell = cell

        self.input_data = tf.placeholder(dtype=tf.float32,
                                         shape=[
                                             self.hps.batch_size,
                                             self.hps.max_seq_len,
                                             self.hps.input_dimension_get
                                         ])
        self.input_handle = self.input_data[:, :, :self.hps.
                                            input_dimension_real]
        # The target/expected vectors of strokes
        self.output_x = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, self.hps.max_seq_len,
                   1])  # always 0 or 1, indicating the action at each step

        # vectors of strokes to be fed to decoder (same as above, but lagged behind
        # one step to include initial dummy value of (0, 0, 1, 0, 0))
        self.input_x = self.input_handle

        # either do vae-bit and get z, or do unconditional, decoder-only
        if hps.conditional:  # vae mode:
            _ = tf.concat([
                self.input_handle[:, :self.hps.input_seq_len, :],
                self.output_x[:, :self.hps.input_seq_len, :]
            ],
                          axis=2)
            self.mean = self.encoder(_)
            # self.sigma = tf.exp(self.presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
            # eps = tf.random_normal(
            #     (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32)
            self.batch_z = self.mean

            # KL cost
            self.kl_cost = -0.5 * tf.reduce_mean((1 + -tf.square(self.mean)))
            self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)
            pre_tile_y = tf.reshape(self.batch_z,
                                    [self.hps.batch_size, 1, self.hps.z_size])
            overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
            actual_input_x = tf.concat([self.input_x, overlay_x], 2)
            self.initial_state = tf.nn.tanh(
                rnn.super_linear(self.batch_z,
                                 cell.state_size,
                                 init_w='gaussian',
                                 weight_start=0.001,
                                 input_size=self.hps.z_size))
        else:  # unconditional, decoder-only generation
            self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                    dtype=tf.float32)
            self.kl_cost = tf.zeros([], dtype=tf.float32)
            actual_input_x = self.input_x
            self.initial_state = cell.zero_state(batch_size=hps.batch_size,
                                                 dtype=tf.float32)

        # TODO(deck): Better understand this comment.
        # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
        # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
        # mixture weight/probability (Pi_k)
        # n_out = (3 + self.num_mixture * 6)
        n_out = 1
        with tf.variable_scope('RNN'):
            output_w = tf.get_variable('output_w',
                                       [self.hps.dec_rnn_size, n_out])
            output_b = tf.get_variable('output_b', [n_out])

        # decoder module of sketch-rnn is below
        output, last_state = tf.nn.dynamic_rnn(
            cell,
            actual_input_x,
            initial_state=self.initial_state,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='RNN')

        output = tf.reshape(output, [-1, hps.dec_rnn_size])
        output = tf.nn.xw_plus_b(output, output_w, output_b)
        self.final_state = last_state

        label = tf.reshape(self.output_x,
                           [self.hps.batch_size, self.hps.max_seq_len])
        out = tf.reshape(output, [self.hps.batch_size, self.hps.max_seq_len])
        self.sigmoid_out = tf.sigmoid(out)
        # self.r_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=out[:,self.hps.input_seq_len:], labels=label[:,self.hps.input_seq_len:]))
        self.r_cost = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=out, labels=label))
        if self.hps.is_training:
            self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
            optimizer = tf.train.AdamOptimizer(self.lr)

            self.kl_weight = tf.Variable(self.hps.kl_weight_start,
                                         trainable=False)
            self.cost = self.r_cost + self.kl_cost * self.kl_weight

            gvs = optimizer.compute_gradients(self.cost)
            g = self.hps.grad_clip
            capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                          for grad, var in gvs]
            self.train_op = optimizer.apply_gradients(
                capped_gvs, global_step=self.global_step, name='train_step')
Example #15
    def build_model(self, hps):
        """Define model architecture."""
        if hps.is_training:
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)

        if hps.dec_model == 'lstm':
            cell_fn = rnn.LSTMCell
        elif hps.dec_model == 'layer_norm':
            cell_fn = rnn.LayerNormLSTMCell
        elif hps.dec_model == 'hyper':
            cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        if hps.enc_model == 'lstm':
            enc_cell_fn = rnn.LSTMCell
        elif hps.enc_model == 'layer_norm':
            enc_cell_fn = rnn.LayerNormLSTMCell
        elif hps.enc_model == 'hyper':
            enc_cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        use_recurrent_dropout = False
        if self.hps.use_recurrent_dropout == 1:
            use_recurrent_dropout = True

        use_input_dropout = False if self.hps.use_input_dropout == 0 else True
        use_output_dropout = False if self.hps.use_output_dropout == 0 else True

        if hps.dec_model == 'hyper':
            cell = cell_fn(hps.dec_rnn_size,
                           use_recurrent_dropout=use_recurrent_dropout,
                           dropout_keep_prob=self.hps.recurrent_dropout_prob)
        else:
            cell = cell_fn(hps.dec_rnn_size,
                           use_recurrent_dropout=use_recurrent_dropout,
                           dropout_keep_prob=self.hps.recurrent_dropout_prob)

        if hps.conditional:  # vae mode:
            if hps.enc_model == 'hyper':
                self.enc_cell_fw = enc_cell_fn(
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)
                self.enc_cell_bw = enc_cell_fn(
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)
            else:
                self.enc_cell_fw = enc_cell_fn(
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)
                self.enc_cell_bw = enc_cell_fn(
                    hps.enc_rnn_size,
                    use_recurrent_dropout=use_recurrent_dropout,
                    dropout_keep_prob=self.hps.recurrent_dropout_prob)

        # dropout:
        tf.logging.info('Input dropout mode = %s.', use_input_dropout)
        tf.logging.info('Output dropout mode = %s.', use_output_dropout)
        tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout)
        if use_input_dropout:
            tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                            self.hps.input_dropout_prob)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, input_keep_prob=self.hps.input_dropout_prob)
        if use_output_dropout:
            tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                            self.hps.output_dropout_prob)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, output_keep_prob=self.hps.output_dropout_prob)
        self.cell = cell

        self.sequence_lengths = tf.placeholder(dtype=tf.int32,
                                               shape=[self.hps.batch_size])
        self.input_data = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

        self.input_x = self.input_data[:, :self.hps.max_seq_len, :]
        self.output_x = self.input_data[:, 1:self.hps.max_seq_len + 1, :]

        self.img_data = tf.placeholder(
            tf.float32,
            [self.hps.batch_size, self.hps.img_size, self.hps.img_size, 1])

        # either do vae-bit and get z, or do unconditional, decoder-only
        if hps.conditional:  # vae mode:
            self.mean, self.presig = self.cnn_encoder(self.img_data)
            self.sigma = tf.exp(self.presig /
                                2.0)  # sigma > 0. div 2.0 -> sqrt.
            eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                                   0.0,
                                   1.0,
                                   dtype=tf.float32)
            self.batch_z = self.mean + tf.multiply(self.sigma, eps)
            # KL cost
            self.kl1 = -0.5 * tf.reduce_mean(self.presig)
            self.kl2 = -0.5 * tf.reduce_mean(-tf.square(self.mean))
            self.kl3 = -0.5 * tf.reduce_mean(-tf.exp(self.presig))

            self.kl_cost = -0.5 * tf.reduce_mean(
                (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
            self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)
            pre_tile_y = tf.reshape(self.batch_z,
                                    [self.hps.batch_size, 1, self.hps.z_size])
            overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
            actual_input_x = tf.concat([self.input_x, overlay_x], 2)
            self.initial_state = tf.nn.tanh(
                rnn.super_linear(self.batch_z,
                                 cell.state_size,
                                 init_w='gaussian',
                                 weight_start=0.001,
                                 input_size=self.hps.z_size))
        else:  # unconditional, decoder-only generation
            self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                    dtype=tf.float32)
            self.kl_cost = tf.zeros([], dtype=tf.float32)
            actual_input_x = self.input_x
            self.initial_state = cell.zero_state(batch_size=hps.batch_size,
                                                 dtype=tf.float32)

        self.num_mixture = hps.num_mixture

        # TODO(deck): Better understand this comment.
        # Number of outputs is end_of_stroke + prob + 2*(mu + sig) + corr
        n_out = (3 + self.num_mixture * 6)

        with tf.variable_scope('RNN'):
            output_w = tf.get_variable('output_w',
                                       [self.hps.dec_rnn_size, n_out])
            output_b = tf.get_variable('output_b', [n_out])

        # decoder module of sketch-rnn is below
        output, last_state = tf.nn.dynamic_rnn(
            cell,
            actual_input_x,
            initial_state=self.initial_state,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='RNN')

        output = tf.reshape(output, [-1, hps.dec_rnn_size])
        output = tf.nn.xw_plus_b(output, output_w, output_b)
        self.final_state = last_state

        def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
            """Returns result of eq # 24 and 25 of http://arxiv.org/abs/1308.0850."""
            norm1 = tf.subtract(x1, mu1)
            norm2 = tf.subtract(x2, mu2)
            s1s2 = tf.multiply(s1, s2)
            z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) -
                 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2))
            neg_rho = 1 - tf.square(rho)
            result = tf.exp(tf.div(-z, 2 * neg_rho))
            denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho))
            result = tf.div(result, denom)
            return result

        def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                         z_pen_logits, x1_data, x2_data, pen_data):
            """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
            result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1,
                                   z_sigma2, z_corr)
            epsilon = 1e-6
            result1 = tf.multiply(result0, z_pi)
            result1 = tf.reduce_sum(result1, 1, keep_dims=True)
            result1 = -tf.log(result1 + epsilon)  # avoid log(0)

            fs = 1.0 - pen_data[:, 2]  # use training data for this
            fs = tf.reshape(fs, [-1, 1])
            result1 = tf.multiply(result1, fs)

            result2 = tf.nn.softmax_cross_entropy_with_logits(
                labels=pen_data, logits=z_pen_logits)
            result2 = tf.reshape(result2, [-1, 1])
            if not self.hps.is_training:  # eval mode, mask eos columns
                result2 = tf.multiply(result2, fs)

            result = result1 + result2
            return result

        # below is where we need to do MDN splitting of distribution params
        def get_mixture_coef(output):
            """Returns the tf slices containing mdn dist params."""
            # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850.
            z = output
            z_pen_logits = z[:, 0:3]  # pen states
            z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(
                z[:, 3:], 6, 1)

            # process output z's into MDN parameters

            # softmax all the pi's and pen states:
            z_pi = tf.nn.softmax(z_pi)
            z_pen = tf.nn.softmax(z_pen_logits)

            # exponentiate the sigmas and also make corr between -1 and 1.
            z_sigma1 = tf.exp(z_sigma1)
            z_sigma2 = tf.exp(z_sigma2)
            z_corr = tf.tanh(z_corr)

            r = [
                z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen,
                z_pen_logits
            ]
            return r

        out = get_mixture_coef(output)
        [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
         o_pen_logits] = out

        self.pi = o_pi
        self.mu1 = o_mu1
        self.mu2 = o_mu2
        self.sigma1 = o_sigma1
        self.sigma2 = o_sigma2
        self.corr = o_corr
        self.pen = o_pen  # state of the pen
        self.pen_logits = o_pen_logits  # state of the pen

        # reshape target data so that it is compatible with prediction shape
        target = tf.reshape(self.output_x, [-1, 5])
        [x1_data, x2_data, eos_data, eoc_data,
         cont_data] = tf.split(target, 5, 1)
        pen_data = tf.concat([eos_data, eoc_data, cont_data], 1)

        lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                                o_pen_logits, x1_data, x2_data, pen_data)

        self.r_cost = tf.reduce_mean(lossfunc)

        if self.hps.is_training:
            self.cost = self.r_cost + self.kl_cost * self.hps.kl_weight

            self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
            optimizer = tf.train.AdamOptimizer(self.lr)

            if self.hps.kl_weight != 0:
                self.kl_weight = tf.Variable(self.hps.kl_weight_start,
                                             trainable=False)
                self.cost = self.r_cost + self.kl_cost * self.kl_weight
            else:
                self.kl_weight = tf.Variable(self.hps.kl_weight,
                                             trainable=False)
                self.cost = self.r_cost

            gvs = optimizer.compute_gradients(self.cost)
            g = self.hps.grad_clip
            capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                          for grad, var in gvs]
            self.train_op = optimizer.apply_gradients(
                capped_gvs, global_step=self.global_step, name='train_step')
Example #16
    def decoder(self, reuse):
        with tf.variable_scope("decode", reuse=reuse):

            if self.hps.conditional:
                self.batch_z = tf.concat([self.batch_z, self.c_expr], 1)
                pre_tile_y = tf.reshape(
                    self.batch_z,
                    [self.hps.batch_size, 1, self.hps.z_size + self.expr_num])
                overlay_x = tf.tile(pre_tile_y,
                                    [1, self.hps.max_seq_len, 1
                                     ])  # replicate the latent input across all time steps
                actual_input_x = tf.concat([self.input_x, overlay_x], 2)

            else:
                self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                        dtype=tf.float32)
                self.batch_z = tf.concat([self.batch_z, self.c_expr], 1)
                actual_input_x = self.input_x

            self.num_mixture = self.hps.num_mixture  # Number of mixtures in Gaussian mixture model.

            n_out = (3 + self.num_mixture * 6)

            output_w = tf.get_variable('output_w',
                                       [self.hps.dec_rnn_size, n_out])
            output_b = tf.get_variable('output_b', [n_out])

            if self.hps.conditional:
                self.initial_state = tf.nn.tanh(
                    rnn.super_linear(self.batch_z,
                                     self.cell.state_size,
                                     init_w='gaussian',
                                     weight_start=0.001,
                                     input_size=self.hps.z_size +
                                     self.expr_num))

            else:
                self.initial_state = self.cell.zero_state(
                    batch_size=self.hps.batch_size, dtype=tf.float32)

            output, last_state = tf.nn.dynamic_rnn(
                self.cell,
                actual_input_x,
                initial_state=self.initial_state,
                time_major=False,
                swap_memory=True,
                dtype=tf.float32,
                scope='RNN')

            self.c_kl_batch_train = tf.zeros([], dtype=tf.float32)
            self.training_logits = output
            output = tf.reshape(self.training_logits,
                                [-1, self.hps.dec_rnn_size])
            output = tf.nn.xw_plus_b(output, output_w, output_b)
            self.final_state = last_state

            out = self.get_mixture_coef(output)

            [
                o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
                o_pen_logits
            ] = out

            self.pi = o_pi
            self.mu1 = o_mu1
            self.mu2 = o_mu2
            self.sigma1 = o_sigma1
            self.sigma2 = o_sigma2
            self.corr = o_corr
            self.pen_logits = o_pen_logits
            self.pen = o_pen