Ejemplo n.º 1
0
    def _module_fn():
        """
        Build the TF-Hub module for a per-voxel Poisson likelihood model.

        Creates two placeholders: ``input`` (rank 5, ``nchannels`` channels)
        and ``observations`` (rank 5, ``n_y`` channels).  A 3-D convolution /
        wide-ResNet stack maps the features to a strictly positive Poisson
        rate for every voxel and label channel.

        Signature inputs: ``features`` (the input placeholder) and ``labels``
        (the observations placeholder).

        Signature outputs:
            sample: a draw from the Poisson distribution, passed through
                ``tf.squeeze`` (all size-1 dimensions are dropped).
            loglikelihood: element-wise Poisson log-probability of ``labels``.

        Reads ``nchannels``, ``n_y``, ``dropout`` and ``is_training`` from the
        enclosing scope.
        """

        feature_layer = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, nchannels],
            name='input')
        obs_layer = tf.placeholder(tf.float32,
                                   shape=[None, None, None, None, n_y],
                                   name='observations')

        # Builds the neural network
        net = slim.conv3d(feature_layer,
                          16,
                          5,
                          activation_fn=tf.nn.leaky_relu,
                          padding='valid')
        for nfilters in (16, 32, 32):
            net = wide_resnet(net,
                              nfilters,
                              activation_fn=tf.nn.leaky_relu,
                              keep_prob=dropout,
                              is_training=is_training)
        net = slim.conv3d(net, 32, 3, activation_fn=tf.nn.tanh)

        # Poisson rate head.  A Poisson has no mixture components, so emit
        # exactly one channel per label channel: emitting extra channels and
        # reshaping to ``n_y`` would fold them into the batch axis and break
        # the alignment with ``obs_layer`` in ``log_prob``.
        net = slim.conv3d(net, n_y, 1, activation_fn=tf.nn.relu)
        cube_size = tf.shape(obs_layer)[1]
        out_rate = tf.reshape(net, [-1, cube_size, cube_size, cube_size, n_y])
        # relu can output exactly 0; the offset keeps the rate strictly positive.
        out_rate = tf.math.add(out_rate, 1e-6, name='rate')
        pdf = tfd.Poisson(rate=out_rate)

        # Sampling op and log-likelihood op exported through the signature.
        sample = tf.squeeze(pdf.sample())
        loglik = pdf.log_prob(obs_layer)
        hub.add_signature(inputs={
            'features': feature_layer,
            'labels': obs_layer
        },
                          outputs={
                              'sample': sample,
                              'loglikelihood': loglik
                          })
Ejemplo n.º 2
0
def module_fn():
    """Build the TF-Hub module for the sigmoid-output 3-D network.

    Placeholders: ``input`` (spatial size ``cube_sizeft``, ``nchannels``
    channels), ``labels`` (spatial size ``cube_size``, one channel) and
    ``keepprob`` (dropout keep probability passed to every wide-ResNet block).
    Only the first convolution receives the L2 weight/bias regularizers.
    Diagnostic lines are written to ``fname`` via ``print(..., file=fname)``.

    Signature outputs:
        default: logits of the final one-channel convolution
            (op name ``logits``).
        prediction: ``sigmoid(logits)`` (op name ``prediction``).
    """
    x = tf.placeholder(
        tf.float32,
        shape=[None, cube_sizeft, cube_sizeft, cube_sizeft, nchannels],
        name='input')
    y = tf.placeholder(tf.float32,
                       shape=[None, cube_size, cube_size, cube_size, 1],
                       name='labels')
    keepprob = tf.placeholder(tf.float32, name='keepprob')
    print('Shape of training and testing data is : ', x.shape, y.shape,
          file=fname)

    # L2 strengths for weights and biases; a falsy strength disables the
    # corresponding regularizer.
    wregwt = bregwt = 0.001
    wreg = slim.regularizers.l2_regularizer(wregwt) if wregwt else None
    breg = slim.regularizers.l2_regularizer(bregwt) if bregwt else None
    print('Regularizing weights are : ', wregwt, bregwt, file=fname)

    net = slim.conv3d(x, 16, 5,
                      activation_fn=tf.nn.leaky_relu,
                      padding='valid',
                      weights_regularizer=wreg,
                      biases_regularizer=breg)
    # Residual trunk: widen to 64 filters, then narrow back down to 16.
    for width in (32, 64, 32, 16):
        net = wide_resnet(net, width,
                          keep_prob=keepprob,
                          activation_fn=tf.nn.leaky_relu)
    logits = tf.identity(slim.conv3d(net, 1, 3, activation_fn=None),
                         name='logits')
    pred = tf.nn.sigmoid(logits, name='prediction')

    hub.add_signature(
        inputs={'input': x, 'label': y, 'keepprob': keepprob},
        outputs={'default': logits, 'prediction': pred})
