def my_less_scalar(_x, _scalar):
    # input
    #   _x : type=nn.Variable
    #   _scalar : type=float
    # output
    #   flags : type=nn.Variable, same shape as _x
    temp = F.r_sub_scalar(_x, _scalar)  # _scalar - _x
    temp = F.sign(temp, alpha=0)        # +1 where _x < _scalar, -1 where _x > _scalar, 0 where equal
    flags = F.relu(temp)                # keep only the +1 entries: 1.0 where _x < _scalar, 0.0 elsewhere
    return flags
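A minimal usage sketch of the helper above; the input values and the surrounding imports are illustrative assumptions, not part of the original snippet:

import numpy as np
import nnabla as nn
import nnabla.functions as F  # needed by my_less_scalar

x = nn.Variable.from_numpy_array(
    np.array([[0.2, 0.7], [1.5, -0.3]], dtype=np.float32))
flags = my_less_scalar(x, 0.5)
flags.forward()
print(flags.d)  # expected: [[1. 0.] [0. 1.]] -- 1.0 wherever x < 0.5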
def __rsub__(self, other):
    """
    Element-wise subtraction.
    Part of the implementation of the subtraction operator.

    Args:
        other (float or ~nnabla.Variable): Internally calling
            :func:`~nnabla.functions.sub2` or
            :func:`~nnabla.functions.r_sub_scalar` according to the type.

    Returns: :class:`nnabla.Variable`

    """
    import nnabla.functions as F
    if isinstance(other, Variable):
        return F.sub2(other, self)
    return F.r_sub_scalar(self, other)
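For reference, a small sketch of how this reflected operator is reached in practice; the example values are assumptions:

import numpy as np
import nnabla as nn

x = nn.Variable.from_numpy_array(np.array([1.0, 2.0, 3.0], dtype=np.float32))
y = 10.0 - x   # float on the left, so Python calls x.__rsub__(10.0) -> F.r_sub_scalar(x, 10.0)
y.forward()
print(y.d)     # expected: [9. 8. 7.]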
def train(args):
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim,
        num_downsample=args.num_downsample, num_upsample=args.num_upsample,
        repeat_num=args.g_repeat_num)
    discriminator = functools.partial(
        model.discriminator, image_size=args.image_size,
        conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir,
                                  image_size=args.image_size,
                                  crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir,
                                       image_size=args.image_size,
                                       crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            x_fake.backward(grad=None)
            solver_gen.update()

            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Learning rates get decayed
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path,
                           "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
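The decay block in the loop above lowers both learning rates linearly toward zero over the second half of training. A standalone sketch of that schedule, where the function name `linear_decay_lr` and the argument `iteration` are assumptions introduced here for illustration:

def linear_decay_lr(initial_lr, iteration, num_max_iter, lr_update_step):
    # Constant for the first half of training; afterwards a fixed amount is
    # subtracted every `lr_update_step` iterations, reaching roughly 0 at the end.
    half = 0.5 * num_max_iter
    num_decays = max(0, (iteration + 1) // lr_update_step - int(half) // lr_update_step)
    return max(0.0, initial_lr - num_decays * (lr_update_step * initial_lr / half))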
def ssd_loss(_ssd_confs, _ssd_locs, _label, _alpha=1):
    # input
    #   _ssd_confs : type=nn.Variable, prediction of class. shape=(batch_size, default boxes, class num + 1)
    #   _ssd_locs  : type=nn.Variable, prediction of location. shape=(batch_size, default boxes, 4)
    #   _label     : type=nn.Variable, shape=(batch_size, default boxes, class num + 1 + 4)
    #   _alpha     : type=float, hyperparameter. this is the weight of loc_loss.
    # output
    #   loss : type=nn.Variable

    def smooth_L1(__pred_locs, __label_locs):
        # input
        #   __pred_locs  : type=nn.Variable
        #   __label_locs : type=nn.Variable
        # output
        #   _loss : type=nn.Variable, loss of location.
        return F.mul_scalar(F.huber_loss(__pred_locs, __label_locs), 0.5)

    # label_conf : type=nn.Variable, label of class. shape=(batch_size, default boxes, class num + 1) (after one_hot)
    # label_loc  : type=nn.Variable, label of location. shape=(batch_size, default boxes, 4)
    label_conf = F.slice(_label, start=(0, 0, 4), stop=_label.shape, step=(1, 1, 1))
    label_loc = F.slice(_label, start=(0, 0, 0),
                        stop=(_label.shape[0], _label.shape[1], 4), step=(1, 1, 1))

    # conf
    ssd_pos_conf, ssd_neg_conf = ssd_separate_conf_pos_neg(_ssd_confs)
    label_conf_pos, _ = ssd_separate_conf_pos_neg(label_conf)
    # pos
    pos_loss = F.sum(F.mul2(F.softmax(ssd_pos_conf, axis=2), label_conf_pos), axis=2)
    # neg
    neg_loss = F.sum(F.log(ssd_neg_conf), axis=2)
    conf_loss = F.sum(F.sub2(pos_loss, neg_loss), axis=1)

    # loc
    pos_label = F.sum(label_conf_pos, axis=2)  # =1 (if there is something), =0 (if there is nothing)
    loc_loss = F.sum(F.mul2(F.sum(smooth_L1(_ssd_locs, label_loc), axis=2), pos_label), axis=1)

    # [2019/07/18]
    label_match_default_box_num = F.slice(_label, start=(0, 0, _label.shape[2] - 1),
                                          stop=_label.shape, step=(1, 1, 1))
    label_match_default_box_num = F.sum(label_match_default_box_num, axis=1)
    label_match_default_box_num = F.r_sub_scalar(label_match_default_box_num, _label.shape[1])
    label_match_default_box_num = F.reshape(
        label_match_default_box_num, (label_match_default_box_num.shape[0],), inplace=False)
    # label_match_default_box_num : type=nn.Variable, inverse number of default boxes that matches with pos.

    # loss
    loss = F.mul2(F.add2(conf_loss, F.mul_scalar(loc_loss, _alpha)), label_match_default_box_num)
    loss = F.mean(loss)
    return loss
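For reference, a minimal sketch of the F.slice pattern used above to split a packed label tensor into its location and confidence parts; the toy shapes and values below are assumptions, not part of the original code:

import numpy as np
import nnabla as nn
import nnabla.functions as F

batch_size, num_boxes, num_classes = 2, 8, 3   # assumed toy sizes
label = nn.Variable((batch_size, num_boxes, 4 + num_classes + 1))
label.d = np.random.rand(*label.shape).astype(np.float32)

# first 4 channels hold box coordinates, the remaining channels hold the class label
label_loc = F.slice(label, start=(0, 0, 0),
                    stop=(label.shape[0], label.shape[1], 4), step=(1, 1, 1))
label_conf = F.slice(label, start=(0, 0, 4), stop=label.shape, step=(1, 1, 1))

print(label_loc.shape, label_conf.shape)   # (2, 8, 4) (2, 8, 4)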