# NOTE: module-level imports assumed by this section; the project-local module
# names (dataload, coco_dataload, debias, models, net2vec, utils) are inferred
# from the call sites below. `load_models` used in collect_leakage is assumed
# to be defined earlier in this file.
import copy
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import higher  # differentiable inner-loop optimization, used in train()

import coco_dataload
import dataload
import debias
import models
import net2vec
import utils


def collect_leakage(device,
                    base_folder='./models/BAM/',
                    specific="bowling_alley",
                    seed=0,
                    module="layer3",
                    experiment="sgd_finetuned",
                    ratios=["0.0", "0.1", "0.2", "0.3", "0.4", "0.5",
                            "0.6", "0.7", "0.8", "0.9", "1.0"],
                    adv=False,
                    baseline=False,
                    epoch=None,
                    multiple=True,
                    force=False,
                    dataset='bam',
                    args=None):
    results = {}
    if dataset == 'bam':
        _, testloader = dataload.get_data_loader_SceneBAM(
            seed=seed, ratio=float(0.5), specific=specific)
    elif dataset != 'coco':
        _, testloader = dataload.get_data_loader_idenProf(
            'idenprof',
            train_shuffle=True,
            train_batch_size=64,
            test_batch_size=64,
            exclusive=True)
    for ratio in ratios:
        model, net, net_forward, activation_probe = load_models(
            device,
            base_folder=base_folder,
            specific=specific,
            seed=seed,
            module=module,
            experiment=experiment,
            ratio=ratio,
            adv=adv,
            baseline=baseline,
            epoch=epoch,
            post=True,
            multiple=multiple,
            leakage=True,
            force=force,
            dataset=dataset,
            args=args)
        model.eval()
        net.eval()
        if dataset == 'coco':
            tmp_args = copy.deepcopy(args)
            tmp_args.ratio = ratio
            tmp_args.gender_balanced = True
            # `ratio` is a string such as "0.1"; int("0.1") raises ValueError,
            # so compare as a float.
            if float(ratio) > 0:
                tmp_args.balanced = True
            _, testloader = coco_dataload.get_data_loader_coco(tmp_args)
        results[ratio], _ = utils.net2vec_accuracy(
            testloader, net_forward, device, train_labels=[-2, -1])
    return results
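
def _example_collect_leakage():
    """Hedged usage sketch (not part of the original training code): sweep
    probe leakage across bias ratios for a single BAM scene. The folder
    layout, scene name, and device string are assumptions; the defaults
    mirror the signature of collect_leakage above."""
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    leakage = collect_leakage(device,
                              base_folder='./models/BAM/',
                              specific='bowling_alley',
                              module='layer3',
                              experiment='sgd_finetuned',
                              dataset='bam')
    for ratio, acc in leakage.items():
        # probe accuracy on the two gender logits; distance from chance (0.5)
        # is the leakage signal
        print(ratio, acc)
    return leakage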
def train(trainloader,
          testloader,
          device,
          seed,
          debias_=True,
          specific=None,
          ratio=0.5,             # bias ratio in the dataset
          n_epochs=5,
          model_lr=1e-3,
          n2v_lr=1e-3,
          combined_n2v_lr=1e-3,  # meta-learning rate for n2v
          alpha=100,             # weight on the debias loss
          beta=0.1,              # weight on the adversarial loss
          out_file=None,
          base_folder="",
          results_folder="",
          experiment="sgd",
          momentum=0,
          module="layer4",
          finetuned=False,
          adversarial=False,
          nonlinear=False,
          subset=False,
          subset_ratio=0.1,
          save_every=False,
          model_momentum=0,
          n2v_momentum=0,
          experimental=False,
          multiple=False,
          debias_multiple=False,
          reset=False,
          reset_counter=1,
          n2v_start=False,
          experiment2=None,
          adaptive_alpha=False,
          n2v_adam=False,
          single=False,
          imagenet=False,
          train_batch_size=64,
          constant_resize=False,
          adaptive_resize=False,
          no_class=False,
          gamma=0,
          partial_projection=False,
          norm='l2',
          constant_alpha=False,
          jump_alpha=False,
          linear_alpha=False,
          mean_debias=False,
          no_limit=False,
          dataset='bam',
          parallel=False,
          gpu_ids=[],
          switch_modes=True):
    print("mu", momentum, "debias", debias_, "alpha", alpha, " | ratio:", ratio)

    def get_vg(W):
        # Bias direction: a single gender row of the probe weight matrix, or
        # the difference between the two gender rows.
        if single:
            return W[-2, :]
        else:
            return W[-2, :] - W[-1, :]

    if dataset == 'bam' or dataset == 'coco':
        model_init_path, n2v_init_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end="resnet_init" + '.pt',
            n2v_end="resnet_n2v_init" + '.pt',
            n2v_module=module,
            experiment=experiment,
            with_n2v=False)
    else:
        model_init_path = os.path.join(base_folder, str(seed), experiment,
                                       'resnet_init.pt')
        n2v_init_path = os.path.join(base_folder, str(seed), experiment,
                                     module, 'resnet_n2v_init.pt')
    if finetuned:
        if dataset == 'bam' or dataset == 'coco':
            model_init_path = utils.get_model_path(
                base_folder,
                seed,
                specific,
                "resnet_" + str(ratio) + ".pt",
                experiment='post_train'
                if not n2v_start else experiment.split('_finetuned')[0])
        else:
            model_init_path = os.path.join(
                base_folder, str(seed),
                'post_train' if not n2v_start else
                experiment.split('_finetuned')[0], 'resnet.pt')
    # debias and adversarial training are mutually exclusive
    assert not (debias_ and adversarial)
    if debias_ and n2v_start:
        ext = "_n2v_" if not nonlinear else "_mlp_"
        if dataset == 'bam' or dataset == 'coco':
            n2v_init_path = utils.get_net2vec_path(
                base_folder, seed, specific, module,
                "resnet" + str(ext) + str(ratio) + ".pt",
                experiment=experiment.split('_finetuned')[0])
        else:
            n2v_init_path = os.path.join(
                base_folder, str(seed),
                experiment.split('_finetuned')[0], module,
                'resnet' + ext[:-1] + '.pt')
    # if we're doing adversarial training, make sure to load the matching
    # n2v as the init...
    if adversarial:
        ext = "_n2v_" if not nonlinear else "_mlp_"
        if dataset == 'bam' or dataset == 'coco':
            n2v_init_path = utils.get_net2vec_path(
                base_folder, seed, specific, module,
                "resnet" + str(ext) + str(ratio) + ".pt",
                experiment='post_train')
        else:
            n2v_init_path = os.path.join(base_folder, str(seed), 'post_train',
                                         module, 'resnet' + ext[:-1] + '.pt')
    num_classes = 10
    num_attributes = 12
    if nonlinear:
        num_attributes = 2
    if multiple:
        num_attributes = 10 + 9 + 2 * 10
    if dataset == 'coco':
        num_classes = 79
        num_attributes = 81
    model, net, net_forward, activation_probe = models.load_models(
        device,
        lambda x, y, z: models.resnet_(
            pretrained=True,
            custom_path=x,
            device=y,
            initialize=z,
            num_classes=num_classes,
            size=50 if (dataset == 'bam' or dataset == 'coco') else 34),
        model_path=model_init_path,
        net2vec_pretrained=True,
        net2vec_path=n2v_init_path,
        module=module,
        num_attributes=num_attributes,
        # we want to make sure to save the inits if not finetuned...
        model_init=not finetuned,
        n2v_init=not (finetuned and
                      (adversarial or (debias_ and n2v_start))),
        loader=trainloader,
        nonlinear=nonlinear,
        # parameters for initially projecting probes to a certain amount of bias
        partial_projection=partial_projection,
        t=gamma)
    print(model_init_path, n2v_init_path)
    model_n2v_combined = models.ProbedModel(model,
                                            net,
                                            module,
                                            switch_modes=switch_modes)
    if n2v_adam:
        combined_optim = torch.optim.Adam(
            [{'params': model_n2v_combined.model.parameters()},
             {'params': model_n2v_combined.net.parameters()}],
            lr=n2v_lr)
        # TODO: allow for momentum training as well
        n2v_optim = torch.optim.Adam(net.parameters(), lr=n2v_lr)
    else:
        combined_optim = torch.optim.SGD(
            [{'params': model_n2v_combined.model.parameters()},
             {'params': model_n2v_combined.net.parameters(),
              'lr': combined_n2v_lr,
              'momentum': n2v_momentum}],
            lr=model_lr,
            momentum=model_momentum)
        # TODO: allow for momentum training as well
        n2v_optim = torch.optim.SGD(net.parameters(),
                                    lr=n2v_lr,
                                    momentum=n2v_momentum)
    model_optim = torch.optim.SGD(model.parameters(),
                                  lr=model_lr,
                                  momentum=model_momentum)
    d_losses = []
    adv_losses = []
    n2v_train_losses = []
    n2v_accs = []
    n2v_val_losses = []
    class_train_losses = []
    class_accs = []
    class_val_losses = []
    alpha_log = []
    magnitudes = []
    magnitudes2 = []
    unreduced = []
    bias_grads = []
    loss_shapes = []
    loss_shapes2 = []
    results = {
        "debias_losses": d_losses,
        "n2v_train_losses": n2v_train_losses,
        "n2v_val_losses": n2v_val_losses,
        "n2v_accs": n2v_accs,
        "class_train_losses": class_train_losses,
        "class_val_losses": class_val_losses,
        "class_accs": class_accs,
        "adv_losses": adv_losses,
        "alphas": alpha_log,
        "magnitudes": magnitudes,
        "magnitudes2": magnitudes2,
        "unreduced": unreduced,
        "bias_grads": bias_grads,
        "loss_shapes": loss_shapes,
        "loss_shapes2": loss_shapes2
    }
    if debias_:
        results_end = str(ratio) + "_debias.pck"
    elif adversarial:
        results_end = str(ratio) + "_adv.pck"
        if nonlinear:
            results_end = str(ratio) + "_mlp_adv.pck"
    else:
        results_end = str(ratio) + "_base.pck"
    if dataset == 'bam' or dataset == 'coco':
        results_path = utils.get_net2vec_path(
            results_folder, seed, specific, module, results_end,
            experiment if experiment2 is None else experiment2)
    else:
        results_path = os.path.join(
            results_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            results_end)
    if debias_:
        model_end = "resnet_debias_" + str(ratio) + '.pt'
        n2v_end = "resnet_n2v_debias_" + str(ratio) + '.pt'
    elif adversarial:
        if not nonlinear:
            model_end = "resnet_adv_" + str(ratio) + '.pt'
        else:
            model_end = "resnet_adv_nonlinear_" + str(ratio) + '.pt'
        if not nonlinear:
            n2v_end = "resnet_n2v_adv_" + str(ratio) + '.pt'
        else:
            n2v_end = "resnet_mlp_adv_" + str(ratio) + '.pt'
    else:
        model_end = "resnet_base_" + str(ratio) + '.pt'
        n2v_end = "resnet_n2v_base_" + str(ratio) + '.pt'
    if dataset != 'bam' and dataset != 'coco':
        model_end = model_end.replace('_' + str(ratio), '')
        n2v_end = n2v_end.replace('_' + str(ratio), '')
    if dataset == 'bam' or dataset == 'coco':
        model_path, n2v_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end=model_end,
            n2v_end=n2v_end,
            n2v_module=module,
            experiment=experiment if experiment2 is None else experiment2,
            with_n2v=True)
    else:
        model_path = os.path.join(
            base_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            model_end)
        n2v_path = os.path.join(
            base_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            n2v_end)
    # default to class index 0 if no match is found below (avoids a NameError
    # when `specific` is None or absent from idx_to_class)
    specific_idx = 0
    if hasattr(trainloader.dataset, 'idx_to_class'):
        for key in trainloader.dataset.idx_to_class:
            if (specific is not None
                    and trainloader.dataset.idx_to_class[key] in specific):
                specific_idx = int(key)
    train_labels = None if not nonlinear else [-2, -1]
    d_last = 0
    resize = constant_resize or adaptive_resize
    if imagenet:
        imagenet_trainloaders, _ = dataload.get_imagenet_tz(
            './datasets/imagenet',
            workers=8,
            train_batch_size=train_batch_size // 8,
            resize=resize,
            constant=constant_resize)
        imagenet_trainloader = dataload.process_imagenet_loaders(
            imagenet_trainloaders)

    # all combined parameters except the probe's weight and bias
    params = list(model_n2v_combined.parameters())[:-2]
    init_alpha = alpha
    last_e = 0
    # set up training criteria (the original used the since-deprecated
    # reduction='elementwise_mean'; 'mean' is the equivalent)
    if dataset == 'coco':
        object_weights = torch.FloatTensor(
            trainloader.dataset.getObjectWeights())
        gender_weights = torch.FloatTensor(
            trainloader.dataset.getGenderWeights())
        all_weights = torch.cat([object_weights, gender_weights])
        probe_criterion = nn.BCEWithLogitsLoss(weight=all_weights.to(device),
                                               reduction='mean')
        downstream_criterion = nn.BCEWithLogitsLoss(
            weight=object_weights.to(device), reduction='mean')
    else:
        probe_criterion = None
        downstream_criterion = nn.CrossEntropyLoss()
    for e in range(n_epochs):
        # save results every epoch...
        with open(results_path, 'wb') as f:
            print("saving results", e)
            print(results_path)
            pickle.dump(results, f)
        model.eval()
        with torch.no_grad():
            n2v_acc, n2v_val_loss = utils.net2vec_accuracy(
                testloader, net_forward, device, train_labels)
            n2v_accs.append(n2v_acc)
            n2v_val_losses.append(n2v_val_loss)
            if dataset != 'coco':
                class_acc, class_val_loss = utils.classification_accuracy(
                    testloader, model, device)
                class_accs.append(class_acc)
                class_val_losses.append(class_val_loss)
            else:
                f1, mAP = utils.detection_results(testloader, model, device)
                print("Epoch", e, "| f1:", f1, "| mAP:", mAP)
                class_accs.append([f1, mAP])
        d_initial = 0
        if not adversarial:
            curr_W = net.weight.data.clone()
            if not multiple:
                vg = get_vg(curr_W).reshape(-1, 1)
                d_initial = debias.debias_loss(curr_W[:-2], vg, t=0).item()
                print("Epoch", e, "bias", str(d_initial), " | debias: ",
                      debias_)
            else:
                ds = np.zeros(10)
                for i in range(10):
                    if i == 0:
                        vg = (curr_W[10, :] - curr_W[11, :]).reshape(-1, 1)
                    else:
                        vg = (curr_W[20 + i, :] - curr_W[29 + i, :]).reshape(
                            -1, 1)
                    ds[i] = debias.debias_loss(curr_W[:10], vg, t=0).item()
                print("Epoch", e, "bias", ds, " | debias: ", debias_)
                print("Accuracies:", n2v_acc)
                d_initial = ds[0]
        else:
            print("Epoch", e, "Adversarial", n2v_accs[-1])
        if adaptive_alpha and (e == 0 or
                               ((d_last / d_initial) >= (5 / 2**(e - 1)) or
                                (0.8 < (d_last / d_initial) < 1.2))):
            old_alpha = alpha
            # we don't want to increase too much if it's already decreasing
            if e == 0 or (d_last / d_initial) >= (5 / 2**(e - 1)):
                # 1e-10 for numerical stability in case d_initial gets very low
                alpha = min(alpha * 2, (15 / (2**e)) / (d_initial + 1e-10))
                print("Option 1")
            if e > 0 and alpha < old_alpha:
                # we want to increase if plateaued
                alpha = max(old_alpha * 1.5, alpha)
                print("Option 2")
            # don't want to go over 1000...
            if alpha > 1000:
                alpha = 1000
            d_last = d_initial
        elif not adaptive_alpha and not constant_alpha:
            if dataset == 'coco' and jump_alpha:
                if e < 2:
                    alpha = 5e3
                elif e >= 2 and e < 4:
                    alpha = 1e4
                else:
                    alpha = init_alpha
            elif jump_alpha and (e - last_e) > 2:
                if not mean_debias:
                    if alpha < 100:
                        alpha = min(alpha * 2, 100)
                        last_e = e
                    else:
                        alpha = 1000
                else:
                    if alpha < 1000:
                        alpha = min(alpha * 2, 1000)
                        last_e = e
                    else:
                        alpha = 10000
            elif linear_alpha and (e - last_e) > 2:
                if alpha < 100:
                    alpha = min(alpha * 2, 100)
                    last_e = e
                else:
                    alpha += (1000 - 100) / (n_epochs - last_e)
            elif not jump_alpha and not linear_alpha:
                if (e + 1) % 3 == 0:
                    # apply alpha schedule
                    alpha = alpha * 1.5
        alpha_log.append(alpha)
        print("Current Alpha:", alpha)
        if save_every and e % 10 == 0 and e > 0 and seed == 0 and debias_:
            torch.save(net.state_dict(),
                       n2v_path.split('.pt')[0] + '_' + str(e) + '.pt')
            torch.save(model.state_dict(),
                       model_path.split('.pt')[0] + '_' + str(e) + '.pt')
        if reset and (e + 1) % reset_counter == 0 and e > 0:
            print("resetting")
            net, net_forward, activation_probe = net2vec.create_net2vec(
                model,
                module,
                num_attributes,
                device,
                pretrained=False,
                initialize=True,
                nonlinear=nonlinear)
            n2v_optim = torch.optim.SGD(net.parameters(),
                                        lr=n2v_lr,
                                        momentum=n2v_momentum)
        model.train()
        ct = 0
        for X, y, genders in trainloader:
            ids = None
            ##### Part 1: update the net2vec (probe) embeddings #####
            model_optim.zero_grad()
            n2v_optim.zero_grad()
            labels = utils.merge_labels(y, genders, device)
            logits = net_forward(X.to(device), switch_modes=switch_modes)
            # update the net2vec embeddings on this same batch
            if train_labels is not None:
                if logits.shape[1] == labels.shape[1]:
                    logits = logits[:, train_labels]
                labels = labels[:, train_labels]
            shapes = []
            shapes2 = []
            if dataset == 'coco':
                prelim_loss = probe_criterion(logits, labels)
            else:
                prelim_loss, ids = utils.balanced_loss(logits,
                                                       labels,
                                                       device,
                                                       0.5,
                                                       ids=ids,
                                                       multiple=multiple,
                                                       specific=specific_idx,
                                                       shapes=shapes)
            prelim_loss.backward()
            # we don't want to update the model parameters here, just in case
            model_optim.zero_grad()
            n2v_train_losses.append(prelim_loss.item())
            n2v_optim.step()
            try:
                magnitudes.append(
                    torch.norm(net.weight.data, dim=1).data.cpu().numpy())
            except AttributeError:
                # nonlinear probes have no single weight matrix
                pass
            ##### Part 2: update conv parameters for classification #####
            model_optim.zero_grad()
            n2v_optim.zero_grad()
            class_logits = model(X.to(device))
            class_loss = downstream_criterion(class_logits, y.to(device))
            class_train_losses.append(class_loss.item())
            if debias_:
                W_curr = net.weight.data
                vg = get_vg(W_curr).reshape(-1, 1)
                unreduced.append(
                    debias.debias_loss(W_curr[:-2], vg, t=0,
                                       unreduced=True).data.cpu().numpy())
            loss = class_loss
            #### Part 2a: debias loss ####
            if debias_:
                model_optim.zero_grad()
                n2v_optim.zero_grad()
                labels = utils.merge_labels(y, genders, device)
                o = net.weight.clone()  # snapshot of probe weights
                combined_optim.zero_grad()
                with higher.innerloop_ctx(model_n2v_combined,
                                          combined_optim) as (fn2v,
                                                              diffopt_n2v):
                    models.update_probe(fn2v)
                    logits = fn2v(X.to(device))
                    if dataset == 'coco':
                        prelim_loss = probe_criterion(logits, labels)
                    else:
                        prelim_loss, ids = utils.balanced_loss(
                            logits,
                            labels,
                            device,
                            0.5,
                            ids=ids,
                            multiple=False,
                            specific=specific_idx,
                            shapes=shapes2)
                    diffopt_n2v.step(prelim_loss)
                    weights = list(fn2v.parameters())[-2]
                    vg = get_vg(weights).reshape(-1, 1)
                    d_loss = debias.debias_loss(weights[:-2],
                                                vg,
                                                t=gamma,
                                                norm=norm,
                                                mean=mean_debias)
                    # only want to record the actual bias...
                    d_losses.append(d_loss.item())
                    grad_of_grads = torch.autograd.grad(
                        alpha * d_loss,
                        list(fn2v.parameters(time=0))[:-2],
                        allow_unused=True)
                    del prelim_loss
                    del logits
                    del vg
                    del fn2v
                    del diffopt_n2v
            #### Part 2b: adversarial loss ####
            if adversarial:
                # reuse the activations cached by the classification forward
                # pass via the activation probe
                logits = net_forward(None, forward=True)[:, -2:]
                labels = genders.type(torch.FloatTensor).reshape(
                    genders.shape[0], -1).to(device)
                adv_loss, _ = utils.balanced_loss(logits,
                                                  labels,
                                                  device,
                                                  0.5,
                                                  ids=ids,
                                                  stable=True)
                adv_losses.append(adv_loss.item())
                # probe is getting too strong, let it retrain...
                if adv_loss < 2:
                    adv_loss = -beta * adv_loss
                loss += adv_loss
            loss.backward()
            if debias_:
                # custom backward pass to fold in the bias regularization...
                max_norm_grad = -1
                param_idx = -1
                for ii in range(len(grad_of_grads)):
                    if (grad_of_grads[ii] is not None
                            and params[ii].grad is not None
                            and torch.isnan(grad_of_grads[ii]).long().sum()
                            < grad_of_grads[ii].reshape(-1).shape[0]):
                        # skip NaN entries, just in case
                        not_nan = ~torch.isnan(grad_of_grads[ii])
                        params[ii].grad[not_nan] += grad_of_grads[ii][not_nan]
                        if grad_of_grads[ii][not_nan].norm().item() > max_norm_grad:
                            max_norm_grad = grad_of_grads[ii][not_nan].norm().item()
                            param_idx = ii
                bias_grads.append((param_idx, max_norm_grad))
                # undo the last step to prevent stability issues (the original
                # used the same 100 threshold for mean and sum debias)
                if not no_limit and max_norm_grad > 100:
                    for ii in range(len(grad_of_grads)):
                        if (grad_of_grads[ii] is not None
                                and params[ii].grad is not None
                                and torch.isnan(grad_of_grads[ii]).long().sum()
                                < grad_of_grads[ii].reshape(-1).shape[0]):
                            not_nan = ~torch.isnan(grad_of_grads[ii])
                            params[ii].grad[not_nan] -= grad_of_grads[ii][not_nan]
                            # scale accordingly
                            # params[ii].grad[not_nan] += grad_of_grads[ii][not_nan] / max_norm_grad
            loss_shapes.append(shapes)
            loss_shapes2.append(shapes2)
            model_optim.step()
            # magnitudes2.append(
            #     torch.norm(net.weight.data, dim=1).data.cpu().numpy())
            ct += 1
    # save final results and weights...
    with open(results_path, 'wb') as f:
        print("saving results", e)
        print(results_path)
        pickle.dump(results, f)
    torch.save(net.state_dict(), n2v_path)
    torch.save(model.state_dict(), model_path)
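
def _example_meta_debias_step(combined, combined_optim, X, labels,
                              probe_criterion, alpha=100.0):
    """Hedged sketch (not part of the original code) isolating the
    meta-gradient pattern used in the debias branch of train() above: take a
    differentiable probe update with `higher`, measure bias on the *updated*
    probe weights, and backpropagate through that update to the time-0 model
    parameters. `combined` is assumed to be a models.ProbedModel and
    debias.debias_loss the project's bias measure."""
    with higher.innerloop_ctx(combined, combined_optim) as (fmodel, diffopt):
        probe_logits = fmodel(X)                     # forward through model + probe
        probe_loss = probe_criterion(probe_logits, labels)
        diffopt.step(probe_loss)                     # differentiable probe update
        W = list(fmodel.parameters())[-2]            # updated probe weight matrix
        vg = (W[-2, :] - W[-1, :]).reshape(-1, 1)    # gender/bias direction
        d_loss = debias.debias_loss(W[:-2], vg, t=0)
        # gradient of the post-update bias w.r.t. the original (time=0) model
        # parameters; train() adds these into .grad before model_optim.step(),
        # so one optimizer step applies both the task and debias objectives
        grads = torch.autograd.grad(alpha * d_loss,
                                    list(fmodel.parameters(time=0))[:-2],
                                    allow_unused=True)
    return d_loss.item(), grads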
def train_net2vec(model,
                  net,
                  net2vec,
                  epochs,
                  trainloader,
                  testloader,
                  device,
                  lr=0.01,
                  save_path='default.pt',
                  train_labels=None,
                  balanced=True,
                  p=0.5,
                  repeat=False,
                  n=None,
                  f=None,
                  multiple=False,
                  specific=None,
                  adam=False,
                  save_best=False,
                  criterion=None,
                  leakage=False):
    if adam:
        optim = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)
        scheduler = None
    else:
        optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim)
    results = {'train_losses': [], 'test_losses': [], 'test_accs': []}
    train_losses = results['train_losses']
    test_losses = results['test_losses']
    test_accs = results['test_accs']
    best_acc = -1
    best_state = None
    model.eval()
    # projections are in [-1, 1], so 2 acts as "not yet set"
    best_proj = 2
    best_proj_epoch = -1
    with open('projection_results.txt', 'a') as fp:
        print("Starting: " + str(epochs) + " " + str(lr), file=fp)
    for e in range(epochs):
        tmp_train_loss = []
        tmp_test_loss = []
        tmp_test_acc = []
        net.train()
        k = 0
        for X, y, genders in trainloader:
            optim.zero_grad()
            if repeat:
                assert n is not None
                labels = utils.repeat_labels(genders[:, 0:1], n, device)
            else:
                labels = utils.merge_labels(y, genders, device)
            logits = net2vec(X.to(device), switch_modes=False)
            if train_labels is not None:
                if logits.shape[1] == labels.shape[1]:
                    logits = logits[:, train_labels]
                labels = labels[:, train_labels]
            if balanced:
                loss, _ = utils.balanced_loss(logits, labels, device, p=p)
            else:
                assert criterion is not None
                loss = criterion(logits, labels)
            if k % 10 == 0:
                print(loss.item())
            loss.backward()
            tmp_train_loss.append(loss.item())
            optim.step()
            k += 1
        train_losses.append(np.mean(tmp_train_loss))
        model.eval()
        net.eval()
        with torch.no_grad():
            tmp_test_acc, (tmp_test_f1, tmp_test_mAP) = utils.net2vec_accuracy(
                testloader, net2vec, device, train_labels, repeat, n,
                leakage=leakage)
            if leakage:
                # fold below-chance accuracy back above 0.5, so leakage is
                # measured as distance from chance
                tmp_test_acc = 0.5 + abs(tmp_test_acc - 0.5)
            if np.max(tmp_test_acc) > best_acc:
                best_acc = np.max(tmp_test_acc)
                best_state = net.state_dict()
            # if scheduler is not None:
            #     scheduler.step(np.mean(tmp_test_acc))
            test_accs.append(tmp_test_acc)
            print("Epoch", e, " :", tmp_test_acc,
                  "f1/mAP:", tmp_test_f1, "/", tmp_test_mAP, file=f)
            if isinstance(net, nn.Linear):
                W = net.weight.data
                vg = W[-2] - W[-1]
                vg = vg / vg.norm()
                # the original read `normalize()(W, p=2, dim=1)`, which is not
                # callable as written; F.normalize performs the intended
                # row-wise l2 normalization
                normalized_W = F.normalize(W, p=2, dim=1)
                mean_proj = (normalized_W @ vg.reshape(-1, 1)).mean().item()
                var_proj = (((normalized_W @ vg.reshape(-1, 1)) -
                             mean_proj)**2).sum().item() / (W.shape[0] - 1)
                proj = (mean_proj,
                        ((W[0] / W[0].norm()).reshape(1, -1)
                         @ vg.reshape(-1, 1)).item(),
                        var_proj)
                print("projection:", proj, " |", save_path)
                with open('projection_results.txt', 'a') as fp:
                    print("projection:", proj, " |", save_path, file=fp)
                if abs(proj[0]) < abs(best_proj):
                    best_proj = proj[0]
                    best_proj_epoch = e
    with open('projection_results.txt', 'a') as fp:
        print("", file=fp)
    if save_best:
        torch.save(
            best_state,
            save_path.split('.pt')[0] + '_' + str(best_acc) + '_' + str(lr) +
            '.pt')
    elif isinstance(net, nn.Linear):
        # `proj` is only defined for linear probes; the guard avoids the
        # NameError the original would raise for nonlinear probes
        torch.save(
            net.state_dict(),
            save_path.split('.pt')[0] +
            '{}_{}_'.format(best_proj, best_proj_epoch) + str(proj[0]) +
            '_' + str(lr) + '_var_{}.pt'.format(proj[2]))
    torch.save(net.state_dict(), save_path)
    return results
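
def _example_train_probe(model, trainloader, testloader, device):
    """Hedged usage sketch (not part of the original code): fit a fresh linear
    probe on a frozen model to measure gender leakage from layer3 activations.
    The call to net2vec.create_net2vec mirrors train() above; the epoch count,
    learning rate, save path, and attribute count (12) are assumptions."""
    net, net_forward, _ = net2vec.create_net2vec(model,
                                                 'layer3',
                                                 12,
                                                 device,
                                                 pretrained=False,
                                                 initialize=True)
    return train_net2vec(model,
                         net,
                         net_forward,
                         epochs=20,
                         trainloader=trainloader,
                         testloader=testloader,
                         device=device,
                         lr=0.01,
                         save_path='./models/probe.pt',
                         train_labels=[-2, -1],  # the two gender logits
                         leakage=True)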