Ejemplo n.º 3
0
    def _module_fn():
        """
        Build the TF-Hub module for a mixture of discretized logistics.

        A 3-D convolution / wide-ResNet stack predicts, for every voxel and
        each of the ``n_y`` label channels, ``n_mixture`` components
        described by a location, a scale and a mixing logit.  Each component
        is a logistic shifted by -0.5 and quantized to the integers 0..15.

        Signature inputs: ``features`` (rank 5, ``nchannels`` channels) and
        ``labels`` (rank 5, ``n_y`` channels).

        Signature outputs:
            sample: a mixture draw passed through ``tf.squeeze`` (all size-1
                dimensions are dropped).
            loglikelihood: mixture log-probability of ``labels``.

        Reads ``nchannels``, ``n_y``, ``n_mixture``, ``dropout`` and
        ``is_training`` from the enclosing scope.
        """

        feature_layer = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, nchannels],
            name='input')
        obs_layer = tf.placeholder(tf.float32,
                                   shape=[None, None, None, None, n_y],
                                   name='observations')

        # Builds the neural network
        net = slim.conv3d(feature_layer,
                          16,
                          5,
                          activation_fn=tf.nn.leaky_relu,
                          padding='valid')
        for nfilters in (16, 32, 32):
            net = wide_resnet(net,
                              nfilters,
                              activation_fn=tf.nn.leaky_relu,
                              keep_prob=dropout,
                              is_training=is_training)
        net = slim.conv3d(net, 32, 3, activation_fn=tf.nn.tanh)

        # Probabilistic head: (loc, scale, logit) for each mixture component
        # and label channel, laid out on the last axis.
        net = slim.conv3d(net, n_mixture * 3 * n_y, 1, activation_fn=None)
        cube_size = tf.shape(obs_layer)[1]
        net = tf.reshape(
            net, [-1, cube_size, cube_size, cube_size, n_y, n_mixture * 3])
        loc, unconstrained_scale, logits = tf.split(net,
                                                    num_or_size_splits=3,
                                                    axis=-1)
        # softplus underflows to exactly 0 for very negative inputs in
        # float32; the small floor keeps the logistic scale strictly positive
        # so log_prob cannot become NaN.
        scale = tf.nn.softplus(unconstrained_scale) + 1e-3

        # Mixture of discretized logistic distributions. The logistic is
        # shifted by -0.5 so the quantization captures "rounding" intervals,
        # `(x-0.5, x+0.5]`, rather than "ceiling" intervals, `(x-1, x]`.
        discretized_logistic_dist = tfd.QuantizedDistribution(
            distribution=tfd.TransformedDistribution(
                distribution=tfd.Logistic(loc=loc, scale=scale),
                bijector=tfb.AffineScalar(shift=-0.5)),
            low=0.,
            high=2.**4 - 1)

        mixture_dist = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits),
            components_distribution=discretized_logistic_dist)

        # Sampling op and log-likelihood op exported through the signature.
        sample = tf.squeeze(mixture_dist.sample())
        loglik = mixture_dist.log_prob(obs_layer)
        hub.add_signature(inputs={
            'features': feature_layer,
            'labels': obs_layer
        },
                          outputs={
                              'sample': sample,
                              'loglikelihood': loglik
                          })
