def feature_matching_loss(dis_maps_real, dis_maps_fake, model_params, weights):
    """Weighted L1 distance between real/fake discriminator feature maps.

    Iterates every discriminator scale listed in the model params and sums
    the per-layer mean absolute errors, weighted by ``weights``; layers with
    zero weight are skipped entirely (they never enter the graph).
    """
    loss = None
    for scale in model_params['discriminator_params']['scales']:
        key = f"feature_maps_{scale}".replace('.', '-')
        for idx, (real_map, fake_map) in enumerate(
                zip(dis_maps_real[key], dis_maps_fake[key])):
            if weights[idx] == 0:
                continue
            term = F.mean(F.absolute_error(real_map, fake_map)) * weights[idx]
            loss = term if loss is None else loss + term
    return loss
def get_esrgan_gen(conf, train_gt, train_lq, fake_h):
    """ Create computation graph and variables for ESRGAN Generator.

    Args:
        conf: config namespace; reads conf.train.batch_size / gt_size here.
        train_gt: ground-truth HR image variable (passed through to output).
        train_lq: low-quality LR image variable (passed through to output).
        fake_h: generator (SR) output variable.

    Returns:
        Model_gen namedtuple bundling inputs and the feature / GAN losses.
    """
    # Reference variable fed with real HR images for the relativistic GAN loss.
    var_ref = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size, conf.train.gt_size))
    # Feature Loss (L1 Loss) on pretrained VGG19 activations
    load_vgg19 = PretrainedVgg19()
    real_fea = load_vgg19(train_gt)
    # need_grad set to False, to avoid BP to vgg19 network
    real_fea.need_grad = False
    fake_fea = load_vgg19(fake_h)
    feature_loss = F.mean(F.absolute_error(fake_fea, real_fea))
    feature_loss.persistent = True
    # Gan Loss Generator
    with nn.parameter_scope("dis"):
        pred_g_fake = discriminator(fake_h)
        pred_d_real = discriminator(var_ref)
    pred_d_real.persistent = True
    pred_g_fake.persistent = True
    # Unlink D(real) so the generator update does not backprop through it.
    unlinked_pred_d_real = pred_d_real.get_unlinked_variable()
    gan_loss = RelativisticAverageGanLoss(GanLoss())
    gan_loss_gen_out = gan_loss(unlinked_pred_d_real, pred_g_fake)
    loss_gan_gen = gan_loss_gen_out.generator_loss
    loss_gan_gen.persistent = True
    Model_gen = namedtuple('Model_gen', [
        'train_gt', 'train_lq', 'var_ref', 'feature_loss', 'loss_gan_gen',
        'pred_d_real', 'pred_g_fake'
    ])
    return Model_gen(train_gt, train_lq, var_ref, feature_loss, loss_gan_gen,
                     pred_d_real, pred_g_fake)
def Loss_gen(wave_fake, wave_true, dval_fake, lmd=100):
    """LSGAN-style generator loss plus weighted L1 reconstruction term.

    Pushes the discriminator output on fake audio toward 1 and adds an
    L1 waveform reconstruction penalty scaled by ``lmd``.
    """
    def _sq_err_to(x, val=1):
        # squared distance of x to the constant `val`
        return F.squared_error(x, F.constant(val, x.shape))

    adv_term = F.mean(_sq_err_to(dval_fake, val=1))            # fake -> 1
    rec_term = F.mean(F.absolute_error(wave_fake, wave_true))  # reconstruction
    return adv_term / 2 + lmd * rec_term
def perceptual_loss(self, x, target):
    r"""Returns perceptual loss.

    Runs both inputs through this (multi-discriminator) network and averages
    the L1 distances between all intermediate feature maps (the final output
    of each discriminator is excluded), normalized by the discriminator count.
    """
    out_x = self(x, None)
    out_t = self(target, None)
    terms = []
    for feats_x, feats_t in zip(out_x, out_t):
        for fx, ft in zip(feats_x[:-1], feats_t[:-1]):
            ft.need_grad = False  # block gradients through the target branch
            terms.append(F.mean(F.absolute_error(fx, ft)))
    return sum(terms) / self.hp.num_D
def test(args): """ Training """ ## ~~~~~~~~~~~~~~~~~~~ ## Initial settings ## ~~~~~~~~~~~~~~~~~~~ # Input Variable nn.clear_parameters() # Clear Input = nn.Variable([1, 3, 64, 64]) # Input Trues = nn.Variable([1, 1]) # True Value # Network Definition Name = "CNN" # Name of scope which includes network models (arbitrary) Output_test = network(Input, scope=Name, test=True) # Network & Output Loss_test = F.mean(F.absolute_error( Output_test, Trues)) # Loss Function (Squared Error) # Load data with nn.parameter_scope(Name): nn.load_parameters( os.path.join(args.model_save_path, "network_param_{:04}.h5".format(args.epoch))) # Training Data Setting image_data, mos_data = dt.data_loader(test=True) batches = dt.create_batch(image_data, mos_data, 1) del image_data, mos_data truth = [] result = [] for j in range(batches.iter_n): Input.d, tures = next(batches) Loss_test.forward(clear_no_need_grad=True) result.append(Loss_test.d) truth.append(tures) result = np.array(result) truth = np.squeeze(np.array(truth)) # Evaluation of performance mae = np.average(np.abs(result - truth)) SRCC, p1 = stats.spearmanr(truth, result) # Spearman's Correlation Coefficient PLCC, p2 = stats.pearsonr(truth, result) # Display print("\n Model Parameter [epoch={0}]".format(args.epoch)) print(" Mean Absolute Error with Truth: {0:.4f}".format(mae)) print(" Speerman's Correlation Coefficient: {0:.3f}".format(SRCC)) print(" Pearson's Linear Correlation Coefficient: {0:.3f}".format(PLCC))
def mae(x, y, mask=None, eps=1e-5):
    """Mean absolute error, optionally restricted by a binary mask.

    Without a mask this is a plain mean of |x - y|. With a mask, the error
    is zeroed outside the mask and normalized by the mask sum (plus ``eps``
    to guard against an all-zero mask).
    """
    err = F.absolute_error(x, y)
    if mask is None:
        return F.mean(err)
    assert err.shape[:2] == mask.shape[:2]
    err *= F.reshape(mask, err.shape)
    return F.sum(err) / (F.sum(mask) + eps)
