def __init__(self, hparams, comm=None, test=False, recompute=False,
             init_method=None, input_mean=None, input_scale=None):
    super(D3NetMSS, self).__init__(comm=comm, test=test,
                                   recompute=recompute, init_method=init_method)
    self.hparams = hparams
    if input_mean is None or input_scale is None:
        input_mean = np.zeros((1, 1, 1, self.hparams['fft_size'] // 2 + 1))
        input_scale = np.ones((1, 1, 1, self.hparams['fft_size'] // 2 + 1))
    else:
        input_mean = input_mean.reshape(
            (1, 1, 1, self.hparams['fft_size'] // 2 + 1))
        input_scale = input_scale.reshape(
            (1, 1, 1, self.hparams['fft_size'] // 2 + 1))
    self.in_offset = get_parameter_or_create(
        'in_offset', shape=input_mean.shape, initializer=input_mean)
    self.in_scale = get_parameter_or_create(
        'in_scale', shape=input_scale.shape, initializer=input_scale)
    self.decode_scale = get_parameter_or_create(
        'decode_scale', (1, 1, 1, self.hparams['valid_signal_idx']),
        initializer=I.ConstantInitializer(value=1))
    self.decode_bias = get_parameter_or_create(
        'decode_bias', (1, 1, 1, self.hparams['valid_signal_idx']),
        initializer=I.ConstantInitializer(value=1))
def layer_normalization(x: nn.Variable, eps: float = 1e-6) -> nn.Variable:
    batch_size, sequence_length, dim = x.shape
    scale = nn.parameter.get_parameter_or_create(
        'scale', shape=(1, 1, dim), initializer=I.ConstantInitializer(1.0))
    bias = nn.parameter.get_parameter_or_create(
        'bias', shape=(1, 1, dim), initializer=I.ConstantInitializer(0.0))
    mean = F.mean(x, axis=2, keepdims=True)
    std = F.mean((x - mean)**2, axis=2, keepdims=True)**0.5
    return scale * (x - mean) / (std + eps) + bias
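
# Hedged usage sketch (not part of the original source): layer_normalization
# normalizes over the last axis of a (batch, sequence, dim) input and applies a
# learned per-dimension scale and bias. The names below are illustrative only and
# assume `nnabla as nn` and `nnabla.initializer as I` are imported as elsewhere
# in this file.
def _example_layer_normalization():
    x = nn.Variable((4, 16, 64))              # (batch, sequence, dim)
    with nn.parameter_scope("ln_example"):    # keeps 'scale'/'bias' parameters scoped
        y = layer_normalization(x)
    assert y.shape == x.shape                 # normalization preserves the shape
    return y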
def test_node_representation(self):
    h = nn.Variable((1, 1))
    h.data.data[0] = 1
    x = nn.Variable((1, 1))
    x.data.data[0] = 2
    with nn.parameter_scope("test_node_representation"):
        r = L.node_representation(h, x, 1,
                                  w_init=I.ConstantInitializer(1),
                                  b_init=I.ConstantInitializer(0))
    self.assertEqual((1, 1), r.shape)
    r.forward()
    self.assertEqual(3, r.data.data[0, 0])
def test_compute_simple_hessian(ctx):
    nn.clear_parameters()
    # Network
    state = nn.Variable((1, 2))
    output = PF.affine(state, 1,
                       w_init=I.ConstantInitializer(value=1.),
                       b_init=I.ConstantInitializer(value=1.))
    loss = F.sum(output**2)
    # Input
    state_array = np.array([[1.0, 0.5]])
    state.d = state_array
    # Grad of network
    params = nn.get_parameters().values()
    for param in params:
        param.grad.zero()
    grads = nn.grad([loss], params)
    flat_grads = F.concatenate(*[F.reshape(grad, (-1,)) for grad in grads]) if len(grads) > 1 \
        else F.reshape(grads[0], (-1,))
    # Compute hessian
    hessian = np.zeros(
        (flat_grads.shape[0], flat_grads.shape[0]), dtype=np.float32)
    for i in range(flat_grads.shape[0]):
        flat_grads_i = flat_grads[i]
        flat_grads_i.forward()
        for param in params:
            param.grad.zero()
        flat_grads_i.backward()
        num_index = 0
        for param in params:
            grad = param.g.flatten()  # grad of grad so this is hessian
            hessian[i, num_index:num_index + len(grad)] = grad
            num_index += len(grad)
    actual = hessian
    expected = np.array(
        [[2 * state_array[0, 0]**2,
          2 * state_array[0, 0] * state_array[0, 1],
          2 * state_array[0, 0]],
         [2 * state_array[0, 0] * state_array[0, 1],
          2 * state_array[0, 1]**2,
          2 * state_array[0, 1]],
         [2 * state_array[0, 0], 2 * state_array[0, 1], 2.]])
    assert_allclose(actual, expected)
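
# Note (added for clarity, not in the original test): with parameters
# theta = [w1, w2, b], the affine output is w.s + b, so loss = (w.s + b)^2.
# Because the output is linear in theta, the Hessian is
#   d^2 loss / d theta^2 = 2 * [s1, s2, 1]^T [s1, s2, 1],
# which is exactly the `expected` matrix above for s = [1.0, 0.5].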
def nin(x, c, name, zeroing_w=False):
    lim = np.sqrt(x.shape[1])**-1
    w_init = I.UniformInitializer(lim=(-lim, lim))  # same as pytorch's default
    b_init = I.UniformInitializer(lim=(-lim, lim))  # same as pytorch's default
    if zeroing_w:
        w_init = I.ConstantInitializer(0)
        b_init = I.ConstantInitializer(0)
    return PF.convolution(x, c, kernel=(1, 1), pad=(0, 0), stride=(1, 1),
                          name=name, w_init=w_init, b_init=b_init)
def detect_keypoint(x, block_expansion, num_kp, num_channels, max_features,
                    num_blocks, temperature, estimate_jacobian=False,
                    scale_factor=1, single_jacobian_map=False, pad=0,
                    test=False, comm=None):
    if scale_factor != 1:
        x = anti_alias_interpolate(x, num_channels, scale_factor)

    with nn.parameter_scope("hourglass"):
        feature_map = hourglass(x, block_expansion, num_blocks=num_blocks,
                                max_features=max_features, test=test, comm=comm)

    with nn.parameter_scope("keypoint_detector"):
        inmaps, outmaps = feature_map.shape[1], num_kp
        k_w = I.calc_normal_std_he_forward(
            inmaps, outmaps, kernel=(7, 7)) / np.sqrt(2.)
        k_b = I.calc_normal_std_he_forward(inmaps, outmaps) / np.sqrt(2.)
        w_init = I.UniformInitializer((-k_w, k_w))
        b_init = I.UniformInitializer((-k_b, k_b))
        prediction = PF.convolution(feature_map, outmaps=num_kp,
                                    kernel=(7, 7), pad=(pad, pad),
                                    w_init=w_init, b_init=b_init)

    final_shape = prediction.shape

    heatmap = F.reshape(prediction, (final_shape[0], final_shape[1], -1))
    heatmap = F.softmax(heatmap / temperature, axis=2)
    heatmap = F.reshape(heatmap, final_shape, inplace=False)

    out = gaussian2kp(heatmap)  # {"value": value}, keypoint positions.

    if estimate_jacobian:
        if single_jacobian_map:
            num_jacobian_maps = 1
        else:
            num_jacobian_maps = num_kp

        with nn.parameter_scope("jacobian_estimator"):
            jacobian_map = PF.convolution(
                feature_map, outmaps=4*num_jacobian_maps,
                kernel=(7, 7), pad=(pad, pad),
                w_init=I.ConstantInitializer(0),
                b_init=np.array([1, 0, 0, 1]*num_jacobian_maps))

        jacobian_map = F.reshape(
            jacobian_map, (final_shape[0], num_jacobian_maps, 4,
                           final_shape[2], final_shape[3]))
        heatmap = F.reshape(
            heatmap, heatmap.shape[:2] + (1,) + heatmap.shape[2:], inplace=False)

        jacobian = heatmap * jacobian_map
        jacobian = F.sum(jacobian, axis=(3, 4))
        jacobian = F.reshape(
            jacobian, (jacobian.shape[0], jacobian.shape[1], 2, 2), inplace=False)
        out['jacobian'] = jacobian  # jacobian near each keypoint.

    # out is a dictionary containing {"value": value, "jacobian": jacobian}
    return out
def test_graph_representation(self):
    h = nn.Variable((2, 1))
    h.data.data[0] = 1
    h.data.data[1] = 2
    x = nn.Variable((2, 1))
    x.data.data[0] = 2
    x.data.data[1] = 4
    with nn.parameter_scope("test_graph_representation"):
        r = L.graph_representation(h, x, 1,
                                   w_init=I.ConstantInitializer(1),
                                   b_init=I.ConstantInitializer(0))
    self.assertEqual((1, 1), r.shape)
    r.forward()
    actual = 0
    actual += 1 / (1 + math.exp(-3)) * math.tanh(3)
    actual += 1 / (1 + math.exp(-6)) * math.tanh(6)
    self.assertTrue(np.allclose(actual, r.data.data[0, 0]))
def node_annotation(self):
    h = nn.Variable((1, 2))
    h.data.data[0, 0] = 1
    h.data.data[0, 1] = 0
    x = nn.Variable((1, 1))
    x.data.data[0] = 2
    with nn.parameter_scope("test_node_annotation"):
        h, x = L.node_annotation(h, x, 1,
                                 w_init=I.ConstantInitializer(1),
                                 b_init=I.ConstantInitializer(0))
    self.assertEqual((1, 2), h.shape)
    self.assertEqual((1, 1), x.shape)
    F.sink(h, x).forward()
    actual = 1 / (1 + math.exp(-3))
    self.assertTrue(np.allclose(actual, h.data.data[0, 0]))
    self.assertEqual(0, h.data.data[0, 1])
    self.assertTrue(np.allclose(actual, x.data.data[0, 0]))
def conv_2d(x, o_ch, kernel, name=None):
    """
    Convolution for JSInet
    """
    b = I.ConstantInitializer(0.)
    h = PF.convolution(x, o_ch, kernel=kernel, stride=(1, 1), pad=(1, 1),
                       channel_last=True, b_init=b, name=name)
    return h
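
# Hedged usage sketch (illustrative, not from the original source): conv_2d expects a
# channel-last (NHWC) input and keeps the spatial size for 3x3 kernels because of the
# fixed pad=(1, 1). The variable names below are assumptions.
def _example_conv_2d():
    x = nn.Variable((1, 64, 64, 3))            # NHWC input
    with nn.parameter_scope("conv_2d_example"):
        y = conv_2d(x, 32, kernel=(3, 3), name="conv1")
    assert y.shape == (1, 64, 64, 32)          # same spatial size, 32 output channels
    return y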
def bilstm(x, mask, state_size, w_init=None, inner_w_init=None,
           forget_bias_init=I.ConstantInitializer(1),
           b_init=I.ConstantInitializer(0),
           initial_state=None, dropout=0, train=True, rng=np.random):
    rx = F.flip(x, axes=[1])        # reverse
    rmask = F.flip(mask, axes=[1])  # reverse

    with nn.parameter_scope("forward"):
        hs, _ = lstm(x, mask, state_size, w_init, inner_w_init,
                     forget_bias_init, b_init, initial_state, dropout, train, rng)
    with nn.parameter_scope("backward"):
        rhs, _ = lstm(rx, rmask, state_size, w_init, inner_w_init,
                      forget_bias_init, b_init, initial_state, dropout, train, rng)
    hs2 = F.flip(rhs, axes=[1])  # reverse
    return concatenate(hs, hs2, axis=2)  # (batch_size, length, 2 * state_size)
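
# Hedged usage sketch (illustrative, not from the original source): bilstm runs the
# lstm helper from this code base in both time directions and concatenates the hidden
# states, so a (batch, length, input_size) input with a (batch, length) mask yields
# (batch, length, 2 * state_size). Shapes follow the comment in bilstm above; the
# variable names here are assumptions.
def _example_bilstm():
    x = nn.Variable((8, 20, 128))              # (batch, length, input_size)
    mask = nn.Variable((8, 20))                # 1 for valid steps, 0 for padding
    with nn.parameter_scope("bilstm_example"):
        hs = bilstm(x, mask, state_size=256)
    assert hs.shape == (8, 20, 512)            # forward and backward states concatenated
    return hs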
def conv(x, c, name, kernel=(3, 3), pad=(1, 1), stride=(1, 1), zeroing_w=False):
    # init weight and bias with uniform, which is the same as pytorch
    lim = I.calc_normal_std_he_forward(x.shape[1] * 2, c, tuple(kernel))
    w_init = I.UniformInitializer(lim=(-lim, lim), rng=None)
    b_init = I.UniformInitializer(lim=(-lim, lim), rng=None)
    if zeroing_w:
        w_init = I.ConstantInitializer(0)
        b_init = I.ConstantInitializer(0)
    return PF.convolution(x, c, kernel, pad=pad, stride=stride, name=name,
                          w_init=w_init, b_init=b_init)
def dense(x, output_dim, base_axis=1, w_init=None,
          b_init=I.ConstantInitializer(0), activation=F.tanh):
    if w_init is None:
        w_init = I.UniformInitializer(
            I.calc_uniform_lim_glorot(np.prod(x.shape[1:]), output_dim))
    return activation(
        PF.affine(x, output_dim, base_axis=base_axis,
                  w_init=w_init, b_init=b_init))
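
# Hedged usage sketch (illustrative, not from the original source): dense is an affine
# layer with Glorot-uniform weight initialization by default and a configurable
# activation. Each call is wrapped in its own parameter scope so the default affine
# parameter names do not collide; all names below are assumptions.
def _example_dense():
    x = nn.Variable((16, 100))
    with nn.parameter_scope("dense_example"):
        with nn.parameter_scope("layer1"):
            h = dense(x, 64)                          # tanh activation by default
        with nn.parameter_scope("layer2"):
            y = dense(h, 10, activation=F.identity)   # linear output head
    assert y.shape == (16, 10)
    return y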
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero',
         use_bias=True, scope='conv_0'):
    """
    Convolution for discriminator
    """
    w_n_shape = (channels, kernel, kernel, x.shape[-1])
    w_init = truncated_normal(w_n_shape, mean=0.0, std=0.02)
    b_init = I.ConstantInitializer(0.)
    with nn.parameter_scope(scope):
        if pad > 0:
            h = x.shape[1]
            if h % stride == 0:
                pad = pad * 2
            else:
                pad = max(kernel - (h % stride), 0)
            pad_top = pad // 2
            pad_bottom = pad - pad_top
            pad_left = pad // 2
            pad_right = pad - pad_left
            if pad_type == 'zero':
                x = F.pad(
                    x, (0, 0, pad_top, pad_bottom, pad_left, pad_right, 0, 0))
            if pad_type == 'reflect':
                x = F.pad(
                    x, (0, 0, pad_top, pad_bottom, pad_left, pad_right, 0, 0),
                    mode='reflect')

        def apply_w(w):
            return PF.spectral_norm(w, dim=0)

        x = PF.convolution(x, channels, kernel=(kernel, kernel),
                           stride=(stride, stride), apply_w=apply_w,
                           w_init=w_init, b_init=b_init,
                           with_bias=use_bias, channel_last=True)
    return x
def conv_initializer(f_in, n_out, base_axis, kernel, mode):
    '''
    Conv initializer function
        This function returns various types of initialization for weights and bias
        parameters in convolution layer.

        Args:
            f_in (~nnabla.Variable): input variable.
            n_out (int) : number of output neurons per data.
            base_axis (int): dimensions up to base_axis are treated as the sample dimensions.
            kernel (tuple of int) : convolution kernel size.
            mode (str) : type of initialization to use.
        Returns:
            w (~nnabla.initializer.BaseInitializer): weight parameters
            b (~nnabla.initializer.BaseInitializer): bias parameters
    '''
    if mode == 'nnabla':
        # https://github.com/sony/nnabla/blob/master/python/src/nnabla/parametric_functions.py, line415, 417
        # https://github.com/sony/nnabla/blob/master/python/src/nnabla/initializer.py, line224. 121
        # uniform_lim_glorot = uniform(sqrt(6/(fin+fout)))
        n_input_plane = f_in.shape[base_axis]
        s = np.sqrt(6.0 / (n_input_plane * np.prod(kernel) + n_out))
        w = I.UniformInitializer([-s, s])
        b = I.ConstantInitializer(0)
    return w, b
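
# Hedged usage sketch (illustrative, not from the original source): in 'nnabla' mode
# conv_initializer reproduces the Glorot-uniform convolution initialization explicitly,
# so the returned initializers can be passed straight to PF.convolution. Names below
# are assumptions.
def _example_conv_initializer():
    x = nn.Variable((8, 3, 32, 32))            # NCHW input, base_axis=1
    w_init, b_init = conv_initializer(x, 16, 1, (3, 3), 'nnabla')
    with nn.parameter_scope("conv_init_example"):
        y = PF.convolution(x, 16, kernel=(3, 3), pad=(1, 1),
                           w_init=w_init, b_init=b_init)
    assert y.shape == (8, 16, 32, 32)
    return y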
y.forward()
y.backward()


def check_none_arg(arg, val, none_case):
    if val is None:
        assert arg == none_case
        return
    assert arg == val


@pytest.mark.parametrize("inshape", [(8, 2, 2, 2), (16, 1, 8)])
@pytest.mark.parametrize("n_outmaps", [16, 32])
@pytest.mark.parametrize("base_axis", [1, 2])
@pytest.mark.parametrize("w_init", [None, I.NormalInitializer(), True])
@pytest.mark.parametrize("b_init", [None, I.ConstantInitializer(), True])
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("fix_parameters", [False, True])
@pytest.mark.parametrize("rng", [None, True])
def test_pf_affine_execution(g_rng, inshape, n_outmaps, base_axis, w_init,
                             b_init, with_bias, fix_parameters, rng):
    w_shape = (int(np.prod(inshape[base_axis:])), n_outmaps)
    b_shape = (n_outmaps, )
    w_init = process_param_init(w_init, w_shape, g_rng)
    b_init = process_param_init(b_init, b_shape, g_rng)
    rng = process_rng(rng)
    kw = {}
    insert_if_not_none(kw, 'w_init', w_init)
    insert_if_not_none(kw, 'b_init', b_init)
def Generator(Noisy, z):
    """
    Building generator network
        [Arguments]
        Noisy : Noisy speech waveform (Batch, 1, 16384)
        Output : (Batch, 1, 16384)
    """
    ##  Sub-functions
    ## ---------------------------------

    # Convolution
    def conv(x, output_ch, kernel=(32,), pad=(15,), stride=(2,), name=None,
             w_init=None, b_init=None):
        return PF.convolution(x, output_ch, kernel, pad=pad, stride=stride,
                              name=name, w_init=w_init, b_init=b_init)

    # Deconvolution
    def deconv(x, output_ch, kernel=(32,), pad=(15,), stride=(2,), name=None):
        return PF.deconvolution(x, output_ch, kernel, pad=pad, stride=stride,
                                name=name)

    # Activation Function
    def af(x, name=None):
        return PF.prelu(x, name=name)

    def af2(x, name=None):
        return F.tanh(x)

    # Concatenate input and skip-input
    def concat(x, h, axis=1):
        return F.concatenate(x, h, axis=axis)

    ##  Main Processing
    ## ---------------------------------
    with nn.parameter_scope("gen"):
        # Genc : Encoder in Generator
        # *convolution reshapes output to (No. of Filter, Output Size) automatically
        enc1 = af(conv(Noisy, 16, name="enc1"))    # Input:(16384, 1) --> (16, 8192)
        enc2 = af(conv(enc1, 32, name="enc2"))     # (16, 8192) --> (32, 4096)
        enc3 = af(conv(enc2, 32, name="enc3"))     # (32, 4096) --> (32, 2048)
        enc4 = af(conv(enc3, 64, name="enc4"))     # (32, 2048) --> (64, 1024)
        enc5 = af(conv(enc4, 64, name="enc5"))     # (64, 1024) --> (64, 512)
        enc6 = af(conv(enc5, 128, name="enc6"))    # (64, 512) --> (128, 256)
        enc7 = af(conv(enc6, 128, name="enc7"))    # (128, 256) --> (128, 128)
        enc8 = af(conv(enc7, 256, name="enc8"))    # (128, 128) --> (256, 64)
        enc9 = af(conv(enc8, 256, name="enc9"))    # (256, 64) --> (256, 32)
        enc10 = af(conv(enc9, 512, name="enc10"))  # (256, 32) --> (512, 16)
        enc11 = af2(conv(enc10, 1024, name="enc11",
                         w_init=I.ConstantInitializer(),
                         b_init=I.ConstantInitializer()))  # (512, 16) --> (1024, 8)

        # Latent Variable (concat random sequence)
        with nn.parameter_scope("latent"):
            C = F.concatenate(enc11, z, axis=1)    # (1024, 8) --> (2048, 8)

        # Gdec : Decoder in Generator
        # Concatenate skip input for each layer
        dec1 = concat(af(deconv(C, 512, name="dec1")), enc10)     # (2048, 8) --> (512, 16) >> [concat](1024, 16)
        dec2 = concat(af(deconv(dec1, 256, name="dec2")), enc9)   # (1024, 16) --> (256, 32)
        dec3 = concat(af(deconv(dec2, 256, name="dec3")), enc8)   # (512, 32) --> (256, 64)
        dec4 = concat(af(deconv(dec3, 128, name="dec4")), enc7)   # (512, 64) --> (128, 128)
        dec5 = concat(af(deconv(dec4, 128, name="dec5")), enc6)   # (256, 128) --> (128, 256)
        dec6 = concat(af(deconv(dec5, 64, name="dec6")), enc5)    # (256, 256) --> (64, 512)
        dec7 = concat(af(deconv(dec6, 64, name="dec7")), enc4)    # (128, 512) --> (64, 1024)
        dec8 = concat(af(deconv(dec7, 32, name="dec8")), enc3)    # (128, 1024) --> (32, 2048)
        dec9 = concat(af(deconv(dec8, 32, name="dec9")), enc2)    # (64, 2048) --> (32, 4096)
        dec10 = concat(af(deconv(dec9, 16, name="dec10")), enc1)  # (64, 4096) --> (16, 8192)
        dec11 = F.tanh(deconv(dec10, 1, name="dec11"))            # (32, 8192) --> (1, 16384)

    return dec11
def implicit_network(x, D=512, feature_size=256, L=9, skip_in=[4], N=6,
                     act="softplus", including_input=True,
                     initial_sphere_radius=0.75):
    """Implicit Network.

    Args:
        x: Position on a ray.
        D: Dimension of a network.
        feature_size: Feature dimension of the final output.
        L: Number of layers.
        skip_in: Where the skip connection appears.
        N: Number of frequencies of the positional encoding.
        act: Activation function.
        including_input: Include input to the positional encoding (PE).
        initial_sphere_radius: Radius of the initial network sphere.

    Network architecture looks like:

    x --> [PE(x)] --> affine --> relu --> ... --> concatenate([h, x]) --> affine --> relu
      --> ... --> affine(h) --> [sdf, feature]
    """
    act_map = dict(relu=F.relu, softplus=partial(F.softplus, beta=100))
    Dx = x.shape[-1]
    act = act_map[act]
    h = positional_encoding(x, N, including_input)
    for l in range(L):
        # First
        if l == 0:
            Dh = h.shape[-1]
            Dx = x.shape[-1]
            w_init = GeometricInitializer(Dh, D, 2 / D, Dx)
            h = affine(h, D, w_init=w_init, name=f"affine-{l:02d}")
            h = act(h)
        # Skip
        elif l in skip_in:
            w_init = GeometricInitializer(D, D, 2 / (D - Dx), -Dx)
            h = affine(h, D, w_init=w_init, name=f"affine-{l:02d}")
            h = act(h)
        # Last (scalar + feature_size)
        elif l == L - 1:
            Do = 1 + feature_size
            w_init = np.sqrt(np.pi / D) * np.ones([D, Do])
            h = affine(h, Do, w_init=w_init,
                       b_init=I.ConstantInitializer(-initial_sphere_radius),
                       name=f"affine-last")
        # Intermediate
        else:
            Do = D - Dx if l + 1 in skip_in else D
            w_init = GeometricInitializer(D, Do, 2 / Do)
            h = affine(h, Do, w_init=w_init, name=f"affine-{l:02d}")
            h = act(h)
            h = F.concatenate(*[h, x]) if l + 1 in skip_in else h
            # h = F.concatenate(*[h, x]) / np.sqrt(2) if l + 1 in skip_in else h
            # (the paper used this scale)
    return h
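
# Hedged usage sketch (illustrative, not from the original source): implicit_network
# maps 3D points to [sdf, feature] as sketched in the docstring above; with the
# defaults (feature_size=256) a (num_points, 3) input yields (num_points, 257), where
# column 0 is the signed distance. This assumes the helpers positional_encoding,
# affine and GeometricInitializer from this code base are available; names below are
# assumptions.
def _example_implicit_network():
    pts = nn.Variable((1024, 3))               # points sampled along rays
    with nn.parameter_scope("implicit_net_example"):
        out = implicit_network(pts)            # (1024, 1 + feature_size)
    sdf = F.slice(out, start=(0, 0), stop=(out.shape[0], 1))  # signed distance
    feature = F.slice(out, start=(0, 1), stop=out.shape)      # geometry feature
    return sdf, feature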
def lstm(x, mask, state_size, w_init=None, inner_w_init=None,
         forget_bias_init=I.ConstantInitializer(1),
         b_init=I.ConstantInitializer(0),
         initial_state=None, dropout=0, train=True, rng=np.random):
    """
    x: (batch_size, length, input_size)
    mask: (batch_size, length)
    """
    batch_size, length, input_size = x.shape

    if w_init is None:
        w_init = I.UniformInitializer(
            I.calc_uniform_lim_glorot(input_size, state_size))
    if inner_w_init is None:
        inner_w_init = orthogonal

    retain_prob = 1.0 - dropout
    z_w = nn.Variable((batch_size, 4, input_size), need_grad=False)
    z_w.d = 1
    z_u = nn.Variable((batch_size, 4, state_size), need_grad=False)
    z_u.d = 1

    if dropout > 0:
        if train:
            z_w = F.dropout(z_w, p=retain_prob)
            z_u = F.dropout(z_u, p=retain_prob)
        z_w *= retain_prob
        z_u *= retain_prob
    z_w = F.reshape(z_w, (batch_size, 4, 1, input_size))
    z_w = F.broadcast(z_w, (batch_size, 4, length, input_size))
    z_w = F.split(z_w, axis=1)
    z_u = F.split(z_u, axis=1)
    xi = z_w[0] * x
    xf = z_w[1] * x
    xc = z_w[2] * x
    xo = z_w[3] * x

    with nn.parameter_scope("lstm"):
        # (batch_size, length, state_size)
        xi = PF.affine(xi, state_size, base_axis=2, w_init=w_init,
                       b_init=b_init, name="Wi")
        xf = PF.affine(xf, state_size, base_axis=2, w_init=w_init,
                       b_init=forget_bias_init, name="Wf")
        xc = PF.affine(xc, state_size, base_axis=2, w_init=w_init,
                       b_init=b_init, name="Wc")
        xo = PF.affine(xo, state_size, base_axis=2, w_init=w_init,
                       b_init=b_init, name="Wo")

    if initial_state is None:
        h = nn.Variable((batch_size, state_size), need_grad=False)
        h.data.zero()
    else:
        h = initial_state
    c = nn.Variable((batch_size, state_size), need_grad=False)
    c.data.zero()

    # (batch_size, state_size)
    xi = split(xi, axis=1)
    xf = split(xf, axis=1)
    xc = split(xc, axis=1)
    xo = split(xo, axis=1)

    mask = F.reshape(mask, [batch_size, length, 1])  # (batch_size, length, 1)
    mask = F.broadcast(mask, [batch_size, length, state_size])
    # (batch_size, state_size)
    mask = split(mask, axis=1)

    hs = []
    cs = []

    with nn.parameter_scope("lstm"):
        for i, f, c2, o, m in zip(xi, xf, xc, xo, mask):
            i_t = PF.affine(z_u[0] * h, state_size,
                            w_init=inner_w_init(state_size, state_size),
                            with_bias=False, name="Ui")
            i_t = F.sigmoid(i + i_t)
            f_t = PF.affine(z_u[1] * h, state_size,
                            w_init=inner_w_init(state_size, state_size),
                            with_bias=False, name="Uf")
            f_t = F.sigmoid(f + f_t)
            c_t = PF.affine(z_u[2] * h, state_size,
                            w_init=inner_w_init(state_size, state_size),
                            with_bias=False, name="Uc")
            c_t = f_t * c + i_t * F.tanh(c2 + c_t)
            o_t = PF.affine(z_u[3] * h, state_size,
                            w_init=inner_w_init(state_size, state_size),
                            with_bias=False, name="Uo")
            o_t = F.sigmoid(o + o_t)
            h_t = o_t * F.tanh(c_t)
            h_t = (1 - m) * h + m * h_t
            c_t = (1 - m) * c + m * c_t
            h = h_t
            c = c_t
            h_t = F.reshape(h_t, (batch_size, 1, state_size), inplace=False)
            c_t = F.reshape(c_t, (batch_size, 1, state_size), inplace=False)
            hs.append(h_t)
            cs.append(c_t)

    return concatenate(*hs, axis=1), concatenate(*cs, axis=1)
def Wave_U_Net(Noisy):
    ds_outputs = list()
    num_initial_filters = 24
    num_layers = 12
    filter_size = 15
    merge_filter_size = 5
    b = I.ConstantInitializer()
    w = I.NormalInitializer(sigma=0.02)

    ##  Sub-functions
    ## ---------------------------------

    # Convolution
    def conv(x, output_ch, kernel=(15,), pad=(7,), stride=(1,), name=None):
        return PF.convolution(x, output_ch, kernel, pad=pad, stride=stride,
                              w_init=w, b_init=b, name=name)

    # Activation Function
    def af(x, alpha=0.2):
        return F.leaky_relu(x, alpha)

    # Crop along time and concatenate with the skip connection
    def crop_and_concat(x1, x2):
        def crop(tensor, target_times):
            shape = tensor.shape[2]
            diff = shape - target_times
            if diff == 0:
                return tensor
            crop_start = diff // 2
            crop_end = diff - crop_start
            return F.slice(tensor, start=(0, 0, crop_start),
                           stop=(tensor.shape[0], tensor.shape[1], shape - crop_end),
                           step=(1, 1, 1))
        x1 = crop(x1, x2.shape[2])
        return F.concatenate(x1, x2, axis=1)

    def downsampling_block(x, i):
        with nn.parameter_scope(('ds_block-%2d' % i)):
            ds = af(conv(x, (num_initial_filters + num_initial_filters * i),
                         (filter_size,), (7,), name='conv'))
            ds_slice = F.slice(ds, start=(0, 0, 0), stop=ds.shape,
                               step=(1, 1, 2))  # Decimate by factor of 2
            # ds_slice = F.average_pooling(ds, kernel=(1, 1,), stride=(1, 2,), pad=(0, 0,))
            return ds, ds_slice

    def upsampling_block(x, i):
        with nn.parameter_scope(('us_block-%2d' % i)):
            up = F.unpooling(af(x), (2,))
            cac_x = crop_and_concat(ds_outputs[-i - 1], up)
            us = af(conv(cac_x,
                         num_initial_filters + num_initial_filters * (num_layers - i - 1),
                         (merge_filter_size,), (2,), name='conv'))
            return us

    with nn.parameter_scope('Wave-U-Net'):
        current_layer = Noisy
        ## downsampling block
        for i in range(num_layers):
            ds, current_layer = downsampling_block(current_layer, i)
            ds_outputs.append(ds)
        ## latent variable
        with nn.parameter_scope('latent_variable'):
            current_layer = af(conv(current_layer,
                                    num_initial_filters + num_initial_filters * num_layers))
        ## upsampling block
        for i in range(num_layers):
            current_layer = upsampling_block(current_layer, i)
        current_layer = crop_and_concat(Noisy, current_layer)
        ## output layer
        target_1 = F.tanh(conv(current_layer, 1, (1,), (0,), name='target_1'))
        target_2 = F.tanh(conv(current_layer, 1, (1,), (0,), name='target_2'))
    return target_1, target_2
def main():
    args = get_args()
    state_size = args.state_size
    batch_size = args.batch_size
    num_steps = args.num_steps
    num_layers = args.num_layers
    max_epoch = args.max_epoch
    max_norm = args.gradient_clipping_max_norm
    num_words = 10000
    lr = args.learning_rate

    train_data, val_data, test_data = get_data()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    from nnabla.monitor import Monitor, MonitorSeries
    monitor = Monitor(args.work_dir)
    monitor_perplexity = MonitorSeries(
        "Training perplexity", monitor, interval=10)
    monitor_vperplexity = MonitorSeries(
        "Validation perplexity", monitor,
        interval=(len(val_data)//(num_steps*batch_size)))
    monitor_tperplexity = MonitorSeries(
        "Test perplexity", monitor, interval=(len(test_data)//(num_steps*1)))

    l1 = LSTMWrapper(batch_size, state_size)
    l2 = LSTMWrapper(batch_size, state_size)

    # train graph
    x = nn.Variable((batch_size, num_steps))
    t = nn.Variable((batch_size, num_steps))
    w = I.UniformInitializer((-0.1, 0.1))
    b = I.ConstantInitializer(1)
    loss = get_loss(l1, l2, x, t, w, b, num_words,
                    batch_size, state_size, True)
    l1.share_data()
    l2.share_data()

    # validation graph
    vx = nn.Variable((batch_size, num_steps))
    vt = nn.Variable((batch_size, num_steps))
    vloss = get_loss(l1, l2, vx, vt, w, b, num_words, batch_size, state_size)

    solver = S.Sgd(lr)
    solver.set_parameters(nn.get_parameters())

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    best_val = 10000
    for epoch in range(max_epoch):
        l1.reset_state()
        l2.reset_state()
        for i in range(len(train_data)//(num_steps*batch_size)):
            x.d, t.d = get_batch(train_data, i*num_steps,
                                 batch_size, num_steps)
            solver.zero_grad()
            loss.forward()
            loss.backward(clear_buffer=True)
            solver.weight_decay(1e-5)
            gradient_clipping(nn.get_parameters().values(), max_norm)
            solver.update()
            perp = perplexity(loss.d.copy())
            monitor_perplexity.add(
                (len(train_data)//(num_steps*batch_size))*(epoch)+i, perp)
        l1.reset_state()
        l2.reset_state()
        vloss_avg = 0
        for i in range(len(val_data)//(num_steps * batch_size)):
            vx.d, vt.d = get_batch(val_data, i*num_steps,
                                   batch_size, num_steps)
            vloss.forward()
            vloss_avg += vloss.d.copy()
        vloss_avg /= float((len(val_data)//(num_steps*batch_size)))
        vper = perplexity(vloss_avg)

        if vper < best_val:
            best_val = vper
            if vper < 200:
                save_name = "params_epoch_{:02d}.h5".format(epoch)
                nn.save_parameters(os.path.join(args.save_dir, save_name))
        else:
            solver.set_learning_rate(solver.learning_rate()*0.25)
            logger.info("Decreased learning rate to {:05f}".format(
                solver.learning_rate()))
        monitor_vperplexity.add(
            (len(val_data)//(num_steps*batch_size))*(epoch)+i, vper)

    # for final test split
    t_batch_size = 1
    tl1 = LSTMWrapper(t_batch_size, state_size)
    tl2 = LSTMWrapper(t_batch_size, state_size)
    tloss_avg = 0
    tx = nn.Variable((t_batch_size, num_steps))
    tt = nn.Variable((t_batch_size, num_steps))
    tloss = get_loss(tl1, tl2, tx, tt, w, b, num_words, 1, state_size)
    tl1.share_data()
    tl2.share_data()
    for i in range(len(test_data)//(num_steps * t_batch_size)):
        tx.d, tt.d = get_batch(test_data, i*num_steps, 1, num_steps)
        tloss.forward()
        tloss_avg += tloss.d.copy()
    tloss_avg /= float((len(test_data)//(num_steps*t_batch_size)))
    tper = perplexity(tloss_avg)
    monitor_tperplexity.add(
        (len(test_data)//(num_steps*t_batch_size))*(epoch)+i, tper)
def cond_att_lstm(x, parent_index, mask, context, context_mask, state_size,
                  att_hidden_size, initial_state=None, initial_cell=None,
                  hist=None, dropout=0, train=True, w_init=None,
                  inner_w_init=None, b_init=I.ConstantInitializer(0),
                  forget_bias_init=I.ConstantInitializer(1)):
    """
    x: (batch_size, length, input_size)
    parent_index: (batch_size, length)
    mask: (batch_size, length)
    context: (batch_size, context_length, context_size)
    context_mask: (batch_size, context_length)
    hist: (batch_size, l, state_size)
    """
    batch_size, length, input_size = x.shape
    _, context_length, context_size = context.shape

    if w_init is None:
        w_init = I.UniformInitializer(
            I.calc_uniform_lim_glorot(input_size, state_size))
    if inner_w_init is None:
        inner_w_init = orthogonal

    retain_prob = 1.0 - dropout
    z_w = nn.Variable((batch_size, 4, input_size), need_grad=False)
    z_w.d = 1
    z_u = nn.Variable((batch_size, 4, state_size), need_grad=False)
    z_u.d = 1

    if dropout > 0:
        if train:
            z_w = F.dropout(z_w, p=retain_prob)
            z_u = F.dropout(z_u, p=retain_prob)
        z_w *= retain_prob
        z_u *= retain_prob
    z_w = F.reshape(z_w, (batch_size, 4, 1, input_size))
    z_w = F.broadcast(z_w, (batch_size, 4, length, input_size))
    z_w = F.split(z_w, axis=1)
    z_u = F.split(z_u, axis=1)
    xi = z_w[0] * x
    xf = z_w[1] * x
    xc = z_w[2] * x
    xo = z_w[3] * x

    with nn.parameter_scope("cond_att_lstm"):
        # (batch_size, length, state_size)
        with nn.parameter_scope("lstm"):
            xi = PF.affine(
                xi, state_size, base_axis=2, w_init=w_init, b_init=b_init,
                name="Wi")
            xf = PF.affine(
                xf, state_size, base_axis=2, w_init=w_init,
                b_init=forget_bias_init, name="Wf")
            xc = PF.affine(
                xc, state_size, base_axis=2, w_init=w_init, b_init=b_init,
                name="Wc")
            xo = PF.affine(
                xo, state_size, base_axis=2, w_init=w_init, b_init=b_init,
                name="Wo")

        with nn.parameter_scope("context"):
            # context_att_trans: (batch_size, context_size, att_hidden_size)
            context_att_trans = PF.affine(
                context, att_hidden_size, base_axis=2, w_init=w_init,
                b_init=b_init, name="layer1_c")

        if initial_state is None:
            h = nn.Variable((batch_size, state_size), need_grad=False)
            h.data.zero()
        else:
            h = initial_state

        if initial_cell is None:
            c = nn.Variable((batch_size, state_size), need_grad=False)
            c.data.zero()
        else:
            c = initial_cell

        if hist is None:
            hist = nn.Variable((batch_size, 1, state_size), need_grad=False)
            hist.data.zero()

        # (batch_size, state_size)
        xi = split(xi, axis=1)
        xf = split(xf, axis=1)
        xc = split(xc, axis=1)
        xo = split(xo, axis=1)

        mask = F.reshape(mask, [batch_size, length, 1])  # (batch_size, length, 1)
        mask = F.broadcast(mask, [batch_size, length, state_size])
        # (batch_size, state_size)
        mask = split(mask, axis=1)

        # (batch_size, max_action_length)
        parent_index = parent_index + 1  # index == 0 means that parent is root
        # (batch_size)
        parent_index = split(parent_index, axis=1)

        hs = []
        cs = []
        ctx = []

        for i, f, c2, o, m, p in zip(xi, xf, xc, xo, mask, parent_index):
            h_num = hist.shape[1]
            with nn.parameter_scope("context"):
                h_att_trans = PF.affine(
                    h, att_hidden_size, with_bias=False, w_init=w_init,
                    name="layer1_h")  # (batch_size, att_hidden_size)
                h_att_trans = F.reshape(h_att_trans,
                                        (batch_size, 1, att_hidden_size))
                h_att_trans = F.broadcast(
                    h_att_trans, (batch_size, context_length, att_hidden_size))
                att_hidden = F.tanh(context_att_trans + h_att_trans)
                att_raw = PF.affine(
                    att_hidden, 1, base_axis=2, w_init=w_init, b_init=b_init)
                # (batch_size, context_length, 1)
                att_raw = F.reshape(att_raw, (batch_size, context_length))
                ctx_att = F.exp(att_raw - F.max(att_raw, axis=1, keepdims=True))
                ctx_att = ctx_att * context_mask
                ctx_att = ctx_att / F.sum(ctx_att, axis=1, keepdims=True)
                ctx_att = F.reshape(ctx_att, (batch_size, context_length, 1))
                ctx_att = F.broadcast(
                    ctx_att, (batch_size, context_length, context_size))
                ctx_vec = F.sum(
                    context * ctx_att, axis=1)  # (batch_size, context_size)

            # parent_history
            p = F.reshape(p, (batch_size, 1))
            p = F.one_hot(p, (h_num, ))
            p = F.reshape(p, (batch_size, 1, h_num))
            par_h = F.batch_matmul(p, hist)  # [batch_size, 1, state_size]
            par_h = F.reshape(par_h, (batch_size, state_size))

            with nn.parameter_scope("lstm"):
                i_t = PF.affine(
                    z_u[0] * h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Ui")
                i_t += PF.affine(
                    ctx_vec, state_size,
                    w_init=inner_w_init(context_size, state_size),
                    with_bias=False, name="Ci")
                i_t += PF.affine(
                    par_h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Pi")
                i_t = F.sigmoid(i + i_t)
                f_t = PF.affine(
                    z_u[1] * h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Uf")
                f_t += PF.affine(
                    ctx_vec, state_size,
                    w_init=inner_w_init(context_size, state_size),
                    with_bias=False, name="Cf")
                f_t += PF.affine(
                    par_h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Pf")
                f_t = F.sigmoid(f + f_t)
                c_t = PF.affine(
                    z_u[2] * h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Uc")
                c_t += PF.affine(
                    ctx_vec, state_size,
                    w_init=inner_w_init(context_size, state_size),
                    with_bias=False, name="Cc")
                c_t += PF.affine(
                    par_h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Pc")
                c_t = f_t * c + i_t * F.tanh(c2 + c_t)
                o_t = PF.affine(
                    z_u[3] * h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Uo")
                o_t += PF.affine(
                    ctx_vec, state_size,
                    w_init=inner_w_init(context_size, state_size),
                    with_bias=False, name="Co")
                o_t += PF.affine(
                    par_h, state_size,
                    w_init=inner_w_init(state_size, state_size),
                    with_bias=False, name="Po")
                o_t = F.sigmoid(o + o_t)
                h_t = o_t * F.tanh(c_t)
                h_t = (1 - m) * h + m * h_t
                c_t = (1 - m) * c + m * c_t
                h = h_t
                c = c_t
                h_t = F.reshape(h_t, (batch_size, 1, state_size), inplace=False)
                c_t = F.reshape(c_t, (batch_size, 1, state_size), inplace=False)
                ctx_vec = F.reshape(
                    ctx_vec, (batch_size, 1, context_size), inplace=False)
                hs.append(h_t)
                cs.append(c_t)
                ctx.append(ctx_vec)
                hist = F.concatenate(
                    hist, h_t, axis=1)  # (batch_size, h_num + 1, state_size)

    return concatenate(*hs, axis=1), concatenate(*cs, axis=1), \
        concatenate(*ctx, axis=1), hist