Ejemplo n.º 4
0
    def _module_fn():
        """
        Build the TF-Hub module for a conditional 3-D PixelCNN with a
        Gaussian-mixture output.

        Three wide-ResNet blocks (16, 16, then 1 filters) turn the features
        into a conditioning map, which is concatenated back onto the raw
        features. That conditioning tensor is handed to each of ten chained
        ``PixelCNN3Dlayer`` calls that start from the observations. Two
        ``GatedCNN`` layers (variable scopes ``fc_1`` and ``fc_2``) then
        produce ``3 * n_mixture`` values per voxel and label channel. These
        are split into component means, scales and mixing logits of a
        mixture of *normal* distributions. No discretization is applied.

        Signature outputs: ``sample`` (an unsqueezed mixture draw),
        ``loglikelihood`` (log-probability of ``labels``), and the mixture
        parameters ``loc``, ``scale`` and ``logits``.

        Reads ``nchannels``, ``n_y``, ``n_mixture``, ``dropout``,
        ``is_training``, ``f_map`` and ``cfilter_size`` from the enclosing
        scope.
        """
        feature_layer = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, nchannels],
            name='input')
        obs_layer = tf.placeholder(tf.float32,
                                   shape=[None, None, None, None, n_y],
                                   name='observations')

        # Conditioning branch, then re-attach the raw features channel-wise.
        cond = feature_layer
        for width in (16, 16, 1):
            cond = wide_resnet(cond,
                               width,
                               activation_fn=tf.nn.leaky_relu,
                               keep_prob=dropout,
                               is_training=is_training)
        conditional_im = tf.concat((feature_layer, cond), -1)

        # Ten PixelCNN layers, each fed the output of the one before it.
        stack = [obs_layer]
        for depth in range(10):
            stack = PixelCNN3Dlayer(depth,
                                    stack,
                                    f_map=f_map,
                                    full_horizontal=True,
                                    h=None,
                                    conditional_im=conditional_im,
                                    cfilter_size=cfilter_size,
                                    gatedact='sigmoid')
        h_stack_in = stack[-1]

        with tf.variable_scope("fc_1"):
            fc1 = GatedCNN([1, 1, 1, 1], h_stack_in,
                           orientation=None, gated=False,
                           mask='b').output()
        with tf.variable_scope("fc_2"):
            fc2 = GatedCNN([1, 1, 1, n_mixture * 3 * n_y], fc1,
                           orientation=None, gated=False,
                           mask='b', activation=False).output()

        # Arrange as (..., label channel, 3 * n_mixture) and split the last
        # axis into means, raw scales and mixing logits.
        cube_size = tf.shape(obs_layer)[1]
        params = tf.reshape(
            fc2, [-1, cube_size, cube_size, cube_size, n_y, n_mixture * 3])
        loc, unconstrained_scale, logits = tf.split(params,
                                                    num_or_size_splits=3,
                                                    axis=-1)
        # softplus keeps the scale positive; the floor keeps it away from 0.
        scale = tf.nn.softplus(unconstrained_scale) + 1e-3

        # Gaussian mixture per voxel and label channel.
        mixture_dist = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits),
            components_distribution=tfd.Normal(loc=loc, scale=scale))

        sample = mixture_dist.sample()
        loglik = mixture_dist.log_prob(obs_layer)
        hub.add_signature(
            inputs={'features': feature_layer, 'labels': obs_layer},
            outputs={
                'sample': sample,
                'loglikelihood': loglik,
                'loc': loc,
                'scale': scale,
                'logits': logits
            })
Ejemplo n.º 5
0
                  activation_fn=tf.nn.leaky_relu,
                  padding='valid',
                  weights_regularizer=wreg,
                  biases_regularizer=breg)
# Residual trunk on top of the (regularized) input convolution above, then a
# one-channel logit map and its sigmoid.
net = wide_resnet(net, 32, keep_prob=keepprob, activation_fn=tf.nn.leaky_relu)
net = wide_resnet(net, 64, keep_prob=keepprob, activation_fn=tf.nn.leaky_relu)
net = wide_resnet(net, 32, keep_prob=keepprob, activation_fn=tf.nn.leaky_relu)
net = wide_resnet(net, 16, keep_prob=keepprob, activation_fn=tf.nn.leaky_relu)
net = slim.conv3d(net, 1, 3, activation_fn=None)
pred = tf.nn.sigmoid(net, name='prediction')

#####
# `loss` stays the pure data term so it can still be monitored on its own.
# The weights_regularizer / biases_regularizer given to slim.conv3d only add
# penalty tensors to the REGULARIZATION_LOSSES collection; they do nothing
# unless that collection is added to the objective, so the optimizer
# minimizes data loss + regularization (0 when no regularizer is active).
loss = tf.losses.sigmoid_cross_entropy(y, net)
total_loss = loss + tf.losses.get_regularization_loss()
optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='optimizer')

opt_op = optimizer.minimize(total_loss, name='minimize')

#############################
###Train

sess = tf.Session()
Ejemplo n.º 6
0
# Start from an empty default graph before defining placeholders and layers.
tf.reset_default_graph()

