Example #1
def retrieve_replay_update(args,
                           model,
                           opt,
                           input_x,
                           input_y,
                           buffer,
                           task,
                           loader=None,
                           rehearse=True):
    """ finds buffer samples with maxium interference """
    '''
    ER - MIR and regular ER
    '''

    updated_inds = None

    hid = model.return_hidden(input_x)

    logits = model.linear(hid)
    if args.multiple_heads:
        logits = logits.masked_fill(loader.dataset.mask == 0, -1e9)
    loss_a = F.cross_entropy(logits, input_y, reduction='none')
    loss = (loss_a).sum() / loss_a.size(0)

    opt.zero_grad()
    loss.backward()

    if not rehearse:
        opt.step()
        return model

    if args.method == 'mir_replay':
        bx, by, bt, subsample = buffer.sample(args.subsample,
                                              exclude_task=task,
                                              ret_ind=True)
        grad_dims = []
        for param in model.parameters():
            grad_dims.append(param.data.numel())
        grad_vector = get_grad_vector(args, model.parameters, grad_dims)
        model_temp = get_future_step_parameters(model,
                                                grad_vector,
                                                grad_dims,
                                                lr=args.lr)

        with torch.no_grad():
            logits_track_pre = model(bx)
            buffer_hid = model_temp.return_hidden(bx)
            logits_track_post = model_temp.linear(buffer_hid)

            if args.multiple_heads:
                mask = torch.zeros_like(logits_track_post)
                mask.scatter_(1, loader.dataset.task_ids[bt], 1)
                assert mask.nelement() // mask.sum() == args.n_tasks
                logits_track_post = logits_track_post.masked_fill(
                    mask == 0, -1e9)
                logits_track_pre = logits_track_pre.masked_fill(
                    mask == 0, -1e9)

            pre_loss = F.cross_entropy(logits_track_pre, by, reduction="none")
            post_loss = F.cross_entropy(logits_track_post,
                                        by,
                                        reduction="none")
            scores = post_loss - pre_loss
            EN_logits = entropy_fn(logits_track_pre)  # not used further below
            if args.compare_to_old_logits:
                old_loss = F.cross_entropy(buffer.logits[subsample],
                                           by,
                                           reduction="none")

                updated_mask = pre_loss < old_loss
                updated_inds = updated_mask.data.nonzero().squeeze(1)
                scores = post_loss - torch.min(pre_loss, old_loss)

            all_logits = scores
            big_ind = all_logits.sort(
                descending=True)[1][:args.buffer_batch_size]

            idx = subsample[big_ind]

        mem_x, mem_y = bx[big_ind], by[big_ind]
        logits_y, b_task_ids = buffer.logits[idx], bt[big_ind]
    else:
        mem_x, mem_y, b_task_ids = buffer.sample(args.buffer_batch_size,
                                                 exclude_task=task)

    logits_buffer = model(mem_x)
    if args.multiple_heads:
        mask = torch.zeros_like(logits_buffer)
        mask.scatter_(1, loader.dataset.task_ids[b_task_ids], 1)
        assert mask.nelement() // mask.sum() == args.n_tasks
        logits_buffer = logits_buffer.masked_fill(mask == 0, -1e9)
    F.cross_entropy(logits_buffer, mem_y).backward()

    if updated_inds is not None:
        buffer.logits[subsample[updated_inds]] = deepcopy(
            logits_track_pre[updated_inds])
    opt.step()
    return model
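
The example above assumes module-level imports of torch, torch.nn.functional as F, copy.deepcopy and numpy, plus two project-specific helpers, get_grad_vector and get_future_step_parameters, that implement the "virtual SGD step" MIR scores against. The helper names come from the example itself, but the bodies below are only an assumed minimal sketch; the project's own implementations may differ.

# Sketch (assumption): what the two MIR helpers are taken to do.
from copy import deepcopy

import torch


def get_grad_vector(args, pp, grad_dims):
    """Flatten the current parameter gradients into a single vector."""
    grads = torch.zeros(sum(grad_dims))
    if getattr(args, 'cuda', False):
        grads = grads.to(args.device)
    cnt = 0
    for param in pp():  # pp is passed as model.parameters (the bound method)
        beg = sum(grad_dims[:cnt])
        end = beg + grad_dims[cnt]
        if param.grad is not None:
            grads[beg:end].copy_(param.grad.data.view(-1))
        cnt += 1
    return grads


def get_future_step_parameters(model, grad_vector, grad_dims, lr=0.1):
    """Return a deep copy of `model` after one plain SGD step along `grad_vector`."""
    new_model = deepcopy(model)
    cnt = 0
    for param in new_model.parameters():
        beg = sum(grad_dims[:cnt])
        end = beg + grad_dims[cnt]
        this_grad = grad_vector[beg:end].contiguous().view(param.data.size())
        param.data = param.data - lr * this_grad
        cnt += 1
    return new_model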
Example #2
def retrieve_gen_for_gen(args, x, gen, prev_gen, prev_cls):
    """Searches the previous generator's latent space for samples whose
    reconstruction/KL loss would be most interfered with by the virtual
    update of the current generator (MIR retrieval for generative replay)."""

    grad_vector = get_grad_vector(args, gen.parameters, gen.grad_dims)

    virtual_gen = get_future_step_parameters(gen, grad_vector, gen.grad_dims,
                                             args.lr)

    _, z_mu, z_var, _, _, _ = prev_gen(x)

    z_new_max = None
    for i in range(args.n_mem):

        with torch.no_grad():

            if args.mir_init_prior:
                z_new = prev_gen.prior.sample(
                    (z_mu.shape[0], )).to(args.device)
            else:
                z_new = prev_gen.reparameterize(z_mu, z_var)

        for j in range(args.mir_iters):
            z_new.requires_grad = True

            x_new = prev_gen.decode(z_new)


            prev_x_mean, prev_z_mu, prev_z_var, prev_ldj, prev_z0, prev_zk = \
                    prev_gen(x_new)
            _, prev_rec, prev_kl, _ = calculate_loss(prev_x_mean, x_new, prev_z_mu, \
                    prev_z_var, prev_z0, prev_zk, prev_ldj, args, beta=1)

            virtual_x_mean, virtual_z_mu, virtual_z_var, virtual_ldj, virtual_z0, virtual_zk = \
                    virtual_gen(x_new)
            _, virtual_rec, virtual_kl, _ = calculate_loss(virtual_x_mean, x_new, virtual_z_mu, \
                    virtual_z_var, virtual_z0, virtual_zk, virtual_ldj, args, beta=1)

            #TODO(warning, KL can explode)

            # maximise the interference
            KL = 0
            if args.gen_kl_coeff > 0.:
                KL = virtual_kl - prev_kl

            REC = 0
            if args.gen_rec_coeff > 0.:
                REC = virtual_rec - prev_rec

            # the predictions from the two models should be confident
            ENT = 0
            if args.gen_ent_coeff > 0.:
                y_pre = prev_cls(x_new)
                ENT = cross_entropy(y_pre, y_pre)
            #TODO(should we do the args.curr_entropy thing?)

            DIV = 0
            # the newly found samples should be different from each other
            if args.gen_div_coeff > 0.:
                for found_z_i in range(i):
                    DIV += F.mse_loss(
                        z_new, z_new_max[found_z_i * z_new.size(0):found_z_i *
                                         z_new.size(0) + z_new.size(0)]) / (i)

            # (NEW) stay on gaussian shell loss:
            SHELL = 0
            if args.gen_shell_coeff > 0.:
                SHELL = mse(
                    torch.norm(z_new, 2, dim=1),
                    torch.ones_like(torch.norm(z_new, 2, dim=1)) *
                    np.sqrt(args.z_size))

            gain = args.gen_kl_coeff * KL + \
                   args.gen_rec_coeff * REC + \
                   -args.gen_ent_coeff * ENT + \
                   args.gen_div_coeff * DIV + \
                   -args.gen_shell_coeff * SHELL

            z_g = torch.autograd.grad(gain, z_new)[0]
            z_new = (z_new + 1 * z_g).detach()

        if z_new_max is None:
            z_new_max = z_new.clone()
        else:
            z_new_max = torch.cat([z_new_max, z_new.clone()])

    z_new_max.requires_grad = False

    if np.isnan(z_new_max.to('cpu').numpy()).any():
        mir_worked = 0
        mem_x = prev_gen.generate(args.batch_size * args.n_mem).detach()
    else:
        mem_x = prev_gen.decode(z_new_max).detach()
        mir_worked = 1

    return mem_x, mir_worked
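
Examples #2 and #3 also call a soft-target cross_entropy(logits, target_logits) that is not torch.nn.functional.cross_entropy, an mse criterion for the shell term, and (in Example #1) an entropy_fn over logits. The helper names are from the examples; the bodies below are only an assumed sketch of what they compute, and the project's own utilities may differ.

# Sketch (assumption): soft-target cross-entropy, entropy, and the mse criterion.
import torch
import torch.nn.functional as F

# the `mse` used for the SHELL term is assumed to be a plain MSE criterion
mse = torch.nn.MSELoss()


def cross_entropy(logits, target_logits):
    """Cross-entropy between softmax(target_logits) and softmax(logits).
    With logits == target_logits this reduces to the mean prediction entropy,
    which is how ENT = cross_entropy(y_pre, y_pre) is used above."""
    log_p = F.log_softmax(logits, dim=1)
    q = F.softmax(target_logits, dim=1)
    return -(q * log_p).sum(dim=1).mean()


def entropy_fn(logits):
    """Per-sample entropy of the softmax distribution over `logits`."""
    p = F.softmax(logits, dim=1)
    return -(p * torch.log(p + 1e-12)).sum(dim=1)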
Example #3
def retrieve_gen_for_cls(args, x, cls, prev_cls, prev_gen):
    """Searches the previous generator's latent space for samples on which the
    virtually updated classifier disagrees most with the previous classifier;
    returns the decoded samples, their soft labels, and a success flag."""

    grad_vector = get_grad_vector(args, cls.parameters, cls.grad_dims)

    virtual_cls = get_future_step_parameters(cls, grad_vector, cls.grad_dims,
                                             args.lr)

    _, z_mu, z_var, _, _, _ = prev_gen(x)

    z_new_max = None

    for i in range(args.n_mem):

        with torch.no_grad():

            if args.mir_init_prior:
                z_new = prev_gen.prior.sample(
                    (z_mu.shape[0], )).to(args.device)
            else:
                z_new = prev_gen.reparameterize(z_mu, z_var)

        for j in range(args.mir_iters):

            z_new.requires_grad = True

            x_new = prev_gen.decode(z_new)
            y_pre = prev_cls(x_new)
            y_virtual = virtual_cls(x_new)

            # maximise the interference:
            XENT = 0
            if args.cls_xent_coeff > 0.:
                XENT = cross_entropy(y_virtual, y_pre)

            # the predictions from the two models should be confident
            ENT = 0
            if args.cls_ent_coeff > 0.:
                ENT = cross_entropy(y_pre, y_pre)
            #TODO(should we do the args.curr_entropy thing?)

            # the newly found samples should be different from each other
            DIV = 0
            if args.cls_div_coeff > 0.:
                for found_z_i in range(i):
                    DIV += F.mse_loss(
                        z_new, z_new_max[found_z_i * z_new.size(0):found_z_i *
                                         z_new.size(0) + z_new.size(0)]) / (i)

            # (NEW) stay on gaussian shell loss:
            SHELL = 0
            if args.cls_shell_coeff > 0.:
                SHELL = mse(
                    torch.norm(z_new, 2, dim=1),
                    torch.ones_like(torch.norm(z_new, 2, dim=1)) *
                    np.sqrt(args.z_size))

            gain = args.cls_xent_coeff * XENT + \
                   -args.cls_ent_coeff * ENT + \
                   args.cls_div_coeff * DIV + \
                   -args.cls_shell_coeff * SHELL

            z_g = torch.autograd.grad(gain, z_new)[0]
            z_new = (z_new + 1 * z_g).detach()

        if z_new_max is None:
            z_new_max = z_new.clone()
        else:
            z_new_max = torch.cat([z_new_max, z_new.clone()])

    z_new_max.requires_grad = False

    if np.isnan(z_new_max.to('cpu').numpy()).any():
        mir_worked = 0
        mem_x = prev_gen.generate(args.batch_size * args.n_mem).detach()
    else:
        mem_x = prev_gen.decode(z_new_max).detach()
        mir_worked = 1

    mem_y = torch.softmax(prev_cls(mem_x), dim=1).detach()

    return mem_x, mem_y, mir_worked
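
One hypothetical way to drive retrieve_gen_for_cls in a single classifier update is sketched below. The wrapper name classifier_replay_step, the optimizer opt_cls, the labels y, and the distillation loss on the soft targets mem_y are illustrative assumptions, not the original training loop. Note that the incoming-batch loss is backpropagated before retrieval, since the virtual update inside the function reads the classifier's current gradients (and the classifier is assumed to carry a grad_dims attribute, as in the example above).

import torch.nn.functional as F


def classifier_replay_step(args, cls, opt_cls, prev_cls, prev_gen, x, y):
    """Illustrative ER step built around retrieve_gen_for_cls (assumed usage)."""
    opt_cls.zero_grad()
    # incoming-batch loss first: its gradients define the virtual update
    F.cross_entropy(cls(x), y).backward()

    mem_x, mem_y, mir_worked = retrieve_gen_for_cls(args, x, cls, prev_cls,
                                                    prev_gen)

    # distill the previous classifier's soft predictions on the retrieved samples
    mem_log_probs = F.log_softmax(cls(mem_x), dim=1)
    (-(mem_y * mem_log_probs).sum(dim=1).mean()).backward()

    opt_cls.step()
    return mir_worked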
def retrieve_replay_update(args,
                           model,
                           opt,
                           input_x,
                           input_y,
                           buffer,
                           task,
                           loader=None,
                           rehearse=True):
    """ finds buffer samples with maxium interference """
    '''
    ER - MIR and regular ER
    '''

    updated_inds = None
    if args.imprint:
        imprint(model, input_x, input_y, args.imprint > 1)

    hid = model.return_hidden(input_x)

    logits = model.linear(hid)
    if args.multiple_heads:
        logits = logits.masked_fill(loader.dataset.mask == 0, -1e9)

    opt.zero_grad()

    if not rehearse:
        loss_a = F.cross_entropy(logits, input_y, reduction='none')
        loss = (loss_a).sum() / loss_a.size(0)
        loss.backward()

        if args.friction:
            get_weight_accumelated_gradient_norm(model)
        if args.kl_far == -1:  # dummy flag that was used to try normalizing the last-layer weights
            weight_loss = get_weight_norm_diff(model, task=task, cl=5)
            weight_loss.backward()
        opt.step()

        return model

#####################################################################################################
    if args.method == 'mir_replay':
        bx, by, bt, subsample = buffer.sample(args.subsample,
                                              exclude_task=task,
                                              ret_ind=True)
        if args.cuda:
            bx = bx.to(args.device)
            by = by.to(args.device)
            bt = bt.to(args.device)
        grad_dims = []
        for param in model.parameters():
            grad_dims.append(param.data.numel())
        grad_vector = get_grad_vector(args, model.parameters, grad_dims)
        model_temp = get_future_step_parameters(model,
                                                grad_vector,
                                                grad_dims,
                                                lr=args.lr)

        with torch.no_grad():
            logits_track_pre = model(bx)
            buffer_hid = model_temp.return_hidden(bx)
            logits_track_post = model_temp.linear(buffer_hid)

            if args.multiple_heads:
                mask = torch.zeros_like(logits_track_post)
                mask.scatter_(1, loader.dataset.task_ids[bt], 1)
                assert mask.nelement() // mask.sum() == args.n_tasks
                logits_track_post = logits_track_post.masked_fill(
                    mask == 0, -1e9)
                logits_track_pre = logits_track_pre.masked_fill(
                    mask == 0, -1e9)

            pre_loss = F.cross_entropy(logits_track_pre, by, reduction="none")
            post_loss = F.cross_entropy(logits_track_post,
                                        by,
                                        reduction="none")
            scores = post_loss - pre_loss
            EN_logits = entropy_fn(logits_track_pre)  # not used further below
            if args.compare_to_old_logits:
                old_loss = F.cross_entropy(buffer.logits[subsample],
                                           by,
                                           reduction="none")

                updated_mask = pre_loss < old_loss
                updated_inds = updated_mask.data.nonzero().squeeze(1)
                scores = post_loss - torch.min(pre_loss, old_loss)

            all_logits = scores
            big_ind = all_logits.sort(
                descending=True)[1][:args.buffer_batch_size]

            idx = subsample[big_ind]

        mem_x, mem_y = bx[big_ind], by[big_ind]
        logits_y, b_task_ids = buffer.logits[idx], bt[big_ind]

        logits_buffer = model(mem_x)

        if args.mask:
            yy = torch.cat((input_y, mem_y), dim=0)
            mask = get_mask_unused_memories(yy, logits.size(1))
            logits = logits.masked_fill(mask, -1e9)
            logits_buffer = logits_buffer.masked_fill(mask, -1e9)

        F.cross_entropy(logits, input_y).backward()

        (args.multiplier * F.cross_entropy(logits_buffer, mem_y)).backward()
##############################Nearest Neighbours###############################################

#########################Nearest Neighbours and Making Updated Local#################################################
    if args.method == 'oth_cl_neib_replay_mid_fix':
        #subsampling
        bx, by, bt, subsample = buffer.sample(args.subsample,
                                              exclude_task=task,
                                              ret_ind=True)
        #get hidden representation
        model.eval()
        b_hidden = model.return_hidden(bx)
        input_hidden = model.return_hidden(input_x)
        model.train()
        # get nearby samples
        close_indices, all_dist = get_nearbysamples(args, input_x,
                                                    input_hidden, b_hidden, by)
        _, sorted_indices = torch.sort(all_dist / args.buffer_batch_size)
        # indices that are not neighbours, ordered by distance
        not_neighbor_inds = [
            x for x in sorted_indices if x not in close_indices
        ]
        if args.far:  # fix the furthest points
            far_indices = torch.tensor(not_neighbor_inds[-len(close_indices):])
        else:  # fix the mid points
            mid = int(len(not_neighbor_inds) / 2)
            far_indices = torch.tensor(
                not_neighbor_inds[mid:min(len(not_neighbor_inds),
                                          mid + len(close_indices))])

        #get corresponding samples of neighbours and far samples
        close_indices = torch.tensor(close_indices)
        mem_x = bx[close_indices]
        mem_y = by[close_indices]
        logits_buffer = model(mem_x)
        far_indices = torch.tensor(far_indices)
        far_mem_x = bx[far_indices]
        if args.untilconvergence:
            utils.optimize(args, model, opt, input_x, input_y, mem_x, mem_y,
                           far_mem_x, task)
            return model
        else:
            utils.compute_lossgrad(args, model, input_x, input_y, mem_x, mem_y,
                                   far_mem_x)


####################################RAND########################################

    if args.method == 'rand_replay':

        mem_x, mem_y, bt = buffer.sample(args.buffer_batch_size,
                                         exclude_task=task)
        if args.untilconvergence:
            utils.rand_optimize(args, model, opt, input_x, input_y, mem_x,
                                mem_y, task)
            return model

        logits_buffer = model(mem_x)

        if args.mask:
            yy = torch.cat((input_y, mem_y), dim=0)
            mask = get_mask_unused_memories(yy, logits.size(1))
            logits = logits.masked_fill(mask, -1e9)
            logits_buffer = logits_buffer.masked_fill(mask, -1e9)

        F.cross_entropy(logits, input_y).backward()

        (args.multiplier * F.cross_entropy(logits_buffer, mem_y)).backward()

    if updated_inds is not None:
        buffer.logits[subsample[updated_inds]] = deepcopy(
            logits_track_pre[updated_inds])

    if args.friction:
        get_weight_accumelated_gradient_norm(model)
        weight_gradient_norm(model, args.friction)
    if args.kl_far == -1:
        weight_loss = get_weight_norm_diff(model, task=task, cl=5)
        weight_loss.backward()
    opt.step()

    return model
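
A hypothetical outer loop driving this retrieve_replay_update over a task stream is sketched below. The wrapper name train_on_stream, the train_loaders structure, and the buffer.add_reservoir call are assumptions about the surrounding project, shown only to illustrate how the function is intended to be called per incoming batch.

def train_on_stream(args, model, opt, buffer, train_loaders):
    """Illustrative (assumed) outer loop over a sequence of task loaders."""
    for task, loader in enumerate(train_loaders):
        for input_x, input_y in loader:
            if args.cuda:
                input_x = input_x.to(args.device)
                input_y = input_y.to(args.device)
            # one replay-augmented update on the incoming batch
            model = retrieve_replay_update(args, model, opt, input_x, input_y,
                                           buffer, task, loader=loader,
                                           rehearse=(task > 0))
            # reservoir insertion of the incoming batch; the exact buffer API
            # (add_reservoir) is an assumption about the surrounding project
            buffer.add_reservoir(input_x, input_y, None, task)
    return model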