Example 1
def collect_leakage(device,
                    base_folder='./models/BAM/',
                    specific="bowling_alley", 
                    seed=0, 
                    module="layer3",
                    experiment="sgd_finetuned",
                    ratios=["0.0","0.1","0.2","0.3","0.4","0.5","0.6","0.7","0.8","0.9", "1.0"],
                    adv=False,
                    baseline=False,
                    epoch=None,
                    multiple=True,
                    force=False,
                    dataset='bam',
                    args=None):
    results = {}
    if dataset == 'bam':
        _, testloader = dataload.get_data_loader_SceneBAM(seed=seed, ratio=0.5, specific=specific)
    elif dataset != 'coco':
        _, testloader = dataload.get_data_loader_idenProf('idenprof',
                                                          train_shuffle=True,
                                                          train_batch_size=64,
                                                          test_batch_size=64,
                                                          exclusive=True)
    for ratio in ratios:
        model, net, net_forward, activation_probe = load_models(
            device,
            base_folder=base_folder,
            specific=specific, 
            seed=seed, 
            module=module,
            experiment=experiment,
            ratio=ratio,
            adv=adv,
            baseline=baseline,
            epoch=epoch,
            post=True,
            multiple=multiple,
            leakage=True,
            force=force,
            dataset=dataset,
            args=args
        )
        model.eval()
        net.eval()
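        # COCO's test loader depends on the current ratio, so it is rebuilt
        # on every iteration; the BAM/idenProf loader built above is reused.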
        if dataset == 'coco':
            tmp_args = copy.deepcopy(args)
            tmp_args.ratio = ratio
            tmp_args.gender_balanced = True
            if float(ratio) > 0:  # ratios are strings like "0.1"; int() would raise
                tmp_args.balanced = True
            _, testloader = coco_dataload.get_data_loader_coco(
                tmp_args
            )

        results[ratio],_ = utils.net2vec_accuracy(
            testloader, 
            net_forward, 
            device, 
            train_labels=[-2,-1]
        )
    return results
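
A minimal usage sketch (the CUDA fallback and the argument choices are assumptions; `torch` and the project-local modules used above must be importable):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# maps each ratio string to the probe's accuracy on the gender columns [-2, -1]
leakage = collect_leakage(device, dataset='bam', specific="bowling_alley", seed=0)
for ratio, acc in leakage.items():
    print(ratio, acc)
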
Example 2
def train(
        trainloader,
        testloader,
        device,
        seed,
        debias_=True,
        specific=None,
        ratio=0.5,  # bias ratio in dataset
        n_epochs=5,
        model_lr=1e-3,
        n2v_lr=1e-3,
        combined_n2v_lr=1e-3,  # meta-learning rate for the n2v probe
        alpha=100,  # weight on the debias loss
        beta=0.1,  # weight on the adversarial loss
        out_file=None,
        base_folder="",
        results_folder="",
        experiment="sgd",
        momentum=0,
        module="layer4",
        finetuned=False,
        adversarial=False,
        nonlinear=False,
        subset=False,
        subset_ratio=0.1,
        save_every=False,
        model_momentum=0,
        n2v_momentum=0,
        experimental=False,
        multiple=False,
        debias_multiple=False,
        reset=False,
        reset_counter=1,
        n2v_start=False,
        experiment2=None,
        adaptive_alpha=False,
        n2v_adam=False,
        single=False,
        imagenet=False,
        train_batch_size=64,
        constant_resize=False,
        adaptive_resize=False,
        no_class=False,
        gamma=0,
        partial_projection=False,
        norm='l2',
        constant_alpha=False,
        jump_alpha=False,
        linear_alpha=False,
        mean_debias=False,
        no_limit=False,
        dataset='bam',
        parallel=False,
        gpu_ids=[],
        switch_modes=True):
    print("mu", momentum, "debias", debias_, "alpha", alpha, " | ratio:",
          ratio)

    def get_vg(W):
        # Gender direction in probe space: the single gender row, or the
        # difference between the two gender rows (indices -2 and -1).
        if single:
            return W[-2, :]
        return W[-2, :] - W[-1, :]

    if dataset == 'bam' or dataset == 'coco':
        model_init_path, n2v_init_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end="resnet_init" + '.pt',
            n2v_end="resnet_n2v_init" + '.pt',
            n2v_module=module,
            experiment=experiment,
            with_n2v=False)
    else:
        model_init_path = os.path.join(base_folder, str(seed), experiment,
                                       'resnet_init.pt')
        n2v_init_path = os.path.join(base_folder, str(seed), experiment,
                                     module, 'resnet_n2v_init.pt')
    if finetuned:
        if dataset == 'bam' or dataset == 'coco':
            model_init_path = utils.get_model_path(
                base_folder,
                seed,
                specific,
                "resnet_" + str(ratio) + ".pt",
                experiment='post_train'
                if not n2v_start else experiment.split('_finetuned')[0])
        else:
            model_init_path = os.path.join(
                base_folder, str(seed), 'post_train' if not n2v_start else
                experiment.split('_finetuned')[0], 'resnet.pt')
        # debias_ and adversarial are mutually exclusive
        assert not (debias_ and adversarial)
        if debias_ and n2v_start:
            ext = "_n2v_" if not nonlinear else "_mlp_"
            if dataset == 'bam' or dataset == 'coco':
                n2v_init_path = utils.get_net2vec_path(
                    base_folder,
                    seed,
                    specific,
                    module,
                    "resnet" + str(ext) + str(ratio) + ".pt",
                    experiment=experiment.split('_finetuned')[0])
            else:
                n2v_init_path = os.path.join(base_folder, str(seed),
                                             experiment.split('_finetuned')[0],
                                             module,
                                             'resnet' + ext[:-1] + '.pt')
        # if we're also doing adversarial, make sure to load the matching n2v as init...
        if adversarial:
            ext = "_n2v_" if not nonlinear else "_mlp_"
            if dataset == 'bam' or dataset == 'coco':
                n2v_init_path = utils.get_net2vec_path(base_folder,
                                                       seed,
                                                       specific,
                                                       module,
                                                       "resnet" + str(ext) +
                                                       str(ratio) + ".pt",
                                                       experiment='post_train')
            else:
                n2v_init_path = os.path.join(base_folder, str(seed),
                                             'post_train', module,
                                             'resnet' + ext[:-1] + '.pt')
    num_classes = 10
    num_attributes = 12
    if nonlinear:
        num_attributes = 2
    if multiple:
        num_attributes = 10 + 9 + 2 * 10
    if dataset == 'coco':
        num_classes = 79
        num_attributes = 81
    model, net, net_forward, activation_probe = models.load_models(
        device,
        lambda x, y, z: models.resnet_(pretrained=True,
                                       custom_path=x,
                                       device=y,
                                       initialize=z,
                                       num_classes=num_classes,
                                       size=50 if (dataset == 'bam' or dataset
                                                   == 'coco') else 34),
        model_path=model_init_path,
        net2vec_pretrained=True,
        net2vec_path=n2v_init_path,
        module=module,
        num_attributes=num_attributes,
        # we want to make sure to save the inits if not finetuned...
        model_init=not finetuned,
        n2v_init=not (finetuned and
                      (adversarial or (debias_ and n2v_start))),
        loader=trainloader,
        nonlinear=nonlinear,
        # parameters if we want to initially project probes to have a certain amount of bias
        partial_projection=partial_projection,
        t=gamma)
    print(model_init_path, n2v_init_path)
    model_n2v_combined = models.ProbedModel(model,
                                            net,
                                            module,
                                            switch_modes=switch_modes)
    if n2v_adam:
        combined_optim = torch.optim.Adam(
            [{
                'params': model_n2v_combined.model.parameters()
            }, {
                'params': model_n2v_combined.net.parameters()
            }],
            lr=n2v_lr)
        # TODO: allow for momentum training as well
        n2v_optim = torch.optim.Adam(net.parameters(), lr=n2v_lr)
    else:
        combined_optim = torch.optim.SGD(
            [{
                'params': model_n2v_combined.model.parameters()
            }, {
                'params': model_n2v_combined.net.parameters(),
                'lr': combined_n2v_lr,
                'momentum': n2v_momentum
            }],
            lr=model_lr,
            momentum=model_momentum)

        # TODO: allow for momentum training as well
        n2v_optim = torch.optim.SGD(net.parameters(),
                                    lr=n2v_lr,
                                    momentum=n2v_momentum)
    model_optim = torch.optim.SGD(model.parameters(),
                                  lr=model_lr,
                                  momentum=model_momentum)
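    # Three optimizers with distinct roles: `model_optim` steps only the
    # backbone, `n2v_optim` steps only the probe, and `combined_optim`
    # covers both parameter groups so that `higher` can differentiate
    # through a probe update during the debias step below.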

    d_losses = []
    adv_losses = []
    n2v_train_losses = []
    n2v_accs = []
    n2v_val_losses = []
    class_train_losses = []
    class_accs = []
    class_val_losses = []
    alpha_log = []
    magnitudes = []
    magnitudes2 = []
    unreduced = []
    bias_grads = []
    loss_shapes = []
    loss_shapes2 = []

    results = {
        "debias_losses": d_losses,
        "n2v_train_losses": n2v_train_losses,
        "n2v_val_losses": n2v_val_losses,
        "n2v_accs": n2v_accs,
        "class_train_losses": class_train_losses,
        "class_val_losses": class_val_losses,
        "class_accs": class_accs,
        "adv_losses": adv_losses,
        "alphas": alpha_log,
        "magnitudes": magnitudes,
        "magnitudes2": magnitudes2,
        "unreduced": unreduced,
        "bias_grads": bias_grads,
        "loss_shapes": loss_shapes,
        "loss_shapes2": loss_shapes2
    }
    if debias_:
        results_end = str(ratio) + "_debias.pck"
    elif adversarial:
        results_end = str(ratio) + "_adv.pck"
        if nonlinear:
            results_end = str(ratio) + "_mlp_adv.pck"
    else:
        results_end = str(ratio) + "_base.pck"

    if dataset == 'bam' or dataset == 'coco':
        results_path = utils.get_net2vec_path(
            results_folder, seed, specific, module, results_end,
            experiment if experiment2 is None else experiment2)
    else:
        results_path = os.path.join(
            results_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            results_end)
    if debias_:
        model_end = "resnet_debias_" + str(ratio) + '.pt'
        n2v_end = "resnet_n2v_debias_" + str(ratio) + '.pt'
    elif adversarial:
        if not nonlinear:
            model_end = "resnet_adv_" + str(ratio) + '.pt'
        else:
            model_end = "resnet_adv_nonlinear_" + str(ratio) + '.pt'
        if not nonlinear:
            n2v_end = "resnet_n2v_adv_" + str(ratio) + '.pt'
        else:
            n2v_end = "resnet_mlp_adv_" + str(ratio) + '.pt'
    else:
        model_end = "resnet_base_" + str(ratio) + '.pt'
        n2v_end = "resnet_n2v_base_" + str(ratio) + '.pt'

    if dataset != 'bam' and dataset != 'coco':
        model_end = model_end.replace('_' + str(ratio), '')
        n2v_end = n2v_end.replace('_' + str(ratio), '')

    if dataset == 'bam' or dataset == 'coco':
        model_path, n2v_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end=model_end,
            n2v_end=n2v_end,
            n2v_module=module,
            experiment=experiment if experiment2 is None else experiment2,
            with_n2v=True,
        )
    else:
        model_path = os.path.join(
            base_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            model_end)
        n2v_path = os.path.join(
            base_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            n2v_end)
    specific_idx = 0
    if hasattr(trainloader.dataset, 'idx_to_class'):
        # pick out the class index that matches `specific` (0 otherwise)
        for key in trainloader.dataset.idx_to_class:
            if specific is not None and trainloader.dataset.idx_to_class[
                    key] in specific:
                specific_idx = int(key)
    train_labels = None if not nonlinear else [-2, -1]
    d_last = 0
    resize = constant_resize or adaptive_resize
    if imagenet:
        imagenet_trainloaders, _ = dataload.get_imagenet_tz(
            './datasets/imagenet',
            workers=8,
            train_batch_size=train_batch_size // 8,
            resize=resize,
            constant=constant_resize)
        imagenet_trainloader = dataload.process_imagenet_loaders(
            imagenet_trainloaders)

    # every parameter except the probe's own (its weight and bias come last)
    params = list(model_n2v_combined.parameters())[:-2]
    init_alpha = alpha
    last_e = 0

    # setup training criteria
    if dataset == 'coco':
        object_weights = torch.FloatTensor(
            trainloader.dataset.getObjectWeights())
        gender_weights = torch.FloatTensor(
            trainloader.dataset.getGenderWeights())
        all_weights = torch.cat([object_weights, gender_weights])
        probe_criterion = nn.BCEWithLogitsLoss(weight=all_weights.to(device),
                                               reduction='mean')
        downstream_criterion = nn.BCEWithLogitsLoss(
            weight=object_weights.to(device), reduction='mean')
    else:
        probe_criterion = None
        downstream_criterion = nn.CrossEntropyLoss()

    for e in range(n_epochs):
        # save results every epoch...
        with open(results_path, 'wb') as f:
            print("saving results", e)
            print(results_path)
            pickle.dump(results, f)

        model.eval()

        with torch.no_grad():
            n2v_acc, n2v_val_loss = utils.net2vec_accuracy(
                testloader, net_forward, device, train_labels)
            n2v_accs.append(n2v_acc)
            n2v_val_losses.append(n2v_val_loss)

            if dataset != 'coco':
                class_acc, class_val_loss = utils.classification_accuracy(
                    testloader, model, device)
                class_accs.append(class_acc)
                class_val_losses.append(class_val_loss)
            else:
                f1, mAP = utils.detection_results(testloader, model, device)
                print("Epoch", e, "| f1:", f1, "| mAP:", mAP)
                class_accs.append([f1, mAP])

        d_initial = 0
        if not adversarial:
            curr_W = net.weight.data.clone()
            if not multiple:
                vg = get_vg(curr_W).reshape(-1, 1)
                d_initial = debias.debias_loss(curr_W[:-2], vg, t=0).item()
                print("Epoch", e, "bias", str(d_initial), " | debias: ",
                      debias_)
            else:
                ds = np.zeros(10)
                for i in range(10):
                    if i == 0:
                        vg = (curr_W[10, :] - curr_W[11, :]).reshape(-1, 1)
                    else:
                        vg = (curr_W[20 + i, :] - curr_W[29 + i, :]).reshape(
                            -1, 1)
                    ds[i] = debias.debias_loss(curr_W[:10], vg, t=0).item()
                print("Epoch", e, "bias", ds, " | debias: ", debias_)
                print("Accuracies:", n2v_acc)
                d_initial = ds[0]
        else:
            print("Epoch", e, "Adversarial", n2v_accs[-1])
        if adaptive_alpha and (e == 0 or ((d_last / d_initial) >=
                                          (5 / 2**(e - 1)) or
                                          (0.8 < (d_last / d_initial) < 1.2))):
            old_alpha = alpha
            # we don't want to increase too much if it's already decreasing
            if e == 0 or (d_last / d_initial) >= (5 / 2**(e - 1)):
                # cap the increase; the 1e-10 keeps this numerically stable
                # in case d_initial gets really low
                alpha = min(alpha * 2, (15 / (2**e)) / (d_initial + 1e-10))
                print("Option 1")
            if e > 0 and alpha < old_alpha:
                # the bias has plateaued; increase alpha rather than letting
                # the schedule shrink it
                alpha = max(old_alpha * 1.5, alpha)
                print("Option 2")
            # don't want to go over 1000...
            if alpha > 1000:
                alpha = 1000
            d_last = d_initial
        elif not adaptive_alpha and not constant_alpha:
            if dataset == 'coco' and jump_alpha:
                if e < 2:
                    alpha = 5e3
                elif e >= 2 and e < 4:
                    alpha = 1e4
                else:
                    alpha = init_alpha
            elif jump_alpha and (e - last_e) > 2:
                if not mean_debias:
                    if alpha < 100:
                        alpha = min(alpha * 2, 100)
                        last_e = e
                    else:
                        alpha = 1000
                else:
                    if alpha < 1000:
                        alpha = min(alpha * 2, 1000)
                        last_e = e
                    else:
                        alpha = 10000
            elif linear_alpha and (e - last_e) > 2:
                if alpha < 100:
                    alpha = min(alpha * 2, 100)
                    last_e = e
                else:
                    alpha += (1000 - 100) / (n_epochs - last_e)
            elif not jump_alpha and not linear_alpha:
                if (e + 1) % 3 == 0:
                    # apply alpha schedule?
                    # alpha = min(alpha * 1.2, max(init_alpha,1000))
                    alpha = alpha * 1.5
        alpha_log.append(alpha)
        print("Current Alpha:,", alpha)
        if save_every and e % 10 == 0 and e > 0 and seed == 0 and debias_:
            torch.save(net.state_dict(),
                       n2v_path.split('.pt')[0] + '_' + str(e) + '.pt')
            torch.save(model.state_dict(),
                       model_path.split('.pt')[0] + '_' + str(e) + '.pt')
        if reset and (e + 1) % reset_counter == 0 and e > 0:
            print("resetting")
            net, net_forward, activation_probe = net2vec.create_net2vec(
                model,
                module,
                num_attributes,
                device,
                pretrained=False,
                initialize=True,
                nonlinear=nonlinear)
            n2v_optim = torch.optim.SGD(net.parameters(),
                                        lr=n2v_lr,
                                        momentum=n2v_momentum)

        model.train()
        ct = 0
        for X, y, genders in trainloader:
            ids = None
            ##### Part 1: Update the Embeddings #####
            model_optim.zero_grad()
            n2v_optim.zero_grad()
            labels = utils.merge_labels(y, genders, device)
            logits = net_forward(X.to(device), switch_modes=switch_modes)
            # Now actually update net2vec embeddings, making sure to use the same batch
            if train_labels is not None:
                if logits.shape[1] == labels.shape[1]:
                    logits = logits[:, train_labels]
                labels = labels[:, train_labels]
            shapes = []
            shapes2 = []
            if dataset == 'coco':
                prelim_loss = probe_criterion(logits, labels)
            else:
                prelim_loss, ids = utils.balanced_loss(logits,
                                                       labels,
                                                       device,
                                                       0.5,
                                                       ids=ids,
                                                       multiple=multiple,
                                                       specific=specific_idx,
                                                       shapes=shapes)
            #print("prelim_loss:", prelim_loss.item())
            prelim_loss.backward()
            # we don't want to update these parameters, just in case
            model_optim.zero_grad()
            n2v_train_losses.append(prelim_loss.item())
            n2v_optim.step()
            try:
                magnitudes.append(
                    torch.norm(net.weight.data, dim=1).data.cpu().numpy())
            except AttributeError:
                # the nonlinear (MLP) probe has no single weight matrix
                pass

            ##### Part 2: Update Conv parameters for classification #####
            model_optim.zero_grad()
            n2v_optim.zero_grad()
            class_logits = model(X.to(device))
            class_loss = downstream_criterion(class_logits, y.to(device))
            class_train_losses.append(class_loss.item())

            if debias_:
                W_curr = net.weight.data
                vg = get_vg(W_curr).reshape(-1, 1)
                unreduced.append(
                    debias.debias_loss(W_curr[:-2], vg, t=0,
                                       unreduced=True).data.cpu().numpy())

            loss = class_loss
            #### Part 2a: Debias Loss
            if debias_:
                model_optim.zero_grad()
                n2v_optim.zero_grad()

                labels = utils.merge_labels(y, genders, device)
                o = net.weight.clone()
                combined_optim.zero_grad()
                with higher.innerloop_ctx(model_n2v_combined,
                                          combined_optim) as (fn2v,
                                                              diffopt_n2v):
                    models.update_probe(fn2v)
                    logits = fn2v(X.to(device))
                    if dataset == 'coco':
                        prelim_loss = probe_criterion(logits, labels)
                    else:
                        prelim_loss, ids = utils.balanced_loss(
                            logits,
                            labels,
                            device,
                            0.5,
                            ids=ids,
                            multiple=False,
                            specific=specific_idx,
                            shapes=shapes2)
                    diffopt_n2v.step(prelim_loss)
                    weights = list(fn2v.parameters())[-2]
                    vg = get_vg(weights).reshape(-1, 1)
                    d_loss = debias.debias_loss(weights[:-2],
                                                vg,
                                                t=gamma,
                                                norm=norm,
                                                mean=mean_debias)
                    # only want to save the actual bias...
                    d_losses.append(d_loss.item())
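                    # Meta-gradient: differentiate the post-update debias
                    # loss (scaled by alpha) w.r.t. the original time=0
                    # parameters; it is folded into the backbone gradients
                    # in the custom backward pass below.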
                    grad_of_grads = torch.autograd.grad(
                        alpha * d_loss,
                        list(fn2v.parameters(time=0))[:-2],
                        allow_unused=True)

                    del prelim_loss
                    del logits
                    del vg
                    del fn2v
                    del diffopt_n2v
            #### Part 2b: Adversarial Loss
            if adversarial:
                logits = net_forward(
                    None, forward=True)[:, -2:]  # just use activation probe
                labels = genders.type(torch.FloatTensor).reshape(
                    genders.shape[0], -1).to(device)
                adv_loss, _ = utils.balanced_loss(logits,
                                                  labels,
                                                  device,
                                                  0.5,
                                                  ids=ids,
                                                  stable=True)
                adv_losses.append(adv_loss.item())
                # only push against the adversary while it is strong (low
                # loss); otherwise drop the term and let it retrain
                if adv_loss < 2:
                    adv_loss = -beta * adv_loss
                    loss += adv_loss
            loss.backward()
            if debias_:
                # custom backward to include the bias regularization....
                max_norm_grad = -1
                param_idx = -1
                for ii in range(len(grad_of_grads)):
                    if (grad_of_grads[ii] is not None
                            and params[ii].grad is not None
                            and torch.isnan(grad_of_grads[ii]).long().sum() <
                            grad_of_grads[ii].reshape(-1).shape[0]):
                        # guard against NaN entries in the meta-gradient
                        not_nan = ~torch.isnan(grad_of_grads[ii])
                        params[ii].grad[not_nan] += grad_of_grads[ii][not_nan]
                        if grad_of_grads[ii][not_nan].norm().item(
                        ) > max_norm_grad:
                            max_norm_grad = grad_of_grads[ii][not_nan].norm(
                            ).item()
                            param_idx = ii
                bias_grads.append((param_idx, max_norm_grad))
                # if the meta-gradient is too large, undo its contribution to
                # keep the update stable (the 100 threshold applies whether
                # or not mean_debias is set)
                if not no_limit and max_norm_grad > 100:
                    for ii in range(len(grad_of_grads)):
                        if (grad_of_grads[ii] is not None
                                and params[ii].grad is not None and
                                torch.isnan(grad_of_grads[ii]).long().sum() <
                                grad_of_grads[ii].reshape(-1).shape[0]):
                            # guard against NaN entries in the meta-gradient
                            not_nan = ~torch.isnan(grad_of_grads[ii])
                            params[ii].grad[not_nan] -= grad_of_grads[ii][
                                not_nan]
                            # scale accordingly
                            # params[ii].grad[not_nan] += grad_of_grads[ii][not_nan] / max_norm_grad

            loss_shapes.append(shapes)
            loss_shapes2.append(shapes2)
            model_optim.step()
            #magnitudes2.append(
            #    torch.norm(net.weight.data, dim=1).data.cpu().numpy()
            #)
            ct += 1

    # final save of the results and weights after training completes
    with open(results_path, 'wb') as f:
        print("saving results", e)
        print(results_path)
        pickle.dump(results, f)
    torch.save(net.state_dict(), n2v_path)
    torch.save(model.state_dict(), model_path)
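
A minimal invocation sketch; `dataload.get_data_loader_SceneBAM` is the same project-local helper used in Example 1, every keyword below is one of `train`'s own parameters, and the folder paths are assumptions:

import torch

trainloader, testloader = dataload.get_data_loader_SceneBAM(
    seed=0, ratio=0.5, specific="bowling_alley")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train(trainloader, testloader, device, seed=0,
      debias_=True, specific="bowling_alley", ratio=0.5,
      n_epochs=5, alpha=100, module="layer4",
      base_folder="./models/BAM/", results_folder="./results/BAM/",
      experiment="sgd")
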
Example 3
def train_net2vec(model,
                  net,
                  net2vec,
                  epochs,
                  trainloader,
                  testloader,
                  device,
                  lr=0.01,
                  save_path='default.pt',
                  train_labels=None,
                  balanced=True,
                  p=0.5,
                  repeat=False,
                  n=None,
                  f=None,
                  multiple=False,
                  specific=None,
                  adam=False,
                  save_best=False,
                  criterion=None,
                  leakage=False):
    if adam:
        optim        = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)
        scheduler    = None
    else:
        optim        = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
        scheduler    = torch.optim.lr_scheduler.ReduceLROnPlateau(optim)
    results = {
        'train_losses': [],
        'test_losses': [],
        'test_accs': []
    }
    train_losses = results['train_losses']
    test_losses  = results['test_losses']
    test_accs    = results['test_accs']
    best_acc = -1
    best_state = None
    model.eval()
    for e in range(epochs):
        tmp_train_loss = []
        tmp_test_loss  = []
        tmp_test_acc   = []
        #model.train()
        net.train()
        k = 0
        for X,y,genders in trainloader:
            optim.zero_grad()
            if repeat:
                assert n is not None
                labels = utils.repeat_labels(genders[:,0:1], n, device)
            else:
                labels = utils.merge_labels(y, genders, device)
            logits = net2vec(X.to(device), switch_modes=False)
            if train_labels is not None:
                if logits.shape[1] == labels.shape[1]:
                    logits = logits[:, train_labels]
                labels = labels[:, train_labels]
            if balanced:
                loss,_   = utils.balanced_loss(logits, labels, device, p=p)
            else:
                assert criterion is not None
                loss = criterion(logits, labels)
                if k % 10 == 0:
                    print(loss.item())
            loss.backward()
            tmp_train_loss.append(loss.item())
            optim.step()
            k += 1
        train_losses.append(np.mean(tmp_train_loss))
        model.eval()
        net.eval()
        with torch.no_grad():
            tmp_test_acc, (tmp_test_f1, tmp_test_mAP) = utils.net2vec_accuracy(
                testloader, 
                net2vec, 
                device, 
                train_labels,
                repeat,
                n,
                leakage=leakage
            )
            if leakage:
                # fold accuracy around chance (0.5); leakage is symmetric,
                # so 40% is as informative as 60%
                tmp_test_acc = 0.5 + abs(tmp_test_acc - 0.5)
        if np.max(tmp_test_acc) > best_acc:
            best_acc = np.max(tmp_test_acc)
            # snapshot the weights: a bare state_dict() aliases the live
            # tensors and would silently end up holding the final weights
            best_state = copy.deepcopy(net.state_dict())
        #if scheduler is not None:
        #    scheduler.step(np.mean(tmp_test_acc))
        test_accs.append(tmp_test_acc)
        print("Epoch", e, " :", tmp_test_acc, "f1/mAP:", tmp_test_f1, "/", tmp_test_mAP, file=f)
        if isinstance(net, nn.Linear):
            W = net.weight.data
            vg = W[-2] - W[-1]
            vg = vg / vg.norm()
            v  = W[0]
            v  = v / v.norm()
            print("projection:", vg.reshape(1,-1) @ v.reshape(-1,1))
    if save_best:
        torch.save(best_state, save_path.split('.pt')[0] + '_' + str(best_acc) + '.pt')
    else:
        torch.save(net.state_dict(), save_path)
    return results
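
A usage sketch, assuming `model`, `net`, and `net_forward` were obtained from the project's `models.load_models` helper as in Example 2, that the probe's last two rows are the gender attributes (hence `train_labels=[-2, -1]`), and that the save path is an assumption:

results = train_net2vec(model, net, net_forward, epochs=10,
                        trainloader=trainloader, testloader=testloader,
                        device=device, lr=0.01,
                        save_path='./models/BAM/resnet_n2v_0.5.pt',
                        train_labels=[-2, -1])
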
Example 4
def train_net2vec(model,
                  net,
                  net2vec,
                  epochs,
                  trainloader,
                  testloader,
                  device,
                  lr=0.01,
                  save_path='default.pt',
                  train_labels=None,
                  balanced=True,
                  p=0.5,
                  repeat=False,
                  n=None,
                  f=None,
                  multiple=False,
                  specific=None,
                  adam=False,
                  save_best=False,
                  criterion=None,
                  leakage=False):
    if adam:
        optim = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)
        scheduler = None
    else:
        optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim)
    results = {'train_losses': [], 'test_losses': [], 'test_accs': []}
    train_losses = results['train_losses']
    test_losses = results['test_losses']
    test_accs = results['test_accs']
    best_acc = -1
    best_state = None
    model.eval()
    best_proj = 2  # sentinel above any possible |cosine| (which is <= 1)
    best_proj_epoch = -1
    with open('projection_results.txt', 'a') as fp:
        print("Starting: " + str(epochs) + " " + str(lr), file=fp)
    for e in range(epochs):
        tmp_train_loss = []
        tmp_test_loss = []
        tmp_test_acc = []
        # model.train()
        net.train()
        k = 0
        for X, y, genders in trainloader:
            optim.zero_grad()
            if repeat:
                assert n is not None
                labels = utils.repeat_labels(genders[:, 0:1], n, device)
            else:
                labels = utils.merge_labels(y, genders, device)
            logits = net2vec(X.to(device), switch_modes=False)
            if train_labels is not None:
                if logits.shape[1] == labels.shape[1]:
                    logits = logits[:, train_labels]
                labels = labels[:, train_labels]
            if balanced:
                loss, _ = utils.balanced_loss(logits, labels, device, p=p)
            else:
                assert criterion is not None
                loss = criterion(logits, labels)
                if k % 10 == 0:
                    print(loss.item())
            loss.backward()
            tmp_train_loss.append(loss.item())
            optim.step()
            k += 1
        train_losses.append(np.mean(tmp_train_loss))
        model.eval()
        net.eval()
        with torch.no_grad():
            tmp_test_acc, (tmp_test_f1, tmp_test_mAP) = utils.net2vec_accuracy(
                testloader,
                net2vec,
                device,
                train_labels,
                repeat,
                n,
                leakage=leakage)
            if leakage:
                # fold accuracy around chance (0.5); leakage is symmetric
                tmp_test_acc = 0.5 + abs(tmp_test_acc - 0.5)
        if np.max(tmp_test_acc) > best_acc:
            best_acc = np.max(tmp_test_acc)
            # snapshot the weights: a bare state_dict() aliases the live
            # tensors and would silently end up holding the final weights
            best_state = copy.deepcopy(net.state_dict())
        # if scheduler is not None:
        #    scheduler.step(np.mean(tmp_test_acc))
        test_accs.append(tmp_test_acc)
        print("Epoch",
              e,
              " :",
              tmp_test_acc,
              "f1/mAP:",
              tmp_test_f1,
              "/",
              tmp_test_mAP,
              file=f)
        if isinstance(net, nn.Linear):
            W = net.weight.data
            vg = W[-2] - W[-1]
            vg = vg / vg.norm()
            # unit-normalize each probe row before projecting onto vg
            normalized_W = nn.functional.normalize(W, p=2, dim=1)
            mean_proj = (normalized_W @ vg.reshape(-1, 1)).mean().item()
            var_proj = (((normalized_W @ vg.reshape(-1, 1)) - mean_proj)**
                        2).sum().item() / (W.shape[0] - 1)
            proj = (mean_proj, ((W[0] / W[0].norm()).reshape(1, -1)
                                @ vg.reshape(-1, 1)).item(), var_proj)
            print("projection:", proj, " |", save_path)
            with open('projection_results.txt', 'a') as fp:
                print("projection:", proj, " |", save_path, file=fp)
            if abs(proj[0]) < abs(best_proj):
                best_proj = proj[0]
                best_proj_epoch = e
        if e % 5 == 0 and e > 0:
            pass  #torch.save(net.state_dict(), save_path.split('.pt')[
            #0] + '_EPOCH_{}_'.format(e) + str(proj[0]) + '_' + str(lr) + '_var_{}.pt'.format(proj[2]))
    with open('projection_results.txt', 'a') as fp:
        print("", file=fp)
    if save_best:
        torch.save(
            best_state,
            save_path.split('.pt')[0] + '_' + str(best_acc) + '_' + str(lr) +
            '.pt')
    else:
        torch.save(
            net.state_dict(),
            save_path.split('.pt')[0] +
            '{}_{}_'.format(best_proj, best_proj_epoch) + str(proj[0]) + '_' +
            str(lr) + '_var_{}.pt'.format(proj[2]))
        torch.save(net.state_dict(), save_path)
    return results
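
For reference, the projection diagnostic logged above can be reproduced in isolation; a self-contained sketch with a random matrix standing in for the probe's `net.weight.data`:

import torch
import torch.nn.functional as F

W = torch.randn(12, 512)           # stand-in for net.weight.data
vg = W[-2] - W[-1]                 # gender direction: row -2 minus row -1
vg = vg / vg.norm()
rows = F.normalize(W, p=2, dim=1)  # unit-normalize every probe row
projs = rows @ vg                  # cosine of each row with the gender direction
mean_proj = projs.mean().item()
# unbiased variance matches the explicit (n - 1) denominator used above
var_proj = projs.var(unbiased=True).item()
print(mean_proj, var_proj)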