# Input features use spatial size cube_sizeft; the 'centrals' / 'satellites'
# targets use spatial size cube_size with one channel, and 'weights' has the
# same spatial size but no channel axis.
x = tf.placeholder(tf.float32, shape=[None, cube_sizeft, cube_sizeft, cube_sizeft, nchannels], name='input')
ycen = tf.placeholder(tf.float32, shape=[None, cube_size, cube_size, cube_size, 1], name='centrals')
ysat = tf.placeholder(tf.float32, shape=[None, cube_size, cube_size, cube_size, 1], name='satellites')
ywt = tf.placeholder(tf.float32, shape=[None, cube_size, cube_size, cube_size], name='weights')
# Earlier single-target placeholders, kept for reference:
##x = tf.placeholder(tf.float32, shape=[None, cube_sizeft, cube_sizeft, cube_sizeft, nchannels], name='input')
##y = tf.placeholder(tf.float32, shape=[None, cube_size, cube_size, cube_size, 1], name='labels')
##m = tf.placeholder(tf.float32, shape=[None, cube_size, cube_size, cube_size], name='mask')

# Scalar hyper-parameters fed at run time.
keepprob = tf.placeholder(tf.float32, name='keepprob')
lr = tf.placeholder(tf.float32, name='learningrate')

# Shared trunk. NOTE: slim auto-names variable scopes by creation order, so
# reordering these layers would change checkpoint variable names.
net = slim.conv3d(x, 16, 5, activation_fn=tf.nn.leaky_relu, padding='valid')
# net = slim.conv3d(net, 32, 5, activation_fn=None, padding='valid')
net = wide_resnet(net, 32, activation_fn=tf.nn.leaky_relu, keep_prob=keepprob)
net = wide_resnet(net, 32, activation_fn=tf.nn.leaky_relu, keep_prob=keepprob)
net = wide_resnet(net, 32, activation_fn=tf.nn.leaky_relu, keep_prob=keepprob)
net = slim.conv3d(net, 64, 1, activation_fn=tf.nn.relu6)
net = tf.nn.dropout(net, keep_prob=keepprob)

# Rate head: non-negative via relu, then offset by 1e-8 so the rate is
# strictly positive (tensor named 'rate').
#out_rate = slim.conv3d(net, 1, 3, activation_fn=tf.nn.relu)
out_rate = slim.conv3d(net, 1, 1, activation_fn=tf.nn.relu,
                       # The trunk output is non-negative (relu6, then dropout) and these
                       # initial weights are positive, presumably so the relu does not
                       # start out clipping everything to zero. Earlier alternative:
                       #weights_initializer=tf.initializers.random_normal(mean=1, stddev=0.25))
                       weights_initializer=tf.initializers.random_uniform(minval=0.01, maxval=1))
out_rate = tf.math.add(out_rate, 1e-8, name='rate')

# Mask head: raw logits in out_mask, probabilities in pred_mask ('prediction').
out_mask = slim.conv3d(net, 1, 1, activation_fn=None)
pred_mask = tf.nn.sigmoid(out_mask, name='prediction')
Ejemplo n.º 7
0
    def _module_fn():
        """
        Build the TF-Hub module for a 3-D U-Net with a mixture-density head.

        Input stage (selected by ``pad``):
            0 -> one 5^3 'same' convolution (spatial size unchanged);
            2 -> one 5^3 'valid' convolution (trims 2 voxels per face);
            4 -> two 5^3 'valid' convolutions (trims 4 voxels per face).
        Encoder: two wide-ResNet blocks at ``fsize`` filters, then ``nsub``
        rounds of stride-2 max-pooling plus two wide-ResNet blocks, doubling
        the filter count each round.
        Decoder: ``nsub`` rounds of ``dynamic_deconv3d`` + 1x1 convolution,
        concatenated with the matching encoder output (skip connection) and
        followed by two wide-ResNet blocks, halving the filter count each
        round.
        Head: ``n_mixture`` components (loc, scale, logit) per voxel and label
        channel, forming either a discretized-logistic mixture (integers 0..7)
        or a normal mixture, selected by ``distribution``.

        Signature outputs: ``sample`` (unsqueezed mixture draw),
        ``loglikelihood`` (log-probability of ``labels``), ``loc``, ``scale``
        and ``logits``.

        Raises:
            ValueError: if ``pad`` is not 0, 2 or 4, or ``distribution`` is
                not 'logistic' or 'normal'.
        """

        feature_layer = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, nchannels],
            name='input')
        obs_layer = tf.placeholder(tf.float32,
                                   shape=[None, None, None, None, n_y],
                                   name='observations')

        # Input stage.
        if pad == 0:
            d00 = slim.conv3d(feature_layer,
                              fsize,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='same')
        elif pad == 2:
            d00 = slim.conv3d(feature_layer,
                              fsize,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='valid')
        elif pad == 4:
            d00 = slim.conv3d(feature_layer,
                              fsize,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='valid')
            d00 = slim.conv3d(d00,
                              fsize * 2,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='valid')
        else:
            raise ValueError('pad must be 0, 2 or 4, got %r' % (pad,))

        # Encoder: keep each resolution's output for the skip connections.
        cfsize = fsize
        d1 = wide_resnet(d00, cfsize, activation_fn=tf.nn.leaky_relu)
        d2 = wide_resnet(d1, cfsize, activation_fn=tf.nn.leaky_relu)
        dd = [d2]
        for _ in range(nsub):
            cfsize *= 2
            dsub = slim.max_pool3d(dd[-1],
                                   kernel_size=3,
                                   stride=2,
                                   padding='SAME')
            d1 = wide_resnet(dsub, cfsize, activation_fn=tf.nn.leaky_relu)
            d2 = wide_resnet(d1, cfsize, activation_fn=tf.nn.leaky_relu)
            dd.append(d2)

        # Decoder: the coarsest output is the starting point; each round
        # upsamples and merges with the next-finer encoder output.
        usub = dd.pop()
        for i in range(nsub):
            u0 = dynamic_deconv3d('up%d' % i,
                                  usub,
                                  shape=[3, 3, 3, cfsize],
                                  activation=tf.identity)
            cfsize = cfsize // 2
            u0 = slim.conv3d(u0,
                             cfsize,
                             1,
                             activation_fn=tf.identity,
                             padding='same')
            uc = tf.concat([u0, dd.pop()], axis=-1)
            u1 = wide_resnet(uc, cfsize, activation_fn=tf.nn.leaky_relu)
            usub = wide_resnet(u1, cfsize, activation_fn=tf.nn.leaky_relu)

        net = slim.conv3d(usub, 1, 3, activation_fn=tf.nn.tanh)

        # Probabilistic head: (loc, scale, logit) per component and label
        # channel on the last axis.
        net = slim.conv3d(net, n_mixture * 3 * n_y, 1, activation_fn=None)
        cube_size = tf.shape(obs_layer)[1]
        net = tf.reshape(
            net, [-1, cube_size, cube_size, cube_size, n_y, n_mixture * 3])
        loc, unconstrained_scale, logits = tf.split(net,
                                                    num_or_size_splits=3,
                                                    axis=-1)
        # The floor keeps the scale strictly positive even if softplus underflows.
        scale = tf.nn.softplus(unconstrained_scale) + 1e-3

        if distribution == 'logistic':
            # Logistic shifted by -0.5 so quantization captures "rounding"
            # intervals, `(x-0.5, x+0.5]`, not "ceiling" intervals, `(x-1, x]`.
            discretized_logistic_dist = tfd.QuantizedDistribution(
                distribution=tfd.TransformedDistribution(
                    distribution=tfd.Logistic(loc=loc, scale=scale),
                    bijector=tfb.AffineScalar(shift=-0.5)),
                low=0.,
                high=2.**3 - 1)

            mixture_dist = tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=logits),
                components_distribution=discretized_logistic_dist)

        elif distribution == 'normal':
            mixture_dist = tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=logits),
                components_distribution=tfd.Normal(loc=loc, scale=scale))

        else:
            raise ValueError(
                "distribution must be 'logistic' or 'normal', got %r" %
                (distribution,))

        # Sampling op and log-likelihood op exported through the signature.
        sample = mixture_dist.sample()
        loglik = mixture_dist.log_prob(obs_layer)
        hub.add_signature(inputs={
            'features': feature_layer,
            'labels': obs_layer
        },
                          outputs={
                              'sample': sample,
                              'loglikelihood': loglik,
                              'loc': loc,
                              'scale': scale,
                              'logits': logits
                          })
Ejemplo n.º 8
0
    def _module_fn():
        """
        Build the TF-Hub module for a Gaussian-mixture head, optionally in log space.

        A 3-D convolution / wide-ResNet stack predicts ``n_mixture`` normal
        components (logit, loc, scale) for every voxel and each of the ``n_y``
        label channels.

        When ``log`` is truthy the mixture models ``log(logoffset + labels)``:
        ``loglikelihood`` is evaluated on that transformed target (no Jacobian
        term is added, so it is a log-density in log space) and ``sample`` is
        mapped back with ``exp(draw) - logoffset``.  Otherwise the mixture
        models ``labels`` directly.

        Signature outputs: ``sample``, ``loglikelihood``, ``sigma`` (component
        scales), ``mean`` (component locations) and ``logits`` (mixing
        logits).

        Reads ``nchannels``, ``n_y``, ``n_mixture``, ``dropout``,
        ``is_training``, ``log`` and ``logoffset`` from the enclosing scope.
        """

        feature_layer = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, nchannels],
            name='input')
        obs_layer = tf.placeholder(tf.float32,
                                   shape=[None, None, None, None, n_y],
                                   name='observations')

        # Builds the neural network
        net = slim.conv3d(feature_layer,
                          16,
                          5,
                          activation_fn=tf.nn.leaky_relu,
                          padding='same')
        for nfilters in (16, 32, 32):
            net = wide_resnet(net,
                              nfilters,
                              activation_fn=tf.nn.leaky_relu,
                              keep_prob=dropout,
                              is_training=is_training)
        net = slim.conv3d(net, 32, 3, activation_fn=tf.nn.leaky_relu)

        # Probabilistic head.  The mixture is over the *label* channels
        # (``n_y``, the last axis of ``obs_layer``), not the input feature
        # channels; sizing it by ``nchannels`` mis-broadcasts in ``log_prob``
        # whenever the two differ.
        net = slim.conv3d(net,
                          3 * n_mixture * n_y,
                          1,
                          activation_fn=None)
        cube_size = tf.shape(obs_layer)[1]
        net = tf.reshape(
            net,
            [-1, cube_size, cube_size, cube_size, n_y, n_mixture * 3])

        logits, loc, unconstrained_scale = tf.split(net,
                                                    num_or_size_splits=3,
                                                    axis=-1)
        # The floor keeps the scale strictly positive even if softplus underflows.
        scale = tf.nn.softplus(unconstrained_scale) + 1e-3

        mixture_dist = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits),
            components_distribution=tfd.Normal(loc=loc, scale=scale))

        # Build exactly one sampling op and one log-prob op (no duplicate or
        # unused ops just for printing).
        if log:
            sample = tf.exp(mixture_dist.sample()) - logoffset
            logfeature = tf.log(tf.add(logoffset, obs_layer), 'logfeature')
            loglik = mixture_dist.log_prob(logfeature)
        else:
            sample = mixture_dist.sample()
            loglik = mixture_dist.log_prob(obs_layer)

        hub.add_signature(inputs={
            'features': feature_layer,
            'labels': obs_layer
        },
                          outputs={
                              'sample': sample,
                              'loglikelihood': loglik,
                              'sigma': scale,
                              'mean': loc,
                              'logits': logits
                          })
