def _vlb_in_bits_per_dims(self, model, x_start, x_t, t, clip_denoised=True):
    """
    Calculate the variational lower bound in bits/dim.
    """
    B, C, H, W = x_start.shape
    assert x_start.shape == x_t.shape
    assert t.shape == (B, )

    # True posterior parameters of q(x_{t-1} | x_t, x_0)
    mean, _, log_var_clipped = self.q_posterior(x_start, x_t, t)

    # Predicted parameters of p(x_{t-1} | x_t)
    preds = self.p_mean_var(model, x_t, t, clip_denoised=clip_denoised)

    # Negative log-likelihood of x_start under the predicted Gaussian, in bits
    nll = -gaussian_log_likelihood(x_start,
                                   mean=preds.mean,
                                   logstd=0.5 * preds.log_var)
    nll_bits = mean_along_except_batch(nll) / np.log(2.0)
    assert nll.shape == x_start.shape
    assert nll_bits.shape == (B, )

    # KL between the true and predicted posteriors, in bits
    kl = kl_normal(mean, log_var_clipped, preds.mean, preds.log_var)
    kl_bits = mean_along_except_batch(kl) / np.log(2.0)
    assert kl.shape == x_start.shape
    assert kl_bits.shape == (B, )

    # Return NLL at t = 0, otherwise KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
    return F.where(F.equal_scalar(t, 0), nll_bits, kl_bits)
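# For reference, a minimal sketch of the per-element Gaussian KL that the
# `kl_normal` helper above is assumed to compute (in nats; the caller divides
# by np.log(2.0) to convert to bits). The helper name and signature are taken
# from the call site; the repository's actual implementation may differ.
import nnabla.functions as F

def kl_normal_sketch(mean1, log_var1, mean2, log_var2):
    # KL(N(mean1, var1) || N(mean2, var2)), element-wise:
    # 0.5 * (log_var2 - log_var1 + (var1 + (mean1 - mean2)^2) / var2 - 1)
    return 0.5 * (log_var2 - log_var1
                  + (F.exp(log_var1) + (mean1 - mean2) ** 2) / F.exp(log_var2)
                  - 1.0)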
def bool_scatter_backward(inputs):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    m0 = inputs[2]
    o0 = inputs[3] if len(inputs) == 4 else None

    # Gradient wrt the scattered data: gather it back from the masked positions.
    dx = F.bool_gather(dy, m0)
    # The mask is not differentiable.
    dm = None
    if o0 is None:
        return dx, dm

    # Gradient wrt the inplace output buffer flows only where the mask is zero.
    m1 = F.equal_scalar(m0, 0)
    m1 = F.reshape(m1, m1.shape + (1, ) * (dy.ndim - m1.ndim))
    m1 = F.broadcast(m1, dy.shape)
    m1 = no_grad(m1)
    do = dy * m1
    return dx, dm, do
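# A small, hedged illustration of the forward op this backward pairs with:
# F.bool_scatter writes the rows of `x` into the positions where `mask` is
# nonzero, so the backward above gathers the upstream gradient from exactly
# those positions. Shapes and values are illustrative only.
import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable.from_numpy_array(np.array([[1., 2.], [3., 4.]]))  # packed rows
mask = nn.Variable.from_numpy_array(np.array([1., 0., 1.]))       # target slots
y = F.bool_scatter(x, mask)  # shape (3, 2)
y.forward()
# y.d is expected to be [[1, 2], [0, 0], [3, 4]]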
def embedding(x, input_dim, output_dim, init=None, mask_zero=False):
    if init is None:
        init = I.UniformInitializer((-0.1, 0.1))
    # PF.embed creates "embed/W" on first use; initialize it only once.
    initialized = "embed/W" in nn.get_parameters()
    result = PF.embed(x, input_dim, output_dim)
    if not initialized:
        nn.get_parameters()["embed/W"].d = init(
            nn.get_parameters()["embed/W"].shape)
    if mask_zero:
        # The mask is 1 where x is nonzero; index 0 is treated as padding.
        return result, 1 - F.equal_scalar(x, 0)
    else:
        return result
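# Hedged usage sketch: embed a batch of zero-padded token indices and obtain
# the padding mask alongside the embeddings. The scope name and sizes are
# placeholders, not values from the original repository.
import numpy as np
import nnabla as nn

tokens = nn.Variable.from_numpy_array(np.array([[4, 7, 0], [2, 0, 0]]))
with nn.parameter_scope("query_embedding"):
    embedded, mask = embedding(tokens, input_dim=100, output_dim=16,
                               mask_zero=True)
# embedded: (2, 3, 16); mask: (2, 3) with 0 at the padded positions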
def bool_fill_backward(inputs, value=0):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    m0 = inputs[2]

    # Gradient passes through only where the mask is zero (unfilled positions).
    m1 = F.equal_scalar(m0, 0.0)
    m1 = F.broadcast(m1, dy.shape)
    m1 = no_grad(m1)
    dx = dy * m1
    # The mask is not differentiable.
    dm = None
    return dx, dm
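# Hedged illustration of the forward op this backward pairs with: F.bool_fill
# replaces entries where `mask` is nonzero with `value`, so the gradient above
# is blocked at exactly the filled positions. Values are illustrative only.
import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable.from_numpy_array(np.array([[1., 2.], [3., 4.]]))
mask = nn.Variable.from_numpy_array(np.array([[0., 1.], [1., 0.]]))
y = F.bool_fill(x, mask, value=-1e9)  # a common masking trick before softmax
y.forward()
# y.d is expected to be [[1, -1e9], [-1e9, 4]]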
def decoder(target_action, target_action_type, target_node_type,
            target_parent_rule, target_parent_index, query_embed,
            query_embed_mask, rule_num, token_num, node_type_num,
            embedding_size, node_type_embedding_size, state_size,
            hidden_size, previous_action_embed=None, initial_state=None,
            initial_cell=None, hist=None, dropout=0.0, train=True):
    """
    target_action: (batch_size, max_action_length, 3)
    target_action_type: (batch_size, max_action_length, 3)
    target_node_type: (batch_size, max_action_length)
    target_parent_rule: (batch_size, max_action_length)
    target_parent_index: (batch_size, max_action_length)
    """
    batch_size, max_action_length, _ = target_action.shape

    # Node type embedding
    with nn.parameter_scope("node_type_embedding"):
        target_node_type_embed = embedding(target_node_type, node_type_num,
                                           node_type_embedding_size,
                                           mask_zero=False,
                                           init=I.NormalInitializer(0.01))

    # Previous action embedding
    # (batch_size, max_action_length) each
    target_apply_rule, target_gen_token, target_copy_token = split(
        target_action, axis=2)

    with nn.parameter_scope("rule_embedding"):
        # (batch_size, max_action_length, embedding_size)
        target_apply_rule_embed = embedding(target_apply_rule, rule_num,
                                            embedding_size, mask_zero=False,
                                            init=I.NormalInitializer(0.01))
        target_apply_rule_embed = F.reshape(
            target_apply_rule_embed,
            (batch_size, max_action_length, 1, embedding_size))

    with nn.parameter_scope("token_embedding"):
        # (batch_size, max_action_length, embedding_size)
        target_gen_token_embed = embedding(target_gen_token, token_num,
                                           embedding_size, mask_zero=False,
                                           init=I.NormalInitializer(0.01))
        target_gen_token_embed = F.reshape(
            target_gen_token_embed,
            (batch_size, max_action_length, 1, embedding_size))

    # Copy actions carry no embedding of their own; use a zero placeholder.
    target_copy_token = F.reshape(target_copy_token,
                                  (batch_size, max_action_length, 1, 1))
    target_copy_token = F.broadcast(
        target_copy_token, (batch_size, max_action_length, 1, embedding_size))
    target_copy_token *= 0

    # (batch_size, max_action_length, 3, embedding_size)
    target_action_embed = concatenate(target_apply_rule_embed,
                                      target_gen_token_embed,
                                      target_copy_token, axis=2)

    target_action_type2 = F.reshape(target_action_type,
                                    (batch_size, max_action_length, 3, 1))
    target_action_type2 = F.broadcast(
        target_action_type2,
        (batch_size, max_action_length, 3, embedding_size))
    # Keep only the embedding matching each action's type.
    # (batch_size, max_action_length, 3, embedding_size)
    target_action_embed = target_action_embed * target_action_type2
    # (batch_size, max_action_length, embedding_size)
    target_action_embed = F.sum(target_action_embed, axis=2)

    # Shift actions one step so that step t sees the embedding of step t - 1.
    if previous_action_embed is None:
        previous_action_embed = nn.Variable((batch_size, 1, embedding_size),
                                            need_grad=False)
        previous_action_embed.data.zero()
    # (batch_size, max_action_length + 1, embedding_size)
    target_action_embed = concatenate(previous_action_embed,
                                      target_action_embed, axis=1)
    # (batch_size, max_action_length, embedding_size)
    target_action_embed = F.slice(
        target_action_embed,
        start=[0, 0, 0],
        stop=[batch_size, max_action_length, embedding_size])

    # Parent action embedding
    # (batch_size, max_action_length)
    parent_rule_mask = 1 - F.equal_scalar(target_parent_rule, 0)
    parent_rule_mask = F.reshape(parent_rule_mask,
                                 (batch_size, max_action_length, 1))
    parent_rule_mask = F.broadcast(
        parent_rule_mask, (batch_size, max_action_length, embedding_size))
    with nn.parameter_scope("rule_embedding"):
        target_parent_rule_embed = embedding(target_parent_rule, rule_num,
                                             embedding_size, mask_zero=False)
    target_parent_rule_embed = parent_rule_mask * target_parent_rule_embed

    # (batch_size, max_action_length,
    #  embedding_size * 2 + node_type_embedding_size)
    decoder_input = concatenate(target_action_embed, target_node_type_embed,
                                target_parent_rule_embed, axis=2)
    # (batch_size, max_action_length)
    target_action_mask = 1 - F.equal_scalar(
        F.sum(target_action_type, axis=2), 0)

    with nn.parameter_scope("decoder"):
        decoder_hidden_states, decoder_cells, ctx_vectors, new_hist = cond_att_lstm(
            decoder_input,
            target_parent_index,
            target_action_mask,
            query_embed,
            query_embed_mask,
            state_size,
            hidden_size,
            initial_state=initial_state,
            initial_cell=initial_cell,
            hist=hist,
            dropout=dropout,
            train=train)

    return (target_action_embed, decoder_hidden_states, decoder_cells,
            ctx_vectors, target_action_mask, new_hist)
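# Hedged call sketch with illustrative shapes. The encoder producing
# (query_embed, query_embed_mask), the helpers split/cond_att_lstm, and every
# size below are placeholders; none of them come from the original repository.
import nnabla as nn

batch_size, max_action_length, max_query_length = 2, 10, 7
target_action = nn.Variable((batch_size, max_action_length, 3))
target_action_type = nn.Variable((batch_size, max_action_length, 3))
target_node_type = nn.Variable((batch_size, max_action_length))
target_parent_rule = nn.Variable((batch_size, max_action_length))
target_parent_index = nn.Variable((batch_size, max_action_length))
query_embed = nn.Variable((batch_size, max_query_length, 256))
query_embed_mask = nn.Variable((batch_size, max_query_length))

(action_embed, hidden_states, cells,
 ctx_vectors, action_mask, hist) = decoder(
    target_action, target_action_type, target_node_type,
    target_parent_rule, target_parent_index,
    query_embed, query_embed_mask,
    rule_num=100, token_num=1000, node_type_num=50,
    embedding_size=128, node_type_embedding_size=64,
    state_size=256, hidden_size=50)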