Ejemplo n.º 1
0
def my_less_scalar(_x, _scalar):
    # input
    # _x : type=nn.Variable
    # _scalar : type=float
    # output
    # flags : type=nn.Variable, same shape with _x

    temp = F.r_sub_scalar(_x, _scalar)
    temp = F.sign(temp, alpha=0)
    flags = F.relu(temp)
    return flags
Ejemplo n.º 2
0
    def __rsub__(self, other):
        """
        Element-wise subtraction.
        Part of the implementation of the subtraction operator.

        Args:
            other (float or ~nnabla.Variable): Internally calling
                :func:`~nnabla.functions.sub2` or
                :func:`~nnabla.functions.r_sub_scalar` according to the
                type.

        Returns: :class:`nnabla.Variable`

        """
        import nnabla.functions as F
        if isinstance(other, Variable):
            return F.sub2(other, self)
        return F.r_sub_scalar(self, other)
Ejemplo n.º 3
0
def train(args):
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Learning rates get decayed
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
Ejemplo n.º 4
0
def ssd_loss(_ssd_confs, _ssd_locs, _label, _alpha=1):
    # input
    # _ssd_confs : type=nn.Variable, prediction of class. shape=(batch_size, default boxes, class num + 1)
    # _ssd_locs : type=nn.Variable, prediction of location. shape=(batch_size, default boxes, 4)
    # _label : type=nn.Variable, shape=(batch_size, default boxes, class num + 1 + 4)
    # _alpha : type=float, hyperparameter. this is weight of loc_loss.

    # output
    # loss : type=nn.Variable

    def smooth_L1(__pred_locs, __label_locs):
        # input
        # __pred_locs : type=nn.Variable, 
        # __label_locs : type=nn.Variable, 

        # output
        # _loss : type=nn.Variable, loss of location.

        return F.mul_scalar(F.huber_loss(__pred_locs, __label_locs), 0.5)

    # _label_conf : type=nn.Variable, label of class. shape=(batch_size, default boxes, class num + 1) (after one_hot)
    # _label_loc : type=nn.Variable, label of location. shape=(batch_size, default boxes, 4)
    label_conf = F.slice(
        _label, 
        start=(0,0,4), 
        stop=_label.shape, 
        step=(1,1,1)
    )
    label_loc = F.slice(
        _label, 
        start=(0,0,0), 
        stop=(_label.shape[0], _label.shape[1], 4), 
        step=(1,1,1)
    )

    # conf
    ssd_pos_conf, ssd_neg_conf = ssd_separate_conf_pos_neg(_ssd_confs)
    label_conf_pos, _ = ssd_separate_conf_pos_neg(label_conf)
    # pos
    pos_loss = F.sum(
                        F.mul2(
                            F.softmax(ssd_pos_conf, axis=2), 
                            label_conf_pos
                        )
                        , axis=2
                    )
    # neg
    neg_loss = F.sum(F.log(ssd_neg_conf), axis=2)
    conf_loss = F.sum(F.sub2(pos_loss, neg_loss), axis=1)

    # loc
    pos_label = F.sum(label_conf_pos, axis=2)      # =1 (if there is sonething), =0 (if there is nothing)
    loc_loss = F.sum(F.mul2(F.sum(smooth_L1(_ssd_locs, label_loc), axis=2), pos_label), axis=1)

    # [2019/07/18]
    label_match_default_box_num = F.slice(
        _label, 
        start=(0,0,_label.shape[2] - 1), 
        stop=_label.shape, 
        step=(1,1,1)
    )
    label_match_default_box_num = F.sum(label_match_default_box_num, axis=1)
    label_match_default_box_num = F.r_sub_scalar(label_match_default_box_num, _label.shape[1])
    label_match_default_box_num = F.reshape(label_match_default_box_num, (label_match_default_box_num.shape[0],), inplace=False)
    # label_match_default_box_num : type=nn.Variable, inverse number of default boxes that matches with pos.

    # loss
    loss = F.mul2(F.add2(conf_loss, F.mul_scalar(loc_loss, _alpha)), label_match_default_box_num)
    loss = F.mean(loss)
    return loss