def forward(self, x, forward=0, stochastic=False):
        """Run the stacked RNNs over ``x`` and optionally keep generating.

        ``x`` is embedded (with dropout) and passed through every module in
        ``self.rnns`` in turn; each layer's final state is remembered, and the
        result is projected by ``self.projection``.  The trailing comments
        suggest ``x`` is ``(n, t)`` and the embedding ``(n, t, c)`` -- confirm.

        Args:
            x: token indices accepted by ``self.embedding``.
            forward: number of extra steps to generate.  Each step feeds back
                ``argmax + 1`` of the previous scores (no dropout on the fed-back
                embedding) and resumes every layer from its saved state.
            stochastic: when True, ``utils.sample_gumbel`` noise is added in
                place to every projection output before it is returned or
                used for the argmax.

        Returns:
            The projection outputs for the input positions, followed (when
            ``forward > 0``) by the generated positions, concatenated on dim 1.
        """
        def add_noise(scores):
            # Optional Gumbel perturbation, applied in place on the scores.
            if stochastic:
                noise = utils.to_variable(
                    utils.sample_gumbel(shape=scores.size(),
                                        out=scores.data.new()))
                scores += noise
            return scores

        hidden = self.dropout(self.embedding(x))  # (n, t, c)
        states = []
        for layer in self.rnns:
            hidden, layer_state = layer(hidden)
            states.append(layer_state)
        logits = add_noise(self.projection(hidden))

        if not forward > 0:
            return logits

        generated = []
        # Index i of the projection is fed back as token id i + 1.
        token = torch.max(logits[:, -1:, :], dim=2)[1] + 1
        for _ in range(forward):
            step = self.embedding(token)
            for idx, layer in enumerate(self.rnns):
                step, states[idx] = layer(step, states[idx])
            step = add_noise(self.projection(step))
            generated.append(step)
            token = torch.max(step, dim=2)[1] + 1
        return torch.cat([logits] + generated, dim=1)
    def forward(self, x, hidden, forward=0, stochastic=False):
        """Score ``x`` with the LSTM and optionally extend it greedily.

        Args:
            x: token indices fed to ``self.embedding``.  The ``# (n, t, c)``
                comment suggests a batch-first ``(n, t)`` tensor -- confirm.
            hidden: initial state passed to ``self.lstm``.
            forward: number of extra steps to generate.  Each step embeds the
                argmax (+1) of the previous step's scores and continues from
                the LSTM state left by the previous step.
            stochastic: when True, ``utils.sample_gumbel`` noise is added in
                place to every projection output, so the argmax used for the
                next token is taken over the noisy scores.

        Returns:
            ``(logits, hidden)``: projection outputs for the ``t`` input
            positions followed by ``forward`` generated positions
            (concatenated on dim 1), and the final LSTM state.
        """
        def decode(tokens, state):
            # One embed -> dropout -> LSTM -> dropout -> projection pass,
            # continuing from ``state``.
            emb = self.dropout(self.embedding(tokens))  # (n, t, c)
            out, state = self.lstm(emb, state)
            out = self.projection(self.dropout(out))
            if stochastic:
                # NOTE(review): presumably standard Gumbel noise, which makes
                # the argmax below a Gumbel-max sample; confirm sample_gumbel.
                gumbel = utils.to_variable(
                    utils.sample_gumbel(shape=out.size(),
                                        out=out.data.new()))
                out += gumbel
            return out, state

        logits, hidden = decode(x, hidden)
        if forward > 0:
            outputs = [logits]
            # Projection index i is fed back as token id i + 1 (presumably
            # id 0 is reserved, e.g. for padding -- confirm).
            token = torch.max(logits[:, -1:, :], dim=2)[1] + 1
            for _ in range(forward):
                # Feed the predicted token, not the original input sequence.
                step, hidden = decode(token, hidden)
                outputs.append(step)
                token = torch.max(step, dim=2)[1] + 1
            logits = torch.cat(outputs, dim=1)
        return logits, hidden
