Example #1
class ConvDRAW(object):
    def __init__(self, d, lr, lambda_z_wu, read_attn, write_attn, do_classify,
                 do_reconst):

        """ flags for each regularizer """
        self.do_classify = do_classify
        self.do_reconst = do_reconst
        self.read_attn = read_attn
        self.write_attn = write_attn
        """ dataset information """
        self.set_datainfo(d)
        """ external toolkits """
        self.ls = Layers()
        self.lf = LossFunctions(self.ls, self.d, self.encoder)
        self.ii = ImageInterface(_is_3d, self.read_attn, self.write_attn,
                                 GLIMPSE_SIZE_READ, GLIMPSE_SIZE_WRITE, _h, _w,
                                 _c)
        # for reference from get_loss_kl_draw()
        self.T = T
        self.L = L
        self.Z_SIZES = Z_SIZES
        """ placeholders defined outside"""
        self.lr = lr
        self.lambda_z_wu = lambda_z_wu
        """sequence of canvases """
        self.cs = [0] * T
        """ initialization """
        self.init_lstms()
        self.init_time_zero()
        """ workaround for variable_scope(reuse=True) """
        self.DO_SHARE = None

    def set_datainfo(self, d):
        self.d = d  # dataset manager
        global _b, _h, _w, _c, _img_size, _is_3d
        _b = d.batch_size
        _h = d.h
        _w = d.w
        _c = d.c
        _img_size = d.img_size
        _is_3d = d.is_3d

    def init_time_zero(self):

        self.cs[0] = tf.zeros((_b, _h, _w, _c)) if _is_3d else tf.zeros(
            (_b, _img_size))
        self.h_dec[0][0] = tf.zeros((_b, RNN_SIZES[0]))

    def init_lstms(self):

        # Build (T+1) x L tables with list comprehensions; writing
        # [[0] * L] * (T + 1) instead would alias one inner list across all
        # time steps, so an assignment at one t would show up at every t.
        h_enc = [[0] * L for _ in range(T + 1)]        # q(z_i+1 | z_i), bottom-up inference
        e_mus = [[0] * L for _ in range(T + 1)]
        e_logsigmas = [[0] * L for _ in range(T + 1)]
        h_dec = [[0] * L for _ in range(T + 1)]        # q(z_i | .), bidirectional inference
        d_mus = [[0] * L for _ in range(T + 1)]
        d_logsigmas = [[0] * L for _ in range(T + 1)]
        p_mus = [[0] * L for _ in range(T + 1)]        # p(z_i | z_i+1), top-down prior
        p_logsigmas = [[0] * L for _ in range(T + 1)]
        """ set-up LSTM cells """
        e_cells, e_states = [None] * L, [None] * L
        d_cells, d_states = [None] * L, [None] * L

        for l in range(L):
            e_cells[l] = tf.contrib.rnn.core_rnn_cell.LSTMCell(RNN_SIZES[l])
            d_cells[l] = tf.contrib.rnn.core_rnn_cell.LSTMCell(RNN_SIZES[l])
            e_states[l] = e_cells[l].zero_state(_b, tf.float32)
            d_states[l] = d_cells[l].zero_state(_b, tf.float32)
            """ set as standard Gaussian, N(0,I). """
            d_mus[0][l], d_logsigmas[0][l] = tf.zeros(
                (_b, Z_SIZES[l])), tf.zeros((_b, Z_SIZES[l]))
            p_mus[0][l], p_logsigmas[0][l] = tf.zeros(
                (_b, Z_SIZES[l])), tf.zeros((_b, Z_SIZES[l]))

        self.h_enc, self.e_mus, self.e_logsigmas = h_enc, e_mus, e_logsigmas
        self.h_dec, self.d_mus, self.d_logsigmas = h_dec, d_mus, d_logsigmas
        self.p_mus, self.p_logsigmas = p_mus, p_logsigmas
        self.e_cells, self.e_states = e_cells, e_states
        self.d_cells, self.d_states = d_cells, d_states
        self.z = [[0] * L for _ in range(T + 1)]  # avoid the [[0]*L]*(T+1) aliasing pitfall noted above

    ###########################################
    """            LSTM cells               """
    ###########################################
    def lstm_encode(self, state, x, l, is_train):

        scope = 'lstm_encode_' + str(l)
        x = tf.reshape(x, (_b, -1))
        if x.get_shape()[1] != RNN_SIZES[l]:
            print(scope, ':', x.get_shape()[1:], '=>', RNN_SIZES[l])
            x = self.ls.dense(scope, x, RNN_SIZES[l])

        return self.e_cells[l](x, state)

    def lstm_decode(self, state, x, l, is_train):

        scope = 'lstm_decode_' + str(l)
        x = tf.reshape(x, (_b, -1))
        if x.get_shape()[1] != RNN_SIZES[l]:
            print(scope, ':', x.get_shape()[1:], '=>', RNN_SIZES[l])
            x = self.ls.dense(scope, x, RNN_SIZES[l])

        return self.d_cells[l](x, state)

    ###########################################
    """             Encoder                 """
    ###########################################
    def encoder(self, x, t, is_train=True, do_update_bn=True):

        for l in range(L):
            scope = 'Encode_L' + str(l)
            with tf.variable_scope(scope, reuse=self.DO_SHARE):

                if l == 0:
                    x_hat = x - self.canvase_previous(t)

                    # feedback from the lowest decoder layer at the previous
                    # time step; none exists at t == 0, so use zeros
                    h_dec_lowest_prev = tf.zeros(
                        (_b, RNN_SIZES[0])) if t == 0 else self.h_dec[t - 1][0]

                    x_in = self.ii.read(x, x_hat, h_dec_lowest_prev)
                else:
                    x_in = self.h_enc[t][l - 1]

                self.h_enc[t][l], self.e_states[l] = self.lstm_encode(
                    self.e_states[l], x_in, l, is_train)

                x_in = self.ls.dense(scope, self.h_enc[t][l], Z_SIZES[l] * 2)
                self.z[t][l], self.e_mus[t][l], self.e_logsigmas[t][l] = \
                    self.ls.vae_sampler_w_feature_slice(x_in, Z_SIZES[l])
        """ classifier """
        logit = self.ls.dense('top',
                              self.h_enc[t][-1],
                              self.d.l,
                              activation=tf.nn.elu)
        return logit

    ###########################################
    """             Decoder                 """
    ###########################################
    def decoder(self, t, is_train=True, do_update_bn=True):

        for l in range(L - 1, -1, -1):
            scope = 'Decoder_L' + str(l)
            with tf.variable_scope(scope, reuse=self.DO_SHARE):

                if l == L - 1:
                    h_in = self.z[t][l]
                else:
                    h_in = self.concat(self.z[t][l], self.h_dec[t][l + 1], l)

                self.h_dec[t][l], self.d_states[l] = self.lstm_decode(
                    self.d_states[l], h_in, l, is_train)
                """ go out to the input space """
                if l == 0:
                    # [ToDo] replace bellow reconstructor with conv-lstm
                    if _is_3d:

                        o = self.canvase_previous(t) + self.ii.write(
                            self.h_dec[t][l])
                        #if t == T-1: # for MNIST
                        o = tf.nn.sigmoid(o)
                        self.cs[t] = o
                    else:
                        self.cs[t] = tf.nn.sigmoid(
                            self.canvase_previous(t) +
                            self.ii.write(self.h_dec[t][l]))
        return self.cs[t]

    """ set prior after building the decoder """

    def prior(self, t):

        for l in range(L - 1, -1, -1):
            scope = 'Prior_L' + str(l)
            """ preparation for p_* for t+1 and d_* for t with the output from lstm-decoder"""
            if l != 0:
                input = self.ls.dense(scope, self.h_dec[t][l],
                                      Z_SIZES[l] * 2 + Z_SIZES[l - 1] * 2)
                self.p_mus[t + 1][l], self.p_logsigmas[t + 1][l], self.d_mus[
                    t][l], self.d_logsigmas[t][l] = self.ls.split(
                        input, 1, [Z_SIZES[l]] * 2 + [Z_SIZES[l - 1]] * 2)
            else:
                """ no one uses d_* """
                input = self.ls.dense(scope, self.h_dec[t][l], Z_SIZES[l] * 2)
                self.p_mus[t +
                           1][l], self.p_logsigmas[t + 1][l] = self.ls.split(
                               input, 1, [Z_SIZES[l]] * 2)
            """ setting p_mus[0][l] and p_logsigmas[0][l] """
            if t == 0:
                if l == L - 1:
                    """ has already been performed at init() """
                    pass
                else:
                    """ use only the decoder's top-down path as the prior,
                        since p(z) at t-1 does not exist """
                    self.p_mus[t][l] = self.d_mus[t][l + 1]
                    self.p_logsigmas[t][l] = tf.exp(
                        self.d_logsigmas[t][l + 1])  # Eq.19 at t=0
            else:
                if l == L - 1:
                    """ has already been performed at t-1 """
                    pass
                else:
                    """ update p(z) for the current t """
                    _, self.p_mus[t][l], self.p_logsigmas[t][l] = \
                        self.ls.precision_weighted_sampler(
                            scope,
                            (self.p_mus[t][l], tf.exp(self.p_logsigmas[t][l])),
                            (self.d_mus[t][l + 1],
                             tf.exp(self.d_logsigmas[t][l + 1])))  # Eq.19

    ###########################################
    """           Build Graph               """
    ###########################################
    def build_graph_train(self, x_l, y_l, x, is_supervised=True):

        o = dict()  # output
        loss = 0
        logit_ls = []
        """ Build DRAW """
        for t in range(T):
            logit_ls.append(self.encoder(x, t))
            x_reconst = self.decoder(t)
            self.prior(t)

            if t == 0:
                self.DO_SHARE = DO_SHARE = True
                self.ii.set_do_share(DO_SHARE)
                self.ls.set_do_share(DO_SHARE)
        """ p(x|z) Reconstruction Loss """
        o['x'] = x
        o['cs'] = self.cs
        o['Lr'] = self.lf.get_loss_pxz(x_reconst, x, 'DiscretizedLogistic')
        loss += o['Lr']
        """ VAE KL-Divergence Loss """
        o['KL1'], o['KL2'], o['Lz'] = self.lf.get_loss_kl_draw(self)
        loss += self.lambda_z_wu * o['Lz']
        """ set losses """
        o['loss'] = loss
        self.o_train = o
        """ set optimizer """
        optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5)
        grads = optimizer.compute_gradients(loss)
        for i, (g, v) in enumerate(grads):
            if g is not None:
                #g = tf.Print(g, [g], "g %s = "%(v))
                grads[i] = (tf.clip_by_norm(g, 5), v)  # clip gradients
            else:
                print('g is None:', v)
                v = tf.Print(v, [v], "v = ", summarize=10000)
        self.op = optimizer.apply_gradients(grads)  # return train_op

    def build_graph_test(self, x_l, y_l, is_supervised=False):

        o = dict()  # output
        loss = 0
        logit_ls = []
        """ Build DRAW """
        for t in range(T):
            logit_ls.append(
                self.encoder(x_l, t, is_train=False, do_update_bn=False))
            x_reconst = self.decoder(t)
            self.prior(t)

            if t == 0:
                self.DO_SHARE = DO_SHARE = True
                self.ii.set_do_share(DO_SHARE)
                self.ls.set_do_share(DO_SHARE)
        """ classification loss """
        if is_supervised:
            o['Ly'], o['accur'] = self.lf.get_loss_pyx(logit_ls[-1], y_l)
            loss += o['Ly']
        """ for visualizationc """
        o['z'], o['y'] = logit_ls[-1], y_l
        """ set losses """
        o['loss'] = loss
        self.o_test = o

    ###########################################
    """             Utilities               """
    ###########################################
    def canvase_previous(self, t):
        if _is_3d:
            c_prev = tf.zeros((_b, _h, _w, _c)) if t == 0 else self.cs[t - 1]
        else:
            c_prev = tf.zeros((_b, _img_size)) if t == 0 else self.cs[t - 1]
        return c_prev

    def concat(self, x1, x2, l):
        if False:  # [ToDo]
            x1 = tf.reshape(x1, (_b, IMAGE_SIZES[l][0], IMAGE_SIZES[l][1], -1))
            x2 = tf.reshape(x2, (_b, IMAGE_SIZES[l][0], IMAGE_SIZES[l][1], -1))
            return tf.concat([x1, x2], 3)
        else:
            x1 = tf.reshape(x1, (_b, -1))
            x2 = tf.reshape(x2, (_b, -1))
            return tf.concat([x1, x2], 1)
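
A minimal driver sketch for the class above. Everything outside ConvDRAW here is an assumption for illustration: the dataset manager (DatasetManager below) only needs the attributes read in set_datainfo, and the module-level constants T, L, RNN_SIZES and Z_SIZES are presumed to be defined alongside the class.

# Hypothetical usage sketch -- names outside ConvDRAW are assumptions.
import numpy as np
import tensorflow as tf

d = DatasetManager()  # assumed helper exposing batch_size, h, w, c, img_size, is_3d
x = tf.placeholder(tf.float32, (d.batch_size, d.h, d.w, d.c))
lr = tf.placeholder(tf.float32, ())           # learning rate
lambda_z_wu = tf.placeholder(tf.float32, ())  # KL warm-up weight

model = ConvDRAW(d, lr, lambda_z_wu, read_attn=True, write_attn=True,
                 do_classify=False, do_reconst=True)
model.build_graph_train(None, None, x, is_supervised=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(d.batch_size, d.h, d.w, d.c).astype(np.float32)
    _, o = sess.run([model.op, model.o_train],
                    feed_dict={x: batch, lr: 1e-3, lambda_z_wu: 0.1})
    print('loss:', o['loss'])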
Example #2
class ImageInterface(object):

    def __init__(self, is_3d, is_read_attention, is_write_attention, read_n, write_n, h, w, c):
    
        """ to manage do_share flag inside Layers object, ImageInterface has Layers as its own property """
        self.do_share = False
        self.ls       = Layers()
        self.is_3d    = is_3d
        self.read_n   = read_n
        self.write_n  = write_n
        self.h = h
        self.w = w
        self.c = c

        if is_read_attention:
            self.read = self._read_attention
        else:
            self.read = self._read_no_attention
    
        if is_write_attention:
            self.write = self._write_attention
        else:
            self.write = self._write_no_attention
     
    def set_do_share(self, flag):
        self.do_share    = flag
        self.ls.set_do_share(flag)
        
    ###########################
    """       READER        """
    ###########################
    def _read_no_attention(self, x, x_hat, h_dec):
        _h, _w, _c = self.h, self.w, self.c
        if self.is_3d:
            # x is the raw image and x_hat is the error image; each is treated
            # as a separate channel, so r and the return value have shape
            # [-1, _h, _w, _c*2]

            USE_CONV_READ = False # 170720
            if USE_CONV_READ:
                scope = 'read_1'
                x = self.ls.conv2d(scope+'_1', x, 64, activation=tf.nn.elu)
                x = self.ls.max_pool(x)
                x = self.ls.conv2d(scope+'_2', x, 64, activation=tf.nn.elu)
                x = self.ls.max_pool(x)
                x = self.ls.conv2d(scope+'_3', x, 64, activation=tf.nn.elu)

                scope = 'read_hat_1'
                x_hat = self.ls.conv2d(scope+'_1', x_hat, 16, activation=tf.nn.elu)
                x_hat = self.ls.max_pool(x_hat)
                x_hat = self.ls.conv2d(scope+'_2', x_hat, 16, activation=tf.nn.elu)
                x_hat = self.ls.max_pool(x_hat)
                x_hat = self.ls.conv2d(scope+'_3', x_hat, 16, activation=tf.nn.elu)

                r = tf.concat([x,x_hat], 3)
                h_dec = tf.reshape( self.ls.dense(scope, h_dec, _h*_w*_c), [-1, int(_h/4), int(_w/4),_c*4*4])
                return tf.concat([r,h_dec], 3)
            elif False:  # disabled deeper read path, kept for reference

                scope = 'read_1'
                x = self.ls.conv2d(scope+'_1', x, 128, activation=tf.nn.elu)
                x = self.ls.conv2d(scope+'_2', x, 128, activation=tf.nn.elu)
                x = self.ls.conv2d(scope+'_3', x, 128, activation=tf.nn.elu)
                x = self.ls.max_pool(x)
                scope = 'read_2'
                x = self.ls.conv2d(scope+'_1', x, 256, activation=tf.nn.elu)
                x = self.ls.conv2d(scope+'_2', x, 256, activation=tf.nn.elu)
                x = self.ls.conv2d(scope+'_3', x, 256, activation=tf.nn.elu)
                x = self.ls.max_pool(x)
                scope = 'read_3'
                x = self.ls.conv2d(scope+'_1', x, 512, activation=tf.nn.elu)
                x = self.ls.conv2d(scope+'_2', x, 256, activation=tf.nn.elu, filter_size=(1,1))
                x = self.ls.conv2d(scope+'_3', x, 128, activation=tf.nn.elu, filter_size=(1,1))
                x = self.ls.conv2d(scope+'_4', x, 64, activation=tf.nn.elu, filter_size=(1,1))

                scope = 'read_hat_1'
                x_hat = self.ls.conv2d(scope+'_1', x_hat, 128, activation=tf.nn.elu)
                x_hat = self.ls.max_pool(x_hat)
                scope = 'read_hat_2'
                x_hat = self.ls.conv2d(scope+'_1', x_hat, 256, activation=tf.nn.elu)
                x_hat = self.ls.max_pool(x_hat)
                scope = 'read_hat_3'
                x_hat = self.ls.conv2d(scope+'_4', x_hat, 16, activation=tf.nn.elu, filter_size=(1,1))
                r = tf.concat([x,x_hat], 3)
                h_dec = tf.reshape( self.ls.dense(scope, h_dec, _h*_w*_c), [-1, int(_h/4), int(_w/4),_c*4*4])
                return tf.concat([r,h_dec], 3)
            else:
                r = tf.concat([x,x_hat], 3)
            USE_DEC_LOWEST_PREV = True
            if USE_DEC_LOWEST_PREV:
                # use decoder feedback as an element-wise addition,
                # Eq.(21) in [Gregor, 2016]
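                # i.e. read(x, x_hat, h_dec) = concat(x, x_hat, axis=3) + f(h_dec),
                # where f projects the previous lowest decoder state into an
                # image-shaped tensor via a dense (and optionally conv) layer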
                scope = 'read'
                USE_CONV = True
                if USE_CONV:
                    h_dec = tf.reshape( self.ls.dense(scope, h_dec, _h*_w*_c), [-1, _h,_w,_c])
                    h_dec = self.ls.conv2d("conv", h_dec, _c*2, activation=tf.nn.elu)
                    return r + h_dec
                else:
                    h_dec = tf.reshape( self.ls.dense(scope, h_dec, _h*_w*_c*2), [-1, _h,_w,_c*2])
                    return r + h_dec
            else:
                return r
        else:
            return tf.concat([x,x_hat], 1)
    
    def _read_attention( self, x, x_hat, h_dec ):
        _h,_w,_c = self.h, self.w, self.c
        N = self.read_n
        if self.is_3d:
            Fx,Fy,gamma = self._set_window("read", h_dec,N)
            # Fx is (?, 5, 32, 3)
            # gamma is (?, 3)
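            # DRAW read operation (Eq. (27) in [Gregor et al., 2015]):
            #   read(x, x_hat, h_dec) = gamma * [Fy . x . Fx^T,  Fy . x_hat . Fx^T]
            # filter_img below applies it per channel via the transposes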
            def filter_img(img,Fx,Fy,gamma, N):
                # Fx and Fy are (?, 5, 32, 3)
                Fxt = tf.transpose(Fx,perm=[0,3,2,1])
                Fy  = tf.transpose(Fy,perm=[0,3,2,1])
                
                # img.get_shape() has already been (?, 32, 32, 3)
                img  = tf.transpose(img, perm=[0,3,2,1])
                # tf.matmul(img,Fxt) is (?, 3, 32, 5)
                img_Fxt = tf.matmul(img,Fxt)
                img_Fxt = tf.transpose(img_Fxt, perm=[0,1,3,2])
                # Fy: (?, 3, 32, 5)
                Fy  = tf.transpose(Fy,perm=[0,1,3,2])
                glimpse = tf.matmul(Fy, img_Fxt, transpose_b=True)
                # glimpse.get_shape() is (?, 3, 32, 32)
                glimpse = tf.transpose(glimpse, perm=[0,2,3,1])
                glimpse = tf.reshape(glimpse,[-1,N*N, _c])
    
                glimpse = tf.transpose(glimpse, perm=[0,2,1])
                gamma   = tf.reshape(gamma,[-1,1, _c])
                gamma   = tf.transpose(gamma,   perm=[0,2,1])
                o = glimpse*gamma
                o = tf.transpose(o, perm=[0,2,1])
                return o
            x = filter_img( x, Fx, Fy, gamma, N) # batch x (read_n*read_n)
            x_hat = filter_img( x_hat, Fx, Fy, gamma, N)
            x = tf.reshape(x, [-1, N,N,_c])
            x_hat = tf.reshape(x_hat, [-1, N,N,_c])
            return tf.concat([x,x_hat], 3)
        else:
            Fx,Fy,gamma = self._set_window("read", h_dec,N)
            # Fx: (?, 5, 32), gamma: (?, 1)
            def filter_img(img,Fx,Fy,gamma,N):
                #print('filter_img in is_image == False')
                Fxt = tf.transpose(Fx,perm=[0,2,1])
                img = tf.reshape(img,[-1,_w,_h])
                # Fxt : (?, 32, 5)
                # img : (?, 32, 32)
                glimpse = tf.matmul(Fy,tf.matmul(img,Fxt))
                glimpse = tf.reshape(glimpse,[-1,N*N])
                return glimpse*tf.reshape(gamma,[-1,1])
            x = filter_img( x, Fx, Fy, gamma, N) # batch x (read_n*read_n)
            x_hat = filter_img( x_hat, Fx, Fy, gamma, N)
            return tf.concat([x,x_hat], 1) # concat along feature axis
    
    
    ###########################
    """       WRITER        """
    ###########################
    def _write_no_attention(self, h):
        scope = "write"
        _h,_w,_c = self.h, self.w, self.c
        if self.is_3d:
            IS_SIMPLE_WRITE = True
            if IS_SIMPLE_WRITE:
                print('IS_SIMPLE_WRITE:', IS_SIMPLE_WRITE)
                return tf.reshape(self.ls.dense(scope, h, _h*_w*_c, tf.nn.elu), [-1, _h, _w, _c])
            else:
                IS_CONV_LSTM = True
                if IS_CONV_LSTM:
                    raise NotImplementedError
                else:
                    activation = tf.nn.elu
                    print('h in write:', h) # h.shape is (_b, RNN_SIZES[0])
                    L = 1
                    h = tf.reshape( h, (-1, 2,2,64*3)) # should match to RNN_SIZES[0]
                    h = self.ls.deconv2d(scope+'_1', h, 64*2) # 4
                    h = activation(h)
                    L = 2
                    h = self.ls.deconv2d(scope+'_2', h, 16*3) # 8
                    h = activation(h)
                    h = PS(h, 4, color=True)
                    print('h in write:', h)
                return tf.reshape( h, [-1, _h, _w, _c])
        else:
            return self.ls.dense( scope,h, _h*_w*_c )
    
    def _write_attention(self, h_dec):
        scope = "writeW"
        N          = self.write_n
        write_size = N*N
        _h,_w,_c = self.h, self.w, self.c
        Fx, Fy, gamma = self._set_window("write", h_dec, N)
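        # DRAW write operation (Eq. (29) in [Gregor et al., 2015]): a learned
        # N x N patch w is smeared back onto the canvas as (1/gamma) * Fy^T . w . Fx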
        if self.is_3d:
            # Fx and Fy are (?, 5, 32, 3), gamma is (?, 3)
            w = self.ls.dense( scope, h_dec, write_size*_c) # batch x (write_n*write_n) [ToDo] replace self.ls.dense with deconv
            w = tf.reshape(w,[tf.shape(h_dec)[0],N,N,_c])
            w = tf.transpose(w, perm=[0,3,1,2])
            Fyt = tf.transpose(Fy, perm=[0,3,2,1])
            Fx  = tf.transpose(Fx, perm=[0,3,1,2])
    
            w_Fx = tf.matmul(w, Fx)
            # w_Fx.get_shape() is (?, 3, 5, 32)
            w_Fx = tf.transpose(w_Fx, perm=[0,1,3,2])
    
            wr = tf.matmul(Fyt, w_Fx, transpose_b=True)
            wr = tf.reshape(wr,[tf.shape(h_dec)[0],_w*_h, _c])
            wr = tf.transpose(wr, perm=[0,2,1])
            inv_gamma   = tf.reshape(1.0/gamma,[-1,1, _c])
            inv_gamma   = tf.transpose(inv_gamma, perm=[0,2,1])
            o = wr*inv_gamma
            o = tf.transpose(o, perm=[0,2,1])
            o = tf.reshape(o, [tf.shape(h_dec)[0], _w, _h, _c])
            return o
        else:
            w = self.ls.dense( scope, h_dec,write_size) # batch x (write_n*write_n)
            w = tf.reshape(w,[tf.shape(h_dec)[0],N,N])
            Fyt = tf.transpose(Fy,perm=[0,2,1])
            wr = tf.matmul(Fyt,tf.matmul(w,Fx))
            wr = tf.reshape(wr,[tf.shape(h_dec)[0],_w*_h])
            return wr*tf.reshape(1.0/gamma,[-1,1])

    ###########################
    """  Filter Functions   """
    ###########################
    def _filterbank(self, gx, gy, sigma2, delta, N):
        # unpacked here, not inside the 3d branch, so that the flat branch
        # below can also use _h and _w
        _h, _w, _c = self.h, self.w, self.c
        if self.is_3d:
            # gx and delta are (?,3)
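            # DRAW filterbank, Eqs. (19)-(20) and (25)-(26) in [Gregor et al., 2015]:
            #   mu_x[i] = gx + (i - N/2 - 0.5) * delta
            #   Fx[i, a] = exp(-(a - mu_x[i])^2 / (2 * sigma^2)) / Zx
            # with Zx normalizing each filter row to sum to 1. The code below
            # divides by 2*sigma2 inside the square (hence the "2*sigma2?"
            # comments), which deviates from the paper but matches common
            # TensorFlow DRAW implementations.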
            grid_i = tf.reshape(tf.cast(tf.range(N*_c), tf.float32), [1, -1, _c])
            mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19
            mu_y = gy + (grid_i - N / 2 - 0.5) * delta # eq 20
            # shape : [1, N, _c]
            w = tf.reshape( tf.cast( tf.range(_w*_c), tf.float32), [1, 1, -1, _c])
            h = tf.reshape( tf.cast( tf.range(_h*_c), tf.float32), [1, 1, -1, _c])
            mu_x = tf.reshape(mu_x, [-1, N, 1, _c])
            mu_y = tf.reshape(mu_y, [-1, N, 1, _c])
            sigma2 = tf.reshape(sigma2, [-1, 1, 1, _c])
            Fx = tf.exp(-tf.square((w - mu_x) / (2*sigma2))) # 2*sigma2?
            Fy = tf.exp(-tf.square((h - mu_y) / (2*sigma2))) # batch x N x B
            # normalize, sum over A and B dims
            Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps)
            Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps)
            return Fx,Fy
    
        else:
            grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])
            # gx, delta and mu_x are (?, 1); grid_i is (1, 5)
            mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19
            mu_y = gy + (grid_i - N / 2 - 0.5) * delta # eq 20
            h = tf.reshape(tf.cast(tf.range(_h), tf.float32), [1, 1, -1])
            w = tf.reshape(tf.cast(tf.range(_w), tf.float32), [1, 1, -1])
            mu_x = tf.reshape(mu_x, [-1, N, 1])
            mu_y = tf.reshape(mu_y, [-1, N, 1])
            sigma2 = tf.reshape(sigma2, [-1, 1, 1])
            Fx = tf.exp(-tf.square((w - mu_x) / (2*sigma2))) # 2*sigma2?
            Fy = tf.exp(-tf.square((h - mu_y) / (2*sigma2))) # batch x N x B
            # normalize, sum over A and B dims
            Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps)
            Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps)
            return Fx,Fy
    
    def _set_window(self, scope, h_dec, N):
        # unpacked here, not inside the 3d branch, so that the flat branch
        # below can also use _h and _w
        _h, _w, _c = self.h, self.w, self.c
        if self.is_3d:
            # get five (BATCH_SIZE, _c) matrices
            gx_, gy_, log_sigma2, log_delta, log_gamma = self.ls.split( self.ls.dense(scope, h_dec, _c*5), 1, [_c]*5)
            gx_ = tf.reshape(gx_, [-1,1,_c])
            gy_ = tf.reshape(gy_, [-1,1,_c])
            log_sigma2 = tf.reshape(log_sigma2, [-1,1,_c])
            log_delta = tf.reshape(log_delta, [-1,1,_c])
            log_gamma = tf.reshape(log_gamma, [-1,1,_c])
            gx = (_w + 1)/2*(gx_+1)
            gy = (_h + 1)/2*(gy_+1)
            sigma2 = tf.exp(log_sigma2)
            delta = (max(_h, _w) - 1) / (N - 1) * tf.exp(log_delta)  # batch x N
            return self._filterbank(gx, gy, sigma2, delta, N) + (tf.exp(log_gamma),)
        else:
            params = self.ls.dense(scope, h_dec, 5)
            gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(value=params, num_or_size_splits=5, axis=1)
            gx = (_w + 1)/2*(gx_+1)
            gy = (_h + 1)/2*(gy_+1)
            sigma2 = tf.exp(log_sigma2)
            delta = (max(_h, _w) - 1)/(N - 1)*tf.exp(log_delta)  # batch x N
            return self._filterbank(gx, gy, sigma2, delta, N) + (tf.exp(log_gamma),)
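
A small, hypothetical smoke test for the attention reader and writer above. The shapes, the Layers helper behind self.ls, and the module-level eps constant are assumptions; with is_3d=True, read() should return a (batch, read_n, read_n, 2*c) glimpse and write() a (batch, h, w, c) canvas patch.

# Hypothetical smoke test -- shapes and helper objects are assumptions.
import numpy as np
import tensorflow as tf

b, h, w, c = 4, 32, 32, 3  # batch size and image geometry
rnn_size = 256             # width of the decoder state fed to read/write

ii = ImageInterface(is_3d=True, is_read_attention=True, is_write_attention=True,
                    read_n=5, write_n=5, h=h, w=w, c=c)

x = tf.placeholder(tf.float32, (b, h, w, c))       # raw image
x_hat = tf.placeholder(tf.float32, (b, h, w, c))   # error image
h_dec = tf.placeholder(tf.float32, (b, rnn_size))  # previous decoder state

r = ii.read(x, x_hat, h_dec)  # attended glimpse
wr = ii.write(h_dec)          # patch written back onto the canvas

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {x: np.random.rand(b, h, w, c).astype(np.float32),
            x_hat: np.random.rand(b, h, w, c).astype(np.float32),
            h_dec: np.random.rand(b, rnn_size).astype(np.float32)}
    print(sess.run(r, feed).shape, sess.run(wr, feed).shape)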