Example #1
0
 def __init__(self,
              do_convolution,
              dim_latent,
              num_examples,
              dim_hidden,
              dropout_rate=0.1,
              beta=1.0,
              negative_sampling=True,
              *args,
              **kwargs):
     """Build the encoder's sub-layers.

     Args:
       do_convolution: when truthy, a `convolution.ConvHead` is used as the
         first stage; otherwise `dense.identity()` takes its place.
       dim_latent: output dimension of the latent `samplers.NormalSampler`.
       num_examples: example count(s). Summed for the single-branch layers
         and passed unchanged to the two-branch layers and the sampler
         (presumably one count per branch -- confirm against callers).
       dim_hidden: number of units in every hidden `dense.Dense` layer.
       dropout_rate: dropout rate for the conv head and all hidden layers.
       beta: weight handed to the sampler; halved when `negative_sampling`.
       negative_sampling: stored on the instance; also selects the beta split.
       *args, **kwargs: forwarded unchanged to the parent constructor.
     """
     super(Encoder, self).__init__(*args, **kwargs)
     pooled_examples = sum(num_examples)

     if do_convolution:
         self.conv = convolution.ConvHead(
             base_filters=32,
             num_examples=pooled_examples,
             dropout_rate=dropout_rate,
         )
     else:
         self.conv = dense.identity()

     # Layers 1-2: a single trunk fed with the pooled example count.
     trunk = dict(units=dim_hidden,
                  num_examples=pooled_examples,
                  dropout_rate=dropout_rate,
                  activation="elu")
     self.hidden_1 = dense.Dense(name="encoder_hidden_1", **trunk)
     self.hidden_2 = dense.Dense(name="encoder_hidden_2", **trunk)

     # Layers 3-4: two branches, each given the un-summed num_examples.
     branched = dict(units=dim_hidden,
                     num_examples=num_examples,
                     dropout_rate=dropout_rate,
                     num_branches=2,
                     activation="elu")
     self.hidden_3 = dense.Dense(name="encoder_hidden_3", **branched)
     self.hidden_4 = dense.Dense(name="encoder_hidden_4", **branched)

     # With negative sampling the sampler gets half of beta.
     sampler_beta = beta / 2 if negative_sampling else beta
     self.sampler = samplers.NormalSampler(
         dim_output=dim_latent,
         num_branches=2,
         num_examples=num_examples,
         beta=sampler_beta,
     )
     self.negative_sampling = negative_sampling
Example #2
0
 def __init__(self, dim_output, num_branches, num_examples, beta, *args,
              **kwargs):
     """Create the two dense heads (`mu`, `sigma`) of the normal sampler.

     Args:
       dim_output: number of units in each head.
       num_branches: branch count passed to both `dense.Dense` heads.
       num_examples: example count(s) passed to both heads.
       beta: stored as `self.beta`; not otherwise used in this constructor.
       *args, **kwargs: forwarded unchanged to the parent constructor.
     """
     super(NormalSampler, self).__init__(*args, **kwargs)
     # Settings shared by both heads; neither head uses dropout.
     head_kwargs = dict(units=dim_output,
                        num_examples=num_examples,
                        dropout_rate=0.0,
                        num_branches=num_branches)
     # Unconstrained (linear) head.
     self.mu = dense.Dense(activation='linear',
                           name='normal_mu',
                           **head_kwargs)
     # Softplus keeps this head's output positive; the small-stddev
     # initializer starts its kernel weights close to zero.
     self.sigma = dense.Dense(
         activation='softplus',
         kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
         name='normal_sigma',
         **head_kwargs)
     self.beta = beta
Example #3
0
 def __init__(self, dim_output, num_branches, num_examples, beta, *args,
              **kwargs):
     """Create the linear logits head of the categorical sampler.

     Args:
       dim_output: number of logits (units) produced by the head.
       num_branches: branch count passed to the `dense.Dense` head.
       num_examples: example count(s) passed to the head.
       beta: stored as `self.beta`; not otherwise used in this constructor.
       *args, **kwargs: forwarded unchanged to the parent constructor.
     """
     super(CategoricalSampler, self).__init__(*args, **kwargs)
     logits_head = dense.Dense(
         units=dim_output,
         num_examples=num_examples,
         dropout_rate=0.0,
         num_branches=num_branches,
         activation='linear',
         name='categorical_logits',
     )
     self.logits = logits_head
     self.beta = beta
