Example #1
def feature_matching_loss(dis_maps_real, dis_maps_fake, model_params, weights):
    variable_not_exist = True
    for j, scale in enumerate(model_params['discriminator_params']['scales']):
        key = f"feature_maps_{scale}".replace('.', '-')
        for i, (a, b) in enumerate(zip(dis_maps_real[key], dis_maps_fake[key])):
            if weights[i] == 0:
                continue
            if variable_not_exist:
                loss = F.mean(F.absolute_error(a, b)) * weights[i]
                variable_not_exist = False
            else:
                loss += F.mean(F.absolute_error(a, b)) * weights[i]
    return loss
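Below is a minimal, self-contained sketch (not from the original project) of how feature_matching_loss could be driven; the feature_maps_<scale> dictionary layout, the scale list, and the weights are assumptions made only for illustration.
import numpy as np
import nnabla as nn
import nnabla.functions as F

# Hypothetical two-scale setup; keys and weights are illustrative assumptions.
model_params = {'discriminator_params': {'scales': [1, 0.5]}}
weights = [1.0, 0.0, 10.0]  # entries with weight 0 are skipped by the loop above

rng = np.random.RandomState(0)
maps_real, maps_fake = {}, {}
for scale in model_params['discriminator_params']['scales']:
    key = f"feature_maps_{scale}".replace('.', '-')
    maps_real[key] = [nn.Variable.from_numpy_array(rng.randn(1, 4, 8, 8).astype(np.float32))
                      for _ in weights]
    maps_fake[key] = [nn.Variable.from_numpy_array(rng.randn(1, 4, 8, 8).astype(np.float32))
                      for _ in weights]

loss = feature_matching_loss(maps_real, maps_fake, model_params, weights)
loss.forward()
print(loss.d)  # scalar weighted L1 feature-matching loss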
Example #2
def get_esrgan_gen(conf, train_gt, train_lq, fake_h):
    """
    Create computation graph and variables for ESRGAN Generator.
    """
    var_ref = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size, conf.train.gt_size))
    # Feature Loss (L1 Loss)
    load_vgg19 = PretrainedVgg19()
    real_fea = load_vgg19(train_gt)
    # need_grad set to False, to avoid BP to vgg19 network
    real_fea.need_grad = False
    fake_fea = load_vgg19(fake_h)
    feature_loss = F.mean(F.absolute_error(fake_fea, real_fea))
    feature_loss.persistent = True

    # Gan Loss Generator
    with nn.parameter_scope("dis"):
        pred_g_fake = discriminator(fake_h)
        pred_d_real = discriminator(var_ref)
    pred_d_real.persistent = True
    pred_g_fake.persistent = True
    unlinked_pred_d_real = pred_d_real.get_unlinked_variable()
    gan_loss = RelativisticAverageGanLoss(GanLoss())
    gan_loss_gen_out = gan_loss(unlinked_pred_d_real, pred_g_fake)
    loss_gan_gen = gan_loss_gen_out.generator_loss
    loss_gan_gen.persistent = True
    Model_gen = namedtuple('Model_gen', [
        'train_gt', 'train_lq', 'var_ref', 'feature_loss', 'loss_gan_gen',
        'pred_d_real', 'pred_g_fake'
    ])
    return Model_gen(train_gt, train_lq, var_ref, feature_loss, loss_gan_gen,
                     pred_d_real, pred_g_fake)
Example #3
def Loss_gen(wave_fake, wave_true, dval_fake, lmd=100):

    def SquaredError_Scalor(x, val=1):
        return F.squared_error(x, F.constant(val, x.shape))

    E_fake = F.mean(SquaredError_Scalor(dval_fake, val=1))      # fake term
    E_wave = F.mean(F.absolute_error(wave_fake, wave_true))     # reconstruction performance
    return E_fake / 2 + lmd * E_wave
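As a rough illustration only, the following sketch calls Loss_gen with toy tensors; the waveform shape and the discriminator-output shape are assumptions, not taken from the original repository.
import numpy as np
import nnabla as nn

rng = np.random.RandomState(0)
wave_true = nn.Variable.from_numpy_array(rng.randn(4, 1, 16384).astype(np.float32))
wave_fake = nn.Variable.from_numpy_array(rng.randn(4, 1, 16384).astype(np.float32))
dval_fake = nn.Variable.from_numpy_array(rng.randn(4, 1).astype(np.float32))  # assumed discriminator output for fakes

loss = Loss_gen(wave_fake, wave_true, dval_fake, lmd=100)  # E_fake / 2 + lmd * E_wave
loss.forward()
print(loss.d)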
Example #4
 def perceptual_loss(self, x, target):
     r"""Returns perceptual loss."""
     loss = []
     out_x, out_t = self(x, None), self(target, None)
     for (a, t) in zip(out_x, out_t):
         for la, lt in zip(a[:-1], t[:-1]):
             lt.need_grad = False  # avoid grads flowing through targets
             loss.append(F.mean(F.absolute_error(la, lt)))
     return sum(loss) / self.hp.num_D
Example #5
def test(args):
    """
    Training
    """

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Initial settings
    ##  ~~~~~~~~~~~~~~~~~~~

    #   Input Variable
    nn.clear_parameters()  # Clear
    Input = nn.Variable([1, 3, 64, 64])  # Input
    Trues = nn.Variable([1, 1])  # True Value

    #   Network Definition
    Name = "CNN"  # Name of scope which includes network models (arbitrary)
    Output_test = network(Input, scope=Name, test=True)  # Network & Output
    Loss_test = F.mean(F.absolute_error(
        Output_test, Trues))  # Loss Function (Mean Absolute Error)

    #   Load data
    with nn.parameter_scope(Name):
        nn.load_parameters(
            os.path.join(args.model_save_path,
                         "network_param_{:04}.h5".format(args.epoch)))

    # Training Data Setting
    image_data, mos_data = dt.data_loader(test=True)
    batches = dt.create_batch(image_data, mos_data, 1)
    del image_data, mos_data

    truth = []
    result = []
    for j in range(batches.iter_n):
        Input.d, trues = next(batches)
        Trues.d = trues
        Loss_test.forward(clear_no_need_grad=True)
        result.append(Loss_test.d)
        truth.append(trues)

    result = np.array(result)
    truth = np.squeeze(np.array(truth))

    # Evaluation of performance
    mae = np.average(np.abs(result - truth))
    SRCC, p1 = stats.spearmanr(truth,
                               result)  # Spearman's Correlation Coefficient
    PLCC, p2 = stats.pearsonr(truth, result)

    #   Display
    print("\n Model Parameter [epoch={0}]".format(args.epoch))
    print(" Mean Absolute Error with Truth: {0:.4f}".format(mae))
    print(" Speerman's Correlation Coefficient: {0:.3f}".format(SRCC))
    print(" Pearson's Linear Correlation Coefficient: {0:.3f}".format(PLCC))
Example #6
def mae(x, y, mask=None, eps=1e-5):
    # l1 distance and reduce mean
    ae = F.absolute_error(x, y)

    if mask is not None:
        assert ae.shape[:2] == mask.shape[:2]

        ae *= F.reshape(mask, ae.shape)

        return F.sum(ae) / (F.sum(mask) + eps)

    return F.mean(ae)
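A quick sketch (assumed shapes, toy data only) exercising both the masked and unmasked branches of mae() above.
import numpy as np
import nnabla as nn

x = nn.Variable.from_numpy_array(np.ones((2, 3, 4, 4), dtype=np.float32))
y = nn.Variable.from_numpy_array(np.zeros((2, 3, 4, 4), dtype=np.float32))

plain = mae(x, y)  # plain mean of |x - y| over all elements
mask = np.zeros((2, 3, 4, 4), dtype=np.float32)
mask[:, :, :2, :] = 1.0  # keep only the top half of each map
masked = mae(x, y, mask=nn.Variable.from_numpy_array(mask))  # mean over unmasked elements only

plain.forward()
masked.forward()
print(plain.d, masked.d)  # both evaluate to ~1.0 here since |x - y| == 1 everywhere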
Example #7
    def update_graph(self, key='train'):
        r"""Builds the graph and update the placeholder.
        Args:
            key (str, optional): Type of computational graph.
                Defaults to 'train'.
        """
        assert key in ('train', 'valid')
        hp = self.hp

        self.gen.training = key == 'train'
        self.dis.training = key == 'train'

        # define input variables
        x_real = nn.Variable((hp.batch_size, 1, hp.segment_length))
        x_real_mel = compute_mel(x_real, self.mel_basis, hp)

        x_fake = self.gen(x_real_mel)
        x_fake_mel = compute_mel(x_fake, self.mel_basis, hp)

        dis_real_x = self.dis(x_real)
        dis_fake_x = self.dis(x_fake)

        # ------------------------------ Discriminator -----------------------
        d_loss = (discriminator_loss(dis_real_x, 1.0) +
                  discriminator_loss(dis_fake_x, 0.0))
        # -------------------------------- Generator -------------------------
        g_loss_avd = discriminator_loss(dis_fake_x, 1.0)
        g_loss_mel = F.mean(F.absolute_error(x_real_mel, x_fake_mel))
        g_loss_fea = feature_loss(dis_real_x, dis_fake_x)
        g_loss = g_loss_avd + 45 * g_loss_mel + 2 * g_loss_fea

        set_persistent_all(
            g_loss_mel,
            g_loss_avd,
            g_loss_fea,
            d_loss,
            x_fake,
            g_loss,
        )

        self.placeholder[key] = dict(
            x_real=x_real,
            x_fake=x_fake,
            d_loss=d_loss,
            g_loss_avd=g_loss_avd,
            g_loss_mel=g_loss_mel,
            g_loss_fea=g_loss_fea,
            g_loss=g_loss,
        )
Example #8
def vgg16_perceptual_loss(fake, real):
    '''VGG perceptual loss based on VGG-16 network.

    Assuming the values in fake and real are in [0, 255].

    Features are obtained from all ReLU activations of the first convolution
    after each downsampling (maxpooling) layer
    (including the first convolution applied to an image).
    '''
    from nnabla.models.imagenet import VGG16

    class VisitFeatures(object):
        def __init__(self):
            self.features = []
            self.relu_counter = 0
            self.features_at = set([0, 2, 4, 7, 10])

        def __call__(self, f):
            # print(f.name, end='')
            if not f.name.startswith('ReLU'):
                # print('')
                return
            if self.relu_counter in self.features_at:
                self.features.append(f.outputs[0])
                # print('*', end='')
            # print('')
            self.relu_counter += 1

    # We use VGG16 model instead of VGG19 because VGG19
    # is not in nnabla.models.
    vgg = VGG16()

    def get_features(x):
        o = vgg(x, use_up_to='lastconv')
        f = VisitFeatures()
        o.visit(f)
        return f

    with nn.parameter_scope("vgg16_loss"):
        fake_features = get_features(fake)
        real_features = get_features(real)

    volumes = np.array([np.prod(f.shape)
                        for f in fake_features.features], dtype=np.float32)
    weights = volumes[-1] / volumes
    return sum([w * F.mean(F.absolute_error(ff, fr)) for w, ff, fr in zip(weights, fake_features.features, real_features.features)])
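A hedged usage sketch: calling vgg16_perceptual_loss on random images in [0, 255]. Note that nnabla.models.imagenet.VGG16 downloads its pretrained weights on first use, and the 224x224 input size here is just an assumption for illustration.
import numpy as np
import nnabla as nn

rng = np.random.RandomState(0)
fake = nn.Variable.from_numpy_array(rng.uniform(0, 255, (1, 3, 224, 224)).astype(np.float32))
real = nn.Variable.from_numpy_array(rng.uniform(0, 255, (1, 3, 224, 224)).astype(np.float32))

loss = vgg16_perceptual_loss(fake, real)  # volume-weighted L1 over the selected ReLU features
loss.forward()
print(loss.d)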
Example #9
def perceptual_loss(pyramide_real, pyramide_fake, scales, weights, vgg_param_path):
    """
        Compute Perceptual Loss using VGG19 as a feature extractor.
    """
    vgg19 = PretrainedVgg19(param_path=vgg_param_path)
    variable_not_exist = True
    for scale in scales:
        x_vgg = vgg19(pyramide_fake[f'prediction_{scale}'])
        y_vgg = vgg19(pyramide_real[f'prediction_{scale}'])

        for i, weight in enumerate(weights):
            value = F.mean(F.absolute_error(x_vgg[i], y_vgg[i]))
            if variable_not_exist:
                loss = weight * value
                variable_not_exist = False
            else:
                loss += weight * value
    return loss
Example #10
def equivariance_jacobian_loss(kp_driving_jacobian, arithmetic_jacobian,
                               trans_kp_jacobian, weight):
    jacobian_transformed = F.batch_matmul(arithmetic_jacobian,
                                          trans_kp_jacobian)

    normed_driving = F.reshape(
        F.batch_inv(
            F.reshape(kp_driving_jacobian,
                      (-1, ) + kp_driving_jacobian.shape[-2:])),
        kp_driving_jacobian.shape)

    normed_transformed = jacobian_transformed
    value = F.batch_matmul(normed_driving, normed_transformed)

    eye = nn.Variable.from_numpy_array(np.reshape(np.eye(2), (1, 1, 2, 2)))

    jacobian_loss = F.mean(F.absolute_error(eye, value))
    loss = weight * jacobian_loss
    return loss
Example #11
 def forward(self, output, inds, gt, reg_mask, channel_last=False):
     # TODO refactor loss implementation for channel_last without transposing
     if channel_last:
         output = F.transpose(output, (0, 3, 1, 2))
     b = inds.shape[0]
     c = output.shape[1]
     max_objs = inds.shape[1]
     # count of regression targets used for normalization (objects x 2)
     num_objs = F.sum(reg_mask) * 2
     f_map_size = output.shape[2] * output.shape[3]
     output = F.reshape(output, (-1, f_map_size))
     inds = F.broadcast(inds.reshape((b, 1, max_objs)), (b, c, max_objs))
     inds = inds.reshape((-1, max_objs))
     y = output[F.broadcast(F.reshape(F.arange(0, b * c), (b * c, 1)),
                            (b * c, max_objs)), inds].reshape(
                                (b, c, max_objs))
     y = F.transpose(y, (0, 2, 1))
     loss = F.sum(reg_mask * F.absolute_error(y, gt))
     loss = loss / (num_objs + 1e-4)
     return loss
Example #12
def context_preserving_loss(xa, yb):
    def mask_weight(a, b):
        # much different from definition in the paper
        merged_mask = F.concatenate(a, b, axis=1)
        summed_mask = F.sum((merged_mask + 1) / 2, axis=1, keepdims=True)
        clipped = F.clip_by_value(summed_mask,
                                  F.constant(0, shape=summed_mask.shape),
                                  F.constant(1, shape=summed_mask.shape))
        z = clipped * 2 - 1
        mask = (1 - z) / 2
        return mask

    x = xa[:, :3, :, :]
    a = xa[:, 3:, :, :]
    y = yb[:, :3, :, :]
    b = yb[:, 3:, :, :]

    assert x.shape == y.shape and a.shape == b.shape
    W = mask_weight(a, b)
    return F.mean(F.mul2(F.absolute_error(x, y), W))
Example #13
def define_loss(real_out,
                real_feats,
                fake_out,
                fake_feats,
                use_fm=True,
                fm_lambda=10.,
                gan_loss_type="ls"):
    g_gan = 0
    g_feat = 0 if use_fm else F.constant(0)
    d_real = 0
    d_fake = 0

    gan_loss = get_gan_loss(gan_loss_type)

    n_disc = len(real_out)

    for disc_id in real_out.keys():
        r_out = real_out[disc_id]
        r_feats = real_feats[disc_id]
        f_out = fake_out[disc_id]
        f_feats = fake_feats[disc_id]

        # define GAN loss
        _d_real, _d_fake, _g_gan = gan_loss(r_out, f_out)

        d_real += _d_real
        d_fake += _d_fake
        g_gan += _g_gan

        # feature matching
        if use_fm:
            assert r_out.shape == f_out.shape

            for layer_id, r_feat in r_feats.items():
                g_feat += F.mean(F.absolute_error(
                    r_feat, f_feats[layer_id])) * fm_lambda / n_disc

    return g_gan, g_feat, d_real, d_fake
Example #14
 def criteria(x, t):
     return F.mean(F.absolute_error(x, t))
Example #15
def train(args):
    """
    Training
    """

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Initial settings
    ##  ~~~~~~~~~~~~~~~~~~~

    #   Input Variable       args. -> setting.
    M = 64

    nn.clear_parameters()  # Clear
    Input = nn.Variable([args.batch_size, 6, 128, 128])
    Trues = nn.Variable([args.batch_size, 1])  # True Value

    #   Network Definition
    Name = "CNN"  # Name of scope which includes network models (arbitrary)
    Name2 = "CNN"
    preOutput = network(input=Input, feature_num=M,
                        scope=Name)  # Network & Output #add
    # preOutput = F.reshape(preOutput, (args.batch_size, 1, M))      # (B*N, M) > (B, N, M)
    # preOutput = F.mean(preOutput, axis=1, keepdims=True)  # (B, N, M) > (B, 1, M): merge the features of the N shifted images into one; keepdims preserves the dimension
    Output = network2(input=preOutput, scope=Name2)  # fullconnect

    #   Loss Definition
    Loss = F.mean(F.absolute_error(
        Output,
        Trues))  # Loss function: mean absolute error (would cross entropy work instead?)

    #   Solver Setting
    solver = S.Adam(args.learning_rate)  # Adam solver (learning-rate optimization)
    solver2 = S.Adam(args.learning_rate)  # Adam solver (learning-rate optimization)
    solver.weight_decay(0.00001)  # Weight Decay for stable update
    solver2.weight_decay(0.00001)

    with nn.parameter_scope(Name):  # Get updating parameters included in scope
        solver.set_parameters(nn.get_parameters())

    with nn.parameter_scope(
            Name2):  # Get updating parameters included in scope
        solver2.set_parameters(nn.get_parameters())

    #   Training Data Setting
    #image_data, mos_data, image_files = dt.data_loader(test = False)
    image_data, mos_data, similarity = dt.data_loader(test=False)

    #batches = dt.create_batch(image_data, mos_data, args.batch_size, image_files)
    batches = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Learning
    ##  ~~~~~~~~~~~~~~~~~~~
    print('== Start Training ==')

    bar = tqdm(total=(args.epoch - args.retrain) * batches.iter_n, leave=False)
    bar.clear()
    cnt = 0
    loss_disp = True

    #   Load data
    if args.retrain > 0:  # resume training from an intermediate epoch (retrain)
        with nn.parameter_scope(Name):
            print('Retrain from {0} Epoch'.format(args.retrain))
            nn.load_parameters(
                os.path.join(args.model_save_path,
                             "network_param_{:04}.h5".format(args.retrain)))
            solver.set_learning_rate(args.learning_rate /
                                     np.sqrt(args.retrain))

    ##  Training
    for i in range(args.retrain,
                   args.epoch):  # train repeatedly from args.retrain to args.epoch

        bar.set_description_str('Epoch {0}/{1}:'.format(i + 1, args.epoch),
                                refresh=False)  # add a description to the progress bar

        #   Shuffling
        batches.shuffle()

        ##  Batch iteration
        for j in range(batches.iter_n):  # mini-batch iteration

            cnt += 1

            #  Load Batch Data from Training data
            Input_npy, Trues_npy = next(batches)
            size_ = Input_npy.shape
            Input.d = Input_npy.reshape(
                [size_[0] * size_[1], size_[2], size_[3], size_[4]])
            Trues.d = Trues_npy

            #  Update
            solver.zero_grad()  # reset gradients
            #solver2.zero_grad()
            Loss.forward(clear_no_need_grad=True)  # forward pass
            loss_scale = 8
            Loss.backward(loss_scale,
                          clear_buffer=True)  # backward pass (backpropagation)
            #solver2.update()
            solver.scale_grad(1. / loss_scale)
            solver.update()

            # Progress
            if cnt % 10 == 0:
                bar.update(10)  # advance the progress bar
                if loss_disp is not None:
                    bar.set_postfix_str('Loss={0:.3e}'.format(Loss.d),
                                        refresh=False)  # show the current loss while running

        ## Save parameters
        if ((i + 1) % args.model_save_cycle) == 0 or (i + 1) == args.epoch:
            bar.clear()
            with nn.parameter_scope(Name):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'network_param_{:04}.h5'.format(i + 1)))
            with nn.parameter_scope(Name2):
                nn.save_parameters(
                    os.path.join(args.model_save_path2,
                                 'network_param_{:04}.h5'.format(i + 1)))
Example #16
def test(args):
    """
    Training
    """
    M = 64
    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Initial settings
    ##  ~~~~~~~~~~~~~~~~~~~

    #   Input Variable definition
    nn.clear_parameters()  # Clear
    Input = nn.Variable([1, 6, 256, 256])  # Input
    Trues = nn.Variable([1, 1])  # True Value

    #   Network Definition
    Name = "CNN"  # Name of scope which includes network models (arbitrary)
    Name2 = "CNN"

    preOutput = network(input=Input, feature_num=M,
                        scope=Name)  # Network & Output #add
    preOutput = F.reshape(preOutput, (1, N, M))  # (B*N, M) > (B, N, M)
    preOutput_mean = F.mean(
        preOutput, axis=1, keepdims=True
    )  # (B, N, M) > (B, 1, M): merge the features of the N shifted images into one; keepdims preserves the dimension
    Output_test = network2(input=preOutput_mean, scope=Name2)  # fullconnect

    Loss_test = F.mean(F.absolute_error(
        Output_test, Trues))  # Loss Function (Mean Absolute Error)

    #   Load the saved training parameters
    with nn.parameter_scope(Name):
        nn.load_parameters(
            os.path.join(args.model_save_path,
                         "network_param_{:04}.h5".format(args.epoch)))
    with nn.parameter_scope(Name2):
        nn.load_parameters(
            os.path.join(args.model_save_path2,
                         "network_param_{:04}.h5".format(args.epoch)))

    # Test Data Setting
    #image_data, mos_data, image_files = dt.data_loader(test=True)
    image_data, mos_data = dt.data_loader(test=True)
    #batches = dt.create_batch(image_data, mos_data, 1, image_files)
    batches = dt.create_batch(image_data, mos_data, 1)

    del image_data, mos_data

    truth = []
    result = []

    for j in range(batches.iter_n):
        #Input_npy, Trues_npy, image_files = next(batches)
        Input_npy, Trues_npy = next(batches)
        size_ = Input_npy.shape
        # print("Input Image:" +  str(image_files) + " Trues:" + str(Trues_npy))
        Input.d = Input_npy.reshape(
            [size_[0] * size_[1], size_[2], size_[3], size_[4]])
        Trues.d = Trues_npy[0][0]

        Loss_test.forward(clear_no_need_grad=True)
        result.append(Loss_test.d)
        truth.append(Trues.d)

    result = np.array(result)
    mean = np.mean(result)
    truth = np.squeeze(np.array(truth))  # delete

    # Evaluation of performance
    mae = np.average(np.abs(result - truth))
    SRCC, p1 = stats.spearmanr(truth,
                               result)  # Spearman's Correlation Coefficient
    PLCC, p2 = stats.pearsonr(truth, result)

    np.set_printoptions(threshold=np.inf)
    print("result: {}".format(result))
    print("Trues: {}".format(truth))
    print(np.average(result))
    print("\n Model Parameter [epoch={0}]".format(args.epoch))
    print(" Mean Absolute Error with Truth: {0:.4f}".format(mae))
    print(" Speerman's Correlation Coefficient: {0:.5f}".format(SRCC))
    print(" Pearson's Linear Correlation Coefficient: {0:.5f}".format(PLCC))
Example #17
def train(args):
    """
    Training
    """

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Initial settings
    ##  ~~~~~~~~~~~~~~~~~~~

    #   Input Variable
    nn.clear_parameters()  #   Clear
    Input = nn.Variable([args.batch_size, 3, 64, 64])  #   Input
    Trues = nn.Variable([args.batch_size, 1])  #   True Value

    #   Network Definition
    Name = "CNN"  #   Name of scope which includes network models (arbitrary)
    Output = network(Input, scope=Name)  #   Network & Output
    Output_test = network(Input, scope=Name, test=True)

    #   Loss Definition
    Loss = F.mean(F.absolute_error(Output,
                                   Trues))  #   Loss Function (Mean Absolute Error)
    Loss_test = F.mean(F.absolute_error(Output_test, Trues))

    #   Solver Setting
    solver = S.AMSBound(args.learning_rate)  #   AMSBound is used for the solver
    with nn.parameter_scope(
            Name):  #   Get updating parameters included in scope
        solver.set_parameters(nn.get_parameters())

    #   Training Data Setting
    image_data, mos_data = dt.data_loader()
    batches = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data

    #   Test Data Setting
    image_data, mos_data = dt.data_loader(test=True)
    batches_test = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Learning
    ##  ~~~~~~~~~~~~~~~~~~~
    print('== Start Training ==')

    bar = tqdm(total=args.epoch - args.retrain, leave=False)
    bar.clear()
    loss_disp = None
    SRCC = None

    #   Load data
    if args.retrain > 0:
        with nn.parameter_scope(Name):
            print('Retrain from {0} Epoch'.format(args.retrain))
            nn.load_parameters(
                os.path.join(args.model_save_path,
                             "network_param_{:04}.h5".format(args.retrain)))
            solver.set_learning_rate(args.learning_rate /
                                     np.sqrt(args.retrain))

    ##  Training
    for i in range(args.retrain, args.epoch):

        bar.set_description_str('Epoch {0}:'.format(i + 1), refresh=False)
        if (loss_disp is not None) and (SRCC is not None):
            bar.set_postfix_str('Loss={0:.5f},  SRCC={1:.4f}'.format(
                loss_disp, SRCC),
                                refresh=False)
        bar.update(1)

        #   Shuffling
        batches.shuffle()
        batches_test.shuffle()

        ##  Batch iteration
        for j in range(batches.iter_n):

            #  Load Batch Data from Training data
            Input.d, Trues.d = next(batches)

            #  Update
            solver.zero_grad()  #   Initialize
            Loss.forward(clear_no_need_grad=True)  #   Forward path
            Loss.backward(clear_buffer=True)  #   Backward path
            solver.weight_decay(0.00001)  #   Weight Decay for stable update
            solver.update()

        ## Progress
        # Get result for Display
        Input.d, Trues.d = next(batches_test)
        Loss_test.forward(clear_no_need_grad=True)
        Output_test.forward()
        loss_disp = Loss_test.d
        SRCC, _ = stats.spearmanr(Output_test.d, Trues.d)

        # Display text
        # disp(i, batches.iter_n, Loss_test.d)

        ## Save parameters
        if ((i + 1) % args.model_save_cycle) == 0 or (i + 1) == args.epoch:
            bar.clear()
            with nn.parameter_scope(Name):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'network_param_{:04}.h5'.format(i + 1)))
Example #18
def idr_loss(camloc, raydir, alpha, color_gt, mask_obj, conf):
    # Setting
    B, R, _ = raydir.shape
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    # Ray trace (visibility)
    x_hit, mask_hit, dists, mask_pin, mask_pout = \
        ray_trace(partial(sdf_net, conf=conf),
                  camloc, raydir, mask_obj, t_near=conf.t_near, t_far=conf.t_far,
                  sphere_trace_itr=conf.sphere_trace_itr,
                  ray_march_points=conf.ray_march_points,
                  n_chunks=conf.n_chunks,
                  max_post_itr=conf.max_post_itr,
                  post_method=conf.post_method, eps=conf.eps)

    x_hit = x_hit.apply(need_grad=False)
    mask_hit = mask_hit.apply(need_grad=False, persistent=True)
    dists = dists.apply(need_grad=False)
    mask_pin = mask_pin.apply(need_grad=False)
    mask_pout = mask_pout.apply(need_grad=False)
    mask_us = mask_pin + mask_pout
    P = F.sum(mask_us)

    # Current points
    x_curr = (camloc.reshape((B, 1, 3)) + dists * raydir).apply(need_grad=True)

    # Eikonal loss
    bounding_box_size = conf.bounding_box_size
    x_free = F.rand(-bounding_box_size,
                    bounding_box_size,
                    shape=(B, R // 2, 3))
    x_point = F.concatenate(*[x_curr, x_free], axis=1)
    sdf_xp, _, grad_xp = sdf_feature_grad(implicit_network, x_point, conf)
    gp = (F.norm(grad_xp, axis=[grad_xp.ndim - 1], keepdims=True) - 1.0)**2.0
    loss_eikonal = F.sum(gp[:, :R, :] * mask_us) + F.sum(gp[:, R:, :])
    loss_eikonal = loss_eikonal / (P + B * R // 2)
    loss_eikonal = loss_eikonal.apply(persistent=True)

    sdf_curr = sdf_xp[:, :R, :]
    grad_curr = grad_xp[:, :R, :]

    # Mask loss
    logit = -alpha.reshape([1 for _ in range(sdf_curr.ndim)]) * sdf_curr
    loss_mask = F.sigmoid_cross_entropy(logit, mask_obj)
    loss_mask = loss_mask * mask_pout
    loss_mask = F.sum(loss_mask) / P / alpha
    loss_mask = loss_mask.apply(persistent=True)

    # Lighting
    x_hat = sample_network(x_curr, sdf_curr, raydir, grad_curr)
    _, feature, grad = sdf_feature_grad(implicit_network, x_hat, conf)
    normal = grad
    color_pred = lighting_network(x_hat, normal, feature, -raydir, D)

    # Color loss
    loss_color = F.absolute_error(color_gt, color_pred)
    loss_color = loss_color * mask_pin
    loss_color = F.sum(loss_color) / P
    loss_color = loss_color.apply(persistent=True)

    # Total loss
    loss = loss_color + conf.mask_weight * \
        loss_mask + conf.eikonal_weight * loss_eikonal

    return loss, loss_color, loss_mask, loss_eikonal, mask_hit
Example #19
def equivariance_value_loss(kp_driving_value, warped_kp_value, weight):
    value_loss = F.mean(F.absolute_error(kp_driving_value, warped_kp_value))
    loss = weight * value_loss
    return loss
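A trivially small sketch of equivariance_value_loss; the (batch, num_keypoints, 2) layout is an assumption for illustration.
import numpy as np
import nnabla as nn

rng = np.random.RandomState(0)
kp_driving = nn.Variable.from_numpy_array(rng.randn(2, 10, 2).astype(np.float32))
kp_warped = nn.Variable.from_numpy_array(rng.randn(2, 10, 2).astype(np.float32))

loss = equivariance_value_loss(kp_driving, kp_warped, weight=10)  # weight * mean |difference|
loss.forward()
print(loss.d)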
Example #20
def main():
    conf = get_config()
    train_gt_path = sorted(glob.glob(conf.DIV2K.gt_train + "/*.png"))
    train_lq_path = sorted(glob.glob(conf.DIV2K.lq_train + "/*.png"))
    val_gt_path = sorted(glob.glob(conf.SET14.gt_val + "/*.png"))
    val_lq_path = sorted(glob.glob(conf.SET14.lq_val + "/*.png"))
    train_samples = len(train_gt_path)
    val_samples = len(val_gt_path)
    lr_g = conf.hyperparameters.lr_g
    lr_d = conf.hyperparameters.lr_d
    lr_steps = conf.train.lr_steps

    random.seed(conf.train.seed)
    np.random.seed(conf.train.seed)

    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(
        extension_module, device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    # data iterators for train and val data
    from data_loader import data_iterator_sr
    data_iterator_train = data_iterator_sr(
        train_samples, conf.train.batch_size, train_gt_path, train_lq_path, train=True, shuffle=True)
    data_iterator_val = data_iterator_sr(
        val_samples, conf.val.batch_size, val_gt_path, val_lq_path, train=False, shuffle=False)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    train_gt = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size, conf.train.gt_size))
    train_lq = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size // conf.train.scale, conf.train.gt_size // conf.train.scale))

    # setting up monitors for logging
    monitor_path = './nnmonitor' + str(datetime.now().strftime("%Y%m%d%H%M%S"))
    monitor = Monitor(monitor_path)
    monitor_pixel_g = MonitorSeries(
        'l_g_pix per iteration', monitor, interval=100)
    monitor_val = MonitorSeries(
        'Validation loss per epoch', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "Training time per epoch", monitor, interval=1)

    with nn.parameter_scope("gen"):
        nn.load_parameters(conf.train.gen_pretrained)
        fake_h = rrdb_net(train_lq, 64, 23)
        fake_h.persistent = True
    pixel_loss = F.mean(F.absolute_error(fake_h, train_gt))
    pixel_loss.persistent = True
    gen_loss = pixel_loss

    if conf.model.esrgan:
        from esrgan_model import get_esrgan_gen, get_esrgan_dis, get_esrgan_monitors
        gen_model = get_esrgan_gen(conf, train_gt, train_lq, fake_h)
        gen_loss = conf.hyperparameters.eta_pixel_loss * pixel_loss + conf.hyperparameters.feature_loss_weight * gen_model.feature_loss + \
            conf.hyperparameters.lambda_gan_loss * gen_model.loss_gan_gen
        dis_model = get_esrgan_dis(fake_h, gen_model.pred_d_real)
        # Set Discriminator parameters
        solver_dis = S.Adam(lr_d, beta1=0.9, beta2=0.99)
        with nn.parameter_scope("dis"):
            solver_dis.set_parameters(nn.get_parameters())
        esr_mon = get_esrgan_monitors()

    # Set generator Parameters
    solver_gen = S.Adam(alpha=lr_g, beta1=0.9, beta2=0.99)
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    train_size = int(
        train_samples / conf.train.batch_size / comm.n_procs)
    total_epochs = conf.train.n_epochs
    start_epoch = 0
    current_iter = 0
    if comm.rank == 0:
        print("total_epochs", total_epochs)
        print("train_samples", train_samples)
        print("val_samples", val_samples)
        print("train_size", train_size)

    for epoch in range(start_epoch + 1, total_epochs + 1):
        index = 0
        # Training loop for psnr rrdb model
        while index < train_size:
            current_iter += comm.n_procs
            train_gt.d, train_lq.d = data_iterator_train.next()

            if not conf.model.esrgan:
                lr_g = get_repeated_cosine_annealing_learning_rate(
                    current_iter, conf.hyperparameters.eta_max, conf.hyperparameters.eta_min, conf.train.cosine_period,
                    conf.train.cosine_num_period)

            if conf.model.esrgan:
                lr_g = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_g)
                gen_model.var_ref.d = train_gt.d
                gen_model.pred_d_real.grad.zero()
                gen_model.pred_d_real.forward(clear_no_need_grad=True)
                gen_model.pred_d_real.need_grad = False

            # Generator update
            gen_loss.forward(clear_no_need_grad=True)
            solver_gen.zero_grad()
            # All-reduce gradients every 2MiB parameters during backward computation
            if comm.n_procs > 1:
                with nn.parameter_scope('gen'):
                    all_reduce_callback = comm.get_all_reduce_callback()
                    gen_loss.backward(clear_buffer=True,
                                      communicator_callbacks=all_reduce_callback)
            else:
                gen_loss.backward(clear_buffer=True)
            solver_gen.set_learning_rate(lr_g)
            solver_gen.update()

            # Discriminator Update
            if conf.model.esrgan:
                gen_model.pred_d_real.need_grad = True
                lr_d = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_d)
                solver_dis.zero_grad()
                dis_model.l_d_total.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    with nn.parameter_scope('dis'):
                        all_reduce_callback = comm.get_all_reduce_callback()
                    dis_model.l_d_total.backward(
                        clear_buffer=True, communicator_callbacks=all_reduce_callback)
                else:
                    dis_model.l_d_total.backward(clear_buffer=True)
                solver_dis.set_learning_rate(lr_d)
                solver_dis.update()

            index += 1
            if comm.rank == 0:
                monitor_pixel_g.add(
                    current_iter, pixel_loss.d.copy())
                monitor_time.add(epoch * comm.n_procs)
            if comm.rank == 0 and conf.model.esrgan:
                esr_mon.monitor_feature_g.add(
                    current_iter, gen_model.feature_loss.d.copy())
                esr_mon.monitor_gan_g.add(
                    current_iter, gen_model.loss_gan_gen.d.copy())
                esr_mon.monitor_gan_d.add(
                    current_iter, dis_model.l_d_total.d.copy())
                esr_mon.monitor_d_real.add(current_iter, F.mean(
                    gen_model.pred_d_real.data).data)
                esr_mon.monitor_d_fake.add(current_iter, F.mean(
                    gen_model.pred_g_fake.data).data)

        # Validation Loop
        if comm.rank == 0:
            avg_psnr = 0.0
            for idx in range(val_samples):
                val_gt_im, val_lq_im = data_iterator_val.next()
                val_gt = nn.NdArray.from_numpy_array(val_gt_im)
                val_lq = nn.NdArray.from_numpy_array(val_lq_im)
                with nn.parameter_scope("gen"):
                    avg_psnr = val_save(
                        val_gt, val_lq, val_lq_path, idx, epoch, avg_psnr)
            avg_psnr = avg_psnr / val_samples
            monitor_val.add(epoch, avg_psnr)

        # Save generator weights
        if comm.rank == 0:
            if not os.path.exists(conf.train.savemodel):
                os.makedirs(conf.train.savemodel)
            with nn.parameter_scope("gen"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel, "generator_param_%06d.h5" % epoch))
        # Save discriminator weights
        if comm.rank == 0 and conf.model.esrgan:
            with nn.parameter_scope("dis"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel, "discriminator_param_%06d.h5" % epoch))
Example #21
def Loss_reconstruction(wave_fake, wave_true, beta_in, beta_clean):
    E_wave = F.mean(F.absolute_error(wave_fake, wave_true))  # improve reconstruction performance
    B_wave = F.mean(F.absolute_error(beta_in, beta_clean))
    return E_wave + 0.01 * B_wave
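A minimal invocation sketch with toy tensors; the waveform and beta shapes are assumptions, not values from the source.
import numpy as np
import nnabla as nn

rng = np.random.RandomState(0)
wave_true = nn.Variable.from_numpy_array(rng.randn(4, 1, 16384).astype(np.float32))
wave_fake = nn.Variable.from_numpy_array(rng.randn(4, 1, 16384).astype(np.float32))
beta_clean = nn.Variable.from_numpy_array(rng.randn(4, 1, 128).astype(np.float32))
beta_in = nn.Variable.from_numpy_array(rng.randn(4, 1, 128).astype(np.float32))

loss = Loss_reconstruction(wave_fake, wave_true, beta_in, beta_clean)  # E_wave + 0.01 * B_wave
loss.forward()
print(loss.d)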
Example #22
def train(args):

    ##  Sub-functions
    ## ---------------------------------
    ## Save Models
    def save_models(epoch_num, cle_disout, fake_disout, losses_gen, losses_dis, losses_ae):

        # save generator parameter
        with nn.parameter_scope("gen"):
            nn.save_parameters(os.path.join(args.model_save_path, 'generator_param_{:04}.h5'.format(epoch_num + 1)))

        # save discriminator parameter
        with nn.parameter_scope("dis"):
            nn.save_parameters(os.path.join(args.model_save_path, 'discriminator_param_{:04}.h5'.format(epoch_num + 1)))

        # save results
        np.save(os.path.join(args.model_save_path, 'disout_his_{:04}.npy'.format(epoch_num + 1)), np.array([cle_disout, fake_disout]))
        np.save(os.path.join(args.model_save_path, 'losses_gen_{:04}.npy'.format(epoch_num + 1)), np.array(losses_gen))
        np.save(os.path.join(args.model_save_path, 'losses_dis_{:04}.npy'.format(epoch_num + 1)), np.array(losses_dis))
        np.save(os.path.join(args.model_save_path, 'losses_ae_{:04}.npy'.format(epoch_num + 1)), np.array(losses_ae))

    ## Load Models
    def load_models(epoch_num, gen=True, dis=True):

        # load generator parameter
        with nn.parameter_scope("gen"):
            nn.load_parameters(os.path.join(args.model_save_path, 'generator_param_{:04}.h5'.format(args.epoch_from)))

        # load discriminator parameter
        with nn.parameter_scope("dis"):
            nn.load_parameters(os.path.join(args.model_save_path, 'discriminator_param_{:04}.h5'.format(args.epoch_from)))

    ## Update parameters
    class updating:

        def __init__(self):
            self.scale = 8 if args.halfprec else 1

        def __call__(self, solver, loss):
            solver.zero_grad()                                  # initialize
            loss.forward(clear_no_need_grad=True)               # calculate forward
            loss.backward(self.scale, clear_buffer=True)      # calculate backward
            solver.scale_grad(1. / self.scale)                # scaling
            solver.weight_decay(args.weight_decay * self.scale) # decay
            solver.update()                                     # update


    ##  Inital Settings
    ## ---------------------------------

    ##  Create network
    #   Clear
    nn.clear_parameters()
    #   Variables
    noisy = nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Input
    clean = nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Desired output
    z = nn.Variable([args.batch_size, 1024, 8], need_grad=False)       # Random latent variable
    #   Generator
    genout = Generator(noisy, z)  # Predicted clean signal
    genout.persistent = True      # Keep the buffer; not cleared at backward
    loss_gen = Loss_gen(genout, clean, Discriminator(noisy, genout))
    loss_ae = F.mean(F.absolute_error(genout, clean))
    #   Discriminator
    fake_dis = genout.get_unlinked_variable(need_grad=True)
    cle_disout = Discriminator(noisy, clean)
    fake_disout = Discriminator(noisy, fake_dis)
    loss_dis = Loss_dis(Discriminator(noisy, clean), Discriminator(noisy, fake_dis))

    ##  Solver
    # RMSprop.
    # solver_gen = S.RMSprop(args.learning_rate_gen)
    # solver_dis = S.RMSprop(args.learning_rate_dis)
    # Adam
    solver_gen = S.Adam(args.learning_rate_gen)
    solver_dis = S.Adam(args.learning_rate_dis)
    # set parameter
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    ##  Load data & Create batch
    clean_data, noisy_data = dt.data_loader()
    batches     = dt.create_batch(clean_data, noisy_data, args.batch_size)
    del clean_data, noisy_data

    ##  Initial settings for sub-functions
    fig     = figout()
    disp    = display(args.epoch_from, args.epoch, batches.batch_num)
    upd     = updating()

    ##  Train
    ##----------------------------------------------------

    print('== Start Training ==')

    ##  Load "Pre-trained" parameters
    if args.epoch_from > 0:
        print(' Retrain parameter from pre-trained network')
        load_models(args.epoch_from, dis=False)
        losses_gen  = np.load(os.path.join(args.model_save_path, 'losses_gen_{:04}.npy'.format(args.epoch_from)))
        losses_dis  = np.load(os.path.join(args.model_save_path, 'losses_dis_{:04}.npy'.format(args.epoch_from)))
        losses_ae   = np.load(os.path.join(args.model_save_path, 'losses_ae_{:04}.npy'.format(args.epoch_from)))
    else:
        losses_gen  = []
        losses_ae   = []
        losses_dis  = []

    ## Create loss loggers
    point       = len(losses_gen)
    loss_len    = (args.epoch - args.epoch_from) * ((batches.batch_num+1)//10)
    losses_gen  = np.append(losses_gen, np.zeros(loss_len))
    losses_ae   = np.append(losses_ae, np.zeros(loss_len))
    losses_dis  = np.append(losses_dis, np.zeros(loss_len))

    ##  Training
    for i in range(args.epoch_from, args.epoch):

        print('')
        print(' =========================================================')
        print('  Epoch :: {0}/{1}'.format(i + 1, args.epoch))
        print(' =========================================================')
        print('')

        #  Batch iteration
        for j in range(batches.batch_num):
            print('  Train (Epoch. {0}) - {1}/{2}'.format(i+1, j+1, batches.batch_num))

            ##  Batch setting
            clean.d, noisy.d = batches.next(j)
            #z.d = np.random.randn(*z.shape)
            z.d = np.zeros(z.shape)

            ##  Updating
            upd(solver_gen, loss_gen)       # update Generator
            upd(solver_dis, loss_dis)       # update Discriminator

            ##  Display
            if (j+1) % 10 == 0:
                # Get result for Display
                cle_disout.forward()
                fake_disout.forward()
                loss_ae.forward(clear_no_need_grad=True)

                # Display text
                disp(i, j, loss_gen.d, loss_dis.d, loss_ae.d)

                # Data logger
                losses_gen[point] = loss_gen.d
                losses_ae[point]  = loss_ae.d
                losses_dis[point] = loss_dis.d
                point = point + 1

                # Plot
                fig.waveform(noisy.d[0,0,:], genout.d[0,0,:], clean.d[0,0,:])
                fig.loss(losses_gen[0:point-1], losses_ae[0:point-1], losses_dis[0:point-1])
                fig.histogram(cle_disout.d, fake_disout.d)
                pg.QtGui.QApplication.processEvents()


        ## Save parameters
        if ((i+1) % args.model_save_cycle) == 0:
            save_models(i, cle_disout.d, fake_disout.d, losses_gen[0:point-1], losses_dis[0:point-1], losses_ae[0:point-1])  # save model
            exporter = pg.exporters.ImageExporter(fig.win.scene())  # Call pg.QtGui.QApplication.processEvents() before exporters!!
            exporter.export(os.path.join(args.model_save_path, 'plot_{:04}.png'.format(i + 1))) # save fig

    ## Save parameters (Last)
    save_models(args.epoch-1, cle_disout.d, fake_disout.d, losses_gen, losses_dis, losses_ae)
Example #23
def Loss_reconstruction(beta_in, beta_clean):
    B_wave = F.mean(F.absolute_error(beta_in, beta_clean))
    return 0.001 * B_wave  # the coefficient acts as the weight for the distribution-estimation network
Example #24
def recon_loss(x, y):
    return F.mean(F.absolute_error(x, y))
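A quick numeric check (toy data, not from the source) that recon_loss matches the plain NumPy mean absolute error.
import numpy as np
import nnabla as nn

a = np.random.RandomState(0).randn(8, 16).astype(np.float32)
b = np.random.RandomState(1).randn(8, 16).astype(np.float32)

loss = recon_loss(nn.Variable.from_numpy_array(a), nn.Variable.from_numpy_array(b))
loss.forward()
assert np.isclose(loss.d, np.abs(a - b).mean(), atol=1e-6)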
Example #25
def feature_loss(fea_real, fea_fake):
    loss = list()
    for o1, o2 in zip(fea_real, fea_fake):
        for f1, f2 in zip(o1[:-1], o2[:-1]):
            loss.append(F.mean(F.absolute_error(f1, f2)))
    return sum(loss)
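Finally, a toy sketch of feature_loss; the nested structure (a list of per-discriminator outputs whose last entry is the final logits, excluded by [:-1]) is inferred from the loop, and all shapes are assumptions.
import numpy as np
import nnabla as nn

def toy_discriminator_outputs(seed):
    # two discriminators, each yielding three intermediate feature maps plus final logits
    rng = np.random.RandomState(seed)
    return [
        [nn.Variable.from_numpy_array(rng.randn(1, 8, 32, 32).astype(np.float32)) for _ in range(3)]
        + [nn.Variable.from_numpy_array(rng.randn(1, 1).astype(np.float32))]
        for _ in range(2)
    ]

loss = feature_loss(toy_discriminator_outputs(0), toy_discriminator_outputs(1))
loss.forward()
print(loss.d)  # sum of per-layer mean absolute errors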