def update_graph(self, key='train'):
    r"""Builds the graph and update the placeholder.

    Args:
        key (str, optional): Type of computational graph. Defaults to 'train'.
    """
    assert key in ('train', 'valid')
    hp = self.hp
    # Toggle train/eval behavior in both networks.
    self.gen.training = key == 'train'
    self.dis.training = key == 'train'

    # define input variables
    x_real = nn.Variable((hp.batch_size, 1, hp.segment_length))
    x_real_mel = compute_mel(x_real, self.mel_basis, hp)
    # Generator synthesizes a waveform from the real mel-spectrogram.
    x_fake = self.gen(x_real_mel)
    x_fake_mel = compute_mel(x_fake, self.mel_basis, hp)
    dis_real_x = self.dis(x_real)
    dis_fake_x = self.dis(x_fake)

    # ------------------------------ Discriminator -----------------------
    d_loss = (discriminator_loss(dis_real_x, 1.0) +
              discriminator_loss(dis_fake_x, 0.0))

    # -------------------------------- Generator -------------------------
    g_loss_avd = discriminator_loss(dis_fake_x, 1.0)  # adversarial term
    g_loss_mel = F.mean(F.absolute_error(x_real_mel, x_fake_mel))  # mel L1
    g_loss_fea = feature_loss(dis_real_x, dis_fake_x)  # feature matching
    # NOTE(review): 45 / 2 look like HiFi-GAN-style loss weights — confirm.
    g_loss = g_loss_avd + 45 * g_loss_mel + 2 * g_loss_fea

    # Keep intermediate losses alive for logging after backward.
    set_persistent_all(
        g_loss_mel,
        g_loss_avd,
        g_loss_fea,
        d_loss,
        x_fake,
        g_loss,
    )
    self.placeholder[key] = dict(
        x_real=x_real,
        x_fake=x_fake,
        d_loss=d_loss,
        g_loss_avd=g_loss_avd,
        g_loss_mel=g_loss_mel,
        g_loss_fea=g_loss_fea,
        g_loss=g_loss,
    )
def vgg16_perceptual_loss(fake, real):
    '''VGG perceptual loss based on VGG-16 network.

    Assuming the values in fake and real are in [0, 255].
    Features are obtained from all ReLU activations of the first convolution
    after each downsampling (maxpooling) layer (including the first
    convolution applied to an image).
    '''
    from nnabla.models.imagenet import VGG16

    class VisitFeatures(object):
        # Graph visitor collecting the outputs of selected ReLU functions.
        def __init__(self):
            self.features = []
            self.relu_counter = 0
            # ReLU indices of the first conv after each downsampling stage.
            self.features_at = set([0, 2, 4, 7, 10])

        def __call__(self, f):
            if not f.name.startswith('ReLU'):
                return
            if self.relu_counter in self.features_at:
                self.features.append(f.outputs[0])
            self.relu_counter += 1

    # We use VGG16 model instead of VGG19 because VGG19
    # is not in nnabla.models.
    vgg = VGG16()

    def get_features(x):
        # Build the VGG graph up to the last conv and harvest feature maps.
        o = vgg(x, use_up_to='lastconv')
        f = VisitFeatures()
        o.visit(f)
        return f

    with nn.parameter_scope("vgg16_loss"):
        fake_features = get_features(fake)
        real_features = get_features(real)

    # Weight each layer's L1 inversely to its volume (deepest layer = 1.0).
    volumes = np.array([np.prod(f.shape) for f in fake_features.features],
                       dtype=np.float32)
    weights = volumes[-1] / volumes
    return sum([w * F.mean(F.absolute_error(ff, fr))
                for w, ff, fr in zip(weights, fake_features.features,
                                     real_features.features)])
def perceptual_loss(pyramide_real, pyramide_fake, scales, weights,
                    vgg_param_path):
    """
    Compute Perceptual Loss using VGG19 as a feature extractor.

    For every pyramid scale, both predictions are passed through a pretrained
    VGG19 and the weighted L1 distances of corresponding feature maps are
    summed into a single scalar loss.
    """
    vgg19 = PretrainedVgg19(param_path=vgg_param_path)
    loss = None
    for scale in scales:
        fake_feats = vgg19(pyramide_fake[f'prediction_{scale}'])
        real_feats = vgg19(pyramide_real[f'prediction_{scale}'])
        for i, w in enumerate(weights):
            term = w * F.mean(F.absolute_error(fake_feats[i], real_feats[i]))
            loss = term if loss is None else loss + term
    return loss
def equivariance_jacobian_loss(kp_driving_jacobian, arithmetic_jacobian,
                               trans_kp_jacobian, weight):
    """Weighted L1 penalty pushing J_driving^{-1} (A @ J_trans) toward identity.

    Inverts the driving-keypoint jacobians (batched over leading dims),
    multiplies them with the transformed jacobians, and measures the mean
    absolute deviation from a broadcast 2x2 identity.
    """
    transformed = F.batch_matmul(arithmetic_jacobian, trans_kp_jacobian)
    # batch_inv expects a flat batch of 2x2 matrices; restore shape after.
    flat = F.reshape(kp_driving_jacobian,
                     (-1,) + kp_driving_jacobian.shape[-2:])
    inv_driving = F.reshape(F.batch_inv(flat), kp_driving_jacobian.shape)
    product = F.batch_matmul(inv_driving, transformed)
    identity = nn.Variable.from_numpy_array(
        np.reshape(np.eye(2), (1, 1, 2, 2)))
    return weight * F.mean(F.absolute_error(identity, product))
