Example #1
    def check(self, h, w, cs, rs, pa, rtp, dim):
        a = jt.random([h, w])
        a.data

        with jt.log_capture_scope(
                log_v=0,
                log_vprefix="tuner_manager=100",
                # this value is used to force recompilation
                compile_options={"test_reduce_tuner": 1}) as logs:
            amean = jt.mean(a, dims=[dim], keepdims=1)
            a2mean = jt.mean(a * a, dims=[dim], keepdims=1)
            norm_aa = (a - amean.broadcast_var(a)) / (
                jt.sqrt(a2mean - amean * amean).broadcast_var(a))
            norm_aa.data
        logs = find_log_with_re(
            logs,
            "Run tuner reduce: confidence\\((20)\\) candidates\\((.*)\\)$")
        assert len(logs) == 1, logs
        assert logs[0][0] == "20", "confidence of reorder should be 20"
        candidates = simple_parser(logs[0][1])
        assert candidates == {
            "order0": [
                0,
            ],
            "order1": [
                1,
            ],
            "order2": [
                0,
            ],
            "split1": [
                2048,
            ],
        }
Example #2
    def execute(self, x):
        dims = [0] + list(range(2, x.ndim))
        ####### centering calibration begin #######
        x += self.center_weight * self.stas(x)
        ####### centering calibration end #######
        if self.is_train:
            xmean = jt.mean(x, dims=dims)
            x2mean = jt.mean(x * x, dims=dims)
            if self.sync and jt.in_mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            xvar = (x2mean - xmean * xmean).maximum(0.0)
            w = 1.0 / jt.sqrt(xvar + self.eps)
            b = -xmean * w
            norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims)

            self.running_mean.update(self.running_mean + (xmean.reshape(
                (-1, )) - self.running_mean) * self.momentum)
            self.running_var.update(self.running_var +
                                    (xvar.reshape((-1, )) - self.running_var) *
                                    self.momentum)

        else:
            w = 1.0 / jt.sqrt(self.running_var + self.eps)
            b = -self.running_mean * w
            norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims)

        ####### scaling calibration begin #######
        scale_factor = jt.sigmoid(self.scale_weight * self.stas(norm_x) +
                                  self.scale_bias)
        ####### scaling calibration end #######
        return self.weight * scale_factor * norm_x + self.bias
Example #3
    def execute(self, x):
        if len(x.shape) == 3:
            dims = [0, 2]
        else:
            dims = [0]
        if self.is_train:
            xmean = jt.mean(x, dims=dims, keepdims=1)
            x2mean = jt.mean(x * x, dims=dims, keepdims=1)

            if self.sync and jt.in_mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            xvar = x2mean - xmean * xmean
            norm_x = (x - xmean) / jt.sqrt(xvar + self.eps)
            self.running_mean.update(self.running_mean +
                                     (xmean.sum(dims) - self.running_mean) *
                                     self.momentum)
            self.running_var.update(self.running_var +
                                    (xvar.sum(dims) - self.running_var) *
                                    self.momentum)
        else:
            running_mean = self.running_mean.broadcast(x, dims)
            running_var = self.running_var.broadcast(x, dims)
            norm_x = (x - running_mean) / jt.sqrt(running_var + self.eps)
        if not self.affine:
            return norm_x
        w = self.weight.broadcast(x, dims)
        b = self.bias.broadcast(x, dims)
        return norm_x * w + b
Example #4
    def dis_loss(self, real_samps, fake_samps, height, alpha):
        r_preds = self.dis(real_samps, height, alpha)
        f_preds = self.dis(fake_samps, height, alpha)

        # loss = (torch.mean(nn.ReLU()(1 - r_preds)) +
        #         torch.mean(nn.ReLU()(1 + f_preds)))
        loss = (jt.mean(nn.ReLU()(1 - r_preds)) +
                jt.mean(nn.ReLU()(1 + f_preds)))
        return loss
Example #5
def batch_norm(x):
    xmean = jt.mean(x, dims=[0, 2, 3], keepdims=1)
    x2mean = jt.mean(x * x, dims=[0, 2, 3], keepdims=1)
    norm_x = (x - xmean.broadcast_var(x)) / (
        jt.sqrt(x2mean - xmean * xmean + jt.float32(1e-5)).broadcast_var(x))
    w = jt.make_var([x.shape[1]], init=get_init_var)
    b = jt.make_var([x.shape[1]], init=get_init_var)
    w = w.broadcast([1, w.shape[0], 1, 1], [0, 2, 3])
    b = b.broadcast([1, b.shape[0], 1, 1], [0, 2, 3])
    return norm_x * w + b
Example #6
 def execute(self, x):
     dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
     mean = jt.mean(x, dims=dims, keepdims=1)
     numerator = x - mean
     variance = jt.mean(numerator.sqr(), dims=dims, keepdims=1)
     denominator = jt.sqrt(variance + self.eps)
     norm_x = numerator / denominator
     if self.elementwise_affine:
         norm_x = norm_x * self.weight + self.bias
     return norm_x
Example #7
File: nn.py Project: shcig/jittor
    def execute(self, x):
        xmean = jt.mean(x, dims=[2, 3], keepdims=1)
        x2mean = jt.mean(x * x, dims=[2, 3], keepdims=1)
        if self.sync and jt.in_mpi:
            xmean = xmean.mpi_all_reduce("mean")
            x2mean = x2mean.mpi_all_reduce("mean")

        xvar = jt.maximum(x2mean - xmean * xmean, 0)
        norm_x = (x - xmean) / jt.sqrt(xvar + self.eps)
        w = self.weight.broadcast(x, [0, 2, 3])
        b = self.bias.broadcast(x, [0, 2, 3])
        return norm_x * w + b
Example #8
File: nn.py Project: shcig/jittor
 def execute(self, x):
     N, C, H, W = x.shape
     assert C == self.num_channels
     assert C % self.num_groups == 0
     x = x.reshape((N, self.num_groups, int(C / self.num_groups), H * W))
     xmean = jt.mean(x, dims=[2, 3], keepdims=1)
     x2mean = jt.mean(x * x, dims=[2, 3], keepdims=1)
     xvar = jt.maximum(x2mean - xmean * xmean, 0)
     norm_x = (x - xmean) / jt.sqrt(xvar + self.eps)
     w = self.weight.reshape((1, self.num_groups, C // self.num_groups, 1))
     b = self.bias.reshape((1, self.num_groups, C // self.num_groups, 1))
     return (norm_x * w + b).reshape((N, C, H, W))
Example #9
File: nn.py Project: zhangp14/jittor
 def execute(self, x):
     if self.is_train:
         xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
         x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
         xvar = x2mean-xmean*xmean
         norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
         self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
         self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
     else:
         running_mean = self.running_mean.broadcast(x, [0,2,3])
         running_var = self.running_var.broadcast(x, [0,2,3])
         norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
     w = self.weight.broadcast(x, [0,2,3])
     b = self.bias.broadcast(x, [0,2,3])
     return norm_x * w + b
Example #10
    def dis_loss(self, real_samps, fake_samps, height, alpha, r1_gamma=10.0):
        # Obtain predictions
        r_preds = self.dis(real_samps, height, alpha)
        f_preds = self.dis(fake_samps, height, alpha)

        # loss = torch.mean(nn.Softplus()(f_preds)) + torch.mean(nn.Softplus()(-r_preds))
        loss = jt.mean(nn.Softplus()(f_preds)) + jt.mean(
            nn.Softplus()(-r_preds))

        if r1_gamma != 0.0:
            r1_penalty = self.R1Penalty(real_samps.detach(), height,
                                        alpha) * (r1_gamma * 0.5)
            loss += r1_penalty

        return loss
Example #11
def chamfer_loss(pc1, pc2, reduction='mean', sqrt=True):
    '''
    return the chamfer loss from pc1 to pc2.

    Parameters:
    ===========
        pc1:  [B, N, xyz]
        pc2:  [B, N, xyz]
        reduction: 'mean', 'sum', or None
        sqrt: if True, return the Euclidean distance; otherwise the squared distance
    '''
    batch_size_1, n_samples_pc1, _ = pc1.shape
    batch_size_2, n_samples_pc2, _ = pc2.shape

    assert batch_size_1 == batch_size_2
    batch_size = batch_size_1

    idx = jt.code([batch_size, n_samples_pc1],
                  'int32', [pc1, pc2],
                  cpu_src=cpu_src,
                  cuda_src=cuda_src)

    nearest_pts = select_vertices(pc2, idx)
    if sqrt:
        chamfer_distance = (((pc1 - nearest_pts)**2).sum(dim=-1)).sqrt()
    else:
        chamfer_distance = (((pc1 - nearest_pts)**2).sum(dim=-1))

    if reduction is None:
        return chamfer_distance
    elif reduction == 'sum':
        return jt.sum(chamfer_distance)
    elif reduction == 'mean':
        return jt.mean(chamfer_distance)
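A minimal usage sketch (not part of the original snippet): it assumes the chamfer_loss above, together with the cpu_src/cuda_src kernels and the select_vertices helper it relies on, is already defined in the current module, and mirrors the doctest shown in Example #17.

    import jittor as jt

    pc1 = jt.rand([10, 100, 3], dtype=jt.float32)  # [B, N, xyz]
    pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    loss = chamfer_loss(pc1, pc2, reduction='mean', sqrt=True)  # scalar mean nearest-neighbor distance
    print('chamfer loss =', loss.item())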
Example #12
 def execute(self, x):
     avg_out = jt.mean(x, dim=1, keepdims=1)  # mean over the channel dimension
     max_out = jt.max(x, dim=1, keepdims=1)  # max over the channel dimension
     x = jt.contrib.concat([avg_out, max_out], dim=1)  # [b, 2, h, w]
     x = self.conv1(x)
     y = self.sigmoid(x)
     return y
Example #13
 def execute(self, x, cls_label):
     batch_size, _, N = x.size()
     x = self.relu(self.bn1(self.conv1(x)))  # B, D, N
     x = self.relu(self.bn2(self.conv2(x)))
     x1 = self.sa1(x)
     x2 = self.sa2(x1)
     x3 = self.sa3(x2)
     x4 = self.sa4(x3)
     x = concat((x1, x2, x3, x4), dim=1)
     x = self.conv_fuse(x)
     x_max = jt.max(x, 2)
     x_avg = jt.mean(x, 2)
     x_max_feature = x_max.view(batch_size,
                                -1).unsqueeze(-1).repeat(1, 1, N)
     x_avg_feature = x_avg.view(batch_size,
                                -1).unsqueeze(-1).repeat(1, 1, N)
     cls_label_one_hot = cls_label.view(batch_size, 16, 1)
     cls_label_feature = self.label_conv(cls_label_one_hot).repeat(1, 1, N)
     x_global_feature = concat(
         (x_max_feature, x_avg_feature, cls_label_feature), 1)  # 1024 + 64
     x = concat((x, x_global_feature), 1)  # 1024 * 3 + 64
     x = self.relu(self.bns1(self.convs1(x)))
     x = self.dp1(x)
     x = self.relu(self.bns2(self.convs2(x)))
     x = self.convs3(x)
     return x
Example #14
    def check(self, h, w, cs, rs, pa, rtp, dim):
        a = jt.random([h, w])
        a.sync()

        with jt.log_capture_scope(
                log_v=0,
                log_vprefix="tuner_manager=100",
                # this value is used to force recompilation
                compile_options={"test_new_fused_op": 1}) as logs:
            amean = jt.mean(a, dims=[dim], keepdims=1)
            a2mean = jt.mean(a * a, dims=[dim], keepdims=1)
            norm_aa = (a - amean.broadcast_var(a)) / (
                jt.sqrt(a2mean - amean * amean).broadcast_var(a))
            norm_aa.sync()
        logs = find_log_with_re(
            logs,
            "Run tuner reduce: confidence\\((.*)\\) candidates\\((.*)\\)$")
        assert len(logs) == 3, logs
Example #15
 def execute(self, x):
     N = x.shape[0]
     C = self.num_channels
     output_shape = (N, -1)
     # TODO: 3d group norm
     if x.ndim == 4:
         output_shape = x.shape
     assert C % self.num_groups == 0
     x = x.reshape((N, self.num_groups, int(C / self.num_groups), -1))
     xmean = jt.mean(x, dims=[2, 3], keepdims=1)
     x2mean = jt.mean(x * x, dims=[2, 3], keepdims=1)
     xvar = jt.maximum(x2mean - xmean * xmean, 0)
     norm_x = (x - xmean) / jt.sqrt(xvar + self.eps)
     if not self.affine:
         return norm_x.reshape(output_shape)
     w = self.weight.reshape((1, self.num_groups, C // self.num_groups, 1))
     b = self.bias.reshape((1, self.num_groups, C // self.num_groups, 1))
     return (norm_x * w + b).reshape(output_shape)
Example #16
    def gen_loss(self, real_samps, fake_samps, height, alpha):
        # Obtain predictions
        r_preds = self.dis(real_samps, height, alpha)
        f_preds = self.dis(fake_samps, height, alpha)

        # difference between real and fake:
        # r_f_diff = r_preds - torch.mean(f_preds)
        r_f_diff = r_preds - jt.mean(f_preds)

        # difference between fake and real samples
        # f_r_diff = f_preds - torch.mean(r_preds)
        f_r_diff = f_preds - jt.mean(r_preds)

        # return the loss
        # return (torch.mean(nn.ReLU()(1 + r_f_diff))
        #         + torch.mean(nn.ReLU()(1 - f_r_diff)))
        return (jt.mean(nn.ReLU()(1 + r_f_diff)) +
                jt.mean(nn.ReLU()(1 - f_r_diff)))
Example #17
def chamfer_loss(pc1, pc2, reduction='mean', dims='BNC', bidirectional=False):
    ''' return the chamfer loss from pc1 to pc2.

    :param pc1:  input point cloud
    :type pc1: jittor array

    :param pc2:  input point cloud
    :type pc2: jittor array

    :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'.
    :type reduction: str, optional
            
    :param dims: a string that represents each dimension, can be
            '[BNC]' ([batch, number of points, xyz]), or
            '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'.
    :type dims: str, optional

    Example:

    >>> import jittor as jt
    >>> from jittor.loss3d import chamfer_loss
    >>> jt.flags.use_cuda = True
    >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> cf = chamfer_loss(pc1, pc2, dims='BNC', bidirectional=True)
    >>> print('chamfer loss =', cf.item())
    '''
    if bidirectional:
        return chamfer_loss(pc1, pc2, reduction, dims) + chamfer_loss(
            pc2, pc1, reduction, dims)

    assert dims in ['BNC', 'BCN']
    if dims == 'BCN':
        pc1, pc2 = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1)

    batch_size_1, N, _ = pc1.shape
    batch_size_2, M, _ = pc2.shape
    assert batch_size_1 == batch_size_2
    batch_size = batch_size_1

    idx = jt.code([batch_size, N],
                  'int32', [pc1, pc2],
                  cpu_src=cpu_src,
                  cuda_src=cuda_src)

    nearest_pts = pc2.reindex([batch_size, idx.shape[1], 3],
                              ['i0', '@e0(i0, i1)', 'i2'],
                              extras=[idx])

    chamfer_distance = (((pc1 - nearest_pts)**2).sum(dim=-1)).sqrt()
    if reduction is None:
        return chamfer_distance
    elif reduction == 'sum':
        return jt.sum(chamfer_distance)
    elif reduction == 'mean':
        return jt.mean(chamfer_distance)
Example #18
    def execute(self, x):
        try:
            avg_out = jt.mean(x, dim=1, keepdims=True)
            max_out = jt.max(x, dim=1, keepdims=True)
            scale = jt.contrib.concat([avg_out, max_out], dim=1)
            scale = self.conv(scale)
            out = x * self.sigmoid(scale)
        except Exception as e:
            print(e)
            out = x

        return out
Example #19
File: nn.py Project: waTeim/jittor
    def execute(self, x):
        if self.is_train:
            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
            if self.sync and jt.in_mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            self.running_mean.update(self.running_mean +
                (xmean.reshape((-1,)) - self.running_mean) * self.momentum)
            self.running_var.update(self.running_var +
                (xvar.reshape((-1,))-self.running_var)*self.momentum)
        else:
            running_mean = self.running_mean.broadcast(x, [0,2,3])
            running_var = self.running_var.broadcast(x, [0,2,3])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b
Example #20
File: nn.py Project: zhangp14/jittor
def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))

    w = w.broadcast(x, [0,2,3])
    b = b.broadcast(x, [0,2,3])
    if is_train:
        xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
        x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
        xvar = x2mean-xmean*xmean
        norm_x = (x-xmean)/jt.sqrt(xvar+eps)

        running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
        running_var += (xvar.sum([0,2,3])-running_var)*momentum
    else:
        running_mean = running_mean.broadcast(x, [0,2,3])
        running_var = running_var.broadcast(x, [0,2,3])
        norm_x = (x-running_mean)/jt.sqrt(running_var+eps)

    return norm_x * w + b
Example #21
 def execute(self, x, normal=None):
     if normal is None:
         x = (x, x)
     else :
         x = (x, normal)
     
     x = self.pcnn1(x)
     x = self.pcnn2(x)[1] # grab features 
     
     x = x.permute(0, 2, 1) # b, dim, n
     logits = self.fcn(x) 
     logits = jt.mean(logits, dim=2)
     return logits
Example #22
 def execute(self, trans_points, cp, voxel, gridSize, weight=1):
     if len(trans_points.shape) == 4:
         trans_points = trans_points.squeeze(dim=-1)
     nb = pointClosestCellIndex(trans_points)
     idx = jt.matmul(
         nb, jt.transform.to_tensor(jt.array([(gridSize**2), gridSize, 1])))
     mask = (1 - voxel.view((-1), (gridSize**3)).gather(1, idx))
     idx = idx.unsqueeze(2)
     idx = idx.repeat(1, 1, 3)
     mask = mask.unsqueeze(2).repeat(1, 1, 3)
     closest_points = cp.gather(1, idx)
     self.constant = weight
     distance = (trans_points - closest_points)
     distance = (distance * mask)
     # self.save_for_backward(distance)
     self.saved_tensors = distance
     return (jt.mean(jt.sum(jt.sum(jt.pow(distance, 2), 2), 1)) * weight)
Example #23
def adjust_contrast_tensor(img, contrast_factor):
    """Adjust contrast of an RGB image.
    Args:
        img (Tensor): Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.
    Returns:
        Tensor: Contrast adjusted image.
    """
    if contrast_factor < 0:
        raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))

    if not _is_tensor_a_jittor_image(img):
        raise TypeError('tensor is not a jittor image.')

    gray  = rgb_to_grayscale(img)
    gray.dtype = 'float'
    mean = jt.mean(gray)

    return _blend(img, mean, contrast_factor)
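A hedged usage sketch (not from the original file): it assumes adjust_contrast_tensor and the helpers it calls (_is_tensor_a_jittor_image, rgb_to_grayscale, _blend) are defined in the same module, and that images are float CHW jittor Vars in [0, 1].

    import jittor as jt

    img = jt.rand([3, 224, 224])  # assumed CHW float image in [0, 1]
    low = adjust_contrast_tensor(img, 0.5)   # pull pixels toward the mean gray level
    high = adjust_contrast_tensor(img, 2.0)  # push pixels away from the mean gray level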
Example #24
    def gen_loss(self, _, fake_samps, height, alpha):
        f_preds = self.dis(fake_samps, height, alpha)
        # print(f_preds.is_stop_grad())

        # return torch.mean(nn.Softplus()(-f_preds))
        return jt.mean(nn.Softplus()(-f_preds))
Example #25
 def execute(self, input):
     return input / jt.sqrt(jt.mean(input ** 2, dim=1, keepdims=True) + 1e-8)
Example #26
 def gen_loss(self, _, fake_samps, height, alpha):
     # return -torch.mean(self.dis(fake_samps, height, alpha))
     return -jt.mean(self.dis(fake_samps, height, alpha))
Example #27
                                    latent_dim=latent_dim,
                                    n_c=n_c)
        gen_imgs = generator(zn, zc)
        D_gen = discriminator(gen_imgs)
        D_real = discriminator(real_imgs)

        # -----------------
        #  Train Generator
        # -----------------

        if ((i % n_skip_iter) == 0):
            (enc_gen_zn, enc_gen_zc, enc_gen_zc_logits) = encoder(gen_imgs)
            zn_loss = mse_loss(enc_gen_zn, zn)
            zc_loss = xe_loss(enc_gen_zc_logits, zc_idx)
            if wass_metric:
                ge_loss = ((jt.mean(D_gen) + (betan * zn_loss)) +
                           (betac * zc_loss))
            else:
                valid = jt.ones([gen_imgs.shape[0], 1]).stop_grad()
                v_loss = bce_loss(D_gen, valid)
                ge_loss = ((v_loss + (betan * zn_loss)) + (betac * zc_loss))
            optimizer_GE.step(ge_loss)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        if wass_metric:
            grad_penalty = calc_gradient_penalty(discriminator, real_imgs,
                                                 gen_imgs)
            d_loss = ((jt.mean(D_real) - jt.mean(D_gen)) + grad_penalty)
Example #28
total_time = 0.
cnt = 0

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for (i, (real_imgs, _)) in enumerate(dataloader):
        # -----------------
        #  Train Discriminator
        # -----------------

        z = jt.array(np.random.normal(0, 1, (real_imgs.shape[0], opt.latent_dim)).astype(np.float32))
        fake_imgs = generator(z).detach()
        loss_D = ((- jt.mean(discriminator(real_imgs))) + jt.mean(discriminator(fake_imgs)))
        optimizer_D.step(loss_D)
        for p in discriminator.parameters():
            clamp_(p, - opt.clip_value, opt.clip_value)
        
        # ---------------------
        #  Train Generator
        # ---------------------

        if ((i % opt.n_critic) == 0):
            gen_imgs = generator(z)
            loss_G = (- jt.mean(discriminator(gen_imgs)))
            optimizer_G.step(loss_G)
            if warmup_times==-1:
                print(('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]' % (epoch, opt.n_epochs, (batches_done % len(dataloader)), len(dataloader), loss_D.numpy()[0], loss_G.numpy()[0])))
Example #29
                  (fake_B_gray_line2 >= 0)].sum()) / bs * opt.lambda_chamfer2

        # Local loss
        real_B_locals = [
            real_B_eyel, real_B_eyer, real_B_nose, real_B_mouth, real_B_hair,
            real_B_bg
        ]
        loss_G_local = 0
        for j in range(6):
            loss_G_local += criterion_pixelwise(
                fake_B_locals[j], real_B_locals[j]) * opt.lambda_local

        # Line continuity loss
        fake_B_patches, conti_weights = get_patches(fake_B, maskface)
        outputs = regressor(fake_B_patches)
        loss_G_continuity = jt.mean(
            (1.0 - outputs) * conti_weights) * opt.lambda_continuity

        # Total loss
        loss_G = loss_GAN + loss_GAN_local + loss_pixel + (
            loss_G_chamfer +
            loss_G_chamfer2) + loss_G_local + loss_G_continuity
        #pdb.set_trace()
        optimizer_G.step(loss_G)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Real loss
        pred_real = D_global(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)
        loss_real_local = 0
Example #30
 def entropy(self):
     return -jt.sum(jt.mean(self.probs) * jt.log(self.probs))