def forward(self, x, forward=0, stochastic=False):
    h = x  # (n, t)
    h = self.dropout(self.embedding(h))  # (n, t, c)
    states = []
    for rnn in self.rnns:
        h, state = rnn(h)
        states.append(state)
    h = self.projection(h)
    if stochastic:
        gumbel = utils.to_variable(
            utils.sample_gumbel(shape=h.size(), out=h.data.new()))
        h += gumbel
    logits = h
    if forward > 0:
        outputs = []
        h = torch.max(logits[:, -1:, :], dim=2)[1] + 1
        for i in range(forward):
            h = self.embedding(h)
            for j, rnn in enumerate(self.rnns):
                h, state = rnn(h, states[j])
                states[j] = state
            h = self.projection(h)
            if stochastic:
                gumbel = utils.to_variable(
                    utils.sample_gumbel(shape=h.size(), out=h.data.new()))
                h += gumbel
            outputs.append(h)
            h = torch.max(h, dim=2)[1] + 1
        logits = torch.cat([logits] + outputs, dim=1)
    return logits
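# The helpers utils.sample_gumbel and utils.to_variable are used above (and in
# the two functions that follow) but are not defined in this section. The
# sketch below is a minimal assumption of what they could look like:
# sample_gumbel draws standard Gumbel(0, 1) noise of a given shape, so that
# argmax(logits + noise) becomes a sample from the softmax (the Gumbel-max
# trick), and to_variable wraps a tensor for autograd in old-style PyTorch.
# The names match the calls above; the bodies are assumptions, not the
# original implementation.
import torch
from torch.autograd import Variable


def sample_gumbel(shape, out=None, eps=1e-10):
    """Draw standard Gumbel(0, 1) noise of the given shape (assumed helper)."""
    u = torch.rand(shape, out=out) if out is not None else torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)


def to_variable(tensor, requires_grad=False):
    """Wrap a tensor as an autograd Variable, on GPU if available (assumed)."""
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return Variable(tensor, requires_grad=requires_grad)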
def forward(self, x, hidden, forward=0, stochastic=False):
    emb = self.dropout(self.embedding(x))  # (n, t, c)
    output, hidden = self.lstm(emb, hidden)
    output = self.dropout(output)
    output = self.projection(output)
    if stochastic:
        # Gumbel noise makes the argmax below a sample from the softmax.
        gumbel = utils.to_variable(
            utils.sample_gumbel(shape=output.size(), out=output.data.new()))
        output += gumbel
    logits = output
    if forward > 0:
        outputs = []
        # Seed the free-running loop with the most likely last token.
        output = torch.max(logits[:, -1:, :], dim=2)[1] + 1
        for i in range(forward):
            # Feed the previously generated token back in (not the original
            # input x) and carry the hidden state forward.
            emb = self.dropout(self.embedding(output))  # (n, 1, c)
            output, hidden = self.lstm(emb, hidden)
            output = self.dropout(output)
            output = self.projection(output)
            if stochastic:
                gumbel = utils.to_variable(
                    utils.sample_gumbel(shape=output.size(),
                                        out=output.data.new()))
                output += gumbel
            outputs.append(output)
            output = torch.max(output, dim=2)[1] + 1
        logits = torch.cat([logits] + outputs, dim=1)
    return logits, hidden
def generate(self, x, hidden, forward):
    emb = self.dropout(self.embedding(x))  # (n, t, c)
    output, hidden = self.rnn(emb, hidden)
    output = self.dropout(output)
    h = self.projection(output)
    # Gumbel noise + argmax samples from the softmax instead of decoding
    # greedily.
    gumbel = utils.to_variable(
        utils.sample_gumbel(shape=h.size(), out=h.data.new()))
    h += gumbel
    logits = h
    outputs = []
    # Seed the generation loop with the sampled last token and keep the
    # hidden state from the prompt so generation stays conditioned on x.
    h = torch.max(logits[:, -1:, :], dim=2)[1] + 1
    for i in range(forward):
        emb = self.dropout(self.embedding(h))
        h, hidden = self.rnn(emb, hidden)
        h = self.dropout(h)
        h = self.projection(h)
        gumbel = utils.to_variable(
            utils.sample_gumbel(shape=h.size(), out=h.data.new()))
        h += gumbel
        outputs.append(h)
        h = torch.max(h, dim=2)[1] + 1
    logits = torch.cat([logits] + outputs, dim=1)
    return logits
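# Why the Gumbel noise above turns greedy argmax decoding into sampling: for
# logits l_k and i.i.d. Gumbel(0, 1) noise g_k, argmax_k(l_k + g_k) is
# distributed according to softmax(l). The short, self-contained check below
# (not part of the original model) compares the empirical argmax frequencies
# against the softmax probabilities.
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([2.0, 0.5, -1.0])
softmax = np.exp(logits) / np.exp(logits).sum()

draws = 100000
gumbel = -np.log(-np.log(rng.random((draws, logits.size))))
samples = np.argmax(logits + gumbel, axis=1)
freqs = np.bincount(samples, minlength=logits.size) / draws

print('softmax probabilities: ', np.round(softmax, 3))
print('Gumbel-max frequencies:', np.round(freqs, 3))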
def __init__(self, input_, cont_dim=2, discrete_dim=0, filters=[32, 64],
             hidden_dim=1024, model_name="ConcreteVae"):
    """
    Constructs a Variational Autoencoder that supports continuous and
    discrete dimensions. Currently only one discrete dimension is
    supported.

    Args:
        input_        the input tensor
        cont_dim      the number of continuous latent dimensions
        discrete_dim  the number of categories in the discrete latent
                      dimension
        filters       the number of filters for each convolution
        hidden_dim    the dimension of the fully-connected hidden layer
                      between the convolutions and the latent variable
        model_name    the name of the model
    """
    self.input_ = input_
    input_shape = input_.get_shape().as_list()
    print('Input shape {}'.format(input_shape))
    self.model_name = model_name

    # Build the encoder
    # According to Karpathy, generative models work better when they
    # discard pooling layers in favor of larger strides
    # (https://cs231n.github.io/convolutional-networks/#pool)
    net = slim.conv2d(self.input_, filters[0], kernel_size=5, stride=2,
                      padding='SAME')
    net = slim.conv2d(net, filters[1], kernel_size=5, stride=2,
                      padding='SAME')
    # Use dropout to reduce overfitting
    # net = slim.dropout(net, 0.9)
    net = slim.flatten(net)

    # Sample from the latent distribution
    q_z_mean = slim.fully_connected(net, cont_dim, activation_fn=None)
    q_z_log_var = slim.fully_connected(net, cont_dim, activation_fn=None)
    # TODO: support multiple categorical variables
    q_category_logits = slim.fully_connected(net, discrete_dim,
                                             activation_fn=None)
    q_category = tf.nn.softmax(q_category_logits)
    self.q_z_mean = q_z_mean
    self.q_z_log_var = q_z_log_var
    self.q_category = q_category
    self.continuous_z = sample_normal(q_z_mean, q_z_log_var)
    self.tau = tf.Variable(5.0, name="temperature")
    self.category = sample_gumbel(q_category_logits, self.tau)
    self.z = tf.concat([self.continuous_z, self.category], axis=1)

    # Build the decoder
    net = tf.reshape(self.z, [-1, 1, 1, cont_dim + discrete_dim])
    net = slim.conv2d_transpose(net, filters[1], kernel_size=5, stride=2,
                                padding='SAME')
    net = slim.conv2d_transpose(net, filters[0], kernel_size=5, stride=2,
                                padding='SAME')
    net = slim.conv2d_transpose(net, input_shape[3], kernel_size=5,
                                padding='VALID')
    net = slim.flatten(net)
    # TODO: figure out the whole logits and Bernoulli dist vs MSE thing
    # Do not include the batch size in creating the final layer
    self.logits = slim.fully_connected(net, np.product(input_shape[1:]),
                                       activation_fn=None)
    print('Output shape {}'.format(self.logits.get_shape()))
    p_x = Bernoulli(logits=self.logits)
    self.p_x = p_x

    self.loss = self._vae_loss()
    self.learning_rate = tf.Variable(1e-3, name="learning_rate")
    self.optimizer = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate).minimize(
            self.loss, var_list=slim.get_model_variables())
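# sample_normal and sample_gumbel are called above but defined elsewhere. A
# minimal sketch of what they are assumed to do, in the same TF1 style: the
# reparameterization trick for the Gaussian latent, and the Gumbel-Softmax
# (Concrete) relaxation for the categorical latent. Function names match the
# calls above; the bodies are assumptions, not the original implementation.
import tensorflow as tf


def sample_normal(mean, log_var):
    """Reparameterization trick: z = mean + sigma * eps, eps ~ N(0, I)."""
    eps = tf.random_normal(tf.shape(mean))
    return mean + tf.exp(0.5 * log_var) * eps


def sample_gumbel(logits, temperature, eps=1e-20):
    """Gumbel-Softmax: a differentiable, temperature-controlled relaxation
    of sampling from a categorical distribution parameterized by logits."""
    u = tf.random_uniform(tf.shape(logits), minval=0, maxval=1)
    gumbel_noise = -tf.log(-tf.log(u + eps) + eps)
    return tf.nn.softmax((logits + gumbel_noise) / temperature)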
def init_challenge(self, mu=1e-4, sigma=1e-4, pi=1e-4, eps=1e-4):
    r"""Initialize the perturbation vector r.

    We want to learn r, which perturbs the fixed random parameters
    \lambda_0. To do this, we initialize r as \lambda_0 and use a
    generative loss to keep r minimally different from \lambda_0.
    Initialization is lambda_0 + alpha * distribution_noise.
    """
    for idx, row in enumerate(self.dists):
        trainable = row['trainable']
        # We have lambda_0 (for center/maybe scale)
        if not torch.is_tensor(row['lambda_0']):
            # If this is not a torch tensor, convert it
            row['lambda_0'] = torch.tensor(row['lambda_0'])
        # Send lambda_0 center to device
        row['lambda_0'] = row['lambda_0'].to(self.device)
        if (row['family'] == 'gaussian' or
                row['family'] == 'normal' or
                row['family'] == 'cnormal' or
                row['family'] == 'abs_normal'):
            if not torch.is_tensor(row['lambda_0_scale']):
                # If this is not a torch tensor, convert it
                row['lambda_0_scale'] = torch.tensor(
                    row['lambda_0_scale'])  # noqa
            # Send lambda_0 scale to device
            self.dists[idx]['lambda_0_scale'] = row['lambda_0_scale'].to(
                self.device)  # noqa
            # Initialize challenge center/scale
            lambda_r = row['lambda_0'] + torch.randn_like(
                row['lambda_0']) * mu  # noqa
            lambda_r_scale = row['lambda_0_scale'] + torch.randn_like(
                row['lambda_0_scale']) * sigma  # noqa
            # Also add the r parameters to a list
            if self.wn:
                raise RuntimeError('Weightnorm not working for psvrt.')
                lambda_r_g, lambda_r_v = self.init_wns(lambda_r)
                lambda_r_scale_g, lambda_r_scale_v = self.init_wns(
                    lambda_r_scale)  # noqa
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'center_g')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_g, requires_grad=trainable))  # noqa
                attr_name = '{}_{}'.format(row['name'], 'center_v')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_v, requires_grad=trainable))  # noqa
                attr_name = '{}_{}'.format(row['name'], 'scale_g')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_scale_g, requires_grad=trainable))  # noqa
                attr_name = '{}_{}'.format(row['name'], 'scale_v')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_scale_v, requires_grad=trainable))  # noqa
            else:
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'center')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r, requires_grad=trainable))  # noqa
                attr_name = '{}_{}'.format(row['name'], 'scale')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_scale, requires_grad=trainable))  # noqa
        elif row['family'] == 'half_normal':
            if not torch.is_tensor(row['lambda_0_scale']):
                # If this is not a torch tensor, convert it
                row['lambda_0_scale'] = torch.tensor(
                    row['lambda_0_scale'])  # noqa
            # Send lambda_0 scale to device
            self.dists[idx]['lambda_0_scale'] = row['lambda_0_scale'].to(
                self.device)  # noqa
            # Initialize challenge scale
            lambda_r_scale = row['lambda_0_scale'] + torch.abs(
                torch.randn_like(row['lambda_0_scale'])) * sigma  # noqa
            # Also add the r parameters to a list
            if self.wn:
                raise RuntimeError('Weightnorm not working for psvrt.')
                lambda_r_scale_g, lambda_r_scale_v = self.init_wns(
                    lambda_r_scale)
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'scale_g')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_scale_g, requires_grad=trainable))  # noqa
                attr_name = '{}_{}'.format(row['name'], 'scale_v')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_scale_v, requires_grad=trainable))  # noqa
            else:
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'scale')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_scale, requires_grad=trainable))  # noqa
        elif row['family'] == 'relaxed_bernoulli':
            # Handle pi of categorical dist (var/temp is hardcoded)
            soft_log_probs = torch.log(row['lambda_0'] + eps)  # Don't save
            lambda_r = soft_log_probs  # + torch.rand_like(soft_log_probs) * pi  # noqa
            lambda_r = lambda_r.to(self.device)  # noqa
            # Also add the r parameters to a list
            if self.wn:
                raise RuntimeError('Weightnorm not working for psvrt.')
                lambda_r_g, lambda_r_v = self.init_wns(lambda_r)
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'center_g')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_g, requires_grad=trainable))  # noqa
                attr_name = '{}_{}'.format(row['name'], 'center_v')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_v, requires_grad=trainable))  # noqa
            else:
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'center')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r, requires_grad=trainable))  # noqa
        elif row['family'] == 'categorical':
            # Handle pi of categorical dist (var/temp is hardcoded)
            soft_log_probs = torch.log(row['lambda_0'] + eps)  # Don't save
            lambda_r = soft_log_probs + sample_gumbel(
                soft_log_probs) * pi  # noqa
            lambda_r = lambda_r.to(self.device)  # noqa
            # Also add the r parameters to a list
            if self.wn:
                raise RuntimeError('Weightnorm not working for psvrt.')
                lambda_r_g, lambda_r_v = self.init_wns(lambda_r)
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'center_g')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_g, requires_grad=trainable))  # noqa
                attr_name = '{}_{}'.format(row['name'], 'center_v')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r_v, requires_grad=trainable))  # noqa
            else:
                # Also add the r parameters to a list
                attr_name = '{}_{}'.format(row['name'], 'center')
                setattr(self, attr_name, nn.Parameter(
                    lambda_r, requires_grad=trainable))  # noqa
        else:
            raise NotImplementedError
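# init_wns and sample_gumbel are referenced above without definitions. The
# sketch below shows one plausible reading: init_wns splits a tensor into a
# weight-norm style magnitude g and direction v (so the parameter could be
# recomposed as g * v / ||v||), and sample_gumbel draws Gumbel(0, 1) noise
# matching the shape of its argument. Both are assumptions about helpers that
# live elsewhere in the codebase, not the original implementations.
import torch


def init_wns(tensor, eps=1e-8):
    """Split a tensor into weight-norm magnitude g and direction v (assumed)."""
    g = tensor.norm()
    v = tensor / (g + eps)
    return g.reshape(1), v


def sample_gumbel(logits, eps=1e-10):
    """Standard Gumbel(0, 1) noise with the same shape as `logits` (assumed)."""
    u = torch.rand_like(logits)
    return -torch.log(-torch.log(u + eps) + eps)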