def rgb_to_lab(image_srgb, ctx=None):
    """Convert an sRGB image to CIE L*a*b* using the D65 white point.

    Pipeline: gamma-encoded sRGB -> linear RGB -> CIE XYZ -> XYZ divided by
    the D65 reference white -> L*a*b*.

    Args:
        image_srgb: Image to convert. It is validated by ``__check_image``;
            if its largest value exceeds 1 it is rescaled with
            ``__normalize_rgb_image`` first. Pixels are regrouped as
            consecutive value triples, so the channel axis is assumed to be
            the last (innermost) axis -- TODO confirm for all callers.
        ctx: MXNet context the computation runs under. Must be provided.

    Returns:
        An NDArray shaped like the checked/normalised image, with each
        (R, G, B) triple replaced by (L*, a*, b*).

    Raises:
        ValueError: If ``ctx`` or ``image_srgb`` is None.
    """
    if ctx is None:
        raise ValueError("ctx can not be None")
    if image_srgb is None:
        raise ValueError("image_srgb can not be None")

    def _piecewise(values, threshold, low_branch, high_branch):
        # Evaluate both branches everywhere and select per element with 0/1
        # float masks: low_branch where values <= threshold, else high_branch.
        low_mask = nd.cast(values <= threshold, dtype='float32')
        high_mask = nd.cast(values > threshold, dtype='float32')
        return low_branch(values) * low_mask + high_branch(values) * high_mask

    delta = 6 / 29  # CIELAB break-point constant

    with mx.Context(ctx):
        srgb = __check_image(image_srgb)
        if nd.max(srgb).asscalar() > 1:
            srgb = __normalize_rgb_image(srgb)
        pixels = nd.reshape(srgb, [-1, 3])

        # Undo the sRGB transfer curve to get linear RGB.
        linear_rgb = _piecewise(
            pixels,
            0.04045,
            lambda v: v / 12.92,
            lambda v: ((v + 0.055) / 1.055) ** 2.4,
        )

        # Linear RGB -> XYZ. Pixels are row vectors, so the rows of this
        # matrix are the R, G, B inputs and its columns the X, Y, Z outputs.
        xyz_from_rgb = nd.array([
            [0.412453, 0.212671, 0.019334],
            [0.357580, 0.715160, 0.119193],
            [0.180423, 0.072169, 0.950227],
        ])
        xyz = nd.linalg_gemm2(linear_rgb, xyz_from_rgb)

        # Scale X and Z by the D65 reference white (Yn is 1).
        xyz_white = nd.multiply(xyz, nd.array([1 / 0.950456, 1.0, 1 / 1.088754]))

        # f(t) as in https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        f_xyz = _piecewise(
            xyz_white,
            delta ** 3,
            lambda t: t / (3 * delta ** 2) + 4 / 29,
            lambda t: t ** (1 / 3),
        )

        # L* = 116*fy - 16, a* = 500*(fx - fy), b* = 200*(fy - fz).
        # Rows are fx, fy, fz; columns are L*, a*, b*.
        lab_from_f = nd.array([
            [0.0, 500.0, 0.0],
            [116.0, -500.0, 200.0],
            [0.0, 0.0, -200.0],
        ])
        lab = nd.linalg_gemm2(f_xyz, lab_from_f) + nd.array([-16.0, 0.0, 0.0])
        return nd.reshape(lab, srgb.shape)
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
    """Score chunks of quaternion-rotated heads against negative tails.

    The last axis of ``heads`` and ``relations`` is split into four equal
    parts (r, i, j, k). Each head quaternion is multiplied (Hamilton product)
    by its relation quaternion divided by the relation's per-element norm.
    Every chunk of rotated heads is then dot-producted with every tail of
    the same chunk.

    Args:
        heads: Array whose ``shape[1]`` is the hidden size; its total size is
            assumed to equal ``num_chunks * chunk_size * hidden_dim``, and the
            last axis is assumed divisible by 4 -- confirm with callers.
        relations: Presumably the same shape as ``heads`` (split the same way).
        tails: Reshaped to ``(num_chunks, neg_sample_size, hidden_dim)``.
        num_chunks: Number of chunks.
        chunk_size: Number of heads per chunk.
        neg_sample_size: Number of tails per chunk.

    Returns:
        Array of shape ``(num_chunks, chunk_size, neg_sample_size)``.
    """
    hidden_dim = heads.shape[1]
    emb_r, emb_i, emb_j, emb_k = nd.split(heads, num_outputs=4, axis=-1)
    rel_r, rel_i, rel_j, rel_k = nd.split(relations, num_outputs=4, axis=-1)

    # Per-element quaternion norm of the relation; the tiny constant avoids
    # division by zero. Computed once and shared by all four components.
    denom = nd.stack(rel_r, rel_i, rel_j, rel_k, axis=0).norm(ord=2, axis=0) + 1e-15

    # Hamilton product head * (relation / |relation|).
    x_r = (emb_r * rel_r - emb_i * rel_i - emb_j * rel_j - emb_k * rel_k) / denom
    x_i = (emb_r * rel_i + emb_i * rel_r + emb_j * rel_k - emb_k * rel_j) / denom
    x_j = (emb_r * rel_j - emb_i * rel_k + emb_j * rel_r + emb_k * rel_i) / denom
    x_k = (emb_r * rel_k + emb_i * rel_j - emb_j * rel_i + emb_k * rel_r) / denom

    rotated = nd.concat(x_r, x_i, x_j, x_k, dim=-1)
    rotated = rotated.reshape(num_chunks, chunk_size, hidden_dim)

    tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
    tails = nd.transpose(tails, axes=(0, 2, 1))

    # Score each rotated head against every tail of its chunk: the right-hand
    # operand must be the prepared (num_chunks, hidden_dim, neg_sample_size)
    # tails, not the 2-D heads.
    return nd.linalg_gemm2(rotated, tails)
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
    """Score chunks of relation-scaled heads against negative tails.

    Each head is multiplied element-wise by its relation. Every chunk of the
    result is then dot-producted with every tail in the same chunk.

    Returns:
        Array of shape ``(num_chunks, chunk_size, neg_sample_size)``.
    """
    dim = heads.shape[1]
    scaled_heads = (heads * relations).reshape(num_chunks, chunk_size, dim)
    tails_by_chunk = tails.reshape(num_chunks, neg_sample_size, dim)
    tails_t = nd.transpose(tails_by_chunk, axes=(0, 2, 1))
    return nd.linalg_gemm2(scaled_heads, tails_t)
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
    """Score chunks of relation-projected heads against negative tails.

    Each relation row is viewed as a ``(relation_dim, entity_dim)`` matrix
    and applied to its head (as a column vector) with a batched matrix
    product. Every chunk of projected heads is then dot-producted with every
    tail in the same chunk.

    NOTE(review): ``self`` is not a parameter, so it must be supplied by an
    enclosing scope (presumably the method that defines this function).

    Returns:
        Array of shape ``(num_chunks, chunk_size, neg_sample_size)``.
    """
    hidden_dim = heads.shape[1]

    rel_matrices = relations.reshape(-1, self.relation_dim, self.entity_dim)
    head_columns = heads.expand_dims(2)
    projected = mx.nd.batch_dot(rel_matrices, head_columns).squeeze()
    projected = projected.reshape(num_chunks, chunk_size, hidden_dim)

    tails_by_chunk = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
    tails_t = mx.nd.transpose(tails_by_chunk, axes=(0, 2, 1))

    return nd.linalg_gemm2(projected, tails_t)
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
    """Score chunks of tails against negative heads in complex space.

    The last axis of ``tails`` and ``relations`` is split into real and
    imaginary halves. Each tail is multiplied by the complex conjugate of
    its relation; every chunk of that product is then dot-producted with
    every head of the same chunk.

    Returns:
        Array of shape ``(num_chunks, chunk_size, neg_sample_size)``.
    """
    hidden_dim = heads.shape[1]
    t_re, t_im = nd.split(tails, num_outputs=2, axis=-1)
    r_re, r_im = nd.split(relations, num_outputs=2, axis=-1)

    # (t_re + i*t_im) * (r_re - i*r_im)
    prod_re = t_re * r_re + t_im * r_im
    prod_im = t_im * r_re - t_re * r_im

    rotated = nd.concat(prod_re, prod_im, dim=-1)
    rotated = rotated.reshape(num_chunks, chunk_size, hidden_dim)

    heads_by_chunk = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
    heads_t = nd.transpose(heads_by_chunk, axes=(0, 2, 1))

    return nd.linalg_gemm2(rotated, heads_t)