def forward(self, output, inds, gt, reg_mask, channel_last=False):
    """Masked L1 regression loss gathered at object-center indices.

    Args:
        output: predicted regression map, NCHW (NHWC if channel_last).
        inds: (b, max_objs) flattened spatial indices of object centers.
        gt: ground-truth regression targets, matched against the gathered
            predictions of shape (b, max_objs, c).
        reg_mask: mask selecting valid objects in the loss sum.
        channel_last: if True, input is transposed to NCHW first.
    """
    # TODO refactor loss implementation for channel_last without transposing
    if channel_last:
        output = F.transpose(output, (0, 3, 1, 2))
    b = inds.shape[0]
    c = output.shape[1]
    max_objs = inds.shape[1]
    # divide by number of :
    # NOTE(review): the "* 2" presumably accounts for two regression
    # channels per object — confirm against the caller.
    num_objs = F.sum(reg_mask) * 2
    f_map_size = output.shape[2] * output.shape[3]
    # Flatten the spatial dims so centers can be gathered by flat index.
    output = F.reshape(output, (-1, f_map_size))
    # Broadcast indices over channels, then gather per (batch*channel) row.
    inds = F.broadcast(inds.reshape((b, 1, max_objs)), (b, c, max_objs))
    inds = inds.reshape((-1, max_objs))
    y = output[F.broadcast(F.reshape(F.arange(0, b * c), (b * c, 1)),
                           (b * c, max_objs)), inds].reshape(
        (b, c, max_objs))
    y = F.transpose(y, (0, 2, 1))  # -> (b, max_objs, c) to match gt
    loss = F.sum(reg_mask * F.absolute_error(y, gt))
    loss = loss / (num_objs + 1e-4)  # epsilon guards empty-object batches
    return loss
def context_preserving_loss(xa, yb):
    """Mask-weighted L1 distance between the RGB parts of two RGBA-like maps.

    Each input carries 3 image channels followed by mask channels; the two
    masks are combined into a pixel weight that down-weights regions either
    mask marks as foreground.
    """
    def _weight_from_masks(mask_a, mask_b):
        # much different from definition in the paper
        stacked = F.concatenate(mask_a, mask_b, axis=1)
        summed = F.sum((stacked + 1) / 2, axis=1, keepdims=True)
        clipped = F.clip_by_value(summed,
                                  F.constant(0, shape=summed.shape),
                                  F.constant(1, shape=summed.shape))
        signed = clipped * 2 - 1
        return (1 - signed) / 2

    img_x, mask_a = xa[:, :3, :, :], xa[:, 3:, :, :]
    img_y, mask_b = yb[:, :3, :, :], yb[:, 3:, :, :]
    assert img_x.shape == img_y.shape and mask_a.shape == mask_b.shape
    w = _weight_from_masks(mask_a, mask_b)
    return F.mean(F.mul2(F.absolute_error(img_x, img_y), w))
def define_loss(real_out, real_feats, fake_out, fake_feats,
                use_fm=True, fm_lambda=10., gan_loss_type="ls"):
    """Aggregate GAN and optional feature-matching losses over discriminators.

    Returns:
        (g_gan, g_feat, d_real, d_fake) accumulated loss terms; ``g_feat``
        is a constant zero variable when feature matching is disabled.
    """
    gan_loss = get_gan_loss(gan_loss_type)
    n_disc = len(real_out)
    g_gan = 0
    d_real = 0
    d_fake = 0
    g_feat = 0 if use_fm else F.constant(0)
    for disc_id in real_out.keys():
        r_out = real_out[disc_id]
        r_feats = real_feats[disc_id]
        f_out = fake_out[disc_id]
        f_feats = fake_feats[disc_id]
        # GAN loss contribution from this discriminator
        real_term, fake_term, gen_term = gan_loss(r_out, f_out)
        d_real += real_term
        d_fake += fake_term
        g_gan += gen_term
        # feature matching between real/fake intermediate activations
        if use_fm:
            assert r_out.shape == f_out.shape
            for layer_id, r_feat in r_feats.items():
                g_feat += F.mean(F.absolute_error(
                    r_feat, f_feats[layer_id])) * fm_lambda / n_disc
    return g_gan, g_feat, d_real, d_fake
def criteria(x, t):
    """Mean absolute error between prediction ``x`` and target ``t``."""
    diff = F.absolute_error(x, t)
    return F.mean(diff)
