Example #1
def max_z_for_gen(args,
                  virtual_gen,
                  prev_gen,
                  prev_cls,
                  z_mu,
                  z_var,
                  z_t,
                  gradient_steps=10,
                  budget=10):
    """
    retrieve most interfered samples for the vae branch
    :param new_net:
    :param mu:
    :param logvar:
    :param z_t:
    :param gradient_steps:
    :param budget:
    :return:
    """
    z_new_max = None
    for i in range(budget):

        with torch.no_grad():

            if args.mir_init_prior:
                z_new = prev_gen.prior.sample(
                    (z_mu.shape[0], )).to(args.device)
            else:
                z_new = prev_gen.reparameterize(z_mu, z_var)

        for j in range(gradient_steps):
            z_new.requires_grad = True

            x_new = prev_gen.decode(z_new)

            prev_x_mean, prev_z_mu, prev_z_var, prev_ldj, prev_z0, prev_zk = \
                    prev_gen(x_new)
            _, prev_rec, prev_kl, _ = calculate_loss(prev_x_mean, x_new, prev_z_mu, \
                    prev_z_var, prev_z0, prev_zk, prev_ldj, args, beta=1)

            virtual_x_mean, virtual_z_mu, virtual_z_var, virtual_ldj, virtual_z0, virtual_zk = \
                    virtual_gen(x_new)
            _, virtual_rec, virtual_kl, _ = calculate_loss(virtual_x_mean, x_new, virtual_z_mu, \
                    virtual_z_var, virtual_z0, virtual_zk, virtual_ldj, args, beta=1)

            #TODO(warning, KL can explode)

            # maximise the interference
            KL = 0
            if args.gen_kl_coeff > 0.:
                KL = virtual_kl - prev_kl

            REC = 0
            if args.gen_rec_coeff > 0.:
                REC = virtual_rec - prev_rec

            # the previous classifier's predictions on the retrieved samples
            # should be confident (ENT penalises their entropy)
            ENT = 0
            if args.gen_ent_coeff > 0.:
                y_pre = prev_cls(x_new)
                ENT = cross_entropy(y_pre, y_pre)
            #TODO(should we do the args.curr_entropy thing?)

            DIV = 0
            # the newly found samples should be different from each other
            if args.gen_div_coeff > 0.:
                for found_z_i in range(i):
                    DIV += F.mse_loss(
                        z_new, z_new_max[found_z_i * z_new.size(0):found_z_i *
                                         z_new.size(0) + z_new.size(0)]) / (i)

            # (NEW) stay-on-Gaussian-shell loss (a standalone sketch follows this example):
            SHELL = 0
            if args.gen_shell_coeff > 0.:
                SHELL = mse(
                    torch.norm(z_new, 2, dim=1),
                    torch.ones_like(torch.norm(z_new, 2, dim=1)) *
                    np.sqrt(args.z_size))

            gain = args.gen_kl_coeff * KL + \
                   args.gen_rec_coeff * REC - \
                   args.gen_ent_coeff * ENT + \
                   args.gen_div_coeff * DIV - \
                   args.gen_shell_coeff * SHELL

            # one gradient-ascent step on the latent code (unit step size)
            z_g = torch.autograd.grad(gain, z_new)[0]
            z_new = (z_new + 1 * z_g).detach()

        if z_new_max is None:
            z_new_max = z_new.clone()
        else:
            z_new_max = torch.cat([z_new_max, z_new.clone()])

    z_new_max.requires_grad = False

    return z_new_max
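
The SHELL term in max_z_for_gen keeps the optimised latents near the Gaussian shell: for z ~ N(0, I_d) the L2 norm concentrates around sqrt(d). Below is a minimal, self-contained sketch of that idea; it uses F.mse_loss in place of the repository's mse helper, and the batch size and z_size are illustrative assumptions, not values from the code above.

import torch
import torch.nn.functional as F

z_size = 64                                   # illustrative latent dimensionality
z = torch.randn(128, z_size)                  # a batch of standard-normal latents

# For z ~ N(0, I), the L2 norm concentrates around sqrt(z_size), so the
# shell loss pulls optimised codes back towards that radius.
radii = z.norm(p=2, dim=1)
target = torch.full_like(radii, z_size ** 0.5)
shell_loss = F.mse_loss(radii, target)

print(radii.mean().item(), z_size ** 0.5, shell_loss.item())

In the function above this quantity enters the gain with a negative coefficient, so the ascent step is discouraged from pushing z_new far inside or outside the shell.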
Example #2
                # linear KL warm-up for beta (a standalone sketch follows this fragment)
                args.beta = min([(sample_amt) / max([args.warmup, 1.]),
                                 args.max_beta])

                #------ Train Generator ------#

                #-------------------------------
                # Begin Generator Iteration Loop
                for it in range(args.gen_iters):

                    x_mean, z_mu, z_var, ldj, z0, zk = gen(data)
                    gen_loss, rec, kl, _ = calculate_loss(x_mean,
                                                          data,
                                                          z_mu,
                                                          z_var,
                                                          z0,
                                                          zk,
                                                          ldj,
                                                          args,
                                                          beta=args.beta)

                    tot_gen_loss = 0 + gen_loss

                    if task > 0 and args.gen_method != 'no_rehearsal':

                        if it == 0 or not args.reuse_samples:

                            if args.gen_method == 'rand_gen':
                                mem_x = prev_gen.generate(args.batch_size *
                                                          args.n_mem).detach()
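
The first statement of the fragment above implements a linear KL warm-up: beta grows with the number of samples seen and saturates at max_beta. A tiny standalone version of that schedule is sketched below; the warmup and max_beta values are illustrative assumptions, not taken from the repository.

def kl_beta(sample_amt: float, warmup: float, max_beta: float) -> float:
    # Linear warm-up: beta ramps up with the number of samples seen,
    # then saturates at max_beta (mirrors the schedule in the fragment above).
    return min(sample_amt / max(warmup, 1.0), max_beta)

# Example with assumed warmup=1000 and max_beta=1.0:
print([kl_beta(n, warmup=1000, max_beta=1.0) for n in (0, 250, 1000, 5000)])
# [0.0, 0.25, 1.0, 1.0]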
Example #3
def retrieve_gen_for_gen(args, x, gen, prev_gen, prev_cls):
    """Retrieve maximally interfered samples from prev_gen to rehearse the generator."""

    grad_vector = get_grad_vector(args, gen.parameters, gen.grad_dims)

    # build a 'virtual' generator one look-ahead update ahead of gen
    # (a standalone sketch of this step follows this example)
    virtual_gen = get_future_step_parameters(gen, grad_vector, gen.grad_dims,
                                             args.lr)

    _, z_mu, z_var, _, _, _ = prev_gen(x)

    z_new_max = None
    for i in range(args.n_mem):

        with torch.no_grad():

            if args.mir_init_prior:
                z_new = prev_gen.prior.sample(
                    (z_mu.shape[0], )).to(args.device)
            else:
                z_new = prev_gen.reparameterize(z_mu, z_var)

        for j in range(args.mir_iters):
            z_new.requires_grad = True

            x_new = prev_gen.decode(z_new)

            prev_x_mean, prev_z_mu, prev_z_var, prev_ldj, prev_z0, prev_zk = \
                    prev_gen(x_new)
            _, prev_rec, prev_kl, _ = calculate_loss(prev_x_mean, x_new, prev_z_mu, \
                    prev_z_var, prev_z0, prev_zk, prev_ldj, args, beta=1)

            virtual_x_mean, virtual_z_mu, virtual_z_var, virtual_ldj, virtual_z0, virtual_zk = \
                    virtual_gen(x_new)
            _, virtual_rec, virtual_kl, _ = calculate_loss(virtual_x_mean, x_new, virtual_z_mu, \
                    virtual_z_var, virtual_z0, virtual_zk, virtual_ldj, args, beta=1)

            #TODO(warning, KL can explode)

            # maximise the interference
            KL = 0
            if args.gen_kl_coeff > 0.:
                KL = virtual_kl - prev_kl

            REC = 0
            if args.gen_rec_coeff > 0.:
                REC = virtual_rec - prev_rec

            # the previous classifier's predictions on the retrieved samples
            # should be confident (ENT penalises their entropy)
            ENT = 0
            if args.gen_ent_coeff > 0.:
                y_pre = prev_cls(x_new)
                ENT = cross_entropy(y_pre, y_pre)
            #TODO(should we do the args.curr_entropy thing?)

            DIV = 0
            # the newly found samples should be different from each other
            if args.gen_div_coeff > 0.:
                for found_z_i in range(i):
                    DIV += F.mse_loss(
                        z_new, z_new_max[found_z_i * z_new.size(0):found_z_i *
                                         z_new.size(0) + z_new.size(0)]) / (i)

            # (NEW) stay on gaussian shell loss:
            SHELL = 0
            if args.gen_shell_coeff > 0.:
                SHELL = mse(
                    torch.norm(z_new, 2, dim=1),
                    torch.ones_like(torch.norm(z_new, 2, dim=1)) *
                    np.sqrt(args.z_size))

            gain = args.gen_kl_coeff * KL + \
                   args.gen_rec_coeff * REC - \
                   args.gen_ent_coeff * ENT + \
                   args.gen_div_coeff * DIV - \
                   args.gen_shell_coeff * SHELL

            # one gradient-ascent step on the latent code (unit step size)
            z_g = torch.autograd.grad(gain, z_new)[0]
            z_new = (z_new + 1 * z_g).detach()

        if z_new_max is None:
            z_new_max = z_new.clone()
        else:
            z_new_max = torch.cat([z_new_max, z_new.clone()])

    z_new_max.requires_grad = False

    if np.isnan(z_new_max.to('cpu').numpy()).any():
        mir_worked = 0
        mem_x = prev_gen.generate(args.batch_size * args.n_mem).detach()
    else:
        mem_x = prev_gen.decode(z_new_max).detach()
        mir_worked = 1

    return mem_x, mir_worked
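
retrieve_gen_for_gen compares the frozen previous generator against a 'virtual' generator obtained by taking one look-ahead update on the incoming batch. The sketch below shows such a look-ahead step under the assumption that get_future_step_parameters amounts to a single plain SGD step applied to a parameter copy; the repository's implementation may differ, and virtual_step is a hypothetical helper used only for illustration.

import copy

import torch
import torch.nn as nn

def virtual_step(model: nn.Module, lr: float) -> nn.Module:
    # Hypothetical look-ahead: copy the model and apply one plain SGD step
    # using the gradients currently stored on the original parameters.
    future = copy.deepcopy(model)
    with torch.no_grad():
        for p_now, p_next in zip(model.parameters(), future.parameters()):
            if p_now.grad is not None:
                p_next.sub_(lr * p_now.grad)
    return future

# Usage: backprop the loss on the incoming batch, then build the virtual model.
model = nn.Linear(4, 4)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
virtual_model = virtual_step(model, lr=0.1)

The MIR-style search then looks for latents whose decodings the virtual generator handles worse than the previous one (e.g. virtual_rec - prev_rec is maximised), i.e. the samples most interfered by the pending update.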