Example #3
0
    def generate(self, x, hidden, forward):
        """Prime the RNN with ``x`` and then sample ``forward`` more steps.

        Noise from ``utils.sample_gumbel`` is added in place to every
        projection output.  The next input token is the argmax of the noisy
        scores plus one (presumably id 0 is reserved -- confirm against the
        vocabulary).

        Args:
            x: token indices of the priming sequence, fed to ``self.embedding``
                (the ``# (n, t, c)`` comment suggests ``(n, t)`` -- confirm).
            hidden: initial state passed to ``self.rnn``.
            forward: number of tokens to generate after ``x``.

        Returns:
            Noisy scores for the ``x`` positions followed by the ``forward``
            generated positions, concatenated along dim 1.
        """
        def noisy(scores):
            # Add Gumbel noise in place and hand the same tensor back.
            gumbel = utils.to_variable(
                utils.sample_gumbel(shape=scores.size(), out=scores.data.new()))
            scores += gumbel
            return scores

        emb = self.dropout(self.embedding(x))  # (n, t, c)
        output, hidden = self.rnn(emb, hidden)
        logits = noisy(self.projection(self.dropout(output)))

        outputs = [logits]
        token = torch.max(logits[:, -1:, :], dim=2)[1] + 1
        # Continue from the state primed on ``x``.  Resetting it with
        # ``self.init_hidden(1)`` would throw away the context of ``x`` and
        # (presumably) only fit a batch of size 1.
        for _ in range(forward):
            emb = self.dropout(self.embedding(token))
            output, hidden = self.rnn(emb, hidden)
            step = noisy(self.projection(self.dropout(output)))
            outputs.append(step)
            token = torch.max(step, dim=2)[1] + 1
        return torch.cat(outputs, dim=1)
Example #4
0
    def __init__(self, input_, cont_dim=2, discrete_dim=0,
                 filters=(32, 64), hidden_dim=1024, model_name="ConcreteVae"):
        """
        Constructs a Variational Autoencoder that supports continuous and
        discrete dimensions. Currently only one discrete dimension is
        supported.

        Args:
        input_        the input tensor; dimension 3 of its static shape is used
                          as the decoder's output channel count, so it is
                          presumably NHWC (confirm)
        cont_dim      the number of continuous latent dimensions
        discrete_dim  the number of categories in the discrete latent dimension
                          (NOTE(review): with the default 0 the categorical
                          head has zero units -- confirm that path works)
        filters       the number of filters for each of the two convolutions;
                          only filters[0] and filters[1] are read
        hidden_dim    the dimension of the fully-connected hidden layer between
                          the convolutions and the latent variable
                          (NOTE(review): not used by this constructor)
        model_name    the name of the model
        """
        self.input_ = input_
        input_shape = input_.get_shape().as_list()
        print('Input shape {}'.format(input_shape))

        self.model_name = model_name

        # Build the encoder
        # According to karpathy, generative models work better when
        # they discard pooling layers in favor of larger strides
        # (https://cs231n.github.io/convolutional-networks/#pool)
        net = slim.conv2d(self.input_, filters[0], kernel_size=5, stride=2,
                          padding='SAME')
        net = slim.conv2d(net, filters[1], kernel_size=5, stride=2,
                          padding='SAME')
        # Use dropout to reduce overfitting
        # net = slim.dropout(net, 0.9)
        net = slim.flatten(net)

        # Heads for the latent distribution: mean / log-variance of the
        # continuous part and unnormalized logits of the categorical part.
        q_z_mean = slim.fully_connected(net, cont_dim, activation_fn=None)
        q_z_log_var = slim.fully_connected(net, cont_dim, activation_fn=None)
        # TODO: support multiple categorical variables
        q_category_logits = slim.fully_connected(net, discrete_dim,
                                                 activation_fn=None)
        q_category = tf.nn.softmax(q_category_logits)
        self.q_z_mean = q_z_mean
        self.q_z_log_var = q_z_log_var
        self.q_category = q_category

        # Sample the latent code: continuous sample concatenated with a
        # relaxed categorical sample drawn at temperature ``tau``.
        self.continuous_z = sample_normal(q_z_mean, q_z_log_var)
        self.tau = tf.Variable(5.0, name="temperature")
        self.category = sample_gumbel(q_category_logits, self.tau)
        self.z = tf.concat([self.continuous_z, self.category], axis=1)

        # Build the decoder: z becomes a 1x1 spatial map that is upsampled by
        # transposed convolutions; a final dense layer maps the flattened
        # result back to the flattened input size.
        net = tf.reshape(self.z, [-1, 1, 1, cont_dim + discrete_dim])
        net = slim.conv2d_transpose(net, filters[1], kernel_size=5,
                                    stride=2, padding='SAME')
        net = slim.conv2d_transpose(net, filters[0], kernel_size=5,
                                    stride=2, padding='SAME')
        net = slim.conv2d_transpose(net, input_shape[3], kernel_size=5,
                                    padding='VALID')
        net = slim.flatten(net)
        # TODO: figure out the whole logits and Bernoulli dist vs MSE thing
        # Do not include the batch size in creating the final layer.
        # np.prod replaces np.product, which NumPy 2.0 removed.
        self.logits = slim.fully_connected(net, int(np.prod(input_shape[1:])),
                                           activation_fn=None)
        print('Output shape {}'.format(self.logits.get_shape()))
        p_x = Bernoulli(logits=self.logits)
        self.p_x = p_x

        self.loss = self._vae_loss()
        self.learning_rate = tf.Variable(1e-3, name="learning_rate")
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate) \
            .minimize(self.loss, var_list=slim.get_model_variables())