def train(args):
    """Train the two-stage IQA CNN (feature extractor + regression head)."""
    ## ~~~~~~~~~~~~~~~~~~~
    ## Initial settings
    ## ~~~~~~~~~~~~~~~~~~~
    # Input Variable
    # args. -> setting.
    M = 64  # feature vector size produced by the first network
    nn.clear_parameters()  # Clear
    Input = nn.Variable([args.batch_size, 6, 128, 128])
    Trues = nn.Variable([args.batch_size, 1])  # True Value
    # Network Definition
    Name = "CNN"  # Name of scope which includes network models (arbitrary)
    # NOTE(review): Name2 equals Name, so both solvers address the same
    # parameter scope — confirm this is intended.
    Name2 = "CNN"
    preOutput = network(input=Input, feature_num=M, scope=Name)  # Network & Output
    #add
    # preOutput = F.reshape(preOutput, (args.batch_size, 1, M))  # (B*N, M) > (B, N, M)
    # preOutput = F.mean(preOutput, axis=1, keepdims=True)  # (B, N, M) > (B, 1, M): merge features of the N shifted images into one; keepdims preserves the axis
    Output = network2(input=preOutput, scope=Name2)  # fully-connected head
    # Loss Definition: mean absolute error (would cross-entropy be unsuitable?)
    Loss = F.mean(F.absolute_error(Output, Trues))
    # Solver Setting (Adam optimizers for learning-rate adaptation)
    solver = S.Adam(args.learning_rate)
    solver2 = S.Adam(args.learning_rate)
    solver.weight_decay(0.00001)  # Weight Decay for stable update
    solver2.weight_decay(0.00001)
    with nn.parameter_scope(Name):  # Get updating parameters included in scope
        solver.set_parameters(nn.get_parameters())
    with nn.parameter_scope(Name2):  # Get updating parameters included in scope
        solver2.set_parameters(nn.get_parameters())
    # Training Data Setting
    #image_data, mos_data, image_files = dt.data_loader(test = False)
    image_data, mos_data, similarity = dt.data_loader(test=False)
    #batches = dt.create_batch(image_data, mos_data, args.batch_size, image_files)
    batches = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data
    ## ~~~~~~~~~~~~~~~~~~~
    ## Learning
    ## ~~~~~~~~~~~~~~~~~~~
    print('== Start Training ==')
    bar = tqdm(total=(args.epoch - args.retrain) * batches.iter_n, leave=False)
    bar.clear()
    cnt = 0
    loss_disp = True
    # Load data (resume from a mid-training epoch when retrain > 0)
    if args.retrain > 0:
        with nn.parameter_scope(Name):
            print('Retrain from {0} Epoch'.format(args.retrain))
            nn.load_parameters(
                os.path.join(args.model_save_path,
                             "network_param_{:04}.h5".format(args.retrain)))
            solver.set_learning_rate(args.learning_rate / np.sqrt(args.retrain))
    ## Training
    for i in range(args.retrain, args.epoch):  # epochs retrain -> epoch
        bar.set_description_str('Epoch {0}/{1}:'.format(i + 1, args.epoch),
                                refresh=False)  # progress-bar caption
        # Shuffling
        batches.shuffle()
        ## Batch iteration
        for j in range(batches.iter_n):  # mini-batch training
            cnt += 1
            # Load Batch Data from Training data
            Input_npy, Trues_npy = next(batches)
            size_ = Input_npy.shape
            # Flatten (B, N, C, H, W) -> (B*N, C, H, W) for the CNN input
            Input.d = Input_npy.reshape(
                [size_[0] * size_[1], size_[2], size_[3], size_[4]])
            Trues.d = Trues_npy
            # Update
            solver.zero_grad()  # Initialize (reset gradients)
            #solver2.zero_grad()
            Loss.forward(clear_no_need_grad=True)  # Forward path
            # Loss scaling (mixed-precision style); undone via scale_grad below
            loss_scale = 8
            Loss.backward(loss_scale, clear_buffer=True)  # Backward path
            #solver2.update()
            solver.scale_grad(1. / loss_scale)
            solver.update()
            # Progress
            if cnt % 10 == 0:
                bar.update(10)  # advance the progress bar
                if loss_disp is not None:
                    # show the current loss value on the progress bar
                    bar.set_postfix_str('Loss={0:.3e}'.format(Loss.d),
                                        refresh=False)
        ## Save parameters
        if ((i + 1) % args.model_save_cycle) == 0 or (i + 1) == args.epoch:
            bar.clear()
            with nn.parameter_scope(Name):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'network_param_{:04}.h5'.format(i + 1)))
            with nn.parameter_scope(Name2):
                nn.save_parameters(
                    os.path.join(args.model_save_path2,
                                 'network_param_{:04}.h5'.format(i + 1)))
def test(args):
    """Evaluate the two-stage IQA CNN and report MAE / SRCC / PLCC."""
    M = 64  # feature vector size of the first network
    ## ~~~~~~~~~~~~~~~~~~~
    ## Initial settings
    ## ~~~~~~~~~~~~~~~~~~~
    # Input Variable definitions
    nn.clear_parameters()  # Clear
    Input = nn.Variable([1, 6, 256, 256])  # Input
    Trues = nn.Variable([1, 1])  # True Value
    # Network Definition
    Name = "CNN"  # Name of scope which includes network models (arbitrary)
    Name2 = "CNN"
    preOutput = network(input=Input, feature_num=M, scope=Name)  # Network & Output
    #add
    # NOTE(review): N is not defined in this function — presumably the number
    # of shifted crops per image defined at module level; verify before use.
    preOutput = F.reshape(preOutput, (1, N, M))  # (B*N, M) > (B, N, M)
    preOutput_mean = F.mean(
        preOutput, axis=1, keepdims=True
    )  # (B, N, M) > (B, 1, M): merge the features of the N shifted images; keepdims preserves the axis
    Output_test = network2(input=preOutput_mean, scope=Name2)  # fully-connected head
    # Loss: mean absolute error against the true score
    Loss_test = F.mean(F.absolute_error(Output_test, Trues))
    # Load the saved training parameters
    with nn.parameter_scope(Name):
        nn.load_parameters(
            os.path.join(args.model_save_path,
                         "network_param_{:04}.h5".format(args.epoch)))
    with nn.parameter_scope(Name2):
        nn.load_parameters(
            os.path.join(args.model_save_path2,
                         "network_param_{:04}.h5".format(args.epoch)))
    # Test Data Setting
    #image_data, mos_data, image_files = dt.data_loader(test=True)
    image_data, mos_data = dt.data_loader(test=True)
    #batches = dt.create_batch(image_data, mos_data, 1, image_files)
    batches = dt.create_batch(image_data, mos_data, 1)
    del image_data, mos_data
    truth = []
    result = []
    for j in range(batches.iter_n):
        #Input_npy, Trues_npy, image_files = next(batches)
        Input_npy, Trues_npy = next(batches)
        size_ = Input_npy.shape
        # print("Input Image:" + str(image_files) + " Trues:" + str(Trues_npy))
        # Flatten (B, N, C, H, W) -> (B*N, C, H, W) for the CNN input
        Input.d = Input_npy.reshape(
            [size_[0] * size_[1], size_[2], size_[3], size_[4]])
        Trues.d = Trues_npy[0][0]
        Loss_test.forward(clear_no_need_grad=True)
        result.append(Loss_test.d)
        truth.append(Trues.d)
    result = np.array(result)
    mean = np.mean(result)
    truth = np.squeeze(np.array(truth))
    # delete
    # Evaluation of performance
    mae = np.average(np.abs(result - truth))
    SRCC, p1 = stats.spearmanr(truth, result)  # Spearman's Correlation Coefficient
    PLCC, p2 = stats.pearsonr(truth, result)   # Pearson's Correlation Coefficient
    np.set_printoptions(threshold=np.inf)
    print("result: {}".format(result))
    print("Trues: {}".format(truth))
    print(np.average(result))
    print("\n Model Parameter [epoch={0}]".format(args.epoch))
    print(" Mean Absolute Error with Truth: {0:.4f}".format(mae))
    print(" Speerman's Correlation Coefficient: {0:.5f}".format(SRCC))
    print(" Pearson's Linear Correlation Coefficient: {0:.5f}".format(PLCC))
