def test_Abs(tmpdir):
    shape = (4, 5)
    data = np.random.rand(*shape).astype(np.float32)

    model = C.abs(data)
    verify_no_input(model, tmpdir, 'Abs_0')

    x = C.input_variable(shape)
    model = C.abs(x)
    verify_one_input(model, data, tmpdir, 'Abs_1')
def test_Abs(tmpdir, dtype):
    with C.default_options(dtype=dtype):
        shape = (4, 5)
        data = np.random.rand(*shape).astype(dtype)

        model = C.abs(data)
        verify_no_input(model, tmpdir, 'Abs_0')

        x = C.input_variable(shape)
        model = C.abs(x)
        verify_one_input(model, data, tmpdir, 'Abs_1')
def multiFunc(self, arg1):
    # load or create the inputs we need
    multiIn = C.input(shape=arg1.shape, dynamic_axes=arg1.dynamic_axes)
    bit_map = C.constant(self.bit_map)
    max_bits = self.bit_map.max()
    shape = multiIn.shape
    reformed = C.reshape(multiIn, (-1,))
    # let's compute the means we need
    # carry_over represents the remaining value that needs to be binarized. For a single bit, this is just
    # the input. For more bits, it is the difference between the previous bits' approximation and the true value.
    carry_over = multiIn
    approx = C.element_times(multiIn, 0)
    # iterate through the maximum number of bits specified by the bit map, computing each level of binarization
    for i in range(max_bits):
        # determine which values of the input should be binarized to i bits or more
        hot_vals = C.greater(bit_map, i)
        # select only the values which we need to binarize
        valid_vals = C.element_select(hot_vals, carry_over, 0)
        # compute the mean on a per-kernel basis; reshaping allows sum reduction along only axis 0 (the kernels)
        mean = C.element_divide(
            C.reduce_sum(C.reshape(C.abs(valid_vals), (valid_vals.shape[0], -1)), axis=1),
            C.reduce_sum(C.reshape(hot_vals, (hot_vals.shape[0], -1)), axis=1))
        # reshape the mean to match the dimensionality of the input
        mean = C.reshape(mean, (mean.shape[0], mean.shape[1], 1, 1))
        # binarize the carry over
        bits = C.greater(carry_over, 0)
        bits = C.element_select(bits, bits, -1)
        bits = C.element_select(hot_vals, bits, 0)
        # add the equivalent binary representation to the approximation
        approx = C.plus(approx, C.element_times(mean, bits))
        # compute the new carry over
        carry_over = C.plus(C.element_times(C.element_times(-1, bits), mean), carry_over)

    return approx, multiIn
def gradFunc(self, arg):
    # create an input variable corresponding to the inputs of the forward prop function
    gradIn = C.input(shape=arg.shape, dynamic_axes=arg.dynamic_axes)
    # create an input variable for the gradient passed from the next stage
    gradRoot = C.input(shape=arg.shape, dynamic_axes=arg.dynamic_axes)

    # first step is to take the absolute value of the input
    signGrad = C.abs(gradIn)
    # then compare its magnitude to 1
    signGrad = C.less_equal(signGrad, 1)
    # finish by multiplying this result with the input gradient
    return C.element_times(gradRoot, signGrad), gradIn, gradRoot
def test_outputs():
    fwd_state = C.placeholder("placeholder")
    prev_state = C.sequence.past_value(fwd_state, name="prev_state")
    z = C.abs(prev_state, "abs")
    output = z.output
    z = z.replace_placeholders({fwd_state: z.output})

    fwd_state = None
    prev_state = None
    z = None

    for arg in output.owner.arguments:
        print("Argument name: {}, argument owner name {}".format(arg.name, arg.owner.name))
def gradFunc(self, arg):
    # create an input variable corresponding to the inputs of the forward prop function
    gradIn = C.input(shape=arg.shape, dynamic_axes=arg.dynamic_axes)
    # create an input variable for the gradient passed from the next stage
    gradRoot = C.input(shape=arg.shape, dynamic_axes=arg.dynamic_axes)

    signGrad = C.abs(gradIn)
    # the clipping bound is a function of the bit map, since higher bit counts can represent larger numbers
    bit_map = C.constant(self.bit_map)
    signGrad = C.less_equal(signGrad, bit_map)
    outGrad = signGrad

    outGrad = C.element_times(gradRoot, outGrad)
    return outGrad, gradIn, gradRoot
def mae(self, z, l):
    '''
    Small helper function implementing MAE (mean absolute error).
    Used as an error metric during optimization.
    (So far only used within the subclasses based on neural networks.)

    Parameters
    ----------
    z: vector<float>
        prediction
    l: vector<float>
        label

    Returns
    -------
    error: mae
    '''
    return C.reduce_mean(C.abs(z - l))
def SmoothL1Loss(sigma, bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights):
    """
    From https://github.com/smallcorgi/Faster-RCNN_TF/blob/master/lib/fast_rcnn/train.py

    ResultLoss = outside_weights * SmoothL1(inside_weights * (bbox_pred - bbox_targets))
    SmoothL1(x) = 0.5 * (sigma * x)^2,    if |x| < 1 / sigma^2
                  |x| - 0.5 / sigma^2,    otherwise
    """
    sigma2 = sigma * sigma

    inside_mul_abs = C.abs(C.element_times(bbox_inside_weights, C.minus(bbox_pred, bbox_targets)))

    smooth_l1_sign = C.less(inside_mul_abs, 1.0 / sigma2)
    smooth_l1_option1 = C.element_times(C.element_times(inside_mul_abs, inside_mul_abs), 0.5 * sigma2)
    smooth_l1_option2 = C.minus(inside_mul_abs, 0.5 / sigma2)
    smooth_l1_result = C.plus(C.element_times(smooth_l1_option1, smooth_l1_sign),
                              C.element_times(smooth_l1_option2, C.minus(1.0, smooth_l1_sign)))

    return C.element_times(bbox_outside_weights, smooth_l1_result)
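# A minimal NumPy sanity check of the SmoothL1 formula in the docstring above
# (not part of the original source; the inputs are illustrative). At
# |x| = 1/sigma^2 the two branches agree, so the function is continuous there.
import numpy as np

def smooth_l1_np(x, sigma):
    sigma2 = sigma * sigma
    ax = np.abs(x)
    return np.where(ax < 1.0 / sigma2, 0.5 * (sigma * x) ** 2, ax - 0.5 / sigma2)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
print(smooth_l1_np(x, sigma=1.0))  # quadratic near 0, linear in the tails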
def abs(x, name=''):
    '''
    Computes the element-wise absolute value of `x`:

    :math:`abs(x) = |x|`

    Example:
        >>> C.eval(C.abs([-1, 1, -2, 3]))
        [array([[ 1.,  1.,  2.,  3.]])]

    Args:
        x: numpy array or any :class:`cntk.Function` that outputs a tensor
        name (str): the name of the node in the network

    Returns:
        :class:`cntk.Function`
    '''
    from cntk import abs
    x = sanitize_input(x)
    return abs(x, name).output()
def multiFunc(self, arg1):
    multiIn = C.input(shape=arg1.shape, dynamic_axes=arg1.dynamic_axes)
    bit_map = C.constant(self.bit_map)
    max_bits = self.bit_map.max()
    shape = multiIn.shape
    reformed = C.reshape(multiIn, (-1,))
    carry_over = multiIn
    approx = C.element_times(multiIn, 0)
    for i in range(max_bits):
        hot_vals = C.greater(bit_map, i)
        valid_vals = C.element_select(hot_vals, carry_over, 0)
        mean = C.element_divide(C.reduce_sum(C.abs(valid_vals)), C.reduce_sum(hot_vals))
        bits = C.greater(carry_over, 0)
        bits = C.element_select(bits, bits, -1)
        bits = C.element_select(hot_vals, bits, 0)
        approx = C.plus(approx, C.element_times(mean, bits))
        carry_over = C.plus(C.element_times(C.element_times(-1, bits), mean), carry_over)

    return approx, multiIn
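# A small NumPy sketch (not from the original source) of the residual
# binarization idea used by multiFunc above: each pass approximates the
# remaining value (carry_over) as mean(|carry_over|) * sign(carry_over),
# then subtracts that approximation before the next pass.
import numpy as np

x = np.array([0.9, -0.3, 0.5, -0.7], dtype=np.float32)
carry_over, approx = x, np.zeros_like(x)
for _ in range(2):  # two bits of precision
    mean = np.mean(np.abs(carry_over))
    bits = np.where(carry_over > 0, 1.0, -1.0)
    approx += mean * bits
    carry_over = carry_over - mean * bits
print(approx)  # gets closer to x with each additional bit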
def build(self):
    input_kernel = C.Parameter(shape=(self._input_size, self._hidden_dim),
                               init=self._input_initializer)
    recur_kernel = C.Parameter(shape=(self._hidden_dim,),
                               init=self._recurrent_initializer)
    bias = C.Parameter(shape=(self._hidden_dim,), init=0)

    if self._recurrent_min_abs > 0:
        # raise the magnitude of the recurrent weights to at least the minimum, preserving their signs
        abs_kernel = C.abs(recur_kernel)
        min_abs_kernel = C.element_max(abs_kernel, self._recurrent_min_abs)
        recur_kernel = min_abs_kernel * C.element_select(
            C.greater_equal(recur_kernel, C.constant(0)), C.constant(1), C.constant(-1))

    if self._recurrent_max_abs:
        recur_kernel = C.clip(recur_kernel, -self._recurrent_max_abs, self._recurrent_max_abs)

    @C.Function
    def runit(h, x):
        h_t = C.times(x, input_kernel) + bias + recur_kernel * h
        return h_t

    return runit
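# NumPy illustration (not from the original source) of the recurrent-weight
# clamp in build() above: magnitudes are raised to at least the minimum while
# signs are preserved, then optionally clipped from above.
import numpy as np

recur = np.array([0.05, -0.02, 0.8, -1.5], dtype=np.float32)
min_abs, max_abs = 0.1, 1.0
clamped = np.maximum(np.abs(recur), min_abs) * np.where(recur >= 0, 1.0, -1.0)
clamped = np.clip(clamped, -max_abs, max_abs)
print(clamped)  # [ 0.1 -0.1  0.8 -1. ]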
def test_abs():
    assert_cntk_ngraph_array_equal(C.abs([-1, 1, -2, 3]))
    assert_cntk_ngraph_array_equal(C.abs([[1, -2], [3, -4]]))
    assert_cntk_ngraph_array_equal(C.abs([[[1, 2], [-3, 4]], [[1, -2], [3, 4]]]))
def ddist(prediction, c_interval_center, c_interval_radius):
    '''
    Distance of the predictions from the edges of the intervals
    '''
    return cntk.relu(cntk.abs(prediction - c_interval_center) - c_interval_radius)
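# Illustrative check of ddist (not from the original source): points inside the
# interval [center - radius, center + radius] get distance 0, points outside get
# their distance to the nearest edge.
import numpy as np
import cntk

pred = cntk.constant(np.array([0.0, 1.5, -3.0], dtype=np.float32))
print(ddist(pred, c_interval_center=0.0, c_interval_radius=1.0).eval())
# expected: [0.0, 0.5, 2.0]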
h = lambda x: C.tanh(x)
h_prime = lambda x: 1 - C.square(C.tanh(x))

base_dist = MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.])

z_0 = C.input_variable(base_dist.size(), name='sampled')
z_prev = z_0
sum_log_det_jacob = 0.

initializer = C.initializer.uniform(1)
for i in range(K):
    u = C.parameter((2), name='u', init=initializer)
    w = C.parameter((2), name='w', init=initializer)
    b = C.parameter((1), name='b', init=initializer)

    psi = h_prime(C.dot(w, z_prev) + b) * w
    det_jacob = C.abs(1 + C.dot(u, psi))
    sum_log_det_jacob += C.log(EPS + det_jacob)
    z_prev = z_prev + u * h(C.dot(w, z_prev) + b)

z_k = z_prev

log_q_k = C.log(base_dist.pdf(z_0)) - sum_log_det_jacob
log_p = C.log(EPS + true_density(z_k))
kl = C.reduce_mean(log_q_k - log_p)

#%%
lr = 1
lr_schedule = C.learning_parameter_schedule(lr)
learner = C.adam(kl.parameters, lr_schedule, 0.9)
trainer = C.Trainer(kl, (kl, None), learner)
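# Quick NumPy check (not from the original source) of the determinant identity
# the planar-flow loop above relies on: the transform z + u*h(w.z + b) has
# Jacobian I + u*psi^T, and det(I + u psi^T) = 1 + u.psi by the matrix
# determinant lemma, which is why det_jacob is a scalar expression.
import numpy as np

rng = np.random.default_rng(0)
u, psi = rng.normal(size=2), rng.normal(size=2)
lhs = np.linalg.det(np.eye(2) + np.outer(u, psi))
rhs = 1 + u @ psi
print(np.isclose(lhs, rhs))  # True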
def flow_forward(input_dim: int, act_func_pair: tuple = (None, None), batch_norm: bool = False):
    chunk = {}
    log_det_J = 0

    chunk['input_dim'] = input_dim
    _ph = C.placeholder(input_dim, name='place_holder')
    _out = _ph

    if batch_norm:
        # _bn = C.layers.BatchNormalization(name='batch_norm')(_ph)
        # chunk['scale'] = _bn.parameters[0]
        # chunk['bias'] = _bn.parameters[1]
        chunk['mu'] = C.Constant(np.zeros(shape=input_dim))
        chunk['var'] = C.Constant(np.ones(shape=input_dim))

        _eps = C.Constant(1e-7)
        _mu = C.reduce_mean(_ph, axis=C.Axis.default_batch_axis())
        _var = C.reduce_mean(C.square(_ph - _mu), axis=C.Axis.default_batch_axis())

        chunk['muB'] = _mu
        chunk['varB'] = _var

        # _bn = (_ph - chunk['mu']) / C.sqrt(chunk['var'] + _eps)
        _bn = C.sqrt(chunk['var'] + _eps) * _ph + chunk['mu']
        _ph = _bn

        log_det_J += -0.5 * C.reduce_sum(C.log((_var + _eps)))

    chunk['W_rot_mat'] = _W = C.parameter((input_dim, input_dim))
    _W.value = random_rotation_matrix = special_ortho_group.rvs(input_dim)
    # _W.value = np.roll(np.eye(input_dim), input_dim//2, axis=0)
    _out = _ph @ _W
    log_det_J += C.log(C.abs(C.det(_W)))
    # or
    # log_det_J += C.slogdet(_W)[1]

    _half_dim = input_dim // 2
    _x1 = _out[:_half_dim]
    _x2 = _out[_half_dim:]

    _log_s_func, _t_func = act_func_pair
    if _log_s_func is None:  # basic network
        _log_s_func = C.layers.Sequential([
            C.layers.Dense(256, C.leaky_relu),
            C.layers.Dense(256, C.leaky_relu),
            C.layers.Dense(_half_dim, C.tanh),
        ])  # (C.placeholder(input_dim, name='place_holder'))
    if _t_func is None:  # basic network
        _t_func = C.layers.Sequential([
            C.layers.Dense(256, C.leaky_relu),
            C.layers.Dense(256, C.leaky_relu),
            C.layers.Dense(_half_dim),
        ])  # (C.placeholder(input_dim, name='place_holder'))

    chunk['log_s_func'] = _log_s_func
    chunk['t_func'] = _t_func

    _log_s, _t = _log_s_func(_x2), _t_func(_x2)
    _s = C.exp(_log_s)

    _y1 = _s * _x1 + _t
    _y2 = _x2
    _Y = C.splice(_y1, _y2)
    chunk['output'] = _Y

    log_det_J += C.reduce_sum(_log_s)

    return _Y, log_det_J, chunk
def inner(a):
    p = position(a)
    integers = p / s  # every s-th sequence item will be an integer
    valid = C.less_equal(C.abs(C.sin(integers * pi)), tol)  # sin of an integer multiple of pi is close to zero
    result = C.sequence.gather(a, valid)
    return result
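# A quick NumPy check (not from the original source) of the masking trick used
# in inner() above: positions that are exact multiples of s give sin(k*pi) ~ 0,
# so they pass the tolerance test and are kept by the gather.
import numpy as np

s, tol = 3, 1e-6
p = np.arange(9, dtype=np.float64)  # positions 0..8
valid = np.abs(np.sin((p / s) * np.pi)) <= tol
print(valid)  # True at positions 0, 3, 6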
def _build_model(self):
    hidden_size = self.hidden_size
    output_size = self.output_size
    num_layers = self.num_layers
    keep_prob = self.keep_prob

    inputs = cntk.sequence.input_variable((output_size), name='inputs')
    target = cntk.input_variable((output_size), name='target')

    def lstm_cell():
        _cell_creator = cntk.layers.Recurrence(
            cntk.layers.LSTM(hidden_size, use_peepholes=self.params.use_peephole),
            name='basic_lstm')
        if self.params.use_dropout:
            print(" ** using dropout for LSTM ** ")
            _cell_creator = cntk.layers.Dropout(keep_prob=keep_prob)(_cell_creator)
        return _cell_creator

    def gru_cell():
        _cell_creator = cntk.layers.Recurrence(cntk.layers.GRU(hidden_size), name='gru')
        if self.params.use_dropout:
            print(" ** using dropout for GRU ** ")
            _cell_creator = cntk.layers.Dropout(keep_prob=keep_prob)(_cell_creator)
        return _cell_creator

    def cifg_cell():
        _cell_creator = cntk.layers.Recurrence(
            CIFG_LSTM(hidden_size, use_peepholes=self.params.use_peephole),
            name='cifg_lstm')
        if self.params.use_dropout:
            print(" ** using dropout for CIFG-LSTM ** ")
            _cell_creator = cntk.layers.Dropout(keep_prob=keep_prob)(_cell_creator)
        return _cell_creator

    if self.config.cell == 'gru':
        _cell_creator = gru_cell
    elif self.config.cell == 'lstm':
        _cell_creator = lstm_cell
    elif self.config.cell == 'cifg_lstm':
        _cell_creator = cifg_cell
    else:
        raise ValueError("Unsupported cell type, choose from {'lstm', 'gru', 'cifg_lstm'}.")

    if self.params.use_residual:
        print(" ** using residual ** ")
        _output = inputs
        for _ in range(num_layers):
            _output = self.params.resWeight * _cell_creator()(_output) + _output
            # _output = _cell_creator()(_output) + _output
    else:
        cell = cntk.layers.For(range(num_layers), lambda: _cell_creator())
        _output = cell(inputs)

    _output = cntk.sequence.last(_output)
    output = cntk.layers.Dense(output_size)(_output)

    self.output = output
    self.loss = cntk.squared_error(output, target)
    cost_mape = cntk.reduce_mean(cntk.abs(output - target) / target,
                                 axis=cntk.Axis.all_axes(), name='mape')
    cost_mae = cntk.reduce_mean(cntk.abs(output - target),
                                axis=cntk.Axis.all_axes(), name='mae')
    cost_rmse = cntk.reduce_l2((output - target),
                               axis=cntk.Axis.all_axes(), name='rmse')
    self.cost = cntk.combine([cost_mape, cost_mae, cost_rmse])
    self.criterion = cntk.combine([self.loss, cost_mape])
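# Hedged NumPy sketches (not from the original source) of the error metrics
# combined above, to make the definitions concrete; reduce_l2 is taken here to
# be the L2 norm, i.e. the square root of the sum of squares.
import numpy as np

output = np.array([2.0, 4.0], dtype=np.float32)
target = np.array([1.0, 5.0], dtype=np.float32)
mape = np.mean(np.abs(output - target) / target)  # 0.6
mae = np.mean(np.abs(output - target))            # 1.0
l2 = np.sqrt(np.sum((output - target) ** 2))      # ~1.414
print(mape, mae, l2)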
#
# generator and discriminator
#
x = C.input_variable(shape=(img_channel, img_height, img_width), dtype="float32", needs_gradient=True)
y = C.input_variable(shape=(img_channel, img_height, img_width), dtype="float32", needs_gradient=True)

x_real = (x - 127.5) / 127.5
y_real = (y - 127.5) / 127.5

G_fake = pix2pix_generator(x)
D_real = pix2pix_discriminator(y_real, x_real)
D_fake = D_real.clone(method="share", substitutions={y_real.output: G_fake.output, x_real.output: x_real.output})

#
# loss function
#
G_loss = C.reduce_mean(C.square(D_fake - 1.0)) / 2 + lambda_1 * C.reduce_mean(C.abs(y_real - G_fake))
D_loss = C.reduce_mean(C.square(D_real - 1.0)) / 2 + C.reduce_mean(C.square(D_fake)) / 2

#
# optimizer and cyclical learning rate
#
G_learner = C.adam(G_fake.parameters, lr=1e-4, momentum=0.5,
                   gradient_clipping_threshold_per_sample=minibatch_size,
                   gradient_clipping_with_truncation=True)
D_learner = C.adam(D_real.parameters, lr=1e-4, momentum=0.5,
                   gradient_clipping_threshold_per_sample=minibatch_size,
                   gradient_clipping_with_truncation=True)

G_progress_printer = C.logging.ProgressPrinter(tag="Generator")
D_progress_printer = C.logging.ProgressPrinter(tag="Discriminator")

if not os.path.exists("./pix2pix_image"):
    os.mkdir("./pix2pix_image")
def main():
    show_image = False

    if show_image:
        bs = 1
        ci = 3
        co = 3
        cg = co * (ci + 1)
        gd = 8
        gh = 64
        gw = 64
        h = 256
        w = 256
    else:
        bs = 1
        ci = 3
        co = 3
        cg = co * (ci + 1)
        gd = 8
        gh = 64
        gw = 64
        h = 1024
        w = 1024

    im = C.input_variable([bs, ci, h, w], needs_gradient=True, dynamic_axes=[])
    guide = C.input_variable([bs, h, w], needs_gradient=True, dynamic_axes=[])
    guide_no_grad = C.input_variable([bs, h, w], needs_gradient=False, dynamic_axes=[])
    grid = C.input_variable([bs, cg, gd, gh, gw], needs_gradient=True, dynamic_axes=[])

    # Create indices
    xx = np.arange(0, w).reshape(1, -1).repeat(h, 0).astype(np.float32)
    yy = np.arange(0, h).reshape(-1, 1).repeat(w, 1).astype(np.float32)
    xx = C.Constant(xx, xx.shape)
    yy = C.Constant(yy, yy.shape)

    gx = ((xx + 0.5) / w) * gw
    gy = ((yy + 0.5) / h) * gh
    gz = C.clip(guide, 0.0, 1.0) * gd
    gz_no_grad = C.clip(guide_no_grad, 0.0, 1.0) * gd

    fx = C.element_max(C.floor(gx - 0.5), 0.0)
    fy = C.element_max(C.floor(gy - 0.5), 0.0)
    fz = C.element_max(C.floor(gz - 0.5), 0.0)
    fz_no_grad = C.element_max(C.floor(gz_no_grad - 0.5), 0.0)

    wx = gx - 0.5 - fx
    wy = gy - 0.5 - fy
    wx = C.expand_dims(C.expand_dims(wx, -1 - len(wx.shape)), -1 - len(wx.shape))
    wy = C.expand_dims(C.expand_dims(wy, -1 - len(wy.shape)), -1 - len(wy.shape))
    wz = C.abs(gz - 0.5 - fz)
    wz = C.expand_dims(wz, 0)

    fx = C.expand_dims(C.expand_dims(fx, -1 - len(fx.shape)), -1 - len(fx.shape))
    fy = C.expand_dims(C.expand_dims(fy, -1 - len(fy.shape)), -1 - len(fy.shape))

    cx = C.element_min(fx + 1, gw - 1)
    cy = C.element_min(fy + 1, gh - 1)
    cz = C.element_min(fz_no_grad + 1, gd - 1)

    batch_idx = np.arange(bs).reshape(bs, 1, 1, 1).astype(np.float32)
    batch_idx = C.Constant(batch_idx, batch_idx.shape)

    out = []
    flat_grid = C.reshape(grid, [-1])
    for c_ in range(co):
        c_idx = np.arange((ci + 1) * c_, (ci + 1) * (c_ + 1)).reshape(1, ci + 1, 1, 1).astype(np.float32)
        c_idx = C.Constant(c_idx, c_idx.shape)

        def flatten_and_gather(x, y, z):
            linear_idx = x + gw * y + gw * gh * z + c_idx * gw * gh * gd + batch_idx * gw * gh * gd * cg
            flat_linear_idx = C.reshape(linear_idx, [-1])
            return C.reshape(C.gather(flat_grid, flat_linear_idx), linear_idx.shape)

        gather_fff = flatten_and_gather(fx, fy, fz_no_grad)
        gather_ffc = flatten_and_gather(fx, fy, cz)
        gather_fcf = flatten_and_gather(fx, cy, fz_no_grad)
        gather_fcc = flatten_and_gather(fx, cy, cz)
        gather_cff = flatten_and_gather(cx, fy, fz_no_grad)
        gather_cfc = flatten_and_gather(cx, fy, cz)
        gather_ccf = flatten_and_gather(cx, cy, fz_no_grad)
        gather_ccc = flatten_and_gather(cx, cy, cz)

        # trilinear interpolation of the eight enclosing grid cells
        a = gather_fff * (1 - wx) * (1 - wy) * (1 - wz) + \
            gather_ffc * (1 - wx) * (1 - wy) * wz + \
            gather_fcf * (1 - wx) * wy * (1 - wz) + \
            gather_fcc * (1 - wx) * wy * wz + \
            gather_cff * wx * (1 - wy) * (1 - wz) + \
            gather_cfc * wx * (1 - wy) * wz + \
            gather_ccf * wx * wy * (1 - wz) + \
            gather_ccc * wx * wy * wz
        o = C.reduce_sum(a[:, :-1, ...] * im, 1) + a[:, -1, ...]
        print(o.shape)
        out.append(C.expand_dims(o, 0))

    out = C.splice(*out, axis=1)
    loss = C.reduce_l2(out)

    grid_val = np.random.rand(bs, cg, gd, gh, gw).astype(np.float32)
    if show_image:
        guide_val = skio.imread("/data/rgb.png").mean(2)[:h, :w].astype(np.float32)
        guide_val = np.expand_dims(guide_val / 255.0, 0)
        im_val = np.tile(np.expand_dims(guide_val, 1), [1, 3, 1, 1])
        out_val = out.eval({
            im: im_val,
            guide: guide_val,
            guide_no_grad: guide_val,
            grid: grid_val
        })
        out_val = np.clip(np.transpose(np.squeeze(out_val), [1, 2, 0]), 0, 1)
        skio.imsave("/output/imout.png", out_val)
    else:
        im_val = np.random.randn(bs, ci, h, w)
        guide_val = np.random.rand(bs, h, w).astype(np.float32)

        # burn-in iterations
        for it in range(5):
            print('burning (', it, ')')
            g = loss.grad({
                im: im_val,
                guide: guide_val,
                guide_no_grad: guide_val,
                grid: grid_val
            })

        # actual iterations
        start = time.time()
        for it in range(50):
            print('profiling (', it, ')')
            g = loss.grad({
                im: im_val,
                guide: guide_val,
                guide_no_grad: guide_val,
                grid: grid_val
            })
        end = time.time()
        runtime = (end - start) * 1000.0 / 50.0
        print('Runtime:', runtime)
def main(): print("version", C.__version__) bs = 1 n_chans = 1 sigma_s = 16 sigma_r = 12 # 4x4x1024x1024 # 4x12x64x64 sz = 256 # sz = 1024 small_sz = sz // sigma_s yy, xx = np.meshgrid(np.arange(0, sz), np.arange(0, sz)) cc, bb = np.meshgrid(np.arange(0, n_chans), np.arange(0, bs)) xx = np.expand_dims(xx, 0) xx = np.expand_dims(xx, 0) yy = np.expand_dims(yy, 0) yy = np.expand_dims(yy, 0) bb = np.expand_dims(bb, 2) bb = np.expand_dims(bb, 3) cc = np.expand_dims(cc, 2) cc = np.expand_dims(cc, 3) # Compute graph grid = C.Parameter([bs, n_chans, sigma_r, small_sz, small_sz], ) # grid = C.input_variable( # [bs, n_chans, sigma_r, small_sz, small_sz], # dynamic_axes=[], needs_gradient=True) guide = C.input_variable([bs, sz, sz], dynamic_axes=[], needs_gradient=True) guide_non_diff = C.input_variable([bs, sz, sz], dynamic_axes=[]) # Coordinates xx = C.Constant(xx, xx.shape) yy = C.Constant(yy, yy.shape) cc = C.Constant(cc, cc.shape) bb = C.Constant(bb, bb.shape) gx_d, gy_d, gz_d, fx_d, fy_d, fz_d, _, _, _ = grid_coord( guide, xx, yy, sz, small_sz, sigma_r, bs) # Trilerp weights wx = (gx_d - 0.5 - fx_d) wy = (gy_d - 0.5 - fy_d) wz = C.abs(gz_d - 0.5 - fz_d) # Enclosing cell gx, gy, gz, fx, fy, fz, cx, cy, cz = grid_coord(guide_non_diff, xx, yy, sz, small_sz, sigma_r, bs) output_components = [] for ix, x in enumerate([fx, cx]): wx_ = (1 - wx) if ix == 0 else wx for iy, y in enumerate([fy, cy]): wy_ = (1 - wy) if iy == 0 else wy for iz, z in enumerate([fz, cz]): wz_ = (1 - wz) if iz == 0 else wz linear_idx = x + small_sz * (y + small_sz * (z + sigma_r * (cc + n_chans * bb))) # Flatten data for gather op flat_grid = C.reshape( grid, [bs * small_sz * small_sz * sigma_r * n_chans]) flat_linear_idx = C.reshape(linear_idx, [bs * n_chans * sz * sz]) # Slice interp = C.gather(flat_grid, flat_linear_idx) interp_fsz = C.reshape(interp, [bs, n_chans, sz, sz]) output_components.append(interp_fsz * wz_ * wx_ * wy_) out = sum(output_components) loss = C.squared_error(out, guide) # svg = C.logging.graph.plot(out, "/output/graph.svg") grid_data = np.random.uniform(size=(bs, n_chans, sigma_r, small_sz, small_sz)).astype(np.float32) # guide_data = np.random.uniform( # size=(bs, sz, sz)).astype(np.float32) guide_data = skio.imread("/data/rgb.png").mean(2)[:sz, :sz].astype( np.float32) guide_data = np.expand_dims(guide_data, 0) / 255.0 inputs = {guide: guide_data, guide_non_diff: guide_data}
x_hat = F_fake.clone(method="share", substitutions={y_real.output: G_fake.output})  # F(G(X)) -> X'
y_hat = G_fake.clone(method="share", substitutions={x_real.output: F_fake.output})  # G(F(Y)) -> Y'

#
# discriminator
#
Dx_real = cyclegan_discriminator(x_real)
Dx_fake = Dx_real.clone(method="share", substitutions={x_real.output: F_fake.output})
Dy_real = cyclegan_discriminator(y_real)
Dy_fake = Dy_real.clone(method="share", substitutions={y_real.output: G_fake.output})

#
# loss function
#
cycle_consistency_loss = lambda_x * C.reduce_mean(C.abs(x_hat - x_real)) + \
                         lambda_y * C.reduce_mean(C.abs(y_hat - y_real))
F_loss = C.reduce_mean(C.square(Dx_fake - 1.0)) / 2 + cycle_consistency_loss
G_loss = C.reduce_mean(C.square(Dy_fake - 1.0)) / 2 + cycle_consistency_loss
Dx_loss = C.reduce_mean(C.square(Dx_real - 1.0)) / 2 + C.reduce_mean(C.square(Dx_fake)) / 2
Dy_loss = C.reduce_mean(C.square(Dy_real - 1.0)) / 2 + C.reduce_mean(C.square(Dy_fake)) / 2

#
# optimizer
#
F_learner = C.adam(F_fake.parameters, lr=2e-4, momentum=0.5,
                   gradient_clipping_threshold_per_sample=minibatch_size,
                   gradient_clipping_with_truncation=True)
G_learner = C.adam(G_fake.parameters, lr=2e-4, momentum=0.5,
                   gradient_clipping_threshold_per_sample=minibatch_size,
                   gradient_clipping_with_truncation=True)
Dx_learner = C.adam(Dx_real.parameters, lr=1e-4, momentum=0.5,