def backward_impl(self, inputs, outputs, prop_down, accum):
    # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
    # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

    # Inputs
    x0 = inputs[0].data  # probabilities
    t0 = inputs[1].data  # labels
    dz = inputs[2].data  # grad_input
    # Outputs
    dx0 = outputs[0].data
    dt0 = outputs[1].data
    # Grads of inputs
    g_x0 = inputs[0].grad
    g_t0 = inputs[1].grad
    g_dz = inputs[2].grad
    # Grads of outputs
    g_dx0 = outputs[0].grad
    g_dt0 = outputs[1].grad

    # Computation
    ## w.r.t. x0
    if prop_down[0]:
        u0 = g_dx0 * (t0 / x0 ** 2.0 + (1.0 - t0) / (1.0 - x0) ** 2.0)
        u1 = g_dt0 / (x0 * (1.0 - x0))
        g_x0_ = dz * (u0 - u1)
        if accum[0]:
            g_x0 += g_x0_
        else:
            g_x0.copy_from(g_x0_)

    ## w.r.t. t0
    if prop_down[1]:
        # g_t0_ = g_dx0 * dz * (1.0 / x0 + 1.0 / (1.0 - x0))
        g_t0_ = g_dx0 * dz / (x0 * (1.0 - x0))
        if accum[1]:
            g_t0 -= g_t0_
        else:
            g_t0.copy_from(-g_t0_)

    ## w.r.t. dz
    if prop_down[2]:
        # u0 = g_dx0 * ((1.0 - t0) / (1.0 - x0) - t0 / x0)
        u0 = g_dx0 * (x0 - t0) / (x0 * (1.0 - x0))
        u1 = g_dt0 * (F.log(1.0 - x0) - F.log(x0))
        g_dz_ = u0 + u1
        if accum[2]:
            g_dz += g_dz_
        else:
            g_dz.copy_from(g_dz_)
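# A minimal numpy sketch (not part of the original code): it checks by finite
# differences the identity the `u0` term above relies on, namely that the
# derivative of the BCE input-gradient (x - t) / (x * (1 - x)) w.r.t. x equals
# t / x**2 + (1 - t) / (1 - x)**2. Illustration only, with scalar inputs.
import numpy as np

def bce_input_grad(x, t):
    return (x - t) / (x * (1.0 - x))

x, t, eps = 0.3, 1.0, 1e-6
numeric = (bce_input_grad(x + eps, t) - bce_input_grad(x - eps, t)) / (2 * eps)
analytic = t / x ** 2 + (1.0 - t) / (1.0 - x) ** 2
assert np.isclose(numeric, analytic, rtol=1e-4)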
def kl_divergence(ctx, pred, label, log_var):
    with nn.context_scope(ctx):
        s = F.pow_scalar(F.exp(log_var), 0.5)
        elms = softmax_with_temperature(ctx, label, s) \
            * F.log(F.softmax(pred, axis=1))
        loss = -F.mean(F.sum(elms, axis=1))
    return loss
def backward_impl(self, inputs, outputs, prop_down, accum):
    # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
    # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

    # Inputs
    x0 = inputs[0].data
    dy = inputs[1].data
    # Outputs
    dx0 = outputs[0].data
    # Grads of inputs
    g_x0 = inputs[0].grad
    g_dy = inputs[1].grad
    # Grads of outputs
    g_dx0 = outputs[0].grad

    # Compute
    val = self.forward_func.info.args["val"]
    if prop_down[0] or prop_down[1]:
        cv = F.constant(val, x0.shape)
        if not nn.get_auto_forward():
            cv.forward()
        log_v = F.log(cv.data)
    if prop_down[0]:
        if accum[0]:
            g_x0 += g_dx0 * dy * F.r_pow_scalar(x0, val) * log_v ** 2.0
        else:
            g_x0.copy_from(g_dx0 * dy * F.r_pow_scalar(x0, val) * log_v ** 2.0)
    if prop_down[1]:
        if accum[1]:
            g_dy += g_dx0 * F.r_pow_scalar(x0, val) * log_v
        else:
            g_dy.copy_from(g_dx0 * F.r_pow_scalar(x0, val) * log_v)
def warp_coordinates(self, coordinates):
    theta = self.theta
    theta = F.reshape(
        theta, theta.shape[:1] + (1,) + theta.shape[1:], inplace=False)
    if coordinates.shape[0] == self.bs:
        transformed = F.batch_matmul(
            F.tile(theta[:, :, :, :2], (1, coordinates.shape[1], 1, 1)),
            F.reshape(coordinates, coordinates.shape + (1,), inplace=False)) \
            + theta[:, :, :, 2:]
    else:
        # Integer division keeps the tile repetitions integral under Python 3.
        transformed = F.batch_matmul(
            F.tile(theta[:, :, :, :2], (1, coordinates.shape[1], 1, 1)),
            F.tile(F.reshape(coordinates, coordinates.shape + (1,), inplace=False),
                   (self.bs // coordinates.shape[0], 1, 1, 1))) \
            + theta[:, :, :, 2:]
    transformed = F.reshape(
        transformed, transformed.shape[:-1], inplace=False)

    if self.tps:
        control_points = self.control_points
        control_params = self.control_params
        distances = F.reshape(
            coordinates, (coordinates.shape[0], -1, 1, 2), inplace=False) \
            - F.reshape(control_points, (1, 1, -1, 2))
        distances = F.sum(F.abs(distances), axis=distances.ndim - 1)

        result = distances ** 2
        result = result * F.log(distances + 1e-6)
        result = result * control_params
        result = F.sum(result, axis=2)
        result = F.reshape(
            result, (self.bs, coordinates.shape[1], 1), inplace=False)
        transformed = transformed + result

    return transformed
def gaussian_log_likelihood(x, mean, logstd, orig_max_val=255):
    """
    Compute the log-likelihood of a Gaussian distribution for given data `x`.

    Args:
        x (nn.Variable): Target data. It is assumed that the values are in [-1, 1],
                         rescaled from the original range [0, orig_max_val].
        mean (nn.Variable): Gaussian mean. Must be the same shape as x.
        logstd (nn.Variable): Gaussian log standard deviation. Must be the same shape as x.
        orig_max_val (int): The maximum value that x originally has before being rescaled.

    Return:
        Log probabilities of x in nats.
    """
    assert x.shape == mean.shape == logstd.shape
    centered_x = x - mean
    inv_std = F.exp(-logstd)
    half_bin = 1.0 / orig_max_val

    def clamp(val):
        # Clip for numerical stability; the upper bound is effectively never reached.
        return F.clip_by_value(val, min=1e-12, max=1e8)

    # x + 0.5 (in original scale)
    plus_in = inv_std * (centered_x + half_bin)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    log_cdf_plus = F.log(clamp(cdf_plus))

    # x - 0.5 (in original scale)
    minus_in = inv_std * (centered_x - half_bin)
    cdf_minus = approx_standard_normal_cdf(minus_in)
    log_one_minus_cdf_minus = F.log(clamp(1.0 - cdf_minus))

    log_cdf_delta = F.log(clamp(cdf_plus - cdf_minus))

    log_probs = F.where(
        F.less_scalar(x, -0.999),
        log_cdf_plus,  # Edge case for 0. It uses cdf for -inf as cdf_minus.
        F.where(F.greater_scalar(x, 0.999),
                # Edge case for orig_max_val. It uses cdf for +inf as cdf_plus.
                log_one_minus_cdf_minus,
                log_cdf_delta  # otherwise
                )
    )

    assert log_probs.shape == x.shape

    return log_probs
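# A minimal numpy sketch (not part of the original code): the discretized
# log-likelihood above integrates a Gaussian over each pixel bin via CDF
# differences. This standalone check, assuming an exact CDF via math.erf in
# place of approx_standard_normal_cdf, verifies that the per-bin probabilities
# of a discretized N(mean, std) sum to ~1 over the 256 pixel levels.
import math
import numpy as np

def std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

mean, std, orig_max_val = 0.1, 0.2, 255
levels = np.linspace(-1.0, 1.0, orig_max_val + 1)  # pixel values rescaled to [-1, 1]
half_bin = 1.0 / orig_max_val
probs = []
for x in levels:
    upper = np.inf if x > 0.999 else (x + half_bin - mean) / std
    lower = -np.inf if x < -0.999 else (x - half_bin - mean) / std
    probs.append(std_normal_cdf(upper) - std_normal_cdf(lower))
assert np.isclose(sum(probs), 1.0, atol=1e-6)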
def er_loss(ctx, pred):
    with nn.context_scope(ctx):
        bs = pred.shape[0]
        d = np.prod(pred.shape[1:])
        denominator = bs * d
        pred_normalized = F.softmax(pred)
        pred_log_normalized = F.log(F.softmax(pred))
        loss_er = -F.sum(pred_normalized * pred_log_normalized) / denominator
    return loss_er
def net(n_class, xs, xq, init_type='nnabla', embedding='conv4',
        net_type='prototypical', distance='euclid', test=False):
    '''
    Similarity net function
        This function implements the network with settings as specified.

        Args:
            n_class (int): number of classes. Typical setting is 5 or 20.
            xs (~nnabla.Variable): support images.
            xq (~nnabla.Variable): query images.
            init_type (str, optional): initialization type for weights and bias parameters. See conv_initializer function.
            embedding (str, optional): embedding network.
            net_type (str, optional): network type, 'matching' or 'prototypical'.
            distance (str, optional): similarity metric to use. See similarity function.
            test (bool, optional): switch flag for training dataset and test dataset.

        Returns:
            h (~nnabla.Variable): output variable indicating similarity between support and query.
    '''

    # feature embedding for supports and queries
    n_shot = xs.shape[0] // n_class
    n_query = xq.shape[0] // n_class
    if embedding == 'conv4':
        fs = conv4(xs, test, init_type)  # tensor of (n_support, fdim)
        fq = conv4(xq, test, init_type)  # tensor of (n_query, fdim)

    if net_type == 'matching':
        # This example does not include the full-context-embedding of matching networks.
        fs = F.reshape(fs, (1, ) + fs.shape)  # (1, n_way, fdim)
        # (n_way*n_query, 1, fdim)
        fq = F.reshape(fq, (fq.shape[0], 1) + fq.shape[1:])
        h = similarity(fq, fs, distance)
        h = h - F.mean(h, axis=1, keepdims=True)
        if 1 < n_shot:
            h = F.minimum_scalar(F.maximum_scalar(h, -35), 35)
            h = F.softmax(h)
            h = F.reshape(h, (h.shape[0], n_class, n_shot))
            h = F.mean(h, axis=2)
            # Reverse to logit to use same softmax cross entropy
            h = F.log(h)
    elif net_type == 'prototypical':
        if 1 < n_shot:
            fs = F.reshape(fs, (n_class, n_shot) + fs.shape[1:])
            fs = F.mean(fs, axis=1)
        fs = F.reshape(fs, (1, ) + fs.shape)  # (1, n_way, fdim)
        # (n_way*n_query, 1, fdim)
        fq = F.reshape(fq, (fq.shape[0], 1) + fq.shape[1:])
        h = similarity(fq, fs, distance)
        h = h - F.mean(h, axis=1, keepdims=True)

    return h
def sr_loss_with_uncertainty(ctx, pred0, pred1, log_var0, log_var1):
    var0 = F.exp(log_var0)
    var1 = F.exp(log_var1)
    s0 = F.pow_scalar(var0, 0.5)
    s1 = F.pow_scalar(var1, 0.5)  # use var1 (not var0) for the second std
    squared_error = F.squared_error(pred0, pred1)
    with nn.context_scope(ctx):
        loss = F.log(s1 / s0) + (var0 / var1 + squared_error / var1) * 0.5
        loss_sr = F.mean(loss)
    return loss_sr
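# A minimal numpy sketch (not part of the original code): with s1 = sqrt(var1)
# as corrected above, the per-element term equals the Gaussian KL divergence
# KL(N(pred0, var0) || N(pred1, var1)) up to a constant offset of +0.5.
# Illustration only, with scalar stand-in values.
import numpy as np

pred0, pred1 = 0.7, 0.2
log_var0, log_var1 = -1.0, 0.5
var0, var1 = np.exp(log_var0), np.exp(log_var1)
s0, s1 = np.sqrt(var0), np.sqrt(var1)

loss_elem = np.log(s1 / s0) + (var0 / var1 + (pred0 - pred1) ** 2 / var1) * 0.5
kl = np.log(s1 / s0) + (var0 + (pred0 - pred1) ** 2) / (2.0 * var1) - 0.5
assert np.isclose(loss_elem, kl + 0.5)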
def forward(self, x):
    N, C, H, W = x.shape
    log_abs = F.log(F.abs(self.scale))
    logdet = H * W * F.sum(log_abs)
    if self.logdet:
        return self.scale * (x + self.loc), logdet
    else:
        return self.scale * (x + self.loc)
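# A minimal numpy sketch (not part of the original code): for the elementwise
# affine y = scale * (x + loc) with one scale per channel, the Jacobian is
# diagonal, so its log-determinant per sample is H*W*sum(log|scale_c|), which
# is what `logdet` above computes. Illustration only.
import numpy as np

C, H, W = 3, 4, 5
scale = np.random.randn(1, C, 1, 1)
# Full diagonal of the Jacobian: each of the C*H*W outputs is scaled independently.
diag = np.broadcast_to(scale, (1, C, H, W)).ravel()
logdet_direct = np.sum(np.log(np.abs(diag)))
logdet_formula = H * W * np.sum(np.log(np.abs(scale)))
assert np.isclose(logdet_direct, logdet_formula)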
def log_spectrogram(wave, window_size):
    r"""Return log spectrogram.

    Args:
        wave (nn.Variable): Input waveform of shape (B, 1, L).
        window_size (int): Window size.

    Returns:
        nn.Variable: Log spectrogram.
    """
    linear = spectrogram(wave, window_size)
    return F.log(linear * 1e4 + 1.0)
def compute_mel(self, wave):
    hp = self.hparams
    reals, imags = F.stft(wave, window_size=hp.win_length,
                          stride=hp.hop_length, fft_size=hp.n_fft)
    linear = F.pow_scalar(
        F.add2(F.pow_scalar(reals, 2), F.pow_scalar(imags, 2)), 0.5)
    mels = F.batch_matmul(self.basis, linear)
    mels = F.log(F.clip_by_value(mels, 1e-5, np.inf)).apply(need_grad=False)
    return mels
def pow2_backward(inputs, inplace=False):  # Inplacing is obsoleted.
    """
    Args:
        inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
        kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
        list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    x1 = inputs[2]
    dx0 = dy * x1 * x0 ** (x1 - 1)
    dx1 = dy * (x0 ** x1) * F.log(x0)
    return dx0, dx1
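# A minimal numpy sketch (not part of the original code): a finite-difference
# check of the two partial derivatives used above, d/dx0 x0**x1 = x1*x0**(x1-1)
# and d/dx1 x0**x1 = x0**x1 * log(x0). Illustration only, with scalar inputs.
import numpy as np

x0, x1, eps = 1.7, 2.3, 1e-6
num_dx0 = ((x0 + eps) ** x1 - (x0 - eps) ** x1) / (2 * eps)
num_dx1 = (x0 ** (x1 + eps) - x0 ** (x1 - eps)) / (2 * eps)
assert np.isclose(num_dx0, x1 * x0 ** (x1 - 1), rtol=1e-5)
assert np.isclose(num_dx1, x0 ** x1 * np.log(x0), rtol=1e-5)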
def loss(target_action, target_action_type, target_action_mask, rule_prob,
         terminal_gen_action_prob, token_prob, copy_prob):
    batch_size, max_action_length, _ = target_action.shape
    _, _, rule_num = rule_prob.shape
    _, _, token_num = token_prob.shape
    _, _, max_query_length = copy_prob.shape

    # (batch_size, max_action_length)
    target_rule, target_token, target_copy = F.split(target_action, axis=2)

    target_rule = F.reshape(target_rule, (batch_size, max_action_length, 1))
    target_rule = F.one_hot(
        target_rule, (rule_num, ))  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = rule_prob * target_rule  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = F.sum(rule_tgt_prob, axis=2)  # (batch_size, max_action_length)

    target_token = F.reshape(target_token, (batch_size, max_action_length, 1))
    target_token = F.one_hot(
        target_token, (token_num, ))  # (batch_size, max_action_length, token_num)
    token_tgt_prob = token_prob * target_token  # (batch_size, max_action_length, token_num)
    token_tgt_prob = F.sum(token_tgt_prob, axis=2)  # (batch_size, max_action_length)

    target_copy = F.reshape(target_copy, (batch_size, max_action_length, 1))
    target_copy = F.one_hot(
        target_copy, (max_query_length, ))  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = copy_prob * target_copy  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = F.sum(copy_tgt_prob, axis=2)  # (batch_size, max_action_length)

    # (batch_size, max_action_length)
    gen_token_prob, copy_token_prob = F.split(terminal_gen_action_prob, axis=2)
    # (batch_size, max_action_length)
    rule_mask, token_mask, copy_mask = F.split(target_action_type, axis=2)

    # (batch_size, max_action_length)
    target_prob = rule_mask * rule_tgt_prob + \
        token_mask * gen_token_prob * token_tgt_prob + \
        copy_mask * copy_token_prob * copy_tgt_prob
    # (batch_size, max_action_length)
    likelihood = F.log(target_prob + 1e-7)
    loss = -likelihood * target_action_mask
    # (batch_size)
    loss = F.sum(loss, axis=1)
    return F.mean(loss)
def __call__(self, _out_var=None):
    # input
    #   _out_var : type=nn.Variable(), The discriminator output
    # --- self ---
    #   self.coef_dict : type=OrderedDict(), The coefficient dict of the synthesis network (This needs to be on the graph.)
    #   self.data_iterator : type=nnabla data iterator
    # output
    #   loss : type=nn.Variable()

    # --- Calculation of the Fisher Information ---
    if _out_var is not None:
        temp_need_grad = self.y.need_grad
        self.y.need_grad = True
        if len(self.FisherInformation_val_dict) == 0:
            log_likelihood_var = F.log(F.sigmoid(_out_var))
            for i in range(self.iter_num):
                log_likelihood_var.forward(clear_no_need_grad=True)
                self._zero_grad_all()
                log_likelihood_var.backward(clear_buffer=True)
                self._accumulate_grads()
                sys.stdout.write(
                    '\rFisher Information Accumulating ... {}/{}'.format(
                        i + 1, self.iter_num))
                sys.stdout.flush()
            print('')
            for key in self.FisherInformation_val_dict:
                self.FisherInformation_val_dict[key] /= self.iter_num
        self.y.need_grad = temp_need_grad

    # --- make loss graph ---
    loss = 0
    for key in self.FisherInformation_val_dict:
        key_source = key.replace(self.FI_scope + '/', '')
        FI_var = nn.Variable.from_numpy_array(
            self.FisherInformation_val_dict[key].copy())
        FI_var.name = key
        coef_source_var = nn.Variable.from_numpy_array(
            self.coef_dict_for_FI[key_source].d.copy())
        coef_source_var.name = key.replace(self.FI_scope + '/', 'weight_source/')
        loss += F.mean(
            FI_var * (self.coef_dict_for_FI[key_source] - coef_source_var) ** 2)

    # --- save Fisher Information ---
    if self.FI_save_switch:
        self._save_FisherInformation()
    print('[ElasticWeightConsolidation] Success!')

    return loss
def kl_multinomial_backward(inputs, base_axis=1):
    """
    Args:
        inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
        kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
        list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    p = inputs[1]
    q = inputs[2]
    reshape = list(dy.shape[:base_axis]) + \
        [1 for _ in range(p.ndim - base_axis)]
    dy = F.reshape(dy, reshape, inplace=False)
    dp = dy * (F.log(p / q) + 1)
    dq = -dy * p / q
    return dp, dq
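# A minimal numpy sketch (not part of the original code): a finite-difference
# check of the gradients used above for KL(p || q) = sum p * log(p / q),
# namely dKL/dp_i = log(p_i / q_i) + 1 and dKL/dq_i = -p_i / q_i (treating the
# entries as free variables, as the backward above does). Illustration only.
import numpy as np

def kl(p, q):
    return np.sum(p * np.log(p / q))

p = np.array([0.2, 0.3, 0.5])
q = np.array([0.4, 0.4, 0.2])
eps = 1e-6
i = 1
e = np.eye(3)[i] * eps
assert np.isclose((kl(p + e, q) - kl(p - e, q)) / (2 * eps),
                  np.log(p[i] / q[i]) + 1, rtol=1e-5)
assert np.isclose((kl(p, q + e) - kl(p, q - e)) / (2 * eps),
                  -p[i] / q[i], rtol=1e-5)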
def __call__(self, x, return_encoding_indices=False):
    x = F.transpose(x, (0, 2, 3, 1))
    x_flat = x.reshape((-1, self.embedding_dim))

    x_flat_squared = F.broadcast(F.sum(x_flat**2, axis=1, keepdims=True),
                                 (x_flat.shape[0], self.num_embedding))
    emb_wt_squared = F.transpose(
        F.sum(self.embedding_weight**2, axis=1, keepdims=True), (1, 0))

    distances = x_flat_squared + emb_wt_squared - 2 * \
        F.affine(x_flat, F.transpose(self.embedding_weight, (1, 0)))

    encoding_indices = F.min(distances, only_index=True, axis=1, keepdims=True)
    encoding_indices.need_grad = False

    quantized = F.embed(
        encoding_indices.reshape(encoding_indices.shape[:-1]),
        self.embedding_weight).reshape(x.shape)

    if return_encoding_indices:
        return encoding_indices, F.transpose(quantized, (0, 3, 1, 2))

    encodings = F.one_hot(encoding_indices, (self.num_embedding, ))

    e_latent_loss = F.mean(
        F.squared_error(quantized.get_unlinked_variable(need_grad=False), x))
    q_latent_loss = F.mean(
        F.squared_error(quantized, x.get_unlinked_variable(need_grad=False)))
    loss = q_latent_loss + self.commitment_cost * e_latent_loss

    quantized = x + (quantized - x).get_unlinked_variable(need_grad=False)

    avg_probs = F.mean(encodings, axis=0)
    perplexity = F.exp(-F.sum(avg_probs * F.log(avg_probs + 1.0e-10)))

    return loss, F.transpose(quantized, (0, 3, 1, 2)), perplexity, encodings
def compute_mel(wave, basis, hp):
    r"""Compute the mel-spectrogram from the waveform.

    Args:
        wave (nn.Variable): Waveform variable of shape (B, 1, L).
        basis (nn.Variable): Basis for mel-spectrogram computation.
        hp (HParams): Hyper-parameters.

    Returns:
        nn.Variable: Output variable.
    """
    reals, imags = stft(wave, window_size=hp.win_length,
                        stride=hp.hop_length, fft_size=hp.n_fft)
    linear = (reals**2 + imags**2)**0.5
    mels = F.batch_matmul(basis, linear)
    mels = F.log(F.clip_by_value(mels, 1e-5, np.inf))
    return mels
def __init__(self, num_actions, num_envs, batch_size, v_coeff, ent_coeff,
             lr_scheduler):
    # inference graph
    self.infer_obs_t = nn.Variable((num_envs, 4, 84, 84))
    self.infer_pi_t,\
        self.infer_value_t = cnn_network(self.infer_obs_t, num_actions,
                                         'network')
    self.infer_t = F.sink(self.infer_pi_t, self.infer_value_t)

    # evaluation graph
    self.eval_obs_t = nn.Variable((1, 4, 84, 84))
    self.eval_pi_t, _ = cnn_network(self.eval_obs_t, num_actions, 'network')

    # training graph
    self.obss_t = nn.Variable((batch_size, 4, 84, 84))
    self.acts_t = nn.Variable((batch_size, 1))
    self.rets_t = nn.Variable((batch_size, 1))
    self.advs_t = nn.Variable((batch_size, 1))

    pi_t, value_t = cnn_network(self.obss_t, num_actions, 'network')

    # value loss
    l2loss = F.squared_error(value_t, self.rets_t)
    self.value_loss = v_coeff * F.mean(l2loss)

    # policy loss
    log_pi_t = F.log(pi_t + 1e-20)
    a_one_hot = F.one_hot(self.acts_t, (num_actions, ))
    log_probs_t = F.sum(log_pi_t * a_one_hot, axis=1, keepdims=True)
    self.pi_loss = F.mean(log_probs_t * self.advs_t)

    # entropy bonus (encourages exploration)
    entropy = -ent_coeff * F.mean(F.sum(pi_t * log_pi_t, axis=1))

    self.loss = self.value_loss - self.pi_loss - entropy

    self.params = nn.get_parameters()
    self.solver = S.RMSprop(lr_scheduler(0.0), 0.99, 1e-5)
    self.solver.set_parameters(self.params)
    self.lr_scheduler = lr_scheduler
def invertible_conv(x, reverse, rng, scope):
    r"""Invertible 1x1 Convolution Layer.

    Args:
        x (nn.Variable): Input variable.
        reverse (bool): Whether it's a reverse direction.
        rng (numpy.random.RandomState): A random generator.
        scope (str): The scope.

    Returns:
        nn.Variable: The output variable.
    """
    batch_size, c, n_groups = x.shape

    with nn.parameter_scope(scope):
        # initialize w by an orthonormal matrix
        w_init = np.linalg.qr(rng.randn(c, c))[0][None, ...]
        W_var = get_parameter_or_create("W", (1, c, c), w_init, True, True)
        W = F.batch_inv(W_var) if reverse else W_var
        x = F.convolution(x, F.reshape(W, (c, c, 1)), None, stride=(1, ))

    if reverse:
        return x

    log_det = batch_size * n_groups * F.log(F.abs(F.batch_det(W)))
    return x, log_det
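# A minimal numpy sketch (not part of the original code): the QR-based
# initialization above yields an orthonormal matrix, whose determinant has
# absolute value 1, so log|det W| starts at 0 and the flow's log-det term is
# initially neutral. Illustration only.
import numpy as np

rng = np.random.RandomState(0)
c = 8
w = np.linalg.qr(rng.randn(c, c))[0]
assert np.isclose(abs(np.linalg.det(w)), 1.0)
assert np.isclose(np.log(abs(np.linalg.det(w))), 0.0, atol=1e-10)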
def log_mel_spectrogram(wave, sr, window_size, n_mels=80):
    """Return log mel-spectrogram.

    Args:
        wave (nn.Variable): Input waveform of shape (B, 1, L).
        sr (int): Sampling rate.
        window_size (int): Window size.
        n_mels (int): Number of mel banks.

    Returns:
        nn.Variable: Log mel-spectrogram.
    """
    linear = spectrogram(wave, window_size)
    mel_basis = librosa_mel_fn(sr, window_size, n_mels=n_mels,
                               fmin=80.0, fmax=7600.0)
    basis = nn.Variable.from_numpy_array(mel_basis[None, ...])
    mels = F.batch_matmul(basis, linear)
    return F.log(mels * 1e4 + 1.0)
def train_nerf(config, comm, model, dataset='blender'): use_transient = False use_embedding = False if model == 'wild': use_transient = True use_embedding = True elif model == 'uncertainty': use_transient = True elif model == 'appearance': use_embedding = True save_results_dir = config.log.save_results_dir os.makedirs(save_results_dir, exist_ok=True) train_loss_dict = { 'train_coarse_loss': 0.0, 'train_fine_loss': 0.0, 'train_total_loss': 0.0, } test_metric_dict = {'test_loss': 0.0, 'test_psnr': 0.0} monitor_manager = MonitorManager(train_loss_dict, test_metric_dict, save_results_dir) if dataset != 'phototourism': images, poses, _, hwf, i_test, i_train, near_plane, far_plane = get_data( config) height, width, focal_length = hwf else: di = get_photo_tourism_dataiterator(config, 'train', comm) val_di = get_photo_tourism_dataiterator(config, 'val', comm) if model != 'vanilla': if dataset != 'phototourism': config.train.n_vocab = max(np.max(i_train), np.max(i_test)) + 1 print( f'Setting Vocabulary size of embedding as {config.train.n_vocab}') if dataset != 'phototourism': if model in ['vanilla']: if comm is not None: # uncomment the following line to test on fewer images i_test = i_test[3 * comm.rank:3 * (comm.rank + 1)] pass else: # uncomment the following line to test on fewer images i_test = i_test[:3] pass else: # i_test = i_train[0:5] i_test = [i * (comm.rank + 1) for i in range(5)] else: i_test = [1] encode_position_function = get_encoding_function( config.train.num_encodings_position, True, True) if config.train.use_view_directions: encode_direction_function = get_encoding_function( config.train.num_encodings_direction, True, True) else: encode_direction_function = None lr = config.solver.lr num_decay_steps = config.solver.lr_decay_step * 1000 lr_decay_factor = config.solver.lr_decay_factor solver = S.Adam(alpha=lr) load_solver_state = False if config.checkpoint.param_path is not None: nn.load_parameters(config.checkpoint.param_path) load_solver_state = True if comm is not None: num_decay_steps /= comm.n_procs comm_size = comm.n_procs else: comm_size = 1 pbar = trange(config.train.num_iterations // comm_size, disable=(comm is not None and comm.rank > 0)) for i in pbar: if dataset != 'phototourism': idx = np.random.choice(i_train) image = nn.Variable.from_numpy_array(images[idx][None, :, :, :3]) pose = nn.Variable.from_numpy_array(poses[idx]) ray_directions, ray_origins = get_ray_bundle( height, width, focal_length, pose) grid = get_direction_grid(width, height, focal_length, return_ij_2d_grid=True) grid = F.reshape(grid, (-1, 2)) select_inds = np.random.choice(grid.shape[0], size=[config.train.num_rand_points], replace=False) select_inds = F.gather_nd(grid, select_inds[None, :]) select_inds = F.transpose(select_inds, (1, 0)) embed_inp = nn.Variable.from_numpy_array( np.full((config.train.chunksize_fine, ), idx, dtype=int)) ray_origins = F.gather_nd(ray_origins, select_inds) ray_directions = F.gather_nd(ray_directions, select_inds) image = F.gather_nd(image[0], select_inds) else: rays, embed_inp, image = di.next() ray_origins = nn.Variable.from_numpy_array(rays[:, :3]) ray_directions = nn.Variable.from_numpy_array(rays[:, 3:6]) near_plane = nn.Variable.from_numpy_array(rays[:, 6]) far_plane = nn.Variable.from_numpy_array(rays[:, 7]) embed_inp = nn.Variable.from_numpy_array(embed_inp) image = nn.Variable.from_numpy_array(image) hwf = None app_emb, trans_emb = None, None if use_embedding: with nn.parameter_scope('embedding_a'): app_emb = PF.embed(embed_inp, config.train.n_vocab, 
config.train.n_app) if use_transient: with nn.parameter_scope('embedding_t'): trans_emb = PF.embed(embed_inp, config.train.n_vocab, config.train.n_trans) if use_transient: rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass( ray_directions, ray_origins, near_plane, far_plane, app_emb, trans_emb, encode_position_function, encode_direction_function, config, use_transient, hwf=hwf, image=image) course_loss = 0.5 * F.mean(F.squared_error(rgb_map_course, image)) fine_loss = 0.5 * F.mean( F.squared_error(rgb_map_fine, image) / F.reshape(F.pow_scalar(beta, 2), beta.shape + (1, ))) beta_reg_loss = 3 + F.mean(F.log(beta)) sigma_trans_reg_loss = 0.01 * F.mean(transient_sigma) loss = course_loss + fine_loss + beta_reg_loss + sigma_trans_reg_loss else: rgb_map_course, _, _, _, rgb_map_fine, _, _, _ = forward_pass( ray_directions, ray_origins, near_plane, far_plane, app_emb, trans_emb, encode_position_function, encode_direction_function, config, use_transient, hwf=hwf) course_loss = F.mean(F.squared_error(rgb_map_course, image)) fine_loss = F.mean(F.squared_error(rgb_map_fine, image)) loss = course_loss + fine_loss pbar.set_description( f'Total: {np.around(loss.d, 4)}, Course: {np.around(course_loss.d, 4)}, Fine: {np.around(fine_loss.d, 4)}' ) solver.set_parameters(nn.get_parameters(), reset=False, retain_state=True) if load_solver_state: solver.load_states(config['checkpoint']['solver_path']) load_solver_state = False solver.zero_grad() loss.backward(clear_buffer=True) # Exponential LR decay if dataset != 'phototourism': lr_factor = (lr_decay_factor**((i) / num_decay_steps)) solver.set_learning_rate(lr * lr_factor) else: if i % num_decay_steps == 0 and i != 0: solver.set_learning_rate(lr * lr_decay_factor) if comm is not None: params = [x.grad for x in nn.get_parameters().values()] comm.all_reduce(params, division=False, inplace=True) solver.update() if ((i % config.train.save_interval == 0 or i == config.train.num_iterations - 1) and i != 0) and (comm is not None and comm.rank == 0): nn.save_parameters(os.path.join(save_results_dir, f'iter_{i}.h5')) solver.save_states( os.path.join(save_results_dir, f'solver_iter_{i}.h5')) if (i % config.train.test_interval == 0 or i == config.train.num_iterations - 1) and i != 0: avg_psnr, avg_mse = 0.0, 0.0 for i_t in trange(len(i_test)): if dataset != 'phototourism': idx_test = i_test[i_t] image = nn.NdArray.from_numpy_array( images[idx_test][None, :, :, :3]) pose = nn.NdArray.from_numpy_array(poses[idx_test]) ray_directions, ray_origins = get_ray_bundle( height, width, focal_length, pose) ray_directions = F.reshape(ray_directions, (-1, 3)) ray_origins = F.reshape(ray_origins, (-1, 3)) embed_inp = nn.NdArray.from_numpy_array( np.full((config.train.chunksize_fine, ), idx_test, dtype=int)) else: rays, embed_inp, image = val_di.next() ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3]) ray_directions = nn.NdArray.from_numpy_array(rays[0, :, 3:6]) near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6]) far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7]) embed_inp = nn.NdArray.from_numpy_array( embed_inp[0, :config.train.chunksize_fine]) image = nn.NdArray.from_numpy_array(image[0].transpose( 1, 2, 0)) image = F.reshape(image, (1, ) + image.shape) idx_test = 1 app_emb, trans_emb = None, None if use_embedding: with nn.parameter_scope('embedding_a'): app_emb = PF.embed(embed_inp, config.train.n_vocab, config.train.n_app) if use_transient: with nn.parameter_scope('embedding_t'): trans_emb 
= PF.embed(embed_inp, config.train.n_vocab, config.train.n_trans) num_ray_batches = ray_directions.shape[ 0] // config.train.ray_batch_size + 1 if use_transient: rgb_map_fine_list, static_rgb_map_fine_list, transient_rgb_map_fine_list = [], [], [] else: rgb_map_fine_list, depth_map_fine_list = [], [] for r_idx in trange(num_ray_batches): if r_idx != num_ray_batches - 1: ray_d, ray_o = ray_directions[ r_idx * config.train.ray_batch_size:(r_idx + 1) * config.train.ray_batch_size], ray_origins[ r_idx * config.train.ray_batch_size:(r_idx + 1) * config.train.ray_batch_size] if dataset == 'phototourism': near_plane = near_plane_[ r_idx * config.train.ray_batch_size:(r_idx + 1) * config.train.ray_batch_size] far_plane = far_plane_[r_idx * config.train.ray_batch_size: (r_idx + 1) * config.train.ray_batch_size] else: if ray_directions.shape[0] - ( num_ray_batches - 1) * config.train.ray_batch_size == 0: break ray_d, ray_o = ray_directions[ r_idx * config.train.ray_batch_size:, :], ray_origins[ r_idx * config.train.ray_batch_size:, :] if dataset == 'phototourism': near_plane = near_plane_[r_idx * config.train. ray_batch_size:] far_plane = far_plane_[r_idx * config.train. ray_batch_size:] if use_transient: rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass( ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb, encode_position_function, encode_direction_function, config, use_transient, hwf=hwf) rgb_map_fine_list.append(rgb_map_fine) static_rgb_map_fine_list.append(static_rgb_map_fine) transient_rgb_map_fine_list.append( transient_rgb_map_fine) else: _, _, _, _, rgb_map_fine, depth_map_fine, _, _ = \ forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb, encode_position_function, encode_direction_function, config, use_transient, hwf=hwf) rgb_map_fine_list.append(rgb_map_fine) depth_map_fine_list.append(depth_map_fine) if use_transient: rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0) static_rgb_map_fine = F.concatenate( *static_rgb_map_fine_list, axis=0) transient_rgb_map_fine = F.concatenate( *transient_rgb_map_fine_list, axis=0) rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape) static_rgb_map_fine = F.reshape(static_rgb_map_fine, image[0].shape) transient_rgb_map_fine = F.reshape(transient_rgb_map_fine, image[0].shape) static_trans_img_to_save = np.concatenate( (static_rgb_map_fine.data, np.ones((image[0].shape[0], 5, 3)), transient_rgb_map_fine.data), axis=1) img_to_save = np.concatenate( (image[0].data, np.ones( (image[0].shape[0], 5, 3)), rgb_map_fine.data), axis=1) else: rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0) depth_map_fine = F.concatenate(*depth_map_fine_list, axis=0) rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape) depth_map_fine = F.reshape(depth_map_fine, image[0].shape[:-1]) img_to_save = np.concatenate( (image[0].data, np.ones( (image[0].shape[0], 5, 3)), rgb_map_fine.data), axis=1) filename = os.path.join(save_results_dir, f'{i}_{idx_test}.png') try: imsave(filename, np.clip(img_to_save, 0, 1), channel_first=False) print(f'Saved generation at {filename}') if use_transient: filename_static_trans = os.path.join( save_results_dir, f's_t_{i}_{idx_test}.png') imsave(filename_static_trans, np.clip(static_trans_img_to_save, 0, 1), channel_first=False) else: filename_dm = os.path.join(save_results_dir, f'dm_{i}_{idx_test}.png') depth_map_fine = (depth_map_fine.data - depth_map_fine.data.min()) / ( depth_map_fine.data.max() - depth_map_fine.data.min()) 
imsave(filename_dm, depth_map_fine[:, :, None], channel_first=False) plt.imshow(depth_map_fine.data) plt.savefig(filename_dm) plt.close() except: pass avg_mse += F.mean(F.squared_error(rgb_map_fine, image[0])).data avg_psnr += (-10. * np.log10( F.mean(F.squared_error(rgb_map_fine, image[0])).data)) test_metric_dict['test_loss'] = avg_mse / len(i_test) test_metric_dict['test_psnr'] = avg_psnr / len(i_test) monitor_manager.add(i, test_metric_dict) print( f'Saved generations after {i} training iterations! Average PSNR: {avg_psnr/len(i_test)}. Average MSE: {avg_mse/len(i_test)}' )
def kl_divergence(ctx, pred, label):
    with nn.context_scope(ctx):
        elms = F.softmax(label, axis=1) * F.log(F.softmax(pred, axis=1))
        loss = -F.mean(F.sum(elms, axis=1))
    return loss
def log2(x):
    return F.log(x) / np.log(2.)
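# A minimal numpy sketch (not part of the original code): the change-of-base
# identity log2(x) = ln(x) / ln(2) used above, checked against numpy.
import numpy as np

x = np.array([0.5, 1.0, 8.0, 1000.0])
assert np.allclose(np.log(x) / np.log(2.), np.log2(x))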
def quantize_pow2(v):
    return 2 ** F.round(F.log(v) / np.log(2.))
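# A minimal numpy sketch (not part of the original code): the same
# round-to-nearest-power-of-two mapping as above, in plain numpy. The rounding
# happens in the exponent domain, so e.g. 3 -> 4 and 6 -> 8.
import numpy as np

v = np.array([0.3, 1.0, 3.0, 6.0, 100.0])
q = 2.0 ** np.round(np.log(v) / np.log(2.))
assert np.allclose(q, [0.25, 1.0, 4.0, 8.0, 128.0])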
central_bias = PF.embed(x_central, vocab_size, 1)
with nn.parameter_scope('context_bias'):
    context_bias = PF.embed(x_context, vocab_size, 1)

dot_product = F.reshape(F.batch_matmul(
    F.reshape(central_embedding, shape=(batch_size, 1, embedding_size)),
    F.reshape(context_embedding, shape=(batch_size, embedding_size, 1))),
    shape=(batch_size, 1))

prediction = dot_product + central_bias + context_bias

t = nn.Variable((batch_size, 1))
zero = F.constant(0, shape=(batch_size, 1))
one = F.constant(1, shape=(batch_size, 1))
weight = F.clip_by_value(t / 100, zero, one) ** 0.75
loss = F.sum(weight * ((prediction - F.log(t)) ** 2))

# Create solver.
solver = S.Adam()
solver.set_parameters(nn.get_parameters())

# Create monitor
monitor = M.Monitor('./log')
monitor_loss = M.MonitorSeries("Training loss", monitor, interval=1000)
monitor_valid_loss = M.MonitorSeries("Validation loss", monitor, interval=1)
monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=1000)

# Create updater
def train_data_feeder():
    x_central.d, x_context.d, t.d = train_data_iter.next()
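# A minimal numpy sketch (not part of the original code): the weighting term
# above implements the GloVe weighting function f(X) = min((X / x_max)**alpha, 1)
# with x_max = 100 and alpha = 0.75, which downweights rare co-occurrence
# counts and caps the weight of frequent ones at 1. Illustration only.
import numpy as np

def glove_weight(counts, x_max=100.0, alpha=0.75):
    return np.clip(counts / x_max, 0.0, 1.0) ** alpha

counts = np.array([1.0, 10.0, 100.0, 1000.0])
print(glove_weight(counts))  # [0.0316..., 0.1778..., 1.0, 1.0]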
def ce_loss_soft(ctx, pred, target):
    with nn.context_scope(ctx):
        # TODO: divide or not
        loss = -F.mean(F.sum(F.softmax(target) * F.log(F.softmax(pred)), axis=1))
    return loss
def p_mean_var(self, model, x_t, t, clip_denoised=True):
    """
    Compute mean and var of p(x_{t-1}|x_t) from model.

    Args:
        model (Callable): A callable that takes x_t and t and predicts noise (and more).
        x_t (nn.Variable): The (B, C, ...) tensor at timestep t (x_t).
        t (nn.Variable): A 1-D tensor of timesteps. The first axis represents batchsize.
        clip_denoised (bool): If True, clip the denoised signal into [-1, 1].

    Returns:
        An AttrDict containing the following items:
            "mean": the mean predicted by model.
            "var": the variance predicted by model (or pre-defined variance).
            "log_var": the log of "var".
            "xstart": the x_0 predicted from x_t and t by model.
    """
    B, C, H, W = x_t.shape
    assert t.shape == (B, )
    pred = model(x_t, t)

    if self.model_var_type == ModelVarType.LEARNED_RANGE:
        assert pred.shape == (B, 2 * C, H, W)
        pred_noise, pred_var_coeff = chunk(pred, num_chunk=2, axis=1)

        min_log = self._extract(
            self.posterior_log_var_clipped, t, x_t.shape)
        max_log = F.log(self._extract(self.betas, t, x_t.shape))

        # pred_var_coeff should be [0, 1]
        v = F.sigmoid(pred_var_coeff)
        model_log_var = v * max_log + (1 - v) * min_log
        model_var = F.exp(model_log_var)
    else:
        # Model only predicts noise
        pred_noise = pred

        model_log_var, model_var = {
            ModelVarType.FIXED_LARGE: lambda: (
                self._extract(self.log_betas_clipped, t, x_t.shape),
                self._extract(self.betas_clipped, t, x_t.shape)
            ),
            ModelVarType.FIXED_SMALL: lambda: (
                self._extract(
                    self.posterior_log_var_clipped, t, x_t.shape),
                self._extract(self.posterior_var, t, x_t.shape)
            )
        }[self.model_var_type]()

    x_recon = self.predict_xstart_from_noise(
        x_t=x_t, t=t, noise=pred_noise)

    if clip_denoised:
        x_recon = F.clip_by_value(x_recon, -1, 1)

    model_mean, _, _ = self.q_posterior(x_start=x_recon, x_t=x_t, t=t)

    assert model_mean.shape == x_recon.shape == x_t.shape
    assert model_mean.shape == model_var.shape == model_log_var.shape or \
        (model_mean.shape[0] == model_var.shape[0] == model_log_var.shape[0]
         and model_var.shape[1:] == (1, 1, 1)
         and model_log_var.shape[1:] == (1, 1, 1))

    # returns
    ret = AttrDict()
    ret.mean = model_mean
    ret.var = model_var
    ret.log_var = model_log_var
    ret.xstart = x_recon

    return ret
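# A minimal numpy sketch (not part of the original code): in the LEARNED_RANGE
# branch above the network output is squashed with a sigmoid and used to
# interpolate between the two fixed log-variances, so the predicted variance
# always stays within [posterior variance, beta]. Illustration only, with
# made-up values for the two bounds.
import numpy as np

min_log = np.log(1e-4)   # stand-in for posterior_log_var_clipped at some t
max_log = np.log(2e-2)   # stand-in for log(beta_t)
for raw in [-10.0, 0.0, 10.0]:
    v = 1.0 / (1.0 + np.exp(-raw))            # sigmoid
    log_var = v * max_log + (1 - v) * min_log
    assert np.exp(min_log) - 1e-12 <= np.exp(log_var) <= np.exp(max_log) + 1e-12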
def ce_soft(pred, label):
    elms = -F.softmax(label, axis=1) * F.log(F.softmax(pred, axis=1))
    loss = F.mean(F.sum(elms, axis=1))
    return loss
def loss_function(u, v, negative_samples):
    return F.sum(-F.log(
        F.exp(-distance(u, v)) / sum([
            F.exp(-distance(u, x))
            for x in F.split(negative_samples, axis=2)
        ])))
def _build(self): # inference graph self.infer_obs_t = nn.Variable((1, ) + self.obs_shape) with nn.parameter_scope('trainable'): infer_dist = policy_network(self.infer_obs_t, self.action_size, 'actor') self.infer_act_t, _ = _squash_action(infer_dist) self.deterministic_act_t = infer_dist.mean() # training graph self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape) self.acts_t = nn.Variable((self.batch_size, self.action_size)) self.rews_tp1 = nn.Variable((self.batch_size, 1)) self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape) self.ters_tp1 = nn.Variable((self.batch_size, 1)) with nn.parameter_scope('trainable'): dist = policy_network(self.obss_t, self.action_size, 'actor') squashed_act_t, log_prob_t = _squash_action(dist) v_t = v_network(self.obss_t, 'value') q_t1 = q_network(self.obss_t, self.acts_t, 'critic/1') q_t2 = q_network(self.obss_t, self.acts_t, 'critic/2') q_t1_with_actor = q_network(self.obss_t, squashed_act_t, 'critic/1') q_t2_with_actor = q_network(self.obss_t, squashed_act_t, 'critic/2') with nn.parameter_scope('target'): v_tp1 = v_network(self.obss_tp1, 'value') # value loss q_t = F.minimum2(q_t1_with_actor, q_t2_with_actor) v_target = q_t - log_prob_t v_target.need_grad = False self.value_loss = 0.5 * F.mean(F.squared_error(v_t, v_target)) # q function loss scaled_rews_tp1 = self.rews_tp1 * self.reward_scale q_target = scaled_rews_tp1 + self.gamma * v_tp1 * (1.0 - self.ters_tp1) q_target.need_grad = False q1_loss = 0.5 * F.mean(F.squared_error(q_t1, q_target)) q2_loss = 0.5 * F.mean(F.squared_error(q_t2, q_target)) self.critic_loss = q1_loss + q2_loss # policy function loss mean_loss = 0.5 * F.mean(dist.mean()**2) logstd_loss = 0.5 * F.mean(F.log(dist.stddev())**2) policy_reg_loss = self.policy_reg * (mean_loss + logstd_loss) self.objective_loss = F.mean(log_prob_t - q_t) self.actor_loss = self.objective_loss + policy_reg_loss # trainable parameters with nn.parameter_scope('trainable'): with nn.parameter_scope('value'): value_params = nn.get_parameters() with nn.parameter_scope('critic'): critic_params = nn.get_parameters() with nn.parameter_scope('actor'): actor_params = nn.get_parameters() # target parameters with nn.parameter_scope('target/value'): target_params = nn.get_parameters() # target update update_targets = [] sync_targets = [] for key, src in value_params.items(): dst = target_params[key] updated_dst = (1.0 - self.tau) * dst + self.tau * src update_targets.append(F.assign(dst, updated_dst)) sync_targets.append(F.assign(dst, src)) self.update_target_expr = F.sink(*update_targets) self.sync_target_expr = F.sink(*sync_targets) # setup solvers self.value_solver = S.Adam(self.value_lr) self.value_solver.set_parameters(value_params) self.critic_solver = S.Adam(self.critic_lr) self.critic_solver.set_parameters(critic_params) self.actor_solver = S.Adam(self.actor_lr) self.actor_solver.set_parameters(actor_params)
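# A minimal numpy sketch (not part of the original code): the target update
# built with F.assign above is a Polyak (soft) update,
# dst <- (1 - tau) * dst + tau * src, which moves the target value network a
# small step toward the trained one on each call. Illustration only.
import numpy as np

tau = 0.005
src = np.array([1.0, 2.0, 3.0])   # trained parameters
dst = np.zeros_like(src)          # target parameters
for _ in range(1000):
    dst = (1.0 - tau) * dst + tau * src
print(dst)  # close to src: the target tracks the trained network with a lag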
def get_tecogan_model(conf, r_inputs, r_targets, scope_name, tecogan=True): """ Create computation graph and variables for TecoGAN. """ # r_inputs, r_targets : shape (batch, conf.train.rnn_n, h, w, c) rnn_length = conf.train.rnn_n if tecogan: r_inputs, r_targets = get_tecogan_inputs(r_inputs, r_targets) rnn_length = rnn_length * 2 - 1 # get the consecutive frame sequences from the input sequence frame_t_pre, frame_t = r_inputs[:, 0:-1, :, :, :], r_inputs[:, 1:, :, :, :] # Get flow estimations fnet_output = get_fnet_output(conf, rnn_length, frame_t_pre, frame_t, scope_name) # Get the generated HR output frames gen_outputs = get_generator_output(conf, rnn_length, r_inputs, fnet_output.flow_hr, scope_name) s_gen_output = F.reshape( gen_outputs, (conf.train.batch_size * rnn_length, conf.train.crop_size * 4, conf.train.crop_size * 4, 3), inplace=False) s_targets = F.reshape( r_targets, (conf.train.batch_size * rnn_length, conf.train.crop_size * 4, conf.train.crop_size * 4, 3), inplace=False) # Content loss (l2 loss) content_loss = F.mean( F.sum(F.squared_error(s_gen_output, s_targets), axis=[3])) # Warp loss (l2 loss) warp_loss = get_warp_loss(conf, rnn_length, frame_t, frame_t_pre, fnet_output.flow_lr) if tecogan: d_data = get_d_data(conf, fnet_output.flow_hr, gen_outputs, r_targets, rnn_length) # Build the tempo discriminator for the real part and fake part t_d = get_t_d(conf, r_inputs, d_data) # Discriminator layer loss: d_layer_loss = get_d_layer(t_d.real_layers, t_d.fake_layers) # vgg loss (cosine similarity) loss_vgg = get_vgg_loss(s_gen_output, s_targets) # ping pong loss (an l1 loss) gen_out_first = gen_outputs[:, 0:conf.train.rnn_n - 1, :, :, :] gen_out_last_rev = gen_outputs[:, -1:-conf.train.rnn_n:-1, :, :, :] pp_loss = F.mean(F.abs(gen_out_first - gen_out_last_rev)) # adversarial loss t_adversarial_loss = F.mean(-F.log(t_d.tdiscrim_fake_output + conf.train.eps)) # Overall generator loss gen_loss = content_loss + pp_loss * conf.gan.pp_scaling + conf.gan.ratio * \ t_adversarial_loss + conf.gan.vgg_scaling * loss_vgg + \ conf.gan.dt_ratio_0 * d_layer_loss # Discriminator loss t_discrim_fake_loss = F.log(1 - t_d.tdiscrim_fake_output + conf.train.eps) t_discrim_real_loss = F.log(t_d.tdiscrim_real_output + conf.train.eps) t_discrim_loss = F.mean(-(t_discrim_fake_loss + t_discrim_real_loss)) fnet_loss = gen_loss + warp_loss set_persistent_all(r_targets, r_inputs, loss_vgg, gen_out_first, gen_out_last_rev, pp_loss, d_layer_loss, content_loss, warp_loss, gen_loss, t_adversarial_loss, t_discrim_loss, t_discrim_real_loss, d_data.t_vel, d_data.t_gen_output, s_gen_output, s_targets) Network = collections.namedtuple( 'Network', 'content_loss, warp_loss, fnet_loss, vgg_loss,' 'gen_loss, pp_loss, sum_layer_loss,t_adversarial_loss,' 't_discrim_loss,t_gen_output,t_discrim_real_loss') return Network(content_loss=content_loss, warp_loss=warp_loss, fnet_loss=fnet_loss, vgg_loss=loss_vgg, gen_loss=gen_loss, pp_loss=pp_loss, sum_layer_loss=d_layer_loss, t_adversarial_loss=t_adversarial_loss, t_discrim_loss=t_discrim_loss, t_gen_output=d_data.t_gen_output, t_discrim_real_loss=t_discrim_real_loss) gen_loss = content_loss fnet_loss = gen_loss + warp_loss set_persistent_all(content_loss, s_gen_output, warp_loss, gen_loss, fnet_loss) Network = collections.namedtuple( 'Network', 'content_loss, warp_loss, fnet_loss, gen_loss') return Network( content_loss=content_loss, warp_loss=warp_loss, fnet_loss=fnet_loss, gen_loss=gen_loss, )
def sample_from_controller(args): """ 2-layer RNN(LSTM) based controller which outputs an architecture of CNN, represented as a sequence of integers and its list. Given the number of layers, for each layer, it executes 2 types of computation, one for sampling the operation at that layer, another for sampling the skip connection patterns. """ entropys = nn.Variable([1, 1], need_grad=True) log_probs = nn.Variable([1, 1], need_grad=True) skip_penaltys = nn.Variable([1, 1], need_grad=True) entropys.d = log_probs.d = skip_penaltys.d = 0.0 # initialize them all num_layers = args.num_layers lstm_size = args.lstm_size state_size = args.state_size lstm_num_layers = args.lstm_layers skip_target = args.skip_prob temperature = args.temperature tanh_constant = args.tanh_constant num_branch = args.num_ops arc_seq = [] initializer = I.UniformInitializer((-0.1, 0.1)) prev_h = [ nn.Variable([1, lstm_size], need_grad=True) for _ in range(lstm_num_layers) ] prev_c = [ nn.Variable([1, lstm_size], need_grad=True) for _ in range(lstm_num_layers) ] for i in range(len(prev_h)): prev_h[i].d = 0 # initialize variables in lstm layers. prev_c[i].d = 0 inputs = nn.Variable([1, lstm_size]) inputs.d = np.random.normal(0, 0.5, [1, lstm_size]) g_emb = nn.Variable([1, lstm_size]) g_emb.d = np.random.normal(0, 0.5, [1, lstm_size]) skip_targets = nn.Variable([1, 2]) skip_targets.d = np.array([[1.0 - skip_target, skip_target]]) for layer_id in range(num_layers): # One-step stacked LSTM. with nn.parameter_scope("controller_lstm"): next_h, next_c = stack_lstm(inputs, prev_h, prev_c, state_size) prev_h, prev_c = next_h, next_c # shape:(1, lstm_size) # Compute for operation. with nn.parameter_scope("ops"): logit = PF.affine(next_h[-1], num_branch, w_init=initializer, with_bias=False) if temperature is not None: logit = F.mul_scalar(logit, (1 / temperature)) if tanh_constant is not None: logit = F.mul_scalar(F.tanh(logit), tanh_constant) # (1, num_branch) # normalizing logits. normed_logit = np.e**logit.d normed_logit = normed_logit / np.sum(normed_logit) # Sampling operation id from multinomial distribution. ops_id = np.random.multinomial(1, normed_logit[0], 1).nonzero()[1] ops_id = nn.Variable.from_numpy_array(ops_id) # (1, ) arc_seq.append(ops_id.d) # log policy for operation. log_prob = F.softmax_cross_entropy(logit, F.reshape(ops_id, shape=(1, 1))) # (1, ) # accumulate log policy as log probs log_probs = F.add2(log_probs, log_prob) entropy = log_prob * F.exp(-log_prob) entropys = F.add2(entropys, entropy) # accumulate entropy as entropys. w_emb = nn.parameter.get_parameter_or_create("w_emb", [num_branch, lstm_size], initializer, need_grad=False) inputs = F.reshape(w_emb[int(ops_id.d)], (1, w_emb.shape[1])) # (1, lstm_size) with nn.parameter_scope("controller_lstm"): next_h, next_c = stack_lstm(inputs, prev_h, prev_c, lstm_size) prev_h, prev_c = next_h, next_c # (1, lstm_size) with nn.parameter_scope("skip_affine_3"): adding_w_1 = PF.affine(next_h[-1], lstm_size, w_init=initializer, with_bias=False) # (1, lstm_size) if layer_id == 0: inputs = g_emb # (1, lstm_size) anchors = next_h[-1] # (1, lstm_size) anchors_w_1 = adding_w_1 # then goes back to the entry point of the loop else: # (layer_id, lstm_size) this shape during the process query = anchors_w_1 with nn.parameter_scope("skip_affine_1"): query = F.tanh( F.add2( query, PF.affine(next_h[-1], lstm_size, w_init=initializer, with_bias=False))) # (layer_id, lstm_size) + (1, lstm_size) # broadcast occurs here. 
resulting shape is; (layer_id, lstm_size) with nn.parameter_scope("skip_affine_2"): query = PF.affine(query, 1, w_init=initializer, with_bias=False) # (layer_id, 1) # note that each weight for skip_affine_X is shared across all steps of LSTM. # re-define logits, now its shape is;(layer_id, 2) logit = F.concatenate(-query, query, axis=1) if temperature is not None: logit = F.mul_scalar(logit, (1 / temperature)) if tanh_constant is not None: logit = F.mul_scalar(F.tanh(logit), tanh_constant) skip_prob_unnormalized = F.exp(logit) # (layer_id, 2) # normalizing skip_prob_unnormalized. summed = F.sum(skip_prob_unnormalized, axis=1, keepdims=True).apply(need_grad=False) summed = F.concatenate(summed, summed, axis=1) skip_prob_normalized = F.div2(skip_prob_unnormalized, summed) # (layer_id, 2) # Sampling skip_pattern from multinomial distribution. skip_pattern = np.random.multinomial( 1, skip_prob_normalized.d[0], layer_id).nonzero()[1] # (layer_id, 1) arc_seq.append(skip_pattern) skip = nn.Variable.from_numpy_array(skip_pattern) # compute skip penalty. # (layer_id, 2) broadcast occurs here too kl = F.mul2(skip_prob_normalized, F.log(F.div2(skip_prob_normalized, skip_targets))) kl = F.sum(kl, keepdims=True) # get the mean value here in advance. kl = kl * (1.0 / (num_layers - 1)) # accumulate kl divergence as skip penalty. skip_penaltys = F.add2(skip_penaltys, kl) # log policy for connection. log_prob = F.softmax_cross_entropy( logit, F.reshape(skip, shape=(skip.shape[0], 1))) log_probs = F.add2(log_probs, F.sum(log_prob, keepdims=True)) entropy = F.sum(log_prob * F.exp(-log_prob), keepdims=True) # accumulate entropy as entropys. entropys = F.add2(entropys, entropy) skip = F.reshape(skip, (1, layer_id)) inputs = F.affine(skip, anchors).apply(need_grad=False) # (1, lstm_size) inputs = F.mul_scalar(inputs, (1.0 / (1.0 + (np.sum(skip.d))))) # add new row for the next computation # (layer_id + 1, lstm_size) anchors = F.concatenate(anchors, next_h[-1], axis=0) # (layer_id + 1, lstm_size) anchors_w_1 = F.concatenate(anchors_w_1, adding_w_1, axis=0) return arc_seq, log_probs, entropys, skip_penaltys