Example #4
0
 def __init__(self, dim_output, num_branches, num_examples, beta, *args,
              **kwargs):
     """Create the linear logits head of the Bernoulli sampler.

     Args:
       dim_output: number of logits (units) produced by the head.
       num_branches: branch count passed to the `dense.Dense` head.
       num_examples: example count(s) passed to the head.
       beta: stored as `self.beta`; not otherwise used in this constructor.
       *args, **kwargs: forwarded unchanged to the parent constructor.
     """
     super(BernoulliSampler, self).__init__(*args, **kwargs)
     # NOTE(review): the layer name 'encoder_mu' looks copied from another
     # sampler (cf. 'categorical_logits' / 'normal_mu'). It is kept as-is
     # because renaming it would change this layer's weight names.
     logits_head = dense.Dense(
         units=dim_output,
         num_examples=num_examples,
         dropout_rate=0.0,
         num_branches=num_branches,
         activation='linear',
         name='encoder_mu',
     )
     self.logits = logits_head
     self.beta = beta
Example #5
0
 def __init__(self,
              do_convolution,
              num_examples,
              dim_hidden,
              regression,
              dropout_rate=0.1,
              beta=1.0,
              mode="tarnet",
              *args,
              **kwargs):
     """Build the TARNet sub-layers.

     Args:
       do_convolution: when truthy, a `convolution.ConvHead` is used as the
         first stage; otherwise `dense.identity()` takes its place.
       num_examples: example count(s). Summed for the single-branch layers
         and passed unchanged to the two-branch layers and the outcome
         sampler (presumably one count per branch -- confirm with callers).
       dim_hidden: number of units in every hidden `dense.Dense` layer.
       regression: selects a `NormalSampler` outcome head plus MSE loss when
         truthy, otherwise a `BernoulliSampler` head plus binary cross-entropy.
       dropout_rate: dropout rate for the conv head and hidden layers 1-4
         (layer 5 always uses 0.5).
       beta: stored as `self.beta`.
       mode: stored as `self.mode`. When equal to "dragon", an extra
         two-class `CategoricalSampler` is built as `self.t_sampler`; in any
         other mode `self.t_sampler` is NOT created.
       *args, **kwargs: forwarded unchanged to the parent constructor.
     """
     super(TARNet, self).__init__(*args, **kwargs)
     # Honour the caller's dropout_rate here as well, matching the conv
     # stages of Encoder/Decoder. The previous hard-coded 0.1 equals the
     # default, so default construction is unchanged, but a non-default
     # rate was silently ignored for this layer.
     self.conv = (convolution.ConvHead(
         base_filters=32,
         num_examples=sum(num_examples),
         dropout_rate=dropout_rate,
     ) if do_convolution else dense.identity())
     # Layers 1-3: single shared trunk over the pooled example count.
     self.hidden_1 = dense.Dense(
         units=dim_hidden,
         num_examples=sum(num_examples),
         dropout_rate=dropout_rate,
         activation="elu",
         name="tarnet_hidden_1",
     )
     self.hidden_2 = dense.Dense(
         units=dim_hidden,
         num_examples=sum(num_examples),
         dropout_rate=dropout_rate,
         activation="elu",
         name="tarnet_hidden_2",
     )
     self.hidden_3 = dense.Dense(
         units=dim_hidden,
         num_examples=sum(num_examples),
         dropout_rate=dropout_rate,
         activation="elu",
         name="tarnet_hidden_3",
     )
     # Layers 4-5: two branches with the un-summed num_examples.
     # NOTE(review): these are named "encoder_hidden_*" rather than
     # "tarnet_hidden_*"; kept so existing weight names do not change.
     self.hidden_4 = dense.Dense(
         units=dim_hidden,
         num_examples=num_examples,
         dropout_rate=dropout_rate,
         num_branches=2,
         activation="elu",
         name="encoder_hidden_4",
     )
     self.hidden_5 = dense.Dense(
         units=dim_hidden,
         num_examples=num_examples,
         dropout_rate=0.5,
         num_branches=2,
         activation="elu",
         name="encoder_hidden_5",
     )
     y_sampler = samplers.NormalSampler if regression else samplers.BernoulliSampler
     self.y_sampler = y_sampler(dim_output=1,
                                num_branches=2,
                                num_examples=num_examples,
                                beta=0.0)
     self.regression = regression
     self.beta = beta
     # Only the "dragon" mode gets a treatment head; callers must not touch
     # self.t_sampler in other modes.
     if mode == "dragon":
         self.t_sampler = samplers.CategoricalSampler(
             2, num_branches=1, num_examples=sum(num_examples), beta=0.0)
     self.mode = mode
     self.y_loss = (tf.keras.losses.MeanSquaredError()
                    if regression else tf.keras.losses.BinaryCrossentropy())