Example #5
0
    def init_challenge(self, mu=1e-4, sigma=1e-4, pi=1e-4, eps=1e-4):
        r"""Initialize the perturbation vector r.

        We want to learn r, which perturbs the fixed random parameters
        \lambda_0. To do this, we will initialize r as \lambda_0, and
        use a generative loss to ensure r is minimally different than
        \lambda_0.

        Initialization is lambda_0 + alpha * distribution_noise.

        For each dict ``row`` in ``self.dists``, ``row['lambda_0']`` (and
        ``row['lambda_0_scale']`` for the normal-like and half_normal
        families) is converted in place to a tensor on ``self.device``.
        ``nn.Parameter`` attributes named ``<row['name']>_center`` and/or
        ``<row['name']>_scale`` are then set on ``self``, with
        ``requires_grad=row['trainable']``.

        Args:
            mu: std of the Gaussian noise added to the center of the
                gaussian/normal/cnormal/abs_normal families.
            sigma: std of the Gaussian noise added to scales (the absolute
                value of that noise for half_normal).
            pi: multiplier on the ``sample_gumbel`` noise added to the
                categorical log-probabilities.
            eps: added to ``lambda_0`` before taking the log for the
                relaxed_bernoulli and categorical families (avoids log(0)).

        Raises:
            RuntimeError: if ``self.wn`` is truthy (weight norm unsupported).
            NotImplementedError: for an unrecognized ``row['family']``.
        """
        gaussian_like = ('gaussian', 'normal', 'cnormal', 'abs_normal')

        def to_device_tensor(value):
            # Accept plain Python / array-like values as well as tensors.
            if not torch.is_tensor(value):
                value = torch.tensor(value)
            return value.to(self.device)

        def register(name, trainable, *params):
            # Weight-norm (``*_g`` / ``*_v``) parameterization is not
            # supported; the code that followed this raise was unreachable.
            if self.wn:
                raise RuntimeError('Weightnorm not working for psvrt.')
            for suffix, value in params:
                setattr(self, '{}_{}'.format(name, suffix),
                        nn.Parameter(value, requires_grad=trainable))

        for row in self.dists:
            trainable = row['trainable']
            # lambda_0 (center, or probabilities) lives on self.device.
            row['lambda_0'] = to_device_tensor(row['lambda_0'])
            family = row['family']

            if family in gaussian_like:
                row['lambda_0_scale'] = to_device_tensor(row['lambda_0_scale'])
                lambda_r = row['lambda_0'] + torch.randn_like(
                    row['lambda_0']) * mu
                lambda_r_scale = row['lambda_0_scale'] + torch.randn_like(
                    row['lambda_0_scale']) * sigma
                register(row['name'], trainable,
                         ('center', lambda_r), ('scale', lambda_r_scale))
            elif family == 'half_normal':
                row['lambda_0_scale'] = to_device_tensor(row['lambda_0_scale'])
                # Only the scale is learned; |noise| means it is only ever
                # perturbed upward (for sigma >= 0).
                lambda_r_scale = row['lambda_0_scale'] + torch.abs(
                    torch.randn_like(row['lambda_0_scale'])) * sigma
                register(row['name'], trainable, ('scale', lambda_r_scale))
            elif family == 'relaxed_bernoulli':
                # Learn log-probabilities. The noise term is currently
                # disabled for this family, so ``pi`` is unused here.
                lambda_r = torch.log(row['lambda_0'] + eps).to(self.device)
                register(row['name'], trainable, ('center', lambda_r))
            elif family == 'categorical':
                soft_log_probs = torch.log(row['lambda_0'] + eps)
                lambda_r = soft_log_probs + sample_gumbel(
                    soft_log_probs) * pi
                register(row['name'], trainable,
                         ('center', lambda_r.to(self.device)))
            else:
                raise NotImplementedError