def train(args):
    """Train the single-stage IQA CNN with L1 loss; monitors test SRCC per epoch."""
    ## ~~~~~~~~~~~~~~~~~~~
    ## Initial settings
    ## ~~~~~~~~~~~~~~~~~~~
    # Input Variable
    nn.clear_parameters()  # Clear
    Input = nn.Variable([args.batch_size, 3, 64, 64])  # Input
    Trues = nn.Variable([args.batch_size, 1])  # True Value
    # Network Definition
    Name = "CNN"  # Name of scope which includes network models (arbitrary)
    Output = network(Input, scope=Name)  # Network & Output
    Output_test = network(Input, scope=Name, test=True)
    # Loss Definition: mean absolute error (L1)
    Loss = F.mean(F.absolute_error(Output, Trues))
    Loss_test = F.mean(F.absolute_error(Output_test, Trues))
    # Solver Setting
    solver = S.AMSBound(args.learning_rate)  # AMSBound optimizer
    with nn.parameter_scope(
            Name):  # Get updating parameters included in scope
        solver.set_parameters(nn.get_parameters())
    # Training Data Setting
    image_data, mos_data = dt.data_loader()
    batches = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data
    # Test Data Setting
    image_data, mos_data = dt.data_loader(test=True)
    batches_test = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data
    ## ~~~~~~~~~~~~~~~~~~~
    ## Learning
    ## ~~~~~~~~~~~~~~~~~~~
    print('== Start Training ==')
    bar = tqdm(total=args.epoch - args.retrain, leave=False)
    bar.clear()
    loss_disp = None
    SRCC = None
    # Load data (resume from a mid-training epoch when retrain > 0)
    if args.retrain > 0:
        with nn.parameter_scope(Name):
            print('Retrain from {0} Epoch'.format(args.retrain))
            nn.load_parameters(
                os.path.join(args.model_save_path,
                             "network_param_{:04}.h5".format(args.retrain)))
            solver.set_learning_rate(args.learning_rate / np.sqrt(args.retrain))
    ## Training
    for i in range(args.retrain, args.epoch):
        bar.set_description_str('Epoch {0}:'.format(i + 1), refresh=False)
        if (loss_disp is not None) and (SRCC is not None):
            bar.set_postfix_str('Loss={0:.5f}, SRCC={1:.4f}'.format(
                loss_disp, SRCC), refresh=False)
        bar.update(1)
        # Shuffling
        batches.shuffle()
        batches_test.shuffle()
        ## Batch iteration
        for j in range(batches.iter_n):
            # Load Batch Data from Training data
            Input.d, Trues.d = next(batches)
            # Update
            solver.zero_grad()  # Initialize
            Loss.forward(clear_no_need_grad=True)  # Forward path
            Loss.backward(clear_buffer=True)  # Backward path
            solver.weight_decay(0.00001)  # Weight Decay for stable update
            solver.update()
        ## Progress
        # Get result for Display (one test batch per epoch)
        Input.d, Trues.d = next(batches_test)
        Loss_test.forward(clear_no_need_grad=True)
        Output_test.forward()
        loss_disp = Loss_test.d
        SRCC, _ = stats.spearmanr(Output_test.d, Trues.d)
        # Display text
        # disp(i, batches.iter_n, Loss_test.d)
        ## Save parameters
        if ((i + 1) % args.model_save_cycle) == 0 or (i + 1) == args.epoch:
            bar.clear()
            with nn.parameter_scope(Name):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'network_param_{:04}.h5'.format(i + 1)))
