def rgb_to_lab(image_srgb, ctx=None):
    """Convert an sRGB image to CIELAB (D65 white point) using MXNet ndarrays."""
    if ctx is None:
        raise ValueError("ctx cannot be None")

    if image_srgb is None:
        raise ValueError("image_srgb cannot be None")

    with mx.Context(ctx):

        srgb = __check_image(image_srgb)

        # scale 8-bit inputs down to the [0, 1] range expected below
        if nd.max(srgb).asscalar() > 1:
            srgb = __normalize_rgb_image(srgb)

        srgb_pixels = nd.reshape(srgb, [-1, 3])

        # undo the sRGB gamma (companding) to get linear RGB
        linear_mask = nd.cast(srgb_pixels <= 0.04045, dtype='float32')
        exponential_mask = nd.cast(srgb_pixels > 0.04045, dtype='float32')
        rgb_pixels = (srgb_pixels / 12.92) * linear_mask + \
                     (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
        rgb_to_xyz = nd.array([
            #    X        Y          Z
            [0.412453, 0.212671, 0.019334],  # R
            [0.357580, 0.715160, 0.119193],  # G
            [0.180423, 0.072169, 0.950227],  # B
        ])
        xyz_pixels = nd.linalg_gemm2(rgb_pixels, rgb_to_xyz)

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
        # normalize for D65 white point
        xyz_normalized_pixels = nd.multiply(xyz_pixels, nd.array([1 / 0.950456, 1.0, 1 / 1.088754]))

        epsilon = 6 / 29
        linear_mask = nd.cast(xyz_normalized_pixels <= (epsilon ** 3), dtype='float32')
        exponential_mask = nd.cast(xyz_normalized_pixels > (epsilon ** 3), dtype='float32')
        fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon ** 2) + 4 / 29) * linear_mask + \
                        (xyz_normalized_pixels ** (1 / 3)) * exponential_mask

        # convert to lab
        fxfyfz_to_lab = nd.array([
            #     l       a        b
            [0.0, 500.0, 0.0],       # fx
            [116.0, -500.0, 200.0],  # fy
            [0.0, 0.0, -200.0],      # fz
        ])
        lab_pixels = nd.linalg_gemm2(fxfyfz_pixels, fxfyfz_to_lab) + nd.array([-16.0, 0.0, 0.0])

        return nd.reshape(lab_pixels, srgb.shape)
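
For orientation, here is a minimal sketch of how one might call rgb_to_lab; the random 64x64 test image, the 0-255 value range, and the use of mx.cpu() are assumptions for illustration, and the function still relies on the module's private __check_image / __normalize_rgb_image helpers:

import mxnet as mx
from mxnet import nd

ctx = mx.cpu()  # assumed device; any mx.Context should work here
image = nd.random.uniform(0, 255, shape=(64, 64, 3), ctx=ctx)  # hypothetical HxWx3 sRGB image
lab = rgb_to_lab(image, ctx=ctx)
print(lab.shape)  # same shape as the input, channels now L, a, b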
Example 2
def fn(heads, relations, tails, num_chunks, chunk_size,
       neg_sample_size):
    hidden_dim = heads.shape[1]
    # split the head and relation embeddings into quaternion components
    emb_r, emb_i, emb_j, emb_k = nd.split(heads, num_outputs=4, axis=-1)
    rel_r, rel_i, rel_j, rel_k = nd.split(relations, num_outputs=4, axis=-1)

    # element-wise quaternion norm of the relation, so it acts as a unit rotation
    rel_norm = nd.stack(rel_r, rel_i, rel_j, rel_k,
                        axis=0).norm(ord=2, axis=0)

    # Hamilton product of the head with the normalised relation
    x_r = (emb_r * rel_r - emb_i * rel_i - emb_j * rel_j -
           emb_k * rel_k) / (rel_norm + 1e-15)
    x_i = (emb_r * rel_i + emb_i * rel_r + emb_j * rel_k -
           emb_k * rel_j) / (rel_norm + 1e-15)
    x_j = (emb_r * rel_j - emb_i * rel_k + emb_j * rel_r +
           emb_k * rel_i) / (rel_norm + 1e-15)
    x_k = (emb_r * rel_k + emb_i * rel_j - emb_j * rel_i +
           emb_k * rel_r) / (rel_norm + 1e-15)

    emb_quaternion = nd.concat(x_r, x_i, x_j, x_k, dim=-1)
    tmp = emb_quaternion.reshape(num_chunks, chunk_size, hidden_dim)
    # score each rotated head against every negative tail in its chunk
    tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
    tails = nd.transpose(tails, axes=(0, 2, 1))
    return nd.linalg_gemm2(tmp, tails)
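
This and the next few examples follow the same batched negative-sampling pattern, so a quick shape check may help; the toy sizes below are assumptions, not from the original (hidden_dim just has to be divisible by 4 for the quaternion split):

from mxnet import nd

num_chunks, chunk_size, neg_sample_size, hidden_dim = 2, 3, 5, 8  # toy sizes, assumed
heads = nd.random.uniform(shape=(num_chunks * chunk_size, hidden_dim))
relations = nd.random.uniform(shape=(num_chunks * chunk_size, hidden_dim))
tails = nd.random.uniform(shape=(num_chunks * neg_sample_size, hidden_dim))
scores = fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size)
print(scores.shape)  # (2, 3, 5): one score per positive pair and negative tail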
Example 3
def fn(heads, relations, tails, num_chunks, chunk_size,
       neg_sample_size):
    hidden_dim = heads.shape[1]
    # arrange the negative tails as (num_chunks, hidden_dim, neg_sample_size)
    tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
    tails = nd.transpose(tails, axes=(0, 2, 1))
    # element-wise head/relation product, scored against every negative tail
    tmp = (heads * relations).reshape(num_chunks, chunk_size, hidden_dim)
    return nd.linalg_gemm2(tmp, tails)
Example 4
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
    # defined inside a model class, so `self` provides relation_dim / entity_dim
    hidden_dim = heads.shape[1]
    tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
    tails = nd.transpose(tails, axes=(0, 2, 1))
    # project each head through its relation matrix, then score against the tails
    heads = heads.expand_dims(2)
    relations = relations.reshape(-1, self.relation_dim, self.entity_dim)
    tmp = nd.batch_dot(relations, heads).squeeze()
    tmp = tmp.reshape(num_chunks, chunk_size, hidden_dim)
    return nd.linalg_gemm2(tmp, tails)
Example 5
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
    hidden_dim = heads.shape[1]
    # split the tail and relation embeddings into real and imaginary halves
    emb_real, emb_img = nd.split(tails, num_outputs=2, axis=-1)
    rel_real, rel_img = nd.split(relations, num_outputs=2, axis=-1)
    # multiply the tail by the complex conjugate of the relation
    real = emb_real * rel_real + emb_img * rel_img
    img = -emb_real * rel_img + emb_img * rel_real
    emb_complex = nd.concat(real, img, dim=-1)
    tmp = emb_complex.reshape(num_chunks, chunk_size, hidden_dim)
    # score against every negative head in the chunk
    heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
    heads = nd.transpose(heads, axes=(0, 2, 1))
    return nd.linalg_gemm2(tmp, heads)
Example 6
    with autograd.record():
        target_features = vgg(target)
        context_features = vgg(context)
        style_features = vgg(style)
        style_loss = nd.zeros((1, ), ctx=ctx)
        context_loss = nd.zeros((1, ), ctx=ctx)
        for f1, f2, f3 in zip(target_features, context_features,
                              style_features):
            # Content loss between the target and the context (content) image
            context_loss = context_loss + nd.mean((f1 - f2)**2)
            # Reshape conv features
            _, c, h, w = f1.shape
            f1 = f1.reshape((c, h * w))
            f3 = f3.reshape((c, h * w))
            # Compute gram matrix
            f1 = nd.linalg_gemm2(f1, f1, transpose_b=True)
            f3 = nd.linalg_gemm2(f3, f3, transpose_b=True)
            # Compute style loss (target and style image)
            style_loss = style_loss + nd.mean((f1 - f3)**2) / (c * h * w)
        # Compute total loss
        loss = context_loss + style_weight * style_loss
    # backprop and optimize
    loss.backward()
    optimizer.update(step, target, target.grad,
                     optimizer.create_state(step, target))

    if (step + 1) % log_step == 0:
        print('Step [%d/%d], Content Loss: %.4f, Style Loss: %.4f' %
              (step + 1, total_step, context_loss.sum().asscalar(),
               style_loss.sum().asscalar()))
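
As a standalone illustration of the Gram-matrix step above, a small sketch with made-up feature-map sizes; only the reshape and the nd.linalg_gemm2 call mirror the training loop:

from mxnet import nd

c, h, w = 64, 32, 32  # hypothetical channel and spatial sizes
features = nd.random.uniform(shape=(1, c, h, w))  # stand-in for one VGG feature map
f = features.reshape((c, h * w))  # flatten each channel's activations
gram = nd.linalg_gemm2(f, f, transpose_b=True)
print(gram.shape)  # (64, 64): channel-by-channel correlations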