def attention(k, q, v, div_dim=True, softmax=True):
    v_shape = v.shape
    k = F.identity(k)
    q = F.identity(q)
    k = F.reshape(k, (k.shape[0], np.prod(k.shape[1:])))
    q = F.reshape(q, (q.shape[0], np.prod(q.shape[1:])))
    v = q  # F.reshape is inplace
    cf = F.affine(q, F.transpose(k, (1, 0)))
    if div_dim:
        dim = np.prod(v_shape[1:])
        cf /= np.sqrt(dim)
    h = cf
    if softmax:
        h = F.softmax(h)
    h = F.affine(h, v)
    h = F.reshape(h, v_shape)
    return h
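
# A minimal shape-check sketch for attention() above (hedged: the shapes and
# the runtime check are illustrative only, not from the original repo; assumes
# `np`, `nn`, and `F` are imported as in the surrounding code).
def _attention_shape_sketch():
    import nnabla as nn
    q = nn.Variable((8, 64))
    k = nn.Variable((8, 64))
    v = nn.Variable((8, 64))
    out = attention(k, q, v)
    # the output is reshaped back to the value shape
    assert out.shape == (8, 64)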
def upscale_four(inputs, scope='upscale_four'):
    """Mimic the TensorFlow bilinear upscaling for a fixed ratio of 4."""
    with nn.parameter_scope(scope):
        b, h, w, c = inputs.shape
        p_inputs = F.concatenate(
            inputs, inputs[:, -1:, :, :], axis=1)  # pad bottom
        p_inputs = F.concatenate(
            p_inputs, p_inputs[:, :, -1:, :], axis=2)  # pad right

        hi_res_bin = [
            [
                inputs,                   # top-left
                p_inputs[:, :-1, 1:, :]   # top-right
            ],
            [
                p_inputs[:, 1:, :-1, :],  # bottom-left
                p_inputs[:, 1:, 1:, :]    # bottom-right
            ]
        ]

        hi_res_array = []
        for hi in range(4):
            for wj in range(4):
                hi_res_array.append(
                    hi_res_bin[0][0] * (1.0 - 0.25 * hi) * (1.0 - 0.25 * wj)
                    + hi_res_bin[0][1] * (1.0 - 0.25 * hi) * (0.25 * wj)
                    + hi_res_bin[1][0] * (0.25 * hi) * (1.0 - 0.25 * wj)
                    + hi_res_bin[1][1] * (0.25 * hi) * (0.25 * wj)
                )

        hi_res = F.stack(*hi_res_array, axis=3)  # shape (b, h, w, 16, c)
        hi_res_reshape = F.reshape(hi_res, (b, h, w, 4, 4, c))
        hi_res_reshape = F.transpose(hi_res_reshape, (0, 1, 3, 2, 4, 5))
        hi_res_reshape = F.reshape(hi_res_reshape, (b, h * 4, w * 4, c))
    return hi_res_reshape
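
# A hedged shape-check for upscale_four() (illustrative; assumes NHWC inputs,
# as the function's indexing implies).
def _upscale_four_shape_sketch():
    import nnabla as nn
    x = nn.Variable((1, 8, 8, 3))  # (b, h, w, c)
    y = upscale_four(x)
    # the fixed ratio of 4 applies to both spatial dimensions
    assert y.shape == (1, 32, 32, 3)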
def _rnn(x, h, w, b, nonlinearity, with_bias):
    """RNN cell.

    Args:
        x (:obj:`~nnabla.Variable`): Input data.
        h (:obj:`~nnabla.Variable`): Hidden state.
        w (:obj:`~nnabla.Variable`): Weight.
        b (:obj:`~nnabla.Variable`): Bias.
        nonlinearity (str): "tanh" or "relu".
        with_bias (bool): Include the bias or not.
    """
    hidden_size = h.shape[1]
    xh = F.concatenate(*(x, h), axis=1)
    b_ = None
    if with_bias:
        b_ = b
    h_t = F.affine(xh, F.transpose(w, (1, 0)), b_)
    if nonlinearity == 'tanh':
        h_t = F.tanh(h_t)
    elif nonlinearity == 'relu':
        h_t = F.relu(h_t)
    return h_t
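
# A hedged usage sketch for _rnn() (shapes are assumptions chosen to satisfy
# F.affine(xh, w.T): w must be (hidden_size, input_size + hidden_size)).
def _rnn_shape_sketch():
    import nnabla as nn
    x = nn.Variable((4, 10))   # (batch, input_size)
    h = nn.Variable((4, 20))   # (batch, hidden_size)
    w = nn.Variable((20, 30))  # (hidden_size, input_size + hidden_size)
    b = nn.Variable((20, ))
    h_t = _rnn(x, h, w, b, nonlinearity='tanh', with_bias=True)
    assert h_t.shape == (4, 20)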
def __call__(self, conv_in, h=None):
    v_stack_in = conv_in
    h_stack_in = conv_in
    features = []
    with nn.parameter_scope('ConditionalPixelCNN'):
        for i in range(self.num_layers):
            if i == 0:
                kernel_shape = (7, 7)
                mask_type = self.mask_type_A
                residual = False
            else:
                kernel_shape = (3, 3)
                mask_type = self.mask_type_B
                residual = True
            v_stack_gated, v_stack_conv = self.gated_conv(
                v_stack_in, kernel_shape, h, mask_type=mask_type,
                return_payload=True,
                scope_name='vertical_stack_gated_' + str(i))
            h_stack_gated = self.gated_conv(
                h_stack_in, (1, kernel_shape[0]), h, mask_type=mask_type,
                payload=v_stack_conv,
                scope_name='horizontal_stack_gated_' + str(i))
            h_stack_conv = self.gated_conv(
                h_stack_gated, (1, 1), h, mask_type=mask_type, gated=False,
                scope_name='horizontal_stack_conv_' + str(i))
            if residual:
                h_stack_conv += h_stack_in
            v_stack_in = v_stack_gated
            h_stack_in = h_stack_conv

        fc_1 = self.gated_conv(h_stack_in, (1, 1), gated=False,
                               scope_name='fc_1')
        fc_2 = PF.convolution(fc_1, self.out_channels, (1, 1),
                              apply_w=self.mask_type_B, name='fc_2')
        fc_2 = F.transpose(fc_2, (0, 2, 3, 1))
        fc_2 = F.reshape(fc_2, (-1, fc_2.shape[-1]), inplace=True)
    return fc_2
def define_network(self):
    if self.use_inst:
        obj_onehot, bm = encode_inputs(self.ist_mask, self.obj_mask,
                                       n_ids=self.conf.n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = self.obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(self.conf.n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    generator = SpadeGenerator(self.conf.g_ndf,
                               image_shape=self.conf.image_shape)
    z = F.randn(shape=(self.conf.batch_size, self.conf.z_dim))
    fake = generator(z, mask)

    # Pixel intensities of fake are in [-1, 1]. Rescale them to [0, 1].
    fake = (fake + 1) / 2

    return fake
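
# A hedged, self-contained sketch of the one-hot mask construction used in the
# `else` branch above (the 35-class label count is an arbitrary assumption).
def _mask_one_hot_sketch():
    import nnabla as nn
    import nnabla.functions as F
    om = nn.Variable((2, 64, 64))                 # integer label map (B, H, W)
    om = F.reshape(om, om.shape + (1, ))          # (B, H, W, 1)
    obj_onehot = F.one_hot(om, shape=(35, ))      # (B, H, W, 35)
    mask = F.transpose(obj_onehot, (0, 3, 1, 2))  # NCHW for the generator
    assert mask.shape == (2, 35, 64, 64)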
def primary_capsule(h, factor_capsules=32, out_channels=8, kernel=9, stride=2,
                    fix_parameters=False):
    '''
    Takes Conv1 output and produces PrimaryCapsules.

    PrimaryCapsules are computed by using a single Convolution layer.

    Args:
        h (nnabla.Variable): A shape of [B, C, H, W].
        factor_capsules (int): Multiplication factor of output capsules.
            The output capsules will be ``factor_capsules x out_H x out_W``
            where ``out_H`` and ``out_W`` are height and width of the output
            of the ``(kernel, kernel)`` Convolution with ``stride``.
            E.g. ``out_H = (H - (kernel - 1)) / stride``.
        out_channels (int): Number of units in each capsule of the output.
        kernel (int): Kernel size of the Convolution.
        stride (int): Stride of the Convolution.
        fix_parameters (bool): Fix parameters (Set need_grad=False).

    Returns:
        nn.Variable: A shape [B, factor_capsules x H' x W', out_channels].
    '''
    h = PF.convolution(h, out_channels * factor_capsules, (kernel, kernel),
                       stride=(stride, stride), fix_parameters=fix_parameters)
    num_capsules = factor_capsules * h.shape[2] * h.shape[3]
    h = F.reshape(h, [h.shape[0], out_channels, num_capsules])
    h = F.transpose(h, (0, 2, 1))
    return squash(h)
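
# A hedged shape-check for primary_capsule() with CapsNet-like numbers
# (assumes `squash` is shape-preserving): H = W = 20, kernel = 9, stride = 2
# gives out_H = out_W = 6, so num_capsules = 32 * 6 * 6 = 1152.
def _primary_capsule_shape_sketch():
    import nnabla as nn
    h = nn.Variable((8, 256, 20, 20))
    caps = primary_capsule(h)
    assert caps.shape == (8, 1152, 8)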
def backward_impl(self, inputs, outputs, prop_down, accum):
    # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
    # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

    # Args
    axes = self.forward_func.info.args["axes"]
    # Inputs
    x0 = inputs[0].data
    dy = inputs[1].data
    # Outputs
    dx0 = outputs[0].data
    # Grads of inputs
    g_x0 = inputs[0].grad
    g_dy = inputs[1].grad
    # Grads of outputs
    g_dx0 = outputs[0].grad

    # Computation
    if prop_down[1]:
        g_dy_ = F.transpose(g_dx0, axes)
        if accum[1]:
            g_dy += g_dy_
        else:
            g_dy.copy_from(g_dy_)
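
# Why F.transpose(g_dx0, axes) is the right double-backward: for
# y = transpose(x, axes), the backward is dx = transpose(dy, inv_axes), and
# that map is linear, so its gradient w.r.t. dy transposes back with `axes`
# itself. A hedged numpy check of the underlying identity:
def _transpose_double_backward_identity():
    import numpy as np
    axes = (1, 2, 0)
    inv_axes = tuple(np.argsort(axes))  # (2, 0, 1)
    g = np.random.rand(3, 4, 5)
    # transposing by `axes` undoes a transpose by `inv_axes`
    assert np.array_equal(np.transpose(np.transpose(g, inv_axes), axes), g)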
def call(self, char_inputs, mel_inputs=None):
    r"""Return mel-spectrograms, spectrograms, and attention weights.

    Args:
        char_inputs (nn.Variable): Inputs containing indices of characters.
            This has a shape of (B, Tx).

    Returns:
        nn.Variable: The synthetic mel-spectrograms of shape
            (B, T_y/r, n_mels*r).
        nn.Variable: The synthetic mel-spectrograms of shape
            (B, T_y/r, n_mels*r) after postnet.
        nn.Variable: The attention matrix of shape (B, Tx, Ty).
    """
    with nn.parameter_scope('encoder'):
        encoder_outputs = self.encoder(char_inputs)  # (Tx, B, C=512)

    with nn.parameter_scope('decoder'):
        encoder_outputs = F.transpose(encoder_outputs, (1, 0, 2))
        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mel_inputs)

    with nn.parameter_scope('post_net'):
        mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs

    return mel_outputs, mel_outputs_postnet, gate_outputs, alignments
def dot(a, b, out=None):
    '''
    A compatible operation with ``numpy.dot``.

    Note:
        Any operation between nnabla's Variable/NdArray and numpy array is
        not supported.

        If both arguments are 1-D, it is inner product of vectors.
        If both arguments are 2-D, it is matrix multiplication.
        If either a or b is 0-D (scalar), it is equivalent to multiply.
        If b is a 1-D array, it is a sum product over the last axis of a
        and b.
        If b is an M-D array (M >= 2), it is a sum product over the last axis
        of a and the second-to-last axis of b.

    Args:
        a (Variable, NdArray or scalar): Left input array.
        b (Variable, NdArray or scalar): Right input array.
        out: Output argument. This must have the same shape, dtype, and type
            as the result that would be returned for F.dot(a, b).

    Returns:
        ~nnabla.Variable or ~nnabla.NdArray

    Examples:

    .. code-block:: python

        import numpy as np
        import nnabla as nn
        import nnabla.functions as F

        # 2-D matrix * 2-D matrix
        arr1 = np.arange(5*6).reshape(5, 6)
        arr2 = np.arange(6*8).reshape(6, 8)
        nd1 = nn.NdArray.from_numpy_array(arr1)
        nd2 = nn.NdArray.from_numpy_array(arr2)
        ans1 = F.dot(nd1, nd2)
        print(ans1.shape)  # (5, 8)

        var1 = nn.Variable.from_numpy_array(arr1)
        var2 = nn.Variable.from_numpy_array(arr2)
        ans2 = F.dot(var1, var2)
        ans2.forward()
        print(ans2.shape)  # (5, 8)

        out1 = nn.NdArray((5, 8))
        out1.cast(np.float32)
        F.dot(nd1, nd2, out1)
        print(out1.shape)  # (5, 8)

        out2 = nn.Variable((5, 8))
        out2.data.cast(np.float32)
        F.dot(var1, var2, out2)
        out2.forward()
        print(out2.shape)  # (5, 8)

        # N-D matrix * M-D matrix (M >= 2)
        arr1 = np.arange(5*6*7*8).reshape(5, 6, 7, 8)
        arr2 = np.arange(2*3*8*6).reshape(2, 3, 8, 6)
        nd1 = nn.NdArray.from_numpy_array(arr1)
        nd2 = nn.NdArray.from_numpy_array(arr2)
        ans1 = F.dot(nd1, nd2)
        print(ans1.shape)  # (5, 6, 7, 2, 3, 6)

        var1 = nn.Variable.from_numpy_array(arr1)
        var2 = nn.Variable.from_numpy_array(arr2)
        ans2 = F.dot(var1, var2)
        ans2.forward()
        print(ans2.shape)  # (5, 6, 7, 2, 3, 6)

        out1 = nn.NdArray((5, 6, 7, 2, 3, 6))
        out1.cast(np.float32)
        F.dot(nd1, nd2, out1)
        print(out1.shape)  # (5, 6, 7, 2, 3, 6)

        out2 = nn.Variable((5, 6, 7, 2, 3, 6))
        out2.data.cast(np.float32)
        F.dot(var1, var2, out2)
        out2.forward()
        print(out2.shape)  # (5, 6, 7, 2, 3, 6)
    '''
    import nnabla as nn
    import nnabla.functions as F

    def _chk(x, mark=0):
        if isinstance(x, nn.NdArray):
            return x.data, 1
        elif isinstance(x, nn.Variable):
            return x.d, 1
        else:
            return x, mark

    m, mark1 = _chk(a)
    n, mark2 = _chk(b)

    if mark1 and mark2:
        if a.ndim == 1 and b.ndim == 1:
            result = F.sum(a * b)
        elif a.ndim == 2 and b.ndim == 2:
            result = F.affine(a, b)
        elif a.ndim == 0 or b.ndim == 0:
            if a.ndim == 0:
                result = F.mul_scalar(b, m)
                if isinstance(a, nn.NdArray) and isinstance(b, nn.Variable):
                    result.forward()
                    result = result.data
            else:
                result = F.mul_scalar(a, n)
                if isinstance(a, nn.Variable) and isinstance(b, nn.NdArray):
                    result.forward()
                    result = result.data
        elif b.ndim == 1:
            h = F.affine(a, F.reshape(b, (-1, 1)), base_axis=a.ndim - 1)
            result = F.reshape(h, h.shape[:-1])
        elif b.ndim >= 2:
            index = [*range(0, b.ndim)]
            index.insert(0, index.pop(b.ndim - 2))
            b = F.transpose(b, index)
            h = F.affine(a, b, base_axis=a.ndim - 1)
            result = h
    else:
        result = np.dot(a, b)

    if out is not None:
        out_, _ = _chk(out)
        result_, _ = _chk(result)
        if type(out) == type(result) \
                and out_.shape == result_.shape \
                and out_.dtype == result_.dtype:
            if isinstance(out, nn.NdArray):
                out.cast(result.data.dtype)[...] = result.data
            elif isinstance(out, nn.Variable):
                out.rewire_on(result)
            else:
                out = result
        else:
            raise ValueError(
                "Output argument must have the same shape, type and dtype as "
                "the result that would be returned for F.dot(a, b).")
    else:
        return result
def train():
    rng = np.random.RandomState(803)

    conf = get_config()
    comm = init_nnabla(conf)

    # create data iterator
    if conf.dataset == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           save_file=comm.rank == 0)
        n_class = conf.cityscapes.n_label_ids
        use_inst = True
        data_iter = create_cityscapes_iterator(conf.batch_size, data_list,
                                               comm=comm,
                                               image_shape=conf.image_shape,
                                               rng=rng, flip=conf.use_flip)
    elif conf.dataset == "ade20k":
        data_list = get_ade20k_datalist(conf.ade20k,
                                        save_file=comm.rank == 0)
        n_class = conf.ade20k.n_label_ids + 1  # class id + unknown
        use_inst = False
        load_shape = tuple(
            x + 30 for x in conf.image_shape) if conf.use_crop else conf.image_shape
        data_iter = create_ade20k_iterator(conf.batch_size, data_list,
                                           comm=comm, load_shape=load_shape,
                                           crop_shape=conf.image_shape,
                                           rng=rng, flip=conf.use_flip)
    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.dataset))

    real = nn.Variable(shape=(conf.batch_size, 3) + conf.image_shape)
    obj_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)

    if use_inst:
        ist_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)
        obj_onehot, bm = encode_inputs(ist_mask, obj_mask, n_ids=n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    # generator
    generator = SpadeGenerator(conf.g_ndf, image_shape=conf.image_shape)
    z = F.randn(shape=(conf.batch_size, conf.z_dim))
    fake = generator(z, mask)

    # unlinking
    ul_mask, ul_fake = get_unlinked_all(mask, fake)

    # discriminator
    discriminator = PatchGAN(n_scales=conf.d_n_scales)
    d_input_real = F.concatenate(real, ul_mask, axis=1)
    d_input_fake = F.concatenate(ul_fake, ul_mask, axis=1)
    d_real_out, d_real_feats = discriminator(d_input_real)
    d_fake_out, d_fake_feats = discriminator(d_input_fake)
    g_gan, g_feat, d_real, d_fake = discriminator.get_loss(
        d_real_out, d_real_feats, d_fake_out, d_fake_feats,
        use_fm=conf.use_fm, fm_lambda=conf.lambda_fm,
        gan_loss_type=conf.gan_loss_type)

    def _rescale(x):
        return rescale_values(x, input_min=-1, input_max=1,
                              output_min=0, output_max=255)

    g_vgg = vgg16_perceptual_loss(_rescale(ul_fake),
                                  _rescale(real)) * conf.lambda_vgg

    set_persistent_all(fake, mask, g_gan, g_feat, d_real, d_fake, g_vgg)

    # loss
    g_loss = g_gan + g_feat + g_vgg
    d_loss = (d_real + d_fake) / 2

    # load params
    if conf.load_params is not None:
        print("load parameters from {}".format(conf.load_params))
        nn.load_parameters(conf.load_params)

    # Setup Solvers
    g_solver = S.Adam(beta1=0.)
    g_solver.set_parameters(get_params_startswith("spade_generator"))
    d_solver = S.Adam(beta1=0.)
    d_solver.set_parameters(get_params_startswith("discriminator"))

    # lr scheduler
    g_lrs = LinearDecayScheduler(start_lr=conf.g_lr, end_lr=0.,
                                 start_iter=100, end_iter=200)
    d_lrs = LinearDecayScheduler(start_lr=conf.d_lr, end_lr=0.,
                                 start_iter=100, end_iter=200)

    ipe = get_iteration_per_epoch(data_iter._size, conf.batch_size,
                                  round="ceil")
    if not conf.show_interval:
        conf.show_interval = ipe
    if not conf.save_interval:
        conf.save_interval = ipe
    if not conf.niter:
        conf.niter = 200 * ipe

    # Setup Reporter
    losses = {"g_gan": g_gan, "g_feat": g_feat, "g_vgg": g_vgg,
              "d_real": d_real, "d_fake": d_fake}
    reporter = Reporter(comm, losses, conf.save_path,
                        nimage_per_epoch=min(conf.batch_size, 5),
                        show_interval=10)
    progress_iterator = trange(conf.niter, disable=comm.rank > 0)
    reporter.start(progress_iterator)

    colorizer = Colorize(n_class)

    # output all config and dump to file
    if comm.rank == 0:
        conf.dump_to_stdout()
        write_yaml(os.path.join(conf.save_path, "config.yaml"), conf)

    epoch = 0
    for itr in progress_iterator:
        if itr % ipe == 0:
            g_lr = g_lrs(epoch)
            d_lr = d_lrs(epoch)
            g_solver.set_learning_rate(g_lr)
            d_solver.set_learning_rate(d_lr)

            if comm.rank == 0:
                print("\n[epoch {}] update lr to ... g_lr: {}, d_lr: {}".format(
                    epoch, g_lr, d_lr))

            epoch += 1

        if conf.dataset == "cityscapes":
            im, ist, obj = data_iter.next()
            ist_mask.d = ist
        elif conf.dataset == "ade20k":
            im, obj = data_iter.next()
        else:
            raise NotImplementedError()

        real.d = im
        obj_mask.d = obj

        # create fake
        fake.forward()

        # update discriminator
        d_solver.zero_grad()
        d_loss.forward()
        d_loss.backward(clear_buffer=True)
        comm.all_reduced_solver_update(d_solver)

        # update generator
        ul_fake.grad.zero()
        g_solver.zero_grad()
        g_loss.forward()
        g_loss.backward(clear_buffer=True)

        # backward generator
        fake.backward(grad=None, clear_buffer=True)
        comm.all_reduced_solver_update(g_solver)

        # report iteration progress
        reporter()

        # report epoch progress
        show_epoch = itr // conf.show_interval
        if (itr % conf.show_interval) == 0:
            show_images = {
                "RealImages": real.data.get_data("r").transpose((0, 2, 3, 1)),
                "ObjectMask": colorizer(obj).astype(np.uint8),
                "GeneratedImage": fake.data.get_data("r").transpose((0, 2, 3, 1))
            }
            reporter.step(show_epoch, show_images)

        if (itr % conf.save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(conf.save_path,
                             'param_{:03d}.h5'.format(show_epoch)))

    if comm.rank == 0:
        nn.save_parameters(os.path.join(conf.save_path, 'param_final.h5'))
def _spectral_norm_backward(dw_sn, w, u, dim=0, itr=1, eps=1e-12):
    # Forward recomputation

    w_shape = w.shape
    # Transpose if the output dimension is not the most-left dimension.
    if dim != 0:
        dims_transpose = [dim] + [i for i in range(len(w_shape)) if i != dim]
        w = F.transpose(w, dims_transpose)
        w_shape = w.shape
    d0 = w.shape[0]            # Out
    d1 = np.prod(w.shape[1:])  # In
    w = F.reshape(w, [d0, d1])
    u = F.reshape(u, [1, d0])

    # Power method
    for _ in range(itr):
        # v
        v = F.affine(u, w)
        v = v / ((F.sum(v ** 2.0, keepdims=True) + eps) ** 0.5)
        v = F.reshape(v, [d1, 1])
        # u
        u = F.affine(w, v)
        u = u / ((F.sum(u ** 2.0, keepdims=True) + eps) ** 0.5)
        u = F.reshape(u, [1, d0])

    # No grad
    u = no_grad(u)
    v = no_grad(v)

    # Spectral normalization
    wv = F.affine(w, v)
    sigma = F.affine(u, wv)
    w_sn = w / sigma
    # The following process is not necessary for gradient calculation
    # w_sn = F.reshape(w_sn, w_shape)
    # # Transpose again if the output dimension is not the most-left dimension.
    # if dim != 0:
    #     dims_transpose = [i for i in range(1, dim + 1)] \
    #         + [0] + [i for i in range(dim + 1, len(w_shape))]
    #     w_sn = F.transpose(w_sn, dims_transpose)

    # Backward

    # Backward for post-transpose
    if dim != 0:
        dims_transpose = [dim] + [i for i in range(len(w_shape)) if i != dim]
        dw_sn = F.transpose(dw_sn, dims_transpose)
    dw_sn = dw_sn.reshape(w.shape)

    # Backward for spectral norm
    # Sum for broadcast backward
    S = sum_for_arithmetics(dw_sn * w_sn, sigma)
    # Add batch axis
    S = S.reshape((1,) + S.shape)
    u = u.reshape((1,) + u.shape)
    v = v.reshape((1,) + v.shape)
    m = F.batch_matmul(u, S, transpose_a=True)
    m = F.batch_matmul(m, v, transpose_b=True)
    # Remove batch axis
    m = m.reshape((m.shape[1], m.shape[2]))
    dw = (dw_sn - m) / sigma

    # Backward for pre-transpose
    dw = dw.reshape(w_shape)
    if dim != 0:
        dims_transpose = [i for i in range(1, dim + 1)] \
            + [0] + [i for i in range(dim + 1, len(w_shape))]
        dw = F.transpose(dw, dims_transpose)

    return dw, None
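
# A hedged numpy sketch of the power method recomputed above: with enough
# iterations, sigma converges to the largest singular value of w (the code
# above typically uses itr=1 because u is warm-started between steps).
def _power_method_sketch():
    import numpy as np
    rng = np.random.RandomState(0)
    w = rng.randn(8, 5)
    u = rng.randn(1, 8)
    for _ in range(50):
        v = u @ w                               # (1, 5)
        v = v / (np.sqrt((v ** 2).sum()) + 1e-12)
        u = (w @ v.T).T                         # (1, 8)
        u = u / (np.sqrt((u ** 2).sum()) + 1e-12)
    sigma = (u @ w @ v.T).item()
    assert abs(sigma - np.linalg.svd(w, compute_uv=False)[0]) < 1e-3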
def cbhg(inputs, K, projections, depth, is_training, scope):
    r"""Return the 1D Convolution Bank, Highwaynet, bidirectional GRU (CBHG)
    module.

    Args:
        inputs (nn.Variable): NNabla Variable of shape (B, C, T).
        K (int): Maximum kernel size.
        projections (list of int): A list of channels.
        depth (int): A depth. This should be an even number.
        is_training (bool): Whether training mode is activated.
        scope (str): The parameter scope name.

    Returns:
        nn.Variable: Output variable.
    """
    with nn.parameter_scope(scope):
        # Convolution bank: concatenate channels from all 1D convolutions
        with nn.parameter_scope('conv_bank'):
            conv = partial(conv1d, inputs, channels=128, activation=F.relu,
                           is_training=is_training)
            conv_outputs = [
                conv(kernel_size=k, scope=f'conv1d_{k}')
                for k in range(1, K + 1)
            ]
            conv_outputs = F.concatenate(*conv_outputs, axis=1)

        # make sure a valid input to max_pooling
        x = F.pad(conv_outputs, (0, ) * 5 + (1, ), mode='constant')

        # Maxpooling: reshape is needed because nnabla does not support
        # 1D pooling
        maxpool_output = F.max_pooling(
            x.reshape(x.shape + (1, )), kernel=(2, 1),
            stride=(1, 1)).reshape(conv_outputs.shape)

        # Two projection layers:
        proj1_output = conv1d(maxpool_output, kernel_size=3,
                              channels=projections[0], activation=F.relu,
                              is_training=is_training, scope='proj_1')
        proj2_output = conv1d(proj1_output, kernel_size=3,
                              channels=projections[1], activation=None,
                              is_training=is_training, scope='proj_2')

        # Residual connection:
        highway_input = proj2_output + inputs

        assert depth % 2 == 0
        half_depth = depth // 2

        with nn.parameter_scope('highwaynet'):
            # transposing to shape (B, T, C)
            highway_input = F.transpose(highway_input, (0, 2, 1))

            # Handle dimensionality mismatch:
            if highway_input.shape[2] != half_depth:
                highway_input = PF.affine(highway_input, half_depth,
                                          base_axis=2, name='adjust_dim')

            # 4-layer HighwayNet:
            for i in range(4):
                highway_input = highwaynet(highway_input, half_depth,
                                           scope=f'highway_{i+1}')

        with nn.parameter_scope('rnn_net'):
            # transpose to shape (T, B, C)
            rnn_input = F.transpose(highway_input, (1, 0, 2))
            outputs, _ = PF.gru(rnn_input,
                                F.constant(shape=(2, 2, rnn_input.shape[1],
                                                  half_depth)),
                                training=is_training,
                                bidirectional=True)  # (T, B, C)

    return outputs
def __call__(self, x, test=False):
    fft_real, fft_imag = STFT(x, n_fft=self.n_fft, n_hop=self.n_hop)
    x_theta = F.atan2(fft_imag, fft_real)

    x = Spectrogram(fft_real, fft_imag, power=self.power,
                    mono=(self.nb_channels == 1))
    nb_frames, nb_samples, nb_channels, nb_bins = x.shape

    mix_spec = F.identity(x)
    x = x[..., :self.nb_bins]

    # clone
    x_bass = F.identity(x)
    x_drums = F.identity(x)
    x_vocals = F.identity(x)
    x_other = F.identity(x)

    # shift and scale input to mean=0 std=1 (across all bins)
    x_bass += F.reshape(self.input_mean_bass,
                        shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_drums += F.reshape(self.input_mean_drums,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_vocals += F.reshape(self.input_mean_vocals,
                          shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_other += F.reshape(self.input_mean_other,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)

    x_bass *= F.reshape(self.input_scale_bass,
                        shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_drums *= F.reshape(self.input_scale_drums,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_vocals *= F.reshape(self.input_scale_vocals,
                          shape=(1, 1, 1, self.nb_bins), inplace=False)
    x_other *= F.reshape(self.input_scale_other,
                         shape=(1, 1, 1, self.nb_bins), inplace=False)

    # encode and normalize every instance in a batch
    x_bass = self.fc_bn(x_bass, self.hidden_size, "fc1_bass", test,
                        activation='tanh')
    x_drums = self.fc_bn(x_drums, self.hidden_size, "fc1_drums", test,
                         activation='tanh')
    x_vocals = self.fc_bn(x_vocals, self.hidden_size, "fc1_vocals", test,
                          activation='tanh')
    x_other = self.fc_bn(x_other, self.hidden_size, "fc1_other", test,
                         activation='tanh')

    # Average the sources
    cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

    # apply 3 layers of stacked LSTM
    lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass", test)
    lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums", test)
    lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals", test)
    lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other", test)

    # lstm skip connection
    x_bass = F.concatenate(x_bass, lstm_out_bass)
    x_drums = F.concatenate(x_drums, lstm_out_drums)
    x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
    x_other = F.concatenate(x_other, lstm_out_other)
    cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

    # first dense stage + batch norm
    x_bass = self.fc_bn(cross_2, self.hidden_size, "fc2_bass", test,
                        activation='relu')
    x_drums = self.fc_bn(cross_2, self.hidden_size, "fc2_drums", test,
                         activation='relu')
    x_vocals = self.fc_bn(cross_2, self.hidden_size, "fc2_vocals", test,
                          activation='relu')
    x_other = self.fc_bn(cross_2, self.hidden_size, "fc2_other", test,
                         activation='relu')

    # second dense stage + batch norm
    x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass", test)
    x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums", test)
    x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals", test)
    x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other", test)

    # reshape back to original dim
    x_bass = F.reshape(x_bass, (nb_frames, nb_samples, nb_channels,
                                self.nb_output_bins))
    x_drums = F.reshape(x_drums, (nb_frames, nb_samples, nb_channels,
                                  self.nb_output_bins))
    x_vocals = F.reshape(x_vocals, (nb_frames, nb_samples, nb_channels,
                                    self.nb_output_bins))
    x_other = F.reshape(x_other, (nb_frames, nb_samples, nb_channels,
                                  self.nb_output_bins))

    # apply output scaling
    x_bass *= F.reshape(self.output_scale_bass,
                        shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_drums *= F.reshape(self.output_scale_drums,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_vocals *= F.reshape(self.output_scale_vocals,
                          shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_other *= F.reshape(self.output_scale_other,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)

    x_bass += F.reshape(self.output_mean_bass,
                        shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_drums += F.reshape(self.output_mean_drums,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_vocals += F.reshape(self.output_mean_vocals,
                          shape=(1, 1, 1, self.nb_output_bins), inplace=False)
    x_other += F.reshape(self.output_mean_other,
                         shape=(1, 1, 1, self.nb_output_bins), inplace=False)

    # since our output is non-negative, we can apply ReLU
    mask_bass = F.relu(x_bass)
    mask_drums = F.relu(x_drums)
    mask_vocals = F.relu(x_vocals)
    mask_other = F.relu(x_other)

    # (Frames, Bsize, Channels, Fbins)
    x_bass = mask_bass * mix_spec
    x_drums = mask_drums * mix_spec
    x_vocals = mask_vocals * mix_spec
    x_other = mask_other * mix_spec

    if not self.is_predict:
        tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
        # (4(sources), Frames, Bsize(16), 2(channels), Fbins)
        # ==> (4, Bsize, Channels, Fbins, Frames)
        tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
        pred_r, pred_i = [], []
        for i in range(tmp.shape[0]):
            pred_r.append(tmp[i] * F.cos(x_theta))
            pred_i.append(tmp[i] * F.sin(x_theta))
        pred_r = F.stack(*pred_r, axis=0)
        pred_i = F.stack(*pred_i, axis=0)
        pred_r = F.reshape(pred_r,
                           (4 * nb_samples * nb_channels, 2049, nb_frames))
        pred_i = F.reshape(pred_i,
                           (4 * nb_samples * nb_channels, 2049, nb_frames))
        pred = istft(pred_r, pred_i, self.n_fft, self.n_hop, self.n_fft,
                     window_type='hanning', center=True)
        pred = F.reshape(pred, (4, nb_samples, nb_channels, -1))
    else:
        pred = None

    return mix_spec, F.concatenate(mask_bass, mask_drums, mask_vocals,
                                   mask_other, axis=2), pred
def infer(self, mels, sigma=0.9):
    r"""Return the generated audio.

    Args:
        mels (nn.Variable): Inputs containing mel-spectrograms of shape
            (B, n_mels, Ty). Defaults to None. If None, the mel-spectrograms
            are inferred from data.
        sigma (float, optional): Sigma used to infer audio. Defaults to 0.9.

    Returns:
        nn.Variable: A synthetic audio.
    """
    hp = self.hparams
    with nn.parameter_scope('', self.parameter_scope):
        # Upsample spectrogram to size of audio
        with nn.parameter_scope('upsample'):
            with nn.parameter_scope('deconv'):
                mels = PF.deconvolution(mels, hp.n_mels, kernel=(1024, ),
                                        stride=(256, ))
            # cut out conv artifacts
            mels = mels[..., :-(1024 - 256)]  # kernel - stride
            # transforming to correct shape
            mels = F.reshape(mels,
                             mels.shape[:2] + (-1, hp.n_samples_per_group))
            mels = F.transpose(mels, (0, 2, 1, 3))
            mels = F.reshape(mels, mels.shape[:2] + (-1, ))
            # (B, n_mels * n_groups, L/n_groups)
            mels = F.transpose(mels, (0, 2, 1))

        wave = F.randn(shape=(mels.shape[0], self.n_remaining_channels,
                              mels.shape[2])) * sigma

        for k in reversed(range(hp.n_flows)):
            n_half = wave.shape[1] // 2
            audio_0 = wave[:, :n_half, :]
            audio_1 = wave[:, n_half:, :]

            with nn.parameter_scope(f'wn_{k}'):
                output = getattr(self, f'WN_{k}')(audio_0, mels)
                s = output[:, n_half:, :]
                b = output[:, :n_half, :]
                audio_1 = (audio_1 - b) / F.exp(s)
                wave = F.concatenate(audio_0, audio_1, axis=1)

            wave = invertible_conv(wave, reverse=True, rng=self.rng,
                                   scope=f'inv_{k}')

            if k % hp.n_early_every == 0 and k > 0:
                z = F.randn(shape=(mels.shape[0], hp.n_early_size,
                                   mels.shape[2]))
                wave = F.concatenate(sigma * z, wave, axis=1)

        wave = F.transpose(wave, (0, 2, 1))
        wave = F.reshape(wave, (wave.shape[0], -1))

    return wave
def embed_inverse(embed, n_inputs, n_features, base_axis=1):
    W = nn.parameter.get_parameter_or_create("embed/W",
                                             [n_inputs, n_features])
    W = F.transpose(W, axes=[1, 0])
    return F.affine(embed, W, base_axis=base_axis)
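
# A hedged usage sketch for embed_inverse(): it reuses the "embed/W" parameter
# created by PF.embed under the current scope to map embeddings back to
# vocabulary logits (sizes are illustrative).
def _embed_inverse_sketch():
    import nnabla as nn
    import nnabla.parametric_functions as PF
    ids = nn.Variable((4, 1))
    e = PF.embed(ids, n_inputs=100, n_features=16)  # creates "embed/W"
    logits = embed_inverse(e, n_inputs=100, n_features=16, base_axis=2)
    assert logits.shape == (4, 1, 100)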
def global_attention(query: nn.Variable,
                     memory: nn.Variable,
                     mask: Optional[nn.Variable] = None,
                     score: str = 'general',
                     fix_parameters: bool = False) -> nn.Variable:
    '''
    A global attention layer

    Args:
        query (nnabla.Variable): A shape of [batch_size, length_query, embedding_size]
        memory (nnabla.Variable): A shape of [batch_size, length_memory, embedding_size]
        mask (nnabla.Variable): A shape of [batch_size, length_query, length_memory]
        score (str): A kind of score functions for calculating attention
            weights. 'general', 'dot' or 'concat'. See
            [Effective Approaches to Attention-based Neural Machine Translation](http://aclweb.org/anthology/D15-1166)
        fix_parameters (bool): Fix parameters (Set need_grad=False).

    Returns:
        nn.Variable: A shape [batch_size, length_query, embedding_size].
    '''
    batch_size, length_query, embedding_size = query.shape
    _, length_memory, _ = memory.shape

    q = query   # -> (batch_size, length_query, embedding_size)
    k = memory  # -> (batch_size, length_memory, embedding_size)
    v = memory  # -> (batch_size, length_memory, embedding_size)

    if score == 'dot':
        # -> (batch_size, length_query, length_memory)
        logit = F.batch_matmul(q, k, transpose_b=True)
    elif score == 'general':
        with nn.parameter_scope('Wa'):
            # -> (batch_size, length_query, embedding_size)
            wa = time_distributed(PF.affine)(q, embedding_size,
                                             with_bias=False)
        # -> (batch_size, length_query, length_memory)
        logit = F.batch_matmul(wa, k, transpose_b=True)
    elif score == 'concat':
        a_list = []
        for _q in F.split(q, axis=1):
            _q = F.reshape(_q, shape=(batch_size, 1, embedding_size))
            _q = F.broadcast(_q, shape=(batch_size, length_memory,
                                        embedding_size))
            # -> (batch_size, length_memory, embedding_size * 2)
            concat = F.concatenate(_q, k, axis=2)
            with nn.parameter_scope('Wa'):
                # -> (batch_size, length_memory, 1)
                a = time_distributed(PF.affine)(concat, 1, with_bias=False)
            a_list.append(a)
        # -> (batch_size, length_memory, length_query)
        logit = F.concatenate(*a_list, axis=2)
        # -> (batch_size, length_query, length_memory)
        logit = F.transpose(logit, axes=(0, 2, 1))

    # get_attention_logit_mask -> (batch_size, length_query, length_memory)
    if mask is not None:
        logit += get_attention_logit_mask(mask)

    # -> (batch_size, length_query, length_memory)
    attention_weights = F.softmax(logit, axis=2)

    # -> (batch_size, length_query, embedding_size)
    attention_output = F.batch_matmul(attention_weights, v)

    return attention_output
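
# A hedged shape-check for global_attention() using the parameter-free 'dot'
# score (batch and length values are arbitrary).
def _global_attention_shape_sketch():
    import nnabla as nn
    query = nn.Variable((2, 5, 8))   # (batch, length_query, embedding)
    memory = nn.Variable((2, 7, 8))  # (batch, length_memory, embedding)
    out = global_attention(query, memory, score='dot')
    assert out.shape == (2, 5, 8)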
def call(self, memory, inputs=None):
    r"""Return mel-spectrogram and attention matrix.

    Args:
        memory (nn.Variable): A 3D tensor of shape (T, B, C).
        inputs (nn.Variable, optional): A 3D tensor with shape of
            [B, T/r, n_mels*r]. Shifted log mel-spectrogram of sound files.
            Defaults to None.

    Returns:
        nn.Variable: The synthetic mel-spectrograms of shape (B, Ty/r, r*n_mels).
        nn.Variable: The attention matrix of shape (B, Tx, Ty).

    References:
        - https://github.com/Kyubyong/tacotron/
    """
    hp = self._hparams
    bz, mel_shape = hp.batch_size, hp.n_mels * hp.r
    encoder_dim = hp.encoder_embedding_dim

    # initialize input tensor
    input = F.constant(shape=(bz, 1, mel_shape))

    # initialize hidden states
    context = F.constant(shape=(bz, 1, hp.attention_dim))
    hidden = F.constant(shape=(1, 1, bz, encoder_dim))
    h_gru = [F.constant(shape=(1, 1, bz, encoder_dim)),
             F.constant(shape=(1, 1, bz, encoder_dim))]

    outputs, attends = [], []

    for i in range(hp.n_frames):
        if i > 0:
            input = (outputs[-1] if inputs is None
                     else inputs[:, i - 1:i, :])

        # feed a prenet to the input
        input = prenet(input, layer_sizes=hp.prenet_channels,
                       is_training=self.training,
                       scope='prenet_decoder')  # (bz, 1, C)

        # concat the input and context vector
        input = F.concatenate(input, context)  # (bz, 1, 384)

        with nn.parameter_scope('rnn_attention'):
            # calculate the output
            output, hidden = PF.gru(
                input.reshape((1, bz, -1)), hidden, training=self.training,
                bidirectional=False)  # (1, bz, 256), (1, 1, bz, 256)

        # compute the context and attention vectors
        context, attend = Bahdanau_attention(
            F.transpose(hidden[0], (1, 0, 2)), memory,
            out_features=hp.attention_dim,
            scope='Bahdanau_attention')  # (bz, 1, 256), (bz, 1, T)

        with nn.parameter_scope('rnn_decoder'):
            # concat RNN output and attention context vector
            with nn.parameter_scope('project_to_decoder'):
                output = F.concatenate(output,
                                       F.transpose(context, (1, 0, 2)),
                                       axis=2)
                output = PF.affine(output, encoder_dim,
                                   base_axis=2)  # (1, bz, 256)

            # decoder RNN with residual connection
            for j in range(2):
                with nn.parameter_scope(f'gru_resisidual_{j}'):
                    out, h_gru[j] = PF.gru(output, h_gru[j],
                                           training=self.training,
                                           bidirectional=False)
                    output += out  # (1, bz, 256)

            # projector to mels
            with nn.parameter_scope('project_to_mel'):
                output = F.transpose(output, (1, 0, 2))
                # (bz, 1, n_mels*r)
                output = PF.affine(output, mel_shape, base_axis=2)

        outputs.append(output)
        attends.append(attend)

    outputs = F.concatenate(*outputs, axis=1)  # (B, T2, C2)
    attends = F.concatenate(*attends, axis=1)  # (B, T2, T1)

    return outputs, attends
def train():
    args = get_args()

    # Set context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in {}:{}".format(args.context, args.type_config))
    ctx = get_extension_context(args.context, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    data_iterator = data_iterator_librispeech(args.batch_size, args.data_dir)
    _data_source = data_iterator._data_source  # dirty hack...

    # model
    x = nn.Variable(shape=(args.batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding
    if args.use_speaker_id:
        s_id = nn.Variable(shape=(args.batch_size, 1))
        with nn.parameter_scope("speaker_embedding"):
            s_emb = PF.embed(s_id, n_inputs=_data_source.n_speaker,
                             n_features=WavenetConfig.speaker_dims)
            s_emb = F.transpose(s_emb, (0, 2, 1))
    else:
        s_emb = None

    net = WaveNet()
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))  # (B, T, C)

    t = nn.Variable(shape=(args.batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # for generation
    prob = F.softmax(pred)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)

    # setup save env.
    audio_save_path = os.path.join(os.path.abspath(args.model_save_path),
                                   "audio_results")
    if audio_save_path and not os.path.exists(audio_save_path):
        os.makedirs(audio_save_path)

    # Training loop.
    for i in range(args.max_iter):
        # todo: validation
        x.d, _speaker, t.d = data_iterator.next()
        if args.use_speaker_id:
            s_id.d = _speaker.reshape(-1, 1)

        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.update()

        loss.data.cast(np.float32, ctx)
        monitor_loss.add(i, loss.d.copy())

        if i % args.model_save_interval == 0:
            prob.forward()
            audios = mu_law_decode(np.argmax(prob.d, axis=-1),
                                   quantize=data_config.q_bit_len)  # (B, T)
            save_audio(audios, i, audio_save_path)
def nonlocal_net(B_lab_map, relu_layers, temperature=0.001 * 5,
                 detach_flag=False, WTA_scale_weight=1, feature_noise=0):
    batch_size = B_lab_map.shape[0]
    channel = B_lab_map.shape[1]
    image_height = B_lab_map.shape[2]
    image_width = B_lab_map.shape[3]
    feature_height = int(image_height / 4)
    feature_width = int(image_width / 4)

    feature_channel = 64
    in_channels = feature_channel * 4
    inter_channels = 256

    # layer2_1
    A_feature2_1 = layer2_1(relu_layers[0])
    B_feature2_1 = layer2_1(relu_layers[4])
    # layer3_1
    A_feature3_1 = layer3_1(relu_layers[1])
    B_feature3_1 = layer3_1(relu_layers[5])
    # layer4_1
    A_feature4_1 = layer4_1(relu_layers[2])
    B_feature4_1 = layer4_1(relu_layers[6])
    # layer5_1
    A_feature5_1 = layer5_1(relu_layers[3])
    B_feature5_1 = layer5_1(relu_layers[7])

    if A_feature5_1.shape[2] != A_feature2_1.shape[2] or \
            A_feature5_1.shape[3] != A_feature2_1.shape[3]:
        A_feature5_1 = pad_replicate(A_feature5_1)
        B_feature5_1 = pad_replicate(B_feature5_1)

    A_features = layer(F.concatenate(A_feature2_1, A_feature3_1,
                                     A_feature4_1, A_feature5_1, axis=1),
                       feature_channel * 4)
    B_features = layer(F.concatenate(B_feature2_1, B_feature3_1,
                                     B_feature4_1, B_feature5_1, axis=1),
                       feature_channel * 4)

    # pairwise cosine similarity
    theta = PF.convolution(A_features, inter_channels, kernel=(1, 1),
                           stride=(1, 1), name='theta')
    theta_re = F.reshape(theta, (batch_size, inter_channels, -1))
    # center the feature
    theta_re = theta_re - F.mean(theta_re, axis=2, keepdims=True)
    theta_norm = F.norm(theta_re, p=2, axis=1,
                        keepdims=True) + sys.float_info.epsilon
    theta_re = F.div2(theta_re, theta_norm)
    # 2*(feature_height*feature_width)*256
    theta_permute = F.transpose(theta_re, (0, 2, 1))

    phi = PF.convolution(B_features, inter_channels, kernel=(1, 1),
                         stride=(1, 1), name='phi')
    phi_re = F.reshape(phi, (batch_size, inter_channels, -1))
    # center the feature
    phi_re = phi_re - F.mean(phi_re, axis=2, keepdims=True)
    phi_norm = F.norm(phi_re, p=2, axis=1,
                      keepdims=True) + sys.float_info.epsilon
    phi_re = F.div2(phi_re, phi_norm)

    # 2*(feature_height*feature_width)*(feature_height*feature_width)
    f = F.batch_matmul(theta_permute, phi_re)

    f_shape = f.shape
    f = F.reshape(f, (1,) + f_shape)
    f_similarity = F.reshape(f, (1,) + f_shape)
    similarity_map = F.max(f_similarity, axis=3, keepdims=True)
    similarity_map = F.reshape(similarity_map,
                               (batch_size, 1, feature_height, feature_width))

    # f can be negative
    # if WTA_scale_weight == 1:
    f_WTA = f
    f_WTA = f_WTA / temperature

    f_WTA_sp = f_WTA.shape
    f_WTA = F.reshape(f_WTA, (f_WTA_sp[1], f_WTA_sp[2], f_WTA_sp[3]))
    # 2*1936*1936; softmax along the horizontal line (dim=-1)
    f_div_C = F.softmax(f_WTA, axis=2)

    # downsample the reference color
    B_lab = F.average_pooling(B_lab_map, (4, 4))
    B_lab = F.reshape(B_lab, (batch_size, channel, -1))
    B_lab = F.transpose(B_lab, (0, 2, 1))  # 2*1936*channel

    # multiply the corr map with color
    y = F.batch_matmul(f_div_C, B_lab)  # 2*1936*channel
    y = F.transpose(y, (0, 2, 1))
    y = F.reshape(y, (batch_size, channel,
                      feature_height, feature_width))  # 2*3*44*44
    y = F.interpolate(y, scale=(4, 4), mode='nearest', align_corners=False)
    similarity_map = F.interpolate(similarity_map, scale=(4, 4),
                                   mode='nearest', align_corners=False)

    return y, similarity_map
def train_nerf(config, comm, model, dataset='blender'):
    use_transient = False
    use_embedding = False

    if model == 'wild':
        use_transient = True
        use_embedding = True
    elif model == 'uncertainty':
        use_transient = True
    elif model == 'appearance':
        use_embedding = True

    save_results_dir = config.log.save_results_dir
    os.makedirs(save_results_dir, exist_ok=True)

    train_loss_dict = {
        'train_coarse_loss': 0.0,
        'train_fine_loss': 0.0,
        'train_total_loss': 0.0,
    }

    test_metric_dict = {'test_loss': 0.0, 'test_psnr': 0.0}

    monitor_manager = MonitorManager(train_loss_dict, test_metric_dict,
                                     save_results_dir)

    if dataset != 'phototourism':
        images, poses, _, hwf, i_test, i_train, near_plane, far_plane = \
            get_data(config)
        height, width, focal_length = hwf
    else:
        di = get_photo_tourism_dataiterator(config, 'train', comm)
        val_di = get_photo_tourism_dataiterator(config, 'val', comm)

    if model != 'vanilla':
        if dataset != 'phototourism':
            config.train.n_vocab = max(np.max(i_train), np.max(i_test)) + 1
        print(f'Setting Vocabulary size of embedding as {config.train.n_vocab}')

    if dataset != 'phototourism':
        if model in ['vanilla']:
            if comm is not None:
                # uncomment the following line to test on fewer images
                i_test = i_test[3 * comm.rank:3 * (comm.rank + 1)]
                pass
            else:
                # uncomment the following line to test on fewer images
                i_test = i_test[:3]
                pass
        else:
            # i_test = i_train[0:5]
            i_test = [i * (comm.rank + 1) for i in range(5)]
    else:
        i_test = [1]

    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    lr = config.solver.lr
    num_decay_steps = config.solver.lr_decay_step * 1000
    lr_decay_factor = config.solver.lr_decay_factor
    solver = S.Adam(alpha=lr)

    load_solver_state = False
    if config.checkpoint.param_path is not None:
        nn.load_parameters(config.checkpoint.param_path)
        load_solver_state = True

    if comm is not None:
        num_decay_steps /= comm.n_procs
        comm_size = comm.n_procs
    else:
        comm_size = 1

    pbar = trange(config.train.num_iterations // comm_size,
                  disable=(comm is not None and comm.rank > 0))

    for i in pbar:
        if dataset != 'phototourism':
            idx = np.random.choice(i_train)
            image = nn.Variable.from_numpy_array(images[idx][None, :, :, :3])
            pose = nn.Variable.from_numpy_array(poses[idx])

            ray_directions, ray_origins = get_ray_bundle(
                height, width, focal_length, pose)

            grid = get_direction_grid(width, height, focal_length,
                                      return_ij_2d_grid=True)
            grid = F.reshape(grid, (-1, 2))

            select_inds = np.random.choice(
                grid.shape[0], size=[config.train.num_rand_points],
                replace=False)
            select_inds = F.gather_nd(grid, select_inds[None, :])
            select_inds = F.transpose(select_inds, (1, 0))

            embed_inp = nn.Variable.from_numpy_array(
                np.full((config.train.chunksize_fine, ), idx, dtype=int))

            ray_origins = F.gather_nd(ray_origins, select_inds)
            ray_directions = F.gather_nd(ray_directions, select_inds)
            image = F.gather_nd(image[0], select_inds)
        else:
            rays, embed_inp, image = di.next()
            ray_origins = nn.Variable.from_numpy_array(rays[:, :3])
            ray_directions = nn.Variable.from_numpy_array(rays[:, 3:6])
            near_plane = nn.Variable.from_numpy_array(rays[:, 6])
            far_plane = nn.Variable.from_numpy_array(rays[:, 7])

            embed_inp = nn.Variable.from_numpy_array(embed_inp)
            image = nn.Variable.from_numpy_array(image)

            hwf = None

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)
        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

        if use_transient:
            rgb_map_coarse, rgb_map_fine, static_rgb_map_fine, \
                transient_rgb_map_fine, beta, static_sigma, \
                transient_sigma = forward_pass(
                    ray_directions, ray_origins, near_plane, far_plane,
                    app_emb, trans_emb, encode_position_function,
                    encode_direction_function, config, use_transient,
                    hwf=hwf, image=image)

            coarse_loss = 0.5 * F.mean(F.squared_error(rgb_map_coarse, image))
            fine_loss = 0.5 * F.mean(
                F.squared_error(rgb_map_fine, image) /
                F.reshape(F.pow_scalar(beta, 2), beta.shape + (1, )))
            beta_reg_loss = 3 + F.mean(F.log(beta))
            sigma_trans_reg_loss = 0.01 * F.mean(transient_sigma)
            loss = coarse_loss + fine_loss + beta_reg_loss + \
                sigma_trans_reg_loss
        else:
            rgb_map_coarse, _, _, _, rgb_map_fine, _, _, _ = forward_pass(
                ray_directions, ray_origins, near_plane, far_plane, app_emb,
                trans_emb, encode_position_function,
                encode_direction_function, config, use_transient, hwf=hwf)

            coarse_loss = F.mean(F.squared_error(rgb_map_coarse, image))
            fine_loss = F.mean(F.squared_error(rgb_map_fine, image))
            loss = coarse_loss + fine_loss

        pbar.set_description(
            f'Total: {np.around(loss.d, 4)}, '
            f'Coarse: {np.around(coarse_loss.d, 4)}, '
            f'Fine: {np.around(fine_loss.d, 4)}')

        solver.set_parameters(nn.get_parameters(), reset=False,
                              retain_state=True)
        if load_solver_state:
            solver.load_states(config['checkpoint']['solver_path'])
            load_solver_state = False

        solver.zero_grad()

        loss.backward(clear_buffer=True)

        # Exponential LR decay
        if dataset != 'phototourism':
            lr_factor = (lr_decay_factor**((i) / num_decay_steps))
            solver.set_learning_rate(lr * lr_factor)
        else:
            if i % num_decay_steps == 0 and i != 0:
                solver.set_learning_rate(lr * lr_decay_factor)

        if comm is not None:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=True)
        solver.update()

        if ((i % config.train.save_interval == 0 or
             i == config.train.num_iterations - 1) and i != 0) and \
                (comm is not None and comm.rank == 0):
            nn.save_parameters(os.path.join(save_results_dir, f'iter_{i}.h5'))
            solver.save_states(
                os.path.join(save_results_dir, f'solver_iter_{i}.h5'))

        if (i % config.train.test_interval == 0 or
                i == config.train.num_iterations - 1) and i != 0:
            avg_psnr, avg_mse = 0.0, 0.0
            for i_t in trange(len(i_test)):
                if dataset != 'phototourism':
                    idx_test = i_test[i_t]
                    image = nn.NdArray.from_numpy_array(
                        images[idx_test][None, :, :, :3])
                    pose = nn.NdArray.from_numpy_array(poses[idx_test])

                    ray_directions, ray_origins = get_ray_bundle(
                        height, width, focal_length, pose)
                    ray_directions = F.reshape(ray_directions, (-1, 3))
                    ray_origins = F.reshape(ray_origins, (-1, 3))

                    embed_inp = nn.NdArray.from_numpy_array(
                        np.full((config.train.chunksize_fine, ), idx_test,
                                dtype=int))
                else:
                    rays, embed_inp, image = val_di.next()
                    ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
                    ray_directions = nn.NdArray.from_numpy_array(
                        rays[0, :, 3:6])
                    near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
                    far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

                    embed_inp = nn.NdArray.from_numpy_array(
                        embed_inp[0, :config.train.chunksize_fine])
                    image = nn.NdArray.from_numpy_array(
                        image[0].transpose(1, 2, 0))
                    image = F.reshape(image, (1, ) + image.shape)
                    idx_test = 1

                app_emb, trans_emb = None, None
                if use_embedding:
                    with nn.parameter_scope('embedding_a'):
                        app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                           config.train.n_app)
                if use_transient:
                    with nn.parameter_scope('embedding_t'):
                        trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                             config.train.n_trans)

                num_ray_batches = \
                    ray_directions.shape[0] // config.train.ray_batch_size + 1

                if use_transient:
                    rgb_map_fine_list, static_rgb_map_fine_list, \
                        transient_rgb_map_fine_list = [], [], []
                else:
                    rgb_map_fine_list, depth_map_fine_list = [], []

                for r_idx in trange(num_ray_batches):
                    if r_idx != num_ray_batches - 1:
                        ray_d = ray_directions[
                            r_idx * config.train.ray_batch_size:
                            (r_idx + 1) * config.train.ray_batch_size]
                        ray_o = ray_origins[
                            r_idx * config.train.ray_batch_size:
                            (r_idx + 1) * config.train.ray_batch_size]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[
                                r_idx * config.train.ray_batch_size:
                                (r_idx + 1) * config.train.ray_batch_size]
                            far_plane = far_plane_[
                                r_idx * config.train.ray_batch_size:
                                (r_idx + 1) * config.train.ray_batch_size]
                    else:
                        if ray_directions.shape[0] - \
                                (num_ray_batches - 1) * \
                                config.train.ray_batch_size == 0:
                            break
                        ray_d = ray_directions[
                            r_idx * config.train.ray_batch_size:, :]
                        ray_o = ray_origins[
                            r_idx * config.train.ray_batch_size:, :]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[
                                r_idx * config.train.ray_batch_size:]
                            far_plane = far_plane_[
                                r_idx * config.train.ray_batch_size:]

                    if use_transient:
                        rgb_map_coarse, rgb_map_fine, static_rgb_map_fine, \
                            transient_rgb_map_fine, beta, static_sigma, \
                            transient_sigma = forward_pass(
                                ray_d, ray_o, near_plane, far_plane, app_emb,
                                trans_emb, encode_position_function,
                                encode_direction_function, config,
                                use_transient, hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        static_rgb_map_fine_list.append(static_rgb_map_fine)
                        transient_rgb_map_fine_list.append(
                            transient_rgb_map_fine)
                    else:
                        _, _, _, _, rgb_map_fine, depth_map_fine, _, _ = \
                            forward_pass(ray_d, ray_o, near_plane, far_plane,
                                         app_emb, trans_emb,
                                         encode_position_function,
                                         encode_direction_function, config,
                                         use_transient, hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        depth_map_fine_list.append(depth_map_fine)

                if use_transient:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    static_rgb_map_fine = F.concatenate(
                        *static_rgb_map_fine_list, axis=0)
                    transient_rgb_map_fine = F.concatenate(
                        *transient_rgb_map_fine_list, axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                                    image[0].shape)
                    transient_rgb_map_fine = F.reshape(
                        transient_rgb_map_fine, image[0].shape)
                    static_trans_img_to_save = np.concatenate(
                        (static_rgb_map_fine.data,
                         np.ones((image[0].shape[0], 5, 3)),
                         transient_rgb_map_fine.data), axis=1)
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones((image[0].shape[0], 5, 3)),
                         rgb_map_fine.data), axis=1)
                else:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    depth_map_fine = F.concatenate(*depth_map_fine_list,
                                                   axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    depth_map_fine = F.reshape(depth_map_fine,
                                               image[0].shape[:-1])
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones((image[0].shape[0], 5, 3)),
                         rgb_map_fine.data), axis=1)

                filename = os.path.join(save_results_dir,
                                        f'{i}_{idx_test}.png')
                try:
                    imsave(filename, np.clip(img_to_save, 0, 1),
                           channel_first=False)
                    print(f'Saved generation at {filename}')
                    if use_transient:
                        filename_static_trans = os.path.join(
                            save_results_dir, f's_t_{i}_{idx_test}.png')
                        imsave(filename_static_trans,
                               np.clip(static_trans_img_to_save, 0, 1),
                               channel_first=False)
                    else:
                        filename_dm = os.path.join(save_results_dir,
                                                   f'dm_{i}_{idx_test}.png')
                        depth_map_fine = \
                            (depth_map_fine.data - depth_map_fine.data.min()) / \
                            (depth_map_fine.data.max() -
                             depth_map_fine.data.min())
                        imsave(filename_dm, depth_map_fine[:, :, None],
                               channel_first=False)
                        plt.imshow(depth_map_fine)
                        plt.savefig(filename_dm)
                        plt.close()
                except Exception:
                    pass

                avg_mse += F.mean(F.squared_error(rgb_map_fine,
                                                  image[0])).data
                avg_psnr += (-10. * np.log10(
                    F.mean(F.squared_error(rgb_map_fine, image[0])).data))

            test_metric_dict['test_loss'] = avg_mse / len(i_test)
            test_metric_dict['test_psnr'] = avg_psnr / len(i_test)
            monitor_manager.add(i, test_metric_dict)

            print(f'Saved generations after {i} training iterations! '
                  f'Average PSNR: {avg_psnr/len(i_test)}. '
                  f'Average MSE: {avg_mse/len(i_test)}')
def loss_function(u, v, negative_samples):
    return F.sum(-F.log(
        F.exp(-distance(u, v)) / sum([
            F.exp(-distance(u, x))
            for x in F.split(negative_samples, axis=2)
        ])))


u = nn.Variable((batch_size, ))
v = nn.Variable((batch_size, ))
negative_samples = nn.Variable((batch_size, negative_sample_size))

_u = PF.embed(u, vocab_size, embedding_size)
_v = PF.embed(v, vocab_size, embedding_size)
_neg = PF.embed(negative_samples, vocab_size, embedding_size)
_neg = F.transpose(_neg, axes=(0, 2, 1))

loss = loss_function(_u, _v, _neg)

nn.get_parameters()["embed/W"].d = I.UniformInitializer(
    [-0.01, 0.01])(shape=(vocab_size, embedding_size))

solver = RiemannianSgd(lr=0.1)
solver.set_parameters(nn.get_parameters())

trainer = Trainer(inputs=[u, v, negative_samples], loss=loss, solver=solver)
trainer.run(train_data_iter, None, epochs=max_epoch)

line_points = [['mustang.n.01', 'odd-toed_ungulate.n.01'],
               ['elk.n.01', 'even-toed_ungulate.n.01'],
               ['even-toed_ungulate.n.01', 'ungulate.n.01'],
def call(self, memory, decoder_inputs=None):
    r"""Return mel-spectrograms, gate outputs and an attention matrix.

    Args:
        memory (nn.Variable): A 3D tensor of shape (B, T, C).
        decoder_inputs (nn.Variable, optional): A 3D tensor with shape of
            (B, T/r, r*n_mels). Shifted log mel-spectrogram of sound files.
            Defaults to None.

    Returns:
        nn.Variable: The synthetic mel-spectrograms of shape (B, Ty/r, r*n_mels).
        nn.Variable: The gate outputs of shape (B, Ty).
        nn.Variable: The attention matrix of shape (B, Tx, Ty).
    """
    hp = self._hparams
    mel_shape = hp.n_mels * hp.r

    # initialize decoder states
    decoder_input = F.constant(shape=(hp.batch_size, 1, mel_shape))
    decoder_hidden = F.constant(shape=(1, 1, hp.batch_size,
                                       hp.decoder_rnn_dim))
    decoder_cell = F.constant(shape=(1, 1, hp.batch_size,
                                     hp.decoder_rnn_dim))

    # initialize attention states
    attention_weights = F.constant(shape=(hp.batch_size, 1, hp.text_len))
    attention_weights_cum = F.constant(shape=(hp.batch_size, 1, hp.text_len))
    attention_context = F.constant(shape=(hp.batch_size, 1,
                                          hp.encoder_embedding_dim))
    attention_hidden = F.constant(shape=(1, 1, hp.batch_size,
                                         hp.attention_rnn_dim))
    attention_cell = F.constant(shape=(1, 1, hp.batch_size,
                                       hp.attention_rnn_dim))

    # store outputs
    mel_outputs, gate_outputs, alignments = [], [], []

    for i in range(hp.mel_len):
        if i > 0:
            decoder_input = (mel_outputs[-1] if decoder_inputs is None
                             else decoder_inputs[:, i - 1:i, :])
            if decoder_inputs is None:
                decoder_input = decoder_input[None, ...]

        # decoder of shape (B, 1, prenet_channels=256)
        decoder_input = prenet(decoder_input, hp.prenet_channels,
                               is_training=self.training, scope='prenet')

        with nn.parameter_scope('attention_rnn'):
            # cell_input of shape (B, 1, prenet_channels[-1] + C=768)
            cell_input = F.concatenate(decoder_input, attention_context,
                                       axis=2)
            _, attention_hidden, attention_cell = PF.lstm(
                F.transpose(cell_input, (1, 0, 2)),
                attention_hidden, attention_cell,
                training=self.training, name='lstm_attention'
            )  # (1, 1, B, attention_hidden), (1, 1, B, attention_hidden)
            if self.training:
                attention_hidden = F.dropout(attention_hidden,
                                             hp.p_attention_dropout)

        with nn.parameter_scope('location_attention'):
            attention_weights_cat = F.concatenate(attention_weights,
                                                  attention_weights_cum,
                                                  axis=1)
            attention_context, attention_weights = \
                location_sensitive_attention(
                    F.transpose(attention_hidden[0], (1, 0, 2)), memory,
                    attention_weights_cat,
                    attention_location_kernel_size=hp.attention_location_kernel_size,
                    attention_n_filters=hp.attention_location_n_filters,
                    attention_dim=hp.attention_dim,
                    is_training=self.training, scope='ls_attention')
            attention_weights_cum += attention_weights
            alignments.append(attention_weights)

        with nn.parameter_scope('decoder_rnn'):
            # (1, B, attention_rnn_dim + encoder_embedding_dim)
            inp_decoder = F.concatenate(attention_hidden[0],
                                        F.transpose(attention_context,
                                                    (1, 0, 2)),
                                        axis=2)
            _, decoder_hidden, decoder_cell = PF.lstm(
                inp_decoder, decoder_hidden, decoder_cell,
                training=self.training, name='lstm_decoder')
            if self.training:
                decoder_hidden = F.dropout(decoder_hidden,
                                           hp.p_decoder_dropout)

        with nn.parameter_scope('projection'):
            proj_input = F.concatenate(
                decoder_hidden[0, 0],
                F.reshape(attention_context, (hp.batch_size, -1),
                          inplace=False),
                axis=1)  # (B, decoder_rnn_dim + encoder_embedding_dim)
            decoder_output = affine_norm(proj_input, mel_shape, base_axis=1,
                                         with_bias=True,
                                         w_init_gain='affine',
                                         scope='affine')
            mel_outputs.append(decoder_output)

        with nn.parameter_scope('gate_prediction'):
            gate_prediction = affine_norm(proj_input, 1, base_axis=1,
                                          with_bias=True,
                                          w_init_gain='sigmoid',
                                          scope='affine')
            gate_outputs.append(gate_prediction)

    # (B, T2, n_mels*r)
    mel_outputs = F.stack(*mel_outputs, axis=1)
    gate_outputs = F.concatenate(*gate_outputs, axis=1)  # (B, T2)
    alignments = F.concatenate(*alignments, axis=1)  # (B, T1, T2)

    return mel_outputs, gate_outputs, alignments
valid_data_iter = data_iterator_simple(load_valid_func, len(x_valid),
                                       batch_size, shuffle=True,
                                       with_file_cache=False)

char_embedding_dim = 16
lstm_size = 650
filters = [50, 150, 200, 200]
filter_sizes = [1, 3, 5, 7]
# filters = [50, 100, 150, 200, 200, 200, 200]
# filter_sizes = [1, 2, 3, 4, 5, 6, 7]

x = nn.Variable((batch_size, sentence_length, word_length))
h = PF.embed(x, char_vocab_size, char_embedding_dim)
h = F.transpose(h, (0, 3, 1, 2))
output = []
for f, f_size in zip(filters, filter_sizes):
    _h = PF.convolution(h, f, kernel=(1, f_size), pad=(0, f_size // 2),
                        name='conv_{}'.format(f_size))
    _h = F.max_pooling(_h, kernel=(1, word_length))
    output.append(_h)
h = F.concatenate(*output, axis=1)
h = F.transpose(h, (0, 2, 1, 3))
h = F.reshape(h, (batch_size, sentence_length, sum(filters)))
# h = PF.batch_normalization(h, axes=[2])
h = TimeDistributed(Highway)(h, name='highway1')
h = TimeDistributed(Highway)(h, name='highway2')
def call(self, wave):
    hp = self.hparams
    batch_size = hp.batch_size

    # compute mel-spectrogram from waveform
    with nn.parameter_scope('stft'):
        mels = self.compute_mel(wave)

    # Upsample spectrogram to the size of audio
    with nn.parameter_scope('upsample'):
        with nn.parameter_scope('deconv'):
            mels = PF.deconvolution(mels, hp.n_mels, kernel=(1024, ),
                                    stride=(256, ))
        # make sure mels have the same length as wave
        if mels.shape[2] > wave.shape[1]:
            mels = mels[..., :wave.shape[1]]  # (B, L, n_mels)
        # transforming to correct shape
        mels = F.reshape(mels, mels.shape[:2] + (-1, hp.n_samples_per_group))
        mels = F.transpose(mels, (0, 2, 1, 3))
        mels = F.reshape(mels, mels.shape[:2] + (-1, ))
        # (B, n_mels * n_groups, L/n_groups)
        mels = F.transpose(mels, (0, 2, 1))

    # reshape audio
    wave = F.reshape(wave, (batch_size, -1, hp.n_samples_per_group))
    wave = F.transpose(wave, (0, 2, 1))  # (B, n_groups, L/n_groups)

    output_audio, log_s_list, log_det_W_list = [], [], []

    for k in range(hp.n_flows):
        if k % hp.n_early_every == 0 and k > 0:
            output_audio.append(wave[:, :hp.n_early_size, :])
            wave = wave[:, hp.n_early_size:, :]

        # apply invertible convolution
        wave, log_det_W = invertible_conv(wave, reverse=False,
                                          rng=self.rng, scope=f'inv_{k}')
        log_det_W_list.append(log_det_W)

        n_half = wave.shape[1] // 2
        audio_0 = wave[:, :n_half, :]
        audio_1 = wave[:, n_half:, :]

        with nn.parameter_scope(f'wn_{k}'):
            output = getattr(self, f'WN_{k}')(audio_0, mels)
            log_s = output[:, n_half:, :]  # (B, n_half, L/n_groups)
            b = output[:, :n_half, :]      # (B, n_half, L/n_groups)
            audio_1 = F.add2(F.exp(log_s) * audio_1, b, inplace=True)
            log_s_list.append(log_s)

        # (B, n_half*2, L/n_groups)
        wave = F.concatenate(audio_0, audio_1, axis=1)

    output_audio.append(wave)

    return F.concatenate(*output_audio, axis=1), log_s_list, log_det_W_list
def sample_from_controller(args):
    """
    2-layer RNN (LSTM) based controller which outputs an architecture of a CNN,
    represented as a sequence of integers and its list.
    Given the number of layers, for each layer it executes 2 types of
    computation: one for sampling the operation at that layer, the other for
    sampling the skip connection patterns.
    """
    entropys = nn.Variable([1, 1], need_grad=True)
    log_probs = nn.Variable([1, 1], need_grad=True)
    entropys.d = log_probs.d = 0.0  # initialize them all

    num_cells = args.num_cells
    num_nodes = args.num_nodes
    lstm_size = args.lstm_size
    state_size = args.state_size
    lstm_num_layers = args.lstm_layers
    temperature = args.temperature
    tanh_constant = args.tanh_constant
    op_tanh_reduce = args.op_tanh_reduce
    num_branch = args.num_ops

    both_archs = [list(), list()]

    initializer = I.UniformInitializer((-0.1, 0.1))

    prev_h = [nn.Variable([1, lstm_size], need_grad=True)
              for _ in range(lstm_num_layers)]
    prev_c = [nn.Variable([1, lstm_size], need_grad=True)
              for _ in range(lstm_num_layers)]
    for i in range(len(prev_h)):
        prev_h[i].d = 0  # initialize
        prev_c[i].d = 0

    inputs = nn.Variable([1, lstm_size])
    inputs.d = np.random.normal(0, 0.5, [1, lstm_size])
    g_emb = nn.Variable([1, lstm_size])
    g_emb.d = np.random.normal(0, 0.5, [1, lstm_size])

    # first create the conv cell, then the reduction cell
    for ind in range(2):
        idx_seq = list()
        ops_seq = list()
        for node_id in range(num_nodes):
            if node_id == 0:
                anchors = nn.parameter.get_parameter_or_create(
                    "anchors", [2, lstm_size], initializer, need_grad=False)
                anchors_w_1 = nn.parameter.get_parameter_or_create(
                    "anchors_w_1", [2, lstm_size], initializer, need_grad=False)
            else:
                assert anchors.shape[0] == node_id + 2, \
                    "Something is wrong with anchors."
                assert anchors_w_1.shape[0] == node_id + 2, \
                    "Something is wrong with anchors_w_1."

            # for each node, sample the two indices used as its inputs
            for i in range(2):
                # one-step stacked LSTM
                with nn.parameter_scope("controller_lstm"):
                    next_h, next_c = stack_lstm(inputs, prev_h, prev_c, state_size)
                prev_h, prev_c = next_h, next_c  # shape: (1, lstm_size)

                query = anchors_w_1
                with nn.parameter_scope("skip_affine_1"):
                    # (node_id + 2, lstm_size) + (1, lstm_size); broadcast
                    # occurs here, so the result is (node_id + 2, lstm_size)
                    query = F.tanh(F.add2(query, PF.affine(
                        next_h[-1], lstm_size,
                        w_init=initializer, with_bias=False)))
                with nn.parameter_scope("skip_affine_2"):
                    # (node_id + 2, 1)
                    logit = PF.affine(query, 1, w_init=initializer, with_bias=False)

                if temperature is not None:
                    logit = F.mul_scalar(logit, (1 / temperature))
                if tanh_constant is not None:
                    logit = F.mul_scalar(F.tanh(logit), tanh_constant)

                index = F.exp(logit)
                index = F.mul_scalar(index, (1 / index.d.sum()))

                # sample the input index from a multinomial distribution
                index = np.random.multinomial(
                    1, np.reshape(index.d, (1, index.d.size))[0], 1)
                idx_seq.append(index.nonzero()[1])

                label = nn.Variable.from_numpy_array(
                    index.transpose())  # (node_id + 2, 1)
                log_prob = F.softmax_cross_entropy(logit, label)
                log_probs = F.add2(log_probs, F.sum(log_prob, keepdims=True))

                curr_ent = F.softmax_cross_entropy(logit, F.softmax(logit))
                entropy = F.sum(curr_ent, keepdims=True)
                entropys = F.add2(entropys, entropy)

                taking_ind = int(index.nonzero()[1][0])
                # (1, lstm_size)
                inputs = F.reshape(anchors[taking_ind], (1, anchors.shape[1]))

            # for each node, sample the two operations
            for j in range(2):
                with nn.parameter_scope("controller_lstm"):
                    next_h, next_c = stack_lstm(inputs, prev_h, prev_c, state_size)
                prev_h, prev_c = next_h, next_c  # shape: (1, lstm_size)

                # compute the logits for the operation
                with nn.parameter_scope("ops"):
                    # shape of logit: (1, num_branch)
                    logit = PF.affine(next_h[-1], num_branch,
                                      w_init=initializer, with_bias=False)

                if temperature is not None:
                    logit = F.mul_scalar(logit, (1 / temperature))
                if tanh_constant is not None:
                    op_tanh = tanh_constant / op_tanh_reduce
                    logit = F.mul_scalar(F.tanh(logit), op_tanh)

                # normalize the logits
                normed_logit = np.e ** logit.d
                normed_logit = normed_logit / np.sum(normed_logit)

                # sample the operation id from a multinomial distribution
                branch_id = np.random.multinomial(1, normed_logit[0], 1).nonzero()[1]
                branch_id = nn.Variable.from_numpy_array(branch_id)
                ops_seq.append(branch_id.d)

                # log policy for the operation
                log_prob = F.softmax_cross_entropy(
                    logit, F.reshape(branch_id, shape=(1, 1)))
                # accumulate the log policy into log_probs
                log_probs = F.add2(log_probs, log_prob)

                logit = F.transpose(logit, axes=(1, 0))
                curr_ent = F.softmax_cross_entropy(logit, F.softmax(logit))
                entropy = F.sum(curr_ent, keepdims=True)
                entropys = F.add2(entropys, entropy)

                w_emb = nn.parameter.get_parameter_or_create(
                    "w_emb", [num_branch, lstm_size], initializer, need_grad=False)
                # (1, lstm_size)
                inputs = F.reshape(w_emb[int(branch_id.d)], (1, w_emb.shape[1]))

            with nn.parameter_scope("controller_lstm"):
                next_h, next_c = stack_lstm(inputs, prev_h, prev_c, lstm_size)
            prev_h, prev_c = next_h, next_c
            with nn.parameter_scope("skip_affine_3"):
                adding_w_1 = PF.affine(next_h[-1], lstm_size,
                                       w_init=initializer, with_bias=False)

            # (node_id + 2 + 1, lstm_size)
            anchors = F.concatenate(anchors, next_h[-1], axis=0)
            anchors_w_1 = F.concatenate(anchors_w_1, adding_w_1, axis=0)

        for idx, ops in zip(idx_seq, ops_seq):
            both_archs[ind].extend([int(idx), int(ops)])

    return both_archs, log_probs, entropys
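# A hedged sketch (not from this file) of how the controller outputs are
# typically consumed in an ENAS-style REINFORCE update. `val_acc`, `baseline`,
# `entropy_weight`, and `solver` are assumed to be provided by the surrounding
# training script.
both_archs, log_probs, entropys = sample_from_controller(args)
reward = float(val_acc - baseline)  # advantage against a moving-average baseline
# log_probs accumulates softmax cross-entropies, i.e. a negative log-likelihood,
# so scaling it by the advantage and subtracting an entropy bonus yields a
# standard policy-gradient objective to minimize.
loss = F.mul_scalar(log_probs, reward) - F.mul_scalar(entropys, entropy_weight)
loss.forward()
solver.zero_grad()
loss.backward()
solver.update()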
def multi_head_attention(query, key, value, d_model, num_heads,
                         need_weights=False, attn_mask=None):
    in_proj_weight = nn.parameter.get_parameter_or_create(
        name="attn/in_proj_W", shape=(d_model * 3, d_model))
    in_proj_bias = nn.parameter.get_parameter_or_create(
        name="attn/in_proj_b", shape=(d_model * 3, ))
    out_proj_weight = nn.parameter.get_parameter_or_create(
        name="attn/out_proj/W", shape=(d_model, d_model))
    out_proj_bias = nn.parameter.get_parameter_or_create(
        name="attn/out_proj/b", shape=(d_model, ))

    tgt_len, batch_size, embed_dim = query.shape
    src_len, _, _ = key.shape
    head_dim = d_model // num_heads
    assert head_dim * num_heads == embed_dim, 'embed_dim must be divisible by num_heads'

    if attn_mask is not None:
        if attn_mask.ndim == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.reshape((1, tgt_len, src_len))
        elif attn_mask.ndim == 3:
            correct_3d_size = (batch_size * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.ndim} is not supported")

    q, k, v = _in_projection_packed(query, key, value,
                                    in_proj_weight, in_proj_bias)

    q = F.transpose(F.reshape(q, (tgt_len, batch_size * num_heads, head_dim)),
                    (1, 0, 2))  # q: (B*H, L_T, head_dim)
    k = F.transpose(F.reshape(k, (-1, batch_size * num_heads, head_dim)),
                    (1, 0, 2))  # k: (B*H, L_S, head_dim)
    v = F.transpose(F.reshape(v, (-1, batch_size * num_heads, head_dim)),
                    (1, 0, 2))  # v: (B*H, L_S, head_dim)

    dropout_p = 0.0
    attn_output, attn_output_weights = _scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p)

    attn_output = F.reshape(F.transpose(attn_output, (1, 0, 2)),
                            (tgt_len, batch_size, embed_dim))  # (L_T, B, E)
    out_proj_weight = F.transpose(out_proj_weight, (1, 0))
    attn_output = F.affine(attn_output, out_proj_weight, out_proj_bias,
                           base_axis=2)

    return attn_output
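# `_scaled_dot_product_attention` is called above but not defined in this
# excerpt. A minimal sketch consistent with the caller's shapes (an assumption,
# not the repository's exact implementation); the mask is treated as additive
# and broadcastable, and dropout is a no-op here since the caller passes
# dropout_p = 0.0.
def _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p):
    # q: (B*H, L_T, D), k and v: (B*H, L_S, D)
    d = q.shape[-1]
    attn = F.batch_matmul(q, k, transpose_b=True) / np.sqrt(d)  # (B*H, L_T, L_S)
    if attn_mask is not None:
        attn = attn + attn_mask  # e.g. large negative values on masked pairs
    attn = F.softmax(attn, axis=2)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    return F.batch_matmul(attn, v), attn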
def predict_dense_motion(source_image, kp_driving, kp_source, block_expansion,
                         num_blocks, max_features, num_kp, num_channels,
                         estimate_occlusion_map=False, scale_factor=1,
                         kp_variance=0.01, test=False, comm=None):
    if scale_factor != 1:
        source_image = anti_alias_interpolate(source_image, num_channels,
                                              scale_factor)

    bs, _, h, w = source_image.shape

    out_dict = dict()
    heatmap_representation = create_heatmap_representations(
        source_image, kp_driving, kp_source, kp_variance)
    sparse_motion = create_sparse_motions(
        source_image, kp_driving, kp_source, num_kp)
    deformed_source = create_deformed_source_image(
        source_image, sparse_motion, num_kp)

    out_dict['sparse_deformed'] = deformed_source

    input = F.concatenate(heatmap_representation, deformed_source, axis=2)
    input = F.reshape(input, (bs, -1, h, w))

    with nn.parameter_scope("hourglass"):
        prediction = hourglass(input, block_expansion=block_expansion,
                               num_blocks=num_blocks,
                               max_features=max_features,
                               test=test, comm=comm)

    with nn.parameter_scope("mask"):
        inmaps, outmaps = prediction.shape[1], num_kp + 1
        k_w = I.calc_normal_std_he_forward(
            inmaps, outmaps, kernel=(7, 7)) / np.sqrt(2.)
        k_b = I.calc_normal_std_he_forward(inmaps, outmaps) / np.sqrt(2.)
        w_init = I.UniformInitializer((-k_w, k_w))
        b_init = I.UniformInitializer((-k_b, k_b))
        mask = PF.convolution(prediction, outmaps=num_kp + 1,
                              kernel=(7, 7), pad=(3, 3),
                              w_init=w_init, b_init=b_init)

    mask = F.softmax(mask, axis=1)
    out_dict['mask'] = mask

    reshaped_mask = F.reshape(
        mask, mask.shape[:2] + (1, ) + mask.shape[2:], inplace=False)
    sparse_motion = F.transpose(sparse_motion, (0, 1, 4, 2, 3))

    deformation = F.sum(sparse_motion * reshaped_mask, axis=1)
    deformation = F.transpose(deformation, (0, 2, 3, 1))

    out_dict['deformation'] = deformation

    if estimate_occlusion_map:
        with nn.parameter_scope("occlusion_map"):
            occlusion_map = F.sigmoid(
                PF.convolution(prediction, outmaps=1, kernel=(7, 7),
                               pad=(3, 3), w_init=w_init, b_init=b_init))
        out_dict['occlusion_map'] = occlusion_map
    else:
        occlusion_map = None

    return out_dict
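# A hedged sketch (the helper name and the consuming layout are assumptions)
# of how `deformation` and `occlusion_map` from the dict above are typically
# used downstream in a first-order-motion-style generator: warp a feature map
# by the predicted grid, then attenuate it with the occlusion map.
def apply_dense_motion(feature, deformation, occlusion_map=None):
    _, _, h, w = feature.shape
    if deformation.shape[1:3] != (h, w):
        # the grid is (B, H, W, 2); resize it channel-first, then restore
        deformation = F.transpose(deformation, (0, 3, 1, 2))
        deformation = F.interpolate(deformation, output_size=(h, w),
                                    mode='linear', align_corners=True)
        deformation = F.transpose(deformation, (0, 2, 3, 1))
    warped = F.warp_by_grid(feature, deformation, mode='linear',
                            padding_mode='zero', align_corners=True)
    if occlusion_map is not None:
        warped = warped * occlusion_map  # (B, 1, H, W), broadcast over channels
    return warped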
def location_sensitive_attention(query, values, attention_weights_cat,
                                 attention_location_kernel_size,
                                 attention_n_filters, attention_dim,
                                 is_training, scope):
    r"""Returns the location-sensitive attention mechanism.

    Args:
        query (nn.Variable): A query of size (B, 1, C1).
        values (nn.Variable): Values of size (B, T, C2).
        attention_weights_cat (nn.Variable): A variable of shape (B, 2, T).
        attention_location_kernel_size (int): Kernel size of the location filter.
        attention_n_filters (int): Number of location filters.
        attention_dim (int): The projected dimensionality.
        is_training (bool): Training flag.
        scope (str): Parameter scope.

    Returns:
        nn.Variable: The context vector.
        nn.Variable: The attention weight vector.

    References:
        J. K. Chorowski, et al., "Attention-based models for speech
        recognition", in Advances in Neural Information Processing Systems,
        2015, pp. 577-585.
    """
    with nn.parameter_scope(scope):
        x = affine_norm(query, attention_dim, base_axis=2, with_bias=False,
                        w_init_gain='tanh', scope='query')
        y = affine_norm(values, attention_dim, base_axis=2, with_bias=False,
                        w_init_gain='tanh', scope='memory')

        # apply a 1D-convolutional filter
        z = conv_norm(attention_weights_cat, attention_n_filters,
                      kernel_size=attention_location_kernel_size, stride=1,
                      padding=(attention_location_kernel_size - 1) // 2,
                      dilation=1, bias=False, w_init_gain='affine',
                      scope='conv_norm_lsa')
        z = F.transpose(z, (0, 2, 1))

        # location of shape (B, T, attention_dim)
        location = affine_norm(z, attention_dim, base_axis=2, with_bias=False,
                               w_init_gain='tanh', scope='location')

        # scores of shape (B, T, 1)
        scores = affine_norm(F.tanh(x + y + location), 1, base_axis=2,
                             with_bias=False, w_init_gain='affine',
                             scope='scores')

        # attention_weights of shape (B, 1, T)
        attention_weights = F.softmax(scores, axis=1).reshape(
            (query.shape[0], 1, -1))

        # context_vector shape after sum == (B, 1, C)
        context_vector = F.batch_matmul(attention_weights, values)

    return context_vector, attention_weights
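# A hedged sketch of how a Tacotron 2-style decoder step feeds this function.
# The names `prev_attention_weights`, `cum_attention_weights`, `batch_size`,
# and the hyperparameter values (31/32/128 are typical Tacotron 2 settings)
# are assumptions, not taken from this file: the previous attention weights
# and their running cumulative sum are stacked into the (B, 2, T) layout that
# `attention_weights_cat` expects.
attention_weights_cat = F.concatenate(
    F.reshape(prev_attention_weights, (batch_size, 1, -1)),
    F.reshape(cum_attention_weights, (batch_size, 1, -1)),
    axis=1)  # (B, 2, T)
context_vector, attention_weights = location_sensitive_attention(
    query, values, attention_weights_cat,
    attention_location_kernel_size=31, attention_n_filters=32,
    attention_dim=128, is_training=True, scope='lsa')
# accumulate for the next step's location features
cum_attention_weights = cum_attention_weights + attention_weights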
def feature_transform_net(
        feature: nn.Variable, train: bool, K: int = 64
) -> Tuple[nn.Variable, Dict[str, nn.Variable]]:
    """T-Net, creates the transformation matrix.

    Args:
        feature (nn.Variable): feature, shape(batch, number of points, 1, K)
        train (bool): training flag
        K (int): transformation matrix size, default is 64.

    Returns:
        Tuple[nn.Variable, Dict[str, nn.Variable]]: transformation matrix and internal variables
    """
    batch_size, num_points, *_ = feature.shape

    # B*H(=num_points)*W(=dim)*C(=K) to B*C(=K)*H(=num_points)*W(=dim)
    feature = F.transpose(feature, (0, 3, 1, 2))

    with nn.parameter_scope("conv1"):
        conv_h1 = PF.convolution(feature, 64, (1, 1), stride=(1, 1),
                                 with_bias=False)
        conv_h1 = PF.batch_normalization(conv_h1, batch_stat=train)
        conv_h1 = F.relu(conv_h1)

    with nn.parameter_scope("conv2"):
        conv_h2 = PF.convolution(conv_h1, 128, (1, 1), stride=(1, 1),
                                 with_bias=False)
        conv_h2 = PF.batch_normalization(conv_h2, batch_stat=train)
        conv_h2 = F.relu(conv_h2)

    with nn.parameter_scope("conv3"):
        conv_h3 = PF.convolution(conv_h2, 1024, (1, 1), stride=(1, 1),
                                 with_bias=False)
        conv_h3 = PF.batch_normalization(conv_h3, batch_stat=train)
        conv_h3 = F.relu(conv_h3)

    pool_h = F.max_pooling(conv_h3, (num_points, 1))
    pool_h = F.reshape(pool_h, (batch_size, -1))

    with nn.parameter_scope("affine1"):
        affine_h1 = PF.affine(pool_h, 512, with_bias=False)
        affine_h1 = PF.batch_normalization(affine_h1, batch_stat=train)
        affine_h1 = F.relu(affine_h1)

    with nn.parameter_scope("affine2"):
        affine_h2 = PF.affine(affine_h1, 256, with_bias=False)
        affine_h2 = PF.batch_normalization(affine_h2, batch_stat=train)
        affine_h2 = F.relu(affine_h2)

    with nn.parameter_scope("affine3"):
        transform_h = PF.affine(affine_h2, K * K)
        # bias the output toward the identity so the predicted transform
        # starts close to an identity matrix
        eye_mat = nn.Variable.from_numpy_array(
            np.eye(K, dtype=np.float32).flatten())
        eye_mat = F.reshape(eye_mat, (1, K * K))
        transform_h = transform_h + eye_mat

    transform_h = F.reshape(transform_h, (batch_size, K, K))
    return transform_h, {
        "conv_h1": conv_h1,
        "conv_h2": conv_h2,
        "conv_h3": conv_h3,
        "pool_h": pool_h,
        "affine_h1": affine_h1,
        "affine_h2": affine_h2,
        "transform_h": transform_h,
    }
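# PointNet regularizes the feature transform toward orthogonality with
# ||I - A A^T||^2; a hedged sketch of that term for the matrix returned above.
# The 0.001 weight follows the paper's setting; the helper name is an
# assumption, not from this file.
def transform_regularizer(transform_h: nn.Variable,
                          weight: float = 0.001) -> nn.Variable:
    batch_size, k, _ = transform_h.shape
    gram = F.batch_matmul(transform_h, transform_h, transpose_b=True)  # (B, K, K)
    eye = nn.Variable.from_numpy_array(
        np.tile(np.eye(k, dtype=np.float32), (batch_size, 1, 1)))
    return weight * F.mean(F.sum((gram - eye) ** 2, axis=(1, 2)))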
def __call__(self, inp):
    '''
    Define D3Net
    '''
    valid_signal_idx = self.hparams['valid_signal_idx']
    band_split_idxs = self.hparams['band_split_idxs'] + \
        [self.hparams['valid_signal_idx']]

    inp = F.transpose(inp, (0, 2, 1, 3))
    scaled_inp = (inp - self.in_offset) / self.in_scale

    max_final_k = 0
    for k in self.hparams['dens_k']:
        if max_final_k < k[-1]:
            max_final_k = k[-1]

    i = 0
    band_idx_start = 0
    band_out = []
    band_dense_out = []

    # low ~ middle bands
    for num_init_features, dens_k, num_layer_block, b_n_block, comp_rates in zip(
            self.hparams['num_init_features'], self.hparams['dens_k'],
            self.hparams['num_layer_blocks'], self.hparams['b_n_blocks'],
            self.hparams['comp_rates']):
        x_band = scaled_inp[:, :, :, band_idx_start:band_split_idxs[i]]
        x_band = self.conv2d(x_band, num_init_features, kernel_size=3,
                             stride=1, name='features_init/%s' % i, pad=1)
        dense_band = self.md3_block_ds(x_band, num_init_features, dens_k,
                                       num_layer_block, b_n_block, comp_rates,
                                       name='dense_band/%s' % i)
        band_dense_out.append(dense_band[::-1])

        if max_final_k > self.hparams['dens_k'][i][-1]:
            h = self.batch_norm(band_dense_out[-1][0],
                                name='match_fm_conv/%s/norm' % i)
            out = self.conv2d(h, max_final_k, kernel_size=1, stride=1,
                              name='match_fm_conv/%s/conv' % i)
            band_out.append(out)
        else:
            band_out.append(band_dense_out[-1][0])
        band_idx_start = band_split_idxs[i]
        i += 1

    # full band
    full = self.conv2d(scaled_inp[:, :, :, :valid_signal_idx],
                       self.hparams['f_num_init_features'], kernel_size=3,
                       stride=1, name='features_init_full', pad=1)
    full = self.md3_block_ds(full, self.hparams['f_num_init_features'],
                             self.hparams['f_dens_k'],
                             self.hparams['f_num_layer_block'],
                             self.hparams['f_n_blocks'],
                             self.hparams['f_comp_rates'], name='dense_full')

    # concatenate the low ~ middle band outputs, then join them with the full band
    concat_bands = F.concatenate(*band_out, axis=3)
    concat_full = F.concatenate(*[concat_bands, full[-1]], axis=1)

    # final dense block
    final = self.dilated_dense_block_2(concat_full,
                                       self.hparams['ttl_dens_k'],
                                       self.hparams['ttl_num_layer_block'],
                                       scope_name='final_dense')

    # BNC gate: batch normalization, convolution, and a sigmoid gate
    with nn.parameter_scope('out_gate'):
        bn_out = self.batch_norm(final, name='bn')
        gate = F.sigmoid(self.conv2d(bn_out, self.hparams['n_channels'],
                                     kernel_size=1, stride=1,
                                     name='conv_gate/conv'))
        filt = self.conv2d(bn_out, self.hparams['n_channels'], kernel_size=1,
                           stride=1, name='conv_filt/conv')
        out = gate * filt

    out = out * self.decode_scale + self.decode_bias
    out = F.relu(out)
    out = F.concatenate(*[out, inp[:, :, :, valid_signal_idx:]], axis=3)
    out = F.transpose(out, (0, 2, 1, 3))
    return out
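# The `out_gate` block above follows a BNC-gate pattern: a shared batch-
# normalized input feeds two 1x1 convolutions, one squashed by a sigmoid to
# gate the other. A stripped-down, self-contained sketch of the same pattern
# using plain parametric functions (the class helpers used above are assumed
# to be thin wrappers around these):
def bnc_gate(x, n_channels, test=False):
    h = PF.batch_normalization(x, batch_stat=not test, name='bn')
    gate = F.sigmoid(PF.convolution(h, n_channels, (1, 1), name='conv_gate'))
    filt = PF.convolution(h, n_channels, (1, 1), name='conv_filt')
    return gate * filt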
def pointnet_feature_extraction(
        point_cloud: nn.Variable,
        train: bool) -> Tuple[nn.Variable, Dict[str, nn.Variable]]:
    """pointnet feature extraction proposed by Charles R. Qi et al.

    See: https://arxiv.org/pdf/1612.00593.pdf

    Args:
        point_cloud (nn.Variable): point cloud, shape(batch, number of points, 3)
        train (bool): training flag

    Returns:
        Tuple[nn.Variable, Dict[str, nn.Variable]]: pointnet feature and internal variables
    """
    batch_size, num_points, _ = point_cloud.shape

    with nn.parameter_scope("tnet1"):
        point_cloud_transformation_mat, _ = point_cloud_transform_net(
            point_cloud, train)
    transformed_point_cloud = F.batch_matmul(point_cloud,
                                             point_cloud_transformation_mat)

    # expand dim to B*C(=K)*H(=num_points)*W(=dim)
    input_point_cloud = F.reshape(transformed_point_cloud,
                                  (batch_size, 1, num_points, 3))

    with nn.parameter_scope("conv1"):
        conv_h1 = PF.convolution(input_point_cloud, 64, (1, 3), stride=(1, 1),
                                 with_bias=False)
        conv_h1 = PF.batch_normalization(conv_h1, batch_stat=train)
        conv_h1 = F.relu(conv_h1)
        conv_h1 = F.transpose(conv_h1, (0, 2, 3, 1))

    with nn.parameter_scope("tnet2"):
        feature_transformation_mat, _ = feature_transform_net(conv_h1, train,
                                                              K=64)
    transformed_feature = F.batch_matmul(conv_h1[:, :, 0, :],
                                         feature_transformation_mat)

    # expand dim to B*H(=num_points)*W(=dim)*C(=K)
    input_feature = F.reshape(transformed_feature,
                              (batch_size, num_points, 1, 64))
    # B*H(=num_points)*W(=dim)*C(=K) to B*C(=K)*H(=num_points)*W(=dim)
    input_feature = F.transpose(input_feature, (0, 3, 1, 2))

    with nn.parameter_scope("conv2"):
        conv_h2 = PF.convolution(input_feature, 128, (1, 1), stride=(1, 1),
                                 with_bias=False)
        conv_h2 = PF.batch_normalization(conv_h2, batch_stat=train)
        conv_h2 = F.relu(conv_h2)

    with nn.parameter_scope("conv3"):
        conv_h3 = PF.convolution(conv_h2, 1024, (1, 1), stride=(1, 1),
                                 with_bias=False)
        conv_h3 = PF.batch_normalization(conv_h3, batch_stat=train)
        conv_h3 = F.relu(conv_h3)

    pool_h = F.max_pooling(conv_h3, (num_points, 1))
    pool_h = F.reshape(pool_h, (batch_size, -1))

    return pool_h, {
        "transformed_point_cloud": transformed_point_cloud,
        "point_cloud_transformation_mat": point_cloud_transformation_mat,
        "conv_h1": conv_h1,
        "conv_h2": conv_h2,
        "feature_transformation_mat": feature_transformation_mat,
        "transformed_feature": transformed_feature,
        "conv_h3": conv_h3,
        "pool_h": pool_h,
    }
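# A hedged sketch of the PointNet classification head on top of the 1024-d
# global feature returned above. Layer sizes (512, 256) and the 0.3 dropout
# rate follow the paper's settings; `pointnet_classifier` and `num_classes`
# are assumed names, not from this file.
def pointnet_classifier(point_cloud: nn.Variable, num_classes: int,
                        train: bool = True) -> nn.Variable:
    global_feature, _ = pointnet_feature_extraction(point_cloud, train)
    with nn.parameter_scope("cls_affine1"):
        h = PF.affine(global_feature, 512, with_bias=False)
        h = F.relu(PF.batch_normalization(h, batch_stat=train))
    with nn.parameter_scope("cls_affine2"):
        h = PF.affine(h, 256, with_bias=False)
        h = F.relu(PF.batch_normalization(h, batch_stat=train))
        if train:
            h = F.dropout(h, p=0.3)
    with nn.parameter_scope("cls_affine3"):
        return PF.affine(h, num_classes)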