def idr_loss(camloc, raydir, alpha, color_gt, mask_obj, conf):
    """IDR training loss: masked color L1 + mask BCE + eikonal regularizer.

    Args:
        camloc: camera locations, shape (B, 3) (reshaped to (B, 1, 3) below).
        raydir: per-ray directions, shape (B, R, 3).
        alpha: scale for the mask logit (sharper sigmoid as it grows).
        color_gt: ground-truth colors for the sampled rays.
        mask_obj: object masks for the sampled rays.
        conf: configuration namespace with network / ray-tracing settings.

    Returns:
        (loss, loss_color, loss_mask, loss_eikonal, mask_hit)
    """
    # Setting
    B, R, _ = raydir.shape
    L = conf.layers
    D = conf.depth
    feature_size = conf.feature_size

    # Ray trace (visibility): surface hits, distances, and in/out masks.
    x_hit, mask_hit, dists, mask_pin, mask_pout = \
        ray_trace(partial(sdf_net, conf=conf), camloc, raydir, mask_obj,
                  t_near=conf.t_near, t_far=conf.t_far,
                  sphere_trace_itr=conf.sphere_trace_itr,
                  ray_march_points=conf.ray_march_points,
                  n_chunks=conf.n_chunks,
                  max_post_itr=conf.max_post_itr,
                  post_method=conf.post_method,
                  eps=conf.eps)
    # Treat all ray-traced outputs as constants for differentiation.
    x_hit = x_hit.apply(need_grad=False)
    mask_hit = mask_hit.apply(need_grad=False, persistent=True)
    dists = dists.apply(need_grad=False)
    mask_pin = mask_pin.apply(need_grad=False)
    mask_pout = mask_pout.apply(need_grad=False)
    mask_us = mask_pin + mask_pout
    P = F.sum(mask_us)  # number of rays contributing to the losses

    # Current points along the rays; gradients flow into the networks.
    x_curr = (camloc.reshape((B, 1, 3)) + dists * raydir).apply(need_grad=True)

    # Eikonal loss: |grad sdf| should be 1 at surface and free-space samples.
    bounding_box_size = conf.bounding_box_size
    x_free = F.rand(-bounding_box_size, bounding_box_size,
                    shape=(B, R // 2, 3))
    x_point = F.concatenate(*[x_curr, x_free], axis=1)
    sdf_xp, _, grad_xp = sdf_feature_grad(implicit_network, x_point, conf)
    gp = (F.norm(grad_xp, axis=[grad_xp.ndim - 1], keepdims=True) - 1.0)**2.0
    # First R entries are surface samples (masked); the rest are free samples.
    loss_eikonal = F.sum(gp[:, :R, :] * mask_us) + F.sum(gp[:, R:, :])
    loss_eikonal = loss_eikonal / (P + B * R // 2)
    loss_eikonal = loss_eikonal.apply(persistent=True)

    sdf_curr = sdf_xp[:, :R, :]
    grad_curr = grad_xp[:, :R, :]

    # Mask loss: BCE on -alpha * sdf, applied only outside the object.
    logit = -alpha.reshape([1 for _ in range(sdf_curr.ndim)]) * sdf_curr
    loss_mask = F.sigmoid_cross_entropy(logit, mask_obj)
    loss_mask = loss_mask * mask_pout
    loss_mask = F.sum(loss_mask) / P / alpha
    loss_mask = loss_mask.apply(persistent=True)

    # Lighting: differentiable surface point, then rendered color.
    x_hat = sample_network(x_curr, sdf_curr, raydir, grad_curr)
    _, feature, grad = sdf_feature_grad(implicit_network, x_hat, conf)
    normal = grad
    color_pred = lighting_network(x_hat, normal, feature, -raydir, D)

    # Color loss: L1 restricted to rays inside the object mask.
    loss_color = F.absolute_error(color_gt, color_pred)
    loss_color = loss_color * mask_pin
    loss_color = F.sum(loss_color) / P
    loss_color = loss_color.apply(persistent=True)

    # Total loss
    loss = loss_color + conf.mask_weight * \
        loss_mask + conf.eikonal_weight * loss_eikonal
    return loss, loss_color, loss_mask, loss_eikonal, mask_hit
def equivariance_value_loss(kp_driving_value, warped_kp_value, weight):
    """Weighted L1 consistency between driving and warped keypoint values."""
    return weight * F.mean(
        F.absolute_error(kp_driving_value, warped_kp_value))
def main():
    """Train an RRDB super-resolution model (PSNR stage or full ESRGAN stage)."""
    conf = get_config()
    # Collect sorted GT / LQ image paths: train on DIV2K, validate on Set14.
    train_gt_path = sorted(glob.glob(conf.DIV2K.gt_train + "/*.png"))
    train_lq_path = sorted(glob.glob(conf.DIV2K.lq_train + "/*.png"))
    val_gt_path = sorted(glob.glob(conf.SET14.gt_val + "/*.png"))
    val_lq_path = sorted(glob.glob(conf.SET14.lq_val + "/*.png"))
    train_samples = len(train_gt_path)
    val_samples = len(val_gt_path)
    lr_g = conf.hyperparameters.lr_g
    lr_d = conf.hyperparameters.lr_d
    lr_steps = conf.train.lr_steps
    random.seed(conf.train.seed)
    np.random.seed(conf.train.seed)
    # Device context and multi-process communicator setup
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(
        extension_module, device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    # data iterators for train and val data
    from data_loader import data_iterator_sr
    data_iterator_train = data_iterator_sr(
        train_samples, conf.train.batch_size, train_gt_path, train_lq_path,
        train=True, shuffle=True)
    data_iterator_val = data_iterator_sr(
        val_samples, conf.val.batch_size, val_gt_path, val_lq_path,
        train=False, shuffle=False)
    if comm.n_procs > 1:
        # Each worker consumes its own slice of the training data.
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)
    train_gt = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size, conf.train.gt_size))
    train_lq = nn.Variable(
        (conf.train.batch_size, 3, conf.train.gt_size // conf.train.scale,
         conf.train.gt_size // conf.train.scale))
    # setting up monitors for logging
    monitor_path = './nnmonitor' + str(datetime.now().strftime("%Y%m%d%H%M%S"))
    monitor = Monitor(monitor_path)
    monitor_pixel_g = MonitorSeries(
        'l_g_pix per iteration', monitor, interval=100)
    monitor_val = MonitorSeries(
        'Validation loss per epoch', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "Training time per epoch", monitor, interval=1)
    # Build the generator graph from a pretrained RRDB checkpoint.
    with nn.parameter_scope("gen"):
        nn.load_parameters(conf.train.gen_pretrained)
        fake_h = rrdb_net(train_lq, 64, 23)
        fake_h.persistent = True
    pixel_loss = F.mean(F.absolute_error(fake_h, train_gt))
    pixel_loss.persistent = True
    gen_loss = pixel_loss
    if conf.model.esrgan:
        from esrgan_model import get_esrgan_gen, get_esrgan_dis, get_esrgan_monitors
        gen_model = get_esrgan_gen(conf, train_gt, train_lq, fake_h)
        # Total generator loss: pixel + perceptual (VGG) + relativistic GAN.
        gen_loss = conf.hyperparameters.eta_pixel_loss * pixel_loss + \
            conf.hyperparameters.feature_loss_weight * gen_model.feature_loss + \
            conf.hyperparameters.lambda_gan_loss * gen_model.loss_gan_gen
        dis_model = get_esrgan_dis(fake_h, gen_model.pred_d_real)
        # Set Discriminator parameters
        solver_dis = S.Adam(lr_d, beta1=0.9, beta2=0.99)
        with nn.parameter_scope("dis"):
            solver_dis.set_parameters(nn.get_parameters())
        esr_mon = get_esrgan_monitors()
    # Set generator Parameters
    solver_gen = S.Adam(alpha=lr_g, beta1=0.9, beta2=0.99)
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    train_size = int(train_samples / conf.train.batch_size / comm.n_procs)
    total_epochs = conf.train.n_epochs
    start_epoch = 0
    current_iter = 0
    if comm.rank == 0:
        print("total_epochs", total_epochs)
        print("train_samples", train_samples)
        print("val_samples", val_samples)
        print("train_size", train_size)
    for epoch in range(start_epoch + 1, total_epochs + 1):
        index = 0
        # Training loop for psnr rrdb model
        while index < train_size:
            current_iter += comm.n_procs
            train_gt.d, train_lq.d = data_iterator_train.next()
            if not conf.model.esrgan:
                # Cosine-annealed LR for the PSNR-oriented stage.
                lr_g = get_repeated_cosine_annealing_learning_rate(
                    current_iter, conf.hyperparameters.eta_max,
                    conf.hyperparameters.eta_min, conf.train.cosine_period,
                    conf.train.cosine_num_period)
            if conf.model.esrgan:
                # Multi-step LR decay for the GAN stage.
                lr_g = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_g)
                gen_model.var_ref.d = train_gt.d
                gen_model.pred_d_real.grad.zero()
                gen_model.pred_d_real.forward(clear_no_need_grad=True)
                # Freeze D(real) during the generator update.
                gen_model.pred_d_real.need_grad = False
            # Generator update
            gen_loss.forward(clear_no_need_grad=True)
            solver_gen.zero_grad()
            # All-reduce gradients every 2MiB parameters during backward computation
            if comm.n_procs > 1:
                with nn.parameter_scope('gen'):
                    all_reduce_callback = comm.get_all_reduce_callback()
                    gen_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
            else:
                gen_loss.backward(clear_buffer=True)
            solver_gen.set_learning_rate(lr_g)
            solver_gen.update()
            # Discriminator Upate
            if conf.model.esrgan:
                # Re-enable grads through D(real) for the discriminator step.
                gen_model.pred_d_real.need_grad = True
                lr_d = get_multistep_learning_rate(
                    current_iter, lr_steps, lr_d)
                solver_dis.zero_grad()
                dis_model.l_d_total.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    with nn.parameter_scope('dis'):
                        all_reduce_callback = comm.get_all_reduce_callback()
                        dis_model.l_d_total.backward(
                            clear_buffer=True,
                            communicator_callbacks=all_reduce_callback)
                else:
                    dis_model.l_d_total.backward(clear_buffer=True)
                solver_dis.set_learning_rate(lr_d)
                solver_dis.update()
            index += 1
            if comm.rank == 0:
                monitor_pixel_g.add(current_iter, pixel_loss.d.copy())
        monitor_time.add(epoch * comm.n_procs)
        if comm.rank == 0 and conf.model.esrgan:
            esr_mon.monitor_feature_g.add(
                current_iter, gen_model.feature_loss.d.copy())
            esr_mon.monitor_gan_g.add(
                current_iter, gen_model.loss_gan_gen.d.copy())
            esr_mon.monitor_gan_d.add(
                current_iter, dis_model.l_d_total.d.copy())
            esr_mon.monitor_d_real.add(
                current_iter, F.mean(gen_model.pred_d_real.data).data)
            esr_mon.monitor_d_fake.add(
                current_iter, F.mean(gen_model.pred_g_fake.data).data)
        # Validation Loop
        if comm.rank == 0:
            avg_psnr = 0.0
            for idx in range(val_samples):
                val_gt_im, val_lq_im = data_iterator_val.next()
                val_gt = nn.NdArray.from_numpy_array(val_gt_im)
                val_lq = nn.NdArray.from_numpy_array(val_lq_im)
                with nn.parameter_scope("gen"):
                    avg_psnr = val_save(
                        val_gt, val_lq, val_lq_path, idx, epoch, avg_psnr)
            avg_psnr = avg_psnr / val_samples
            monitor_val.add(epoch, avg_psnr)
        # Save generator weights
        if comm.rank == 0:
            if not os.path.exists(conf.train.savemodel):
                os.makedirs(conf.train.savemodel)
            with nn.parameter_scope("gen"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel, "generator_param_%06d.h5" % epoch))
        # Save discriminator weights
        if comm.rank == 0 and conf.model.esrgan:
            with nn.parameter_scope("dis"):
                nn.save_parameters(os.path.join(
                    conf.train.savemodel,
                    "discriminator_param_%06d.h5" % epoch))
def Loss_reconstruction(wave_fake, wave_true, beta_in, beta_clean):
    """Waveform L1 reconstruction plus lightly weighted beta-parameter L1."""
    wave_term = F.mean(F.absolute_error(wave_fake, wave_true))
    beta_term = F.mean(F.absolute_error(beta_in, beta_clean))
    return wave_term + 0.01 * beta_term
def train(args):
    """Train the GAN-based speech denoiser (generator + discriminator)."""
    ## Sub-functions
    ## ---------------------------------
    ## Save Models
    def save_models(epoch_num, cle_disout, fake_disout, losses_gen,
                    losses_dis, losses_ae):
        # save generator parameter
        with nn.parameter_scope("gen"):
            nn.save_parameters(os.path.join(
                args.model_save_path,
                'generator_param_{:04}.h5'.format(epoch_num + 1)))
        # save discriminator parameter
        with nn.parameter_scope("dis"):
            nn.save_parameters(os.path.join(
                args.model_save_path,
                'discriminator_param_{:04}.h5'.format(epoch_num + 1)))
        # save results (discriminator outputs and loss histories)
        np.save(os.path.join(args.model_save_path,
                             'disout_his_{:04}.npy'.format(epoch_num + 1)),
                np.array([cle_disout, fake_disout]))
        np.save(os.path.join(args.model_save_path,
                             'losses_gen_{:04}.npy'.format(epoch_num + 1)),
                np.array(losses_gen))
        np.save(os.path.join(args.model_save_path,
                             'losses_dis_{:04}.npy'.format(epoch_num + 1)),
                np.array(losses_dis))
        np.save(os.path.join(args.model_save_path,
                             'losses_ae_{:04}.npy'.format(epoch_num + 1)),
                np.array(losses_ae))

    ## Load Models
    def load_models(epoch_num, gen=True, dis=True):
        # NOTE(review): the epoch_num/gen/dis parameters are ignored — the
        # body always loads both networks at args.epoch_from. Verify intent.
        # load generator parameter
        with nn.parameter_scope("gen"):
            nn.load_parameters(os.path.join(
                args.model_save_path,
                'generator_param_{:04}.h5'.format(args.epoch_from)))
        # load discriminator parameter
        with nn.parameter_scope("dis"):
            nn.load_parameters(os.path.join(
                args.model_save_path,
                'discriminator_param_{:04}.h5'.format(args.epoch_from)))

    ## Update parameters
    class updating:
        # One optimization step with optional loss scaling (half precision).
        def __init__(self):
            self.scale = 8 if args.halfprec else 1

        def __call__(self, solver, loss):
            solver.zero_grad()  # initialize
            loss.forward(clear_no_need_grad=True)  # calculate forward
            loss.backward(self.scale, clear_buffer=True)  # calculate backward
            solver.scale_grad(1. / self.scale)  # scaling
            solver.weight_decay(args.weight_decay * self.scale)  # decay
            solver.update()  # update

    ## Inital Settings
    ## ---------------------------------
    ## Create network
    # Clear
    nn.clear_parameters()
    # Variables
    noisy = nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Input
    clean = nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Desire
    z = nn.Variable([args.batch_size, 1024, 8],
                    need_grad=False)  # Random Latent Variable
    # Generator
    genout = Generator(noisy, z)  # Predicted Clean
    genout.persistent = True  # Not to clear at backward
    loss_gen = Loss_gen(genout, clean, Discriminator(noisy, genout))
    loss_ae = F.mean(F.absolute_error(genout, clean))
    # Discriminator (unlinked fake input so G and D graphs are separate)
    fake_dis = genout.get_unlinked_variable(need_grad=True)
    cle_disout = Discriminator(noisy, clean)
    fake_disout = Discriminator(noisy, fake_dis)
    loss_dis = Loss_dis(Discriminator(noisy, clean),
                        Discriminator(noisy, fake_dis))
    ## Solver
    # RMSprop.
    # solver_gen = S.RMSprop(args.learning_rate_gen)
    # solver_dis = S.RMSprop(args.learning_rate_dis)
    # Adam
    solver_gen = S.Adam(args.learning_rate_gen)
    solver_dis = S.Adam(args.learning_rate_dis)
    # set parameter
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())
    ## Load data & Create batch
    clean_data, noisy_data = dt.data_loader()
    batches = dt.create_batch(clean_data, noisy_data, args.batch_size)
    del clean_data, noisy_data
    ## Initial settings for sub-functions
    fig = figout()
    disp = display(args.epoch_from, args.epoch, batches.batch_num)
    upd = updating()

    ## Train
    ## ----------------------------------------------------
    print('== Start Training ==')
    ## Load "Pre-trained" parameters
    if args.epoch_from > 0:
        print(' Retrain parameter from pre-trained network')
        load_models(args.epoch_from, dis=False)
        losses_gen = np.load(os.path.join(
            args.model_save_path,
            'losses_gen_{:04}.npy'.format(args.epoch_from)))
        losses_dis = np.load(os.path.join(
            args.model_save_path,
            'losses_dis_{:04}.npy'.format(args.epoch_from)))
        losses_ae = np.load(os.path.join(
            args.model_save_path,
            'losses_ae_{:04}.npy'.format(args.epoch_from)))
    else:
        losses_gen = []
        losses_ae = []
        losses_dis = []
    ## Create loss loggers (pre-sized: one log entry every 10 batches)
    point = len(losses_gen)
    loss_len = (args.epoch - args.epoch_from) * ((batches.batch_num + 1) // 10)
    losses_gen = np.append(losses_gen, np.zeros(loss_len))
    losses_ae = np.append(losses_ae, np.zeros(loss_len))
    losses_dis = np.append(losses_dis, np.zeros(loss_len))
    ## Training
    for i in range(args.epoch_from, args.epoch):
        print('')
        print(' =========================================================')
        print(' Epoch :: {0}/{1}'.format(i + 1, args.epoch))
        print(' =========================================================')
        print('')
        # Batch iteration
        for j in range(batches.batch_num):
            print(' Train (Epoch. {0}) - {1}/{2}'.format(
                i + 1, j + 1, batches.batch_num))
            ## Batch setting
            clean.d, noisy.d = batches.next(j)
            #z.d = np.random.randn(*z.shape)
            z.d = np.zeros(z.shape)  # latent input fixed to zeros
            ## Updating
            upd(solver_gen, loss_gen)  # update Generator
            upd(solver_dis, loss_dis)  # update Discriminator
            ## Display (every 10 batches)
            if (j + 1) % 10 == 0:
                # Get result for Display
                cle_disout.forward()
                fake_disout.forward()
                loss_ae.forward(clear_no_need_grad=True)
                # Display text
                disp(i, j, loss_gen.d, loss_dis.d, loss_ae.d)
                # Data logger
                losses_gen[point] = loss_gen.d
                losses_ae[point] = loss_ae.d
                losses_dis[point] = loss_dis.d
                point = point + 1
                # Plot
                fig.waveform(noisy.d[0, 0, :], genout.d[0, 0, :],
                             clean.d[0, 0, :])
                fig.loss(losses_gen[0:point - 1], losses_ae[0:point - 1],
                         losses_dis[0:point - 1])
                fig.histogram(cle_disout.d, fake_disout.d)
                pg.QtGui.QApplication.processEvents()
        ## Save parameters
        if ((i + 1) % args.model_save_cycle) == 0:
            save_models(i, cle_disout.d, fake_disout.d,
                        losses_gen[0:point - 1], losses_dis[0:point - 1],
                        losses_ae[0:point - 1])  # save model
            # Call pg.QtGui.QApplication.processEvents() before exporters!!
            exporter = pg.exporters.ImageExporter(fig.win.scene())
            exporter.export(os.path.join(
                args.model_save_path,
                'plot_{:04}.png'.format(i + 1)))  # save fig
    ## Save parameters (Last)
    save_models(args.epoch - 1, cle_disout.d, fake_disout.d,
                losses_gen, losses_dis, losses_ae)
def Loss_reconstruction(beta_in, beta_clean):
    """Scaled L1 between input and clean beta parameters.

    The 0.001 coefficient acts as the loss weight for the
    distribution-estimation network.
    """
    beta_term = F.mean(F.absolute_error(beta_in, beta_clean))
    return 0.001 * beta_term
def recon_loss(x, y):
    """L1 reconstruction loss between ``x`` and ``y``."""
    diff = F.absolute_error(x, y)
    return F.mean(diff)
def feature_loss(fea_real, fea_fake):
    """Sum of L1 distances over intermediate discriminator feature maps.

    The final element of each discriminator's output list (its prediction)
    is excluded; only the intermediate activations are matched.
    """
    terms = []
    for real_maps, fake_maps in zip(fea_real, fea_fake):
        terms.extend(
            F.mean(F.absolute_error(r, f))
            for r, f in zip(real_maps[:-1], fake_maps[:-1]))
    return sum(terms)