# One neural-style-transfer optimisation step on the `target` image.
# This is a fragment of a loop body: step, total_step, log_step, vgg, target,
# context, style, ctx, style_weight and optimizer all come from the enclosing
# scope, which is not shown here.
with autograd.record():
    # Feature maps per selected layer for the image being optimised, the
    # content ("context") image and the style image.
    target_features = vgg(target)
    context_features = vgg(context)
    style_features = vgg(style)
    style_loss = nd.zeros((1, ), ctx=ctx)
    # NOTE: context_loss is the content loss (it is printed as "Content Loss").
    context_loss = nd.zeros((1, ), ctx=ctx)
    for f1, f2, f3 in zip(target_features, context_features, style_features):
        # Compute content loss (target and content image)
        context_loss = context_loss + nd.mean((f1 - f2)**2)
        # Reshape conv features
        # NOTE(review): reshaping to (c, h*w) drops the leading axis, which only
        # works if it has size 1; it also assumes the style features have the
        # same shape as the target's -- confirm.
        _, c, h, w = f1.shape
        f1 = f1.reshape((c, h * w))
        f3 = f3.reshape((c, h * w))
        # Compute gram matrix (features times their own transpose, c x c)
        f1 = nd.linalg_gemm2(f1, f1, transpose_b=1)
        f3 = nd.linalg_gemm2(f3, f3, transpose_b=1)
        # Compute style loss (target and style image)
        style_loss = style_loss + nd.mean((f1 - f3)**2) / (c * h * w)
    # Compute total loss
    loss = context_loss + style_weight * style_loss
# backprop and optimize
loss.backward()
# NOTE(review): create_state() runs on every step, so any per-weight state the
# optimizer keeps (e.g. momentum / moment estimates, if it has any) is
# discarded each time. Consider creating it once before the loop and reusing it.
# NOTE(review): `step` is passed as the first (index) argument; in MXNet's
# Optimizer API that argument identifies the weight, not the iteration, so a
# new key is used every step -- confirm this is intended (a constant is usual).
optimizer.update(step, target, target.grad, optimizer.create_state(step, target))
if (step + 1) % log_step == 0:
    print('Step [%d/%d], Content Loss: %.4f, Style Loss: %.4f' % (step + 1, total_step, context_loss.sum().asscalar(), style_loss.sum().asscalar()))