Ejemplo n.º 9
0
    def _module_fn():
        """
        Build the TF-Hub module for a masked mixture-density model.

        A shared 3-D convolution / wide-ResNet trunk feeds two heads:
          * a mask head whose logits (``out_mask``) are trained with sigmoid
            cross-entropy against ``mask_layer``;
          * a mixture head with ``n_mixture`` components per voxel and label
            channel, either discretized logistic (integers 0..7) or normal,
            selected by ``distribution``.

        ``mask_layer`` is ``clip(labels, 0, 0.001) * 1000``: labels <= 0 map
        to 0, labels >= 0.001 map to 1, and values in between scale linearly.

        The exported ``loglikelihood`` is ``-(loss1 + loss2)``. ``loss1`` is
        the negative mixture log-probability, weighted either by
        ``mask_layer`` (``masktype == 'constant'``) or by ``pred_mask``
        (``masktype == 'vary'``). ``loss2`` is the mask cross-entropy.
        ``sample`` is a mixture draw multiplied by ``pred_mask``. The
        unmasked pieces are exported as well.

        Raises:
            ValueError: if ``pad`` is not 0 or 2, ``distribution`` is not
                'logistic' or 'normal', or ``masktype`` is not 'constant' or
                'vary'.
        """

        feature_layer = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, nchannels],
            name='input')
        obs_layer = tf.placeholder(tf.float32,
                                   shape=[None, None, None, None, n_y],
                                   name='observations')
        mask_layer = tf.clip_by_value(obs_layer, 0, 0.001) * 1000

        # Input convolution: 'same' keeps the spatial size, 'valid' trims
        # 2 voxels per face.
        if pad == 0:
            net = slim.conv3d(feature_layer,
                              16,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='same')
        elif pad == 2:
            net = slim.conv3d(feature_layer,
                              16,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='valid')
        else:
            raise ValueError('pad must be 0 or 2, got %r' % (pad,))

        for nfilters in (16, 32, 32):
            net = wide_resnet(net,
                              nfilters,
                              activation_fn=tf.nn.leaky_relu,
                              keep_prob=dropout,
                              is_training=is_training)
        if distribution == 'logistic':
            net = slim.conv3d(net, 32, 3, activation_fn=tf.nn.tanh)
        else:
            net = slim.conv3d(net, 32, 3, activation_fn=tf.nn.leaky_relu)

        # Mask head (one channel of logits and its sigmoid).
        masknet = slim.conv3d(net, 8, 1, activation_fn=tf.nn.leaky_relu)
        out_mask = slim.conv3d(masknet, 1, 1, activation_fn=None)
        pred_mask = tf.nn.sigmoid(out_mask)

        # Mixture head: (loc, scale, logit) per component and label channel.
        likenet = slim.conv3d(net, 64, 1, activation_fn=tf.nn.leaky_relu)
        net = slim.conv3d(likenet, n_mixture * 3 * n_y, 1, activation_fn=None)
        cube_size = tf.shape(obs_layer)[1]
        net = tf.reshape(
            net, [-1, cube_size, cube_size, cube_size, n_y, n_mixture * 3])
        loc, unconstrained_scale, logits = tf.split(net,
                                                    num_or_size_splits=3,
                                                    axis=-1)
        # The floor keeps the scale strictly positive even if softplus underflows.
        scale = tf.nn.softplus(unconstrained_scale) + 1e-3

        if distribution == 'logistic':
            # Logistic shifted by -0.5 so quantization captures "rounding"
            # intervals, `(x-0.5, x+0.5]`, not "ceiling" intervals, `(x-1, x]`.
            discretized_logistic_dist = tfd.QuantizedDistribution(
                distribution=tfd.TransformedDistribution(
                    distribution=tfd.Logistic(loc=loc, scale=scale),
                    bijector=tfb.AffineScalar(shift=-0.5)),
                low=0.,
                high=2.**3 - 1)

            mixture_dist = tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=logits),
                components_distribution=discretized_logistic_dist)

        elif distribution == 'normal':
            mixture_dist = tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=logits),
                components_distribution=tfd.Normal(loc=loc, scale=scale))

        else:
            raise ValueError(
                "distribution must be 'logistic' or 'normal', got %r" %
                (distribution,))

        rawsample = mixture_dist.sample()
        sample = rawsample * pred_mask
        rawloglik = mixture_dist.log_prob(obs_layer)

        # Weight the mixture NLL by the true mask ('constant') or by the
        # predicted mask probability ('vary').
        if masktype == 'constant':
            loss1 = -rawloglik * mask_layer
        elif masktype == 'vary':
            loss1 = -rawloglik * pred_mask
        else:
            raise ValueError(
                "masktype must be 'constant' or 'vary', got %r" % (masktype,))
        # NOTE(review): sigmoid_cross_entropy_with_logits requires logits and
        # labels of the same shape; out_mask has 1 channel and mask_layer has
        # n_y, so this assumes n_y == 1 — confirm.
        loss2 = tf.nn.sigmoid_cross_entropy_with_logits(logits=out_mask,
                                                        labels=mask_layer)
        loglik = -(loss1 + loss2)

        hub.add_signature(inputs={
            'features': feature_layer,
            'labels': obs_layer
        },
                          outputs={
                              'sample': sample,
                              'loglikelihood': loglik,
                              'loc': loc,
                              'scale': scale,
                              'logits': logits,
                              'rawsample': rawsample,
                              'pred_mask': pred_mask,
                              'out_mask': out_mask,
                              'rawloglik': rawloglik,
                              'loss1': loss1,
                              'loss2': loss2
                          })