Example #6
0
 def __init__(self,
              dim_x,
              dim_t,
              dim_y,
              regression,
              num_examples,
              dim_hidden,
              dropout_rate=0.1,
              *args,
              **kwargs):
     """Build the decoder's x, t and y sub-networks.

     Args:
       dim_x: if a tuple or list, a `convolution.ConvTail` is built for the
         x path; otherwise `dense.identity()` is used. Not used otherwise.
       dim_t: output dimension of the `CategoricalSampler` for t.
       dim_y: output dimension of the y sampler.
       regression: selects `NormalSampler` (truthy) or `BernoulliSampler`
         for the y head.
       num_examples: example count(s). Summed for single-branch layers and
         passed unchanged to the two-branch y layers and the y sampler.
       dim_hidden: number of units in every hidden `dense.Dense` layer.
       dropout_rate: default dropout rate; some layers override it (see
         inline comments).
       *args, **kwargs: forwarded unchanged to the parent constructor.
     """
     super(Decoder, self).__init__(*args, **kwargs)
     pooled_examples = sum(num_examples)
     # A tuple/list dim_x switches on the convolutional tail.
     do_convolution = isinstance(dim_x, (tuple, list))

     # Shared settings for the single-branch x and t layers.
     one_branch = dict(units=dim_hidden,
                       num_examples=pooled_examples,
                       num_branches=1,
                       activation="elu")

     # ---- x path ----
     self.x_hidden_1 = dense.Dense(dropout_rate=dropout_rate,
                                   name="decoder_x_hidden_1",
                                   **one_branch)
     # NOTE(review): without convolution this layer always uses 0.1 and
     # ignores dropout_rate; kept as-is, confirm this is intentional.
     x_hidden_2_rate = dropout_rate if do_convolution else 0.1
     self.x_hidden_2 = dense.Dense(dropout_rate=x_hidden_2_rate,
                                   name="decoder_x_hidden_2",
                                   **one_branch)
     if do_convolution:
         self.x_conv = convolution.ConvTail(
             base_filters=32,
             num_examples=pooled_examples,
             dropout_rate=dropout_rate,
         )
     else:
         self.x_conv = dense.identity()

     # ---- t path (last hidden layer uses a fixed 0.5 dropout) ----
     self.t_hidden_1 = dense.Dense(dropout_rate=dropout_rate,
                                   name="decoder_t_hidden_1",
                                   **one_branch)
     self.t_hidden_2 = dense.Dense(dropout_rate=0.5,
                                   name="decoder_t_hidden_2",
                                   **one_branch)
     self.t_sampler = samplers.CategoricalSampler(
         dim_output=dim_t,
         num_branches=1,
         num_examples=pooled_examples,
         beta=0.0)

     # ---- y path: two shared layers, then two two-branch layers ----
     y_trunk = dict(units=dim_hidden,
                    num_examples=pooled_examples,
                    dropout_rate=dropout_rate,
                    activation="elu")
     self.y_hidden_1 = dense.Dense(name="decoder_y_hidden_1", **y_trunk)
     self.y_hidden_2 = dense.Dense(name="decoder_y_hidden_2", **y_trunk)
     y_branched = dict(units=dim_hidden,
                       num_examples=num_examples,
                       num_branches=2,
                       activation="elu")
     self.y_hidden_3 = dense.Dense(dropout_rate=dropout_rate,
                                   name="decoder_y_hidden_3",
                                   **y_branched)
     self.y_hidden_4 = dense.Dense(dropout_rate=0.5,
                                   name="decoder_y_hidden_4",
                                   **y_branched)
     sampler_cls = (samplers.NormalSampler
                    if regression else samplers.BernoulliSampler)
     self.y_sampler = sampler_cls(dim_output=dim_y,
                                  num_branches=2,
                                  num_examples=num_examples,
                                  beta=0.0)