def test_basic(self):
    """Basic test to check that the calculation is sensible and conforms to the formula."""
    test_tensor = torch.Tensor([[0.5, 0.5], [0.0, 1.0]])
    output = target_distribution(test_tensor)
    # Compare element-wise: assertAlmostEqual does not support tuples of tensors.
    for got, expected in zip(output[0].tolist(), (0.75, 0.25)):
        self.assertAlmostEqual(got, expected, places=5)
    for got, expected in zip(output[1].tolist(), (0.0, 1.0)):
        self.assertAlmostEqual(got, expected, places=5)
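# For reference, a minimal sketch of the auxiliary target distribution that the test above
# exercises, following the DEC formulation p_ij = (q_ij^2 / f_j) / sum_j'(q_ij'^2 / f_j')
# with f_j = sum_i q_ij. The name target_distribution matches the helper used throughout,
# but this body is an assumption about its implementation, not the repo's actual code.
import torch


def target_distribution(q: torch.Tensor) -> torch.Tensor:
    """Sharpened auxiliary target P computed from soft assignments Q (DEC, Eq. 3)."""
    weight = q ** 2 / q.sum(dim=0)               # q_ij^2 / f_j
    return (weight.t() / weight.sum(dim=1)).t()  # renormalize each row to sum to 1


# With this definition, target_distribution(torch.Tensor([[0.5, 0.5], [0.0, 1.0]]))
# yields [[0.75, 0.25], [0.0, 1.0]], matching the assertions in test_basic above.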
def activeUnsL(model, node, label, features, adj_lists, num_features, num_hidden, num_cls,
               filetime, labels, xi=1e-6, eps=2.5, num_iters=10):
    # obtain the adj matrix and find the best perturbation direction,
    # then add perturbation to the attention matrix
    encSpc = model(node, actE=True)

    dec_ae = DEC_AE(50, 100, num_hidden)
    dec = DEC(num_cls, 50, dec_ae)
    kmeans = KMeans(n_clusters=dec.cluster_number, n_init=20)

    # form initial cluster centres
    dec.pretrain(encSpc.data)
    features = dec.ae.encoder(encSpc).detach()
    predicted = kmeans.fit_predict(features.cpu().numpy())
    predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
    _, accuracy, _, _ = cluster_accuracy(label, predicted)
    print("ACCU", accuracy)

    cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float)
    dec.assignment.cluster_centers = torch.nn.Parameter(cluster_centers)

    # sum-reduced KL divergence (equivalent to the deprecated size_average=False)
    loss_function = nn.KLDivLoss(reduction="sum")
    delta_label = None
    optimizer = torch.optim.SGD(dec.parameters(), lr=0.01, momentum=0.9)

    for epoch in range(250):
        dec.train()
        output = dec(encSpc)
        target = target_distribution(output).detach()
        loss = loss_function(output.log(), target) / output.shape[0]

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        features = dec.ae.encoder(encSpc).detach()
        predicted = output.argmax(dim=1)
        # fraction of samples whose cluster assignment changed since the last epoch
        delta_label = float((predicted != predicted_previous).float().sum().item()) / predicted_previous.shape[0]
        predicted_previous = predicted
        _, accuracy, _, _ = cluster_accuracy(np.array(predicted), np.array(label))

        if epoch % 50 == 49:
            # count matrix: predicted cluster vs. true label
            count_matrix = np.zeros((num_cls, num_cls), dtype=np.int64)
            for i in range(len(predicted)):
                count_matrix[np.array(predicted)[i], np.array(label)[i]] += 1
            for i in range(num_cls):
                print(count_matrix[i])
            summary(node, labels, np.array(predicted), num_cls, filetime, outlog=False, output=None)
            print(loss)
            print("ACCU", accuracy)
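# activeUnsL above relies on cluster_accuracy, which is not shown in this section. A common
# way to compute unsupervised clustering accuracy is Hungarian matching between predicted and
# true labels; the sketch below is an assumption for illustration (the repo's cluster_accuracy
# appears to return four values here and two values elsewhere, so its real interface differs).
import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian_cluster_accuracy(y_true, y_pred):
    """Best-match clustering accuracy via the Hungarian algorithm (assumed helper)."""
    y_true = np.asarray(y_true).astype(np.int64)
    y_pred = np.asarray(y_pred).astype(np.int64)
    n = max(y_true.max(), y_pred.max()) + 1
    # Contingency table: counts[i, j] = how often prediction i co-occurs with label j.
    counts = np.zeros((n, n), dtype=np.int64)
    for p, t in zip(y_pred, y_true):
        counts[p, t] += 1
    # Choose the prediction-to-label mapping that maximizes the number of matched samples.
    row_ind, col_ind = linear_sum_assignment(counts.max() - counts)
    mapping = dict(zip(row_ind, col_ind))
    accuracy = counts[row_ind, col_ind].sum() / y_pred.size
    return mapping, accuracy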
def Train(epoch, model, data, adj, label, lr, pre_model_save_path, final_model_save_path,
          n_clusters, original_acc, gamma_value, lambda_value, device):
    optimizer = Adam(model.parameters(), lr=lr)
    model.load_state_dict(torch.load(pre_model_save_path, map_location='cpu'))

    # initialize the cluster layer with k-means on the fused embedding z_tilde
    with torch.no_grad():
        x_hat, z_hat, adj_hat, z_ae, z_igae, _, _, _, z_tilde = model(data, adj)
    kmeans = KMeans(n_clusters=n_clusters, n_init=20)
    cluster_id = kmeans.fit_predict(z_tilde.data.cpu().numpy())
    model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device)
    eva(label, cluster_id, 'Initialization')

    for epoch in range(epoch):
        # if opt.args.name in use_adjust_lr:
        #     adjust_learning_rate(optimizer, epoch)
        x_hat, z_hat, adj_hat, z_ae, z_igae, q, q1, q2, z_tilde = model(data, adj)

        tmp_q = q.data
        p = target_distribution(tmp_q)

        loss_ae = F.mse_loss(x_hat, data)                  # AE reconstruction
        loss_w = F.mse_loss(z_hat, torch.spmm(adj, data))  # reconstruction of graph-propagated features
        loss_a = F.mse_loss(adj_hat, adj.to_dense())       # adjacency reconstruction
        loss_igae = loss_w + gamma_value * loss_a
        # self-supervised KL term against the averaged log-assignments of q, q1 and q2
        loss_kl = F.kl_div((q.log() + q1.log() + q2.log()) / 3, p, reduction='batchmean')
        loss = loss_ae + loss_igae + lambda_value * loss_kl
        print('{} loss: {}'.format(epoch, loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        kmeans = KMeans(n_clusters=n_clusters, n_init=20).fit(z_tilde.data.cpu().numpy())
        acc, nmi, ari, f1 = eva(label, kmeans.labels_, epoch)
        acc_reuslt.append(acc)
        nmi_result.append(nmi)
        ari_result.append(ari)
        f1_result.append(f1)

        if acc > original_acc:
            original_acc = acc
            torch.save(model.state_dict(), final_model_save_path)
def main():
    set_seed(args.seed)
    torch.cuda.set_device(args.gpu)

    encoder = Encoder().cuda()
    encoder_group_0 = Encoder().cuda()
    encoder_group_1 = Encoder().cuda()

    dfc = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_0 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_1 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()

    critic = AdversarialNetwork(in_feature=args.k,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_mult).cuda()

    # encoder pre-trained with self-reconstruction
    encoder.load_state_dict(torch.load("./save/encoder_pretrain.pth"))
    # encoders and clustering models trained by DEC
    encoder_group_0.load_state_dict(torch.load("./save/encoder_mnist.pth"))
    encoder_group_1.load_state_dict(torch.load("./save/encoder_usps.pth"))
    dfc_group_0.load_state_dict(torch.load("./save/dec_mnist.pth"))
    dfc_group_1.load_state_dict(torch.load("./save/dec_usps.pth"))

    # load clustering centroids given by k-means
    centers = np.loadtxt("./save/centers.txt")
    cluster_centers = torch.tensor(centers, dtype=torch.float, requires_grad=True).cuda()
    with torch.no_grad():
        print("loading clustering centers...")
        dfc.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    optimizer = torch.optim.Adam(dfc.get_parameters() + encoder.get_parameters() + critic.get_parameters(),
                                 lr=args.lr, weight_decay=5e-4)
    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    encoder_group_0.eval(), encoder_group_1.eval()
    dfc_group_0.eval(), dfc_group_1.eval()

    data_loader = mnist_usps(args)
    len_image_0 = len(data_loader[0])
    len_image_1 = len(data_loader[1])

    for step in range(args.iters):
        encoder.train()
        dfc.train()
        if step % len_image_0 == 0:
            iter_image_0 = iter(data_loader[0])
        if step % len_image_1 == 0:
            iter_image_1 = iter(data_loader[1])
        image_0, _ = iter_image_0.__next__()
        image_1, _ = iter_image_1.__next__()
        image_0, image_1 = image_0.cuda(), image_1.cuda()
        image = torch.cat((image_0, image_1), dim=0)

        predict_0 = dfc_group_0(encoder_group_0(image_0)[0])
        predict_1 = dfc_group_1(encoder_group_1(image_1)[0])

        z, _, _ = encoder(image)
        output = dfc(z)
        output_0, output_1 = output[0:args.bs, :], output[args.bs:args.bs * 2, :]
        target_0 = target_distribution(output_0).detach()
        target_1 = target_distribution(output_1).detach()

        clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
                          + 0.5 * criterion_c(output_1.log(), target_1)
        fair_loss = adv_loss(output, critic)
        partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
                         + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
        total_loss = clustering_loss + args.coeff_fair * fair_loss + args.coeff_par * partition_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        C_LOSS.update(clustering_loss)
        F_LOSS.update(fair_loss)
        P_LOSS.update(partition_loss)

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(data_loader, encoder, dfc)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, 10)
            nmi = normalized_mutual_info_score(labels, predicted, average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted, 60000)

            print("Step:[{:03d}/{:03d}] "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "C.loss:{C_Loss.avg:3.2f};"
                  "F.loss:{F_Loss.avg:3.2f};"
                  "P.loss:{P_Loss.avg:3.2f};".format(step + 1, args.iters, accuracy, nmi, bal,
                                                     en_0, en_1, C_Loss=C_LOSS,
                                                     F_Loss=F_LOSS, P_Loss=P_LOSS))

    return
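# The C_LOSS / F_LOSS / P_LOSS meters used above only need .update(), .val and .avg.
# A minimal sketch of such a meter, assuming the usual running-average behaviour (the
# project's own AverageMeter may differ in details, e.g. how it handles tensor inputs).
class AverageMeter:
    """Tracks the most recent value and a running average (assumed helper)."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # Accept plain floats or 0-dim torch tensors.
        value = value.item() if hasattr(value, "item") else float(value)
        self.val = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count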
# logging file
if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)
logfile = open(args.save_dir + '/dec_log_{}.csv'.format(i), 'w')
logwriter = csv.DictWriter(logfile, fieldnames=['iter', 'acc', 'nmi', 'ari', 'L'])
logwriter.writeheader()

loss = 0
idx = 0
t0 = time()
for ite in range(int(args.maxiter)):
    if ite % args.update_interval == 0:
        q = model.predict_generator(AE_generator, verbose=1)
        p = target_distribution(q)  # update the auxiliary target distribution p
        print(p.shape)

        # evaluate the clustering performance
        y_pred = q.argmax(1)
        if y_true is not None:
            acc = np.round(cluster_acc(y_true, y_pred), 5)
            nmi = np.round(metrics.normalized_mutual_info_score(y_true, y_pred), 5)
            ari = np.round(metrics.adjusted_rand_score(y_true, y_pred), 5)
            loss = np.round(loss, 5)
            logwriter.writerow(dict(iter=ite, acc=acc, nmi=nmi, ari=ari, L=loss))
            print('Iter-%d: ACC= %.4f, NMI= %.4f, ARI= %.4f; L= %.5f'
                  % (ite, acc, nmi, ari, loss))
def clustering(self, x, y=None, update_interval=100, maxiter=200, save_dir='./results/dec'):
    print('Updating auxiliary distribution after %d iterations' % update_interval)
    save_interval = 10  # 10 epochs
    print('Saving models after %d iterations' % save_interval)

    # initialize cluster centers using k-means
    print('Initializing cluster centers with k-means.')
    k_means = KMeans(n_clusters=self.n_clusters, n_init=20, random_state=42)
    y_pred = k_means.fit_predict(self.encoder.predict(x))
    self.model.get_layer(name='clustering').set_weights([k_means.cluster_centers_])

    # logging file
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    log_file = open(save_dir + '/dec_log.csv', 'w')
    log_writer = csv.DictWriter(log_file, fieldnames=['iter', 'acc', 'nmi', 'ari', 'L'])
    log_writer.writeheader()

    loss = 0
    for ite in range(int(maxiter)):
        if ite % update_interval == 0:
            q = self.model.predict(x, verbose=0)
            p = target_distribution(q)  # update the auxiliary target distribution p

            # evaluate the clustering performance
            y_pred = q.argmax(1)
            acc, nmi, ari = get_acc_nmi_ari(y, y_pred)
            loss = np.round(loss, 5)
            log_dict = dict(iter=ite, acc=acc, nmi=nmi, ari=ari, L=loss)
            log_writer.writerow(log_dict)
            print('Iter', ite, ': Acc', acc, ', nmi', nmi, ', ari', ari, '; loss=', loss)

        # training on whole data
        loss = self.model.train_on_batch(x=x, y=p)

        # save intermediate model
        if ite % save_interval == 0:
            # save DEC model checkpoints
            print('saving model to:', save_dir + '/DEC_model_' + str(ite) + '.h5')
            self.model.save(save_dir + '/DEC_model_' + str(ite) + '.h5')

    # save the trained model
    log_file.close()
    print('saving model to:', save_dir + '/DEC_model_final.h5')
    self.model.save(save_dir + '/DEC_model_final.h5')
    return y_pred
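# The 'clustering' layer whose weights are set above produces the soft assignments q that
# feed target_distribution. DEC computes them with a Student's t-kernel over distances to
# the cluster centers; the PyTorch sketch below illustrates that formula and is not the
# Keras clustering layer used by this model.
import torch


def soft_assignment(z: torch.Tensor, centers: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """DEC soft assignment: q_ij proportional to (1 + ||z_i - mu_j||^2 / alpha)^(-(alpha + 1) / 2)."""
    # Squared distance between every embedding z_i and every center mu_j.
    dist_sq = torch.sum((z.unsqueeze(1) - centers.unsqueeze(0)) ** 2, dim=2)
    q = (1.0 + dist_sq / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(dim=1, keepdim=True)  # normalize rows into probabilities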
def train(args, dataloader_list, encoder, device='cpu', centers=None, save_name='DEC'):
    """Trains a DEC-style clustering head on top of the encoder and automatically saves when finished.

    Args:
        args: Namespace object which contains config set from argument parser
              {lr, seed, iters, log_dir, test_interval, adv_multiplier, dfc_hidden_dim}
        dataloader_list (list): this list may consist of only 1 dataloader or multiple
        encoder: Encoder to use
        device: Device configuration
        centers: Initial cluster centers if available
        save_name: Prefix for save files

    Returns:
        DFC: The trained clustering model
    """
    set_seed(args.seed)

    if args.half_tensor:
        torch.set_default_tensor_type('torch.HalfTensor')

    dec = DFC(cluster_number=args.cluster_number, hidden_dimension=args.dfc_hidden_dim).to(device)
    wandb.watch(dec)

    if not (centers is None):
        cluster_centers = centers.clone().detach().requires_grad_(True).to(device)
        with torch.no_grad():
            print("loading clustering centers...")
            dec.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    # depending on the encoder we get the params differently, so we have to use this if
    encoder_param = encoder.get_parameters() if args.encoder_type == 'vae' else [
        {"params": get_update_param(encoder), "lr_mult": 1}]
    optimizer = torch.optim.Adam(dec.get_parameters() + encoder_param, lr=args.dec_lr)

    # criterion_c = nn.KLDivLoss(reduction="sum")
    # following dec code more closely
    criterion_c = nn.KLDivLoss(size_average=False)

    C_LOSS = AverageMeter()

    print("Start training")
    assert 0 < len(dataloader_list) < 3
    concat_dataset = torch.utils.data.ConcatDataset([x.dataset for x in dataloader_list])
    training_dataloader = torch.utils.data.DataLoader(
        dataset=concat_dataset,
        batch_size=args.dec_batch_size,
        shuffle=True,
        num_workers=4)

    for step in range(args.dec_iters):
        encoder.train()
        dec.train()

        if step % len(training_dataloader) == 0:
            iterator = iter(training_dataloader)
        image, _ = iterator.__next__()
        image = image.to(device)

        if args.encoder_type == 'vae':
            z, _, _ = encoder(image)
        elif args.encoder_type == 'resnet50':
            z = encoder(image)
        else:
            raise Exception('Wrong encoder type, how did you get this far in running the code?')

        output = dec(z)
        target = target_distribution(output).detach()
        clustering_loss = criterion_c(output.log(), target) / output.shape[0]

        optimizer.zero_grad()
        clustering_loss.backward()
        optimizer.step()

        C_LOSS.update(clustering_loss)

        wandb.log({f"{save_name} Train C Loss Avg": C_LOSS.avg, f"{save_name} step": step})
        wandb.log({f"{save_name} Train C Loss Cur": C_LOSS.val, f"{save_name} step": step})

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(dataloader_list, encoder, dec,
                                        device=device, encoder_type=args.encoder_type)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, args.cluster_number)
            nmi = normalized_mutual_info_score(labels, predicted, average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted, len(dataloader_list[0]), k=args.cluster_number)

            wandb.log({f"{save_name} Train Accuracy": accuracy,
                       f"{save_name} Train NMI": nmi,
                       f"{save_name} Train Bal": bal,
                       f"{save_name} Train Entropy 0": en_0,
                       f"{save_name} Train Entropy 1": en_1,
                       f"{save_name} step": step})

            print("Step:[{:03d}/{:03d}] "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "Clustering.loss:{C_Loss.avg:3.2f};".format(step + 1, args.dec_iters, accuracy, nmi,
                                                              bal, en_0, en_1, C_Loss=C_LOSS))

            # log tsne visualisation
            if args.encoder_type == "vae":
                tsne_img = tsne_visualization(dataloader_list, encoder, args.cluster_number,
                                              encoder_type=args.encoder_type, device=device)
                if not (tsne_img is None):
                    wandb.log({f"{save_name} TSNE": plt, f"{save_name} step": step})

    torch.save(dec.state_dict(), f'{args.log_dir}DFC_{save_name}.pth')
    return dec
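# balance() above is the fairness metric; its implementation is not part of this section.
# A plausible reading, given the call balance(predicted, len(dataloader_list[0])), is that the
# first group_0_size predictions belong to group 0 and the rest to group 1, with the metric
# being the worst-case per-cluster group ratio plus the entropy of each group's cluster usage.
# The sketch below is an assumption for illustration only, not the repo's balance() function.
import numpy as np


def balance_sketch(predicted, group_0_size, k=10, eps=1e-12):
    """Assumed fairness metric: worst-case cluster balance and per-group assignment entropy."""
    predicted = np.asarray(predicted)
    counts_0 = np.bincount(predicted[:group_0_size], minlength=k).astype(np.float64)
    counts_1 = np.bincount(predicted[group_0_size:], minlength=k).astype(np.float64)
    # Worst-case ratio between the two groups over all clusters (1.0 = perfectly balanced).
    ratios = np.minimum(counts_0 / (counts_1 + eps), counts_1 / (counts_0 + eps))
    bal = float(ratios.min())
    # Entropy of each group's distribution over clusters.
    p0, p1 = counts_0 / counts_0.sum(), counts_1 / counts_1.sum()
    en_0 = float(-(p0 * np.log(p0 + eps)).sum())
    en_1 = float(-(p1 * np.log(p1 + eps)).sum())
    return bal, en_0, en_1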
def train(args, dataloader_list, encoder, encoder_group_0=None, encoder_group_1=None,
          dfc_group_0=None, dfc_group_1=None, device='cpu', centers=None,
          get_loss_trade_off=lambda step: (10, 10, 10), save_name='DFC'):
    """Trains DFC and optionally the critic, automatically saves when finished training.

    Args:
        args: Namespace object which contains config set from argument parser
              {lr, seed, iters, log_dir, test_interval, adv_multiplier, dfc_hidden_dim}
        dataloader_list (list): this list may consist of only 1 dataloader or multiple
        encoder: Encoder to use
        encoder_group_0: Optional pre-trained golden standard model
        encoder_group_1: Optional pre-trained golden standard model
        dfc_group_0: Optional cluster centers file obtained with encoder_group_0
        dfc_group_1: Optional cluster centers file obtained with encoder_group_1
        device: Device configuration
        centers: Initial cluster centers if available
        get_loss_trade_off: Proportional importance of individual loss functions
        save_name: Prefix for save files

    Returns:
        DFC: A trained DFC model
    """
    set_seed(args.seed)

    if args.half_tensor:
        torch.set_default_tensor_type('torch.HalfTensor')

    dfc = DFC(cluster_number=args.cluster_number, hidden_dimension=args.dfc_hidden_dim).to(device)
    wandb.watch(dfc)

    critic = AdversarialNetwork(in_feature=args.cluster_number,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_multiplier).to(device)
    wandb.watch(critic)

    if not (centers is None):
        cluster_centers = centers.clone().detach().requires_grad_(True).to(device)
        with torch.no_grad():
            print("loading clustering centers...")
            dfc.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    encoder_param = encoder.get_parameters() if args.encoder_type == 'vae' else [
        {"params": encoder.parameters(), "lr_mult": 1}]
    optimizer = torch.optim.Adam(dfc.get_parameters() + encoder_param + critic.get_parameters(),
                                 lr=args.dec_lr, weight_decay=5e-4)

    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    partition_loss_enabled = True
    if not encoder_group_0 or not encoder_group_1 or not dfc_group_0 or not dfc_group_1:
        print("Missing Golden Standard models, switching to DEC mode instead of DFC.")
        partition_loss_enabled = False

    if partition_loss_enabled:
        encoder_group_0.eval(), encoder_group_1.eval()
        dfc_group_0.eval(), dfc_group_1.eval()

    print("Start training")
    assert 0 < len(dataloader_list) < 3
    len_image_0 = len(dataloader_list[0])
    len_image_1 = len(dataloader_list[1]) if len(dataloader_list) == 2 else None

    for step in range(args.iters):
        encoder.train()
        dfc.train()

        if step % len_image_0 == 0:
            iter_image_0 = iter(dataloader_list[0])
        if len_image_1 and step % len_image_1 == 0:
            iter_image_1 = iter(dataloader_list[1])

        image_0, _ = iter_image_0.__next__()
        image_0 = image_0.to(device)
        if not (len_image_1 is None):
            image_1, _ = iter_image_1.__next__()
            image_1 = image_1.to(device)
            image = torch.cat((image_0, image_1), dim=0)
        else:
            image_1 = None
            image = torch.cat((image_0,), dim=0)

        if args.encoder_type == 'vae':
            z, _, _ = encoder(image)
        elif args.encoder_type == 'resnet50':
            z = encoder(image)
        else:
            raise Exception('Wrong encoder type, how did you get this far in running the code?')

        output = dfc(z)

        # Golden-standard predictions are only needed (and only computable) when the
        # pre-trained group models are available.
        if partition_loss_enabled:
            features_enc_0 = encoder_group_0(image_0)[0] if args.encoder_type == 'vae' else encoder_group_0(image_0)
            predict_0 = dfc_group_0(features_enc_0)
            if not (image_1 is None):
                features_enc_1 = encoder_group_1(image_1)[0] if args.encoder_type == 'vae' else encoder_group_1(image_1)
                predict_1 = dfc_group_1(features_enc_1)
            else:
                predict_1 = None
        else:
            predict_0, predict_1 = None, None

        output_0 = output[0:args.bs, :]
        output_1 = output[args.bs:args.bs * 2, :] if not (image_1 is None) else None

        # Equation (5) in the paper
        # output_0 and output_1 are the probability distributions P of samples being assigned
        # to a cluster in k; target_0 and target_1 are the auxiliary distributions Q computed
        # from P, Equation (4) in the paper.
        target_0 = target_distribution(output_0).detach()
        target_1 = target_distribution(output_1).detach() if not (output_1 is None) else None
        if not (output_1 is None):
            clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
                              + 0.5 * criterion_c(output_1.log(), target_1)
        else:
            clustering_loss = criterion_c(output_0.log(), target_0)

        # Equation (2) in the paper
        # output = D(A(F(X))); the critic models the distribution of the categorical
        # sensitive subgroup variable G.
        if len(dataloader_list) > 1:
            fair_loss, critic_acc = adv_loss(output, critic, device=device)
        else:
            fair_loss, critic_acc = 0, 0

        if partition_loss_enabled:
            # Equation (3) in the paper
            # predict_0 and predict_1 are the soft cluster assignments of the pre-trained
            # group models; the loss is high if they and the DFC outputs (and thus the
            # cluster structures) differ.
            if not (predict_1 is None):
                partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
                                 + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
            else:
                partition_loss = criterion_p(aff(output_0), aff(predict_0).detach())
        else:
            partition_loss = 0

        loss_trade_off = get_loss_trade_off(step)
        if args.encoder_type == 'resnet50' and args.dataset == 'office_31':  # alpha_s
            loss_trade_off = list(loss_trade_off)
            loss_trade_off[1] = ((512 / 128) ** 2) * (31 / 10)

        total_loss = loss_trade_off[0] * fair_loss \
                     + loss_trade_off[1] * partition_loss \
                     + loss_trade_off[2] * clustering_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        C_LOSS.update(clustering_loss)
        F_LOSS.update(fair_loss)
        P_LOSS.update(partition_loss)

        wandb.log({f"{save_name} Train C Loss Avg": C_LOSS.avg,
                   f"{save_name} Train F Loss Avg": F_LOSS.avg,
                   f"{save_name} Train P Loss Avg": P_LOSS.avg,
                   f"{save_name} step": step,
                   f"{save_name} Critic ACC": critic_acc})
        wandb.log({f"{save_name} Train C Loss Cur": C_LOSS.val,
                   f"{save_name} Train F Loss Cur": F_LOSS.val,
                   f"{save_name} Train P Loss Cur": P_LOSS.val,
                   f"{save_name} step": step})

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(dataloader_list, encoder, dfc,
                                        device=device, encoder_type=args.encoder_type)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, args.cluster_number)
            nmi = normalized_mutual_info_score(labels, predicted, average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted, len_image_0, k=args.cluster_number)

            wandb.log({f"{save_name} Train Accuracy": accuracy,
                       f"{save_name} Train NMI": nmi,
                       f"{save_name} Train Bal": bal,
                       f"{save_name} Train Entropy 0": en_0,
                       f"{save_name} Train Entropy 1": en_1,
                       f"{save_name} step": step})

            print("Step:[{:03d}/{:03d}] "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "Clustering.loss:{C_Loss.avg:3.2f};"
                  "Fairness.loss:{F_Loss.avg:3.2f};"
                  "Partition.loss:{P_Loss.avg:3.2f};".format(step + 1, args.iters, accuracy, nmi, bal,
                                                             en_0, en_1, C_Loss=C_LOSS,
                                                             F_Loss=F_LOSS, P_Loss=P_LOSS))

            # log tsne visualisation
            if args.encoder_type == "vae":
                tsne_img = tsne_visualization(dataloader_list, encoder, args.cluster_number,
                                              encoder_type=args.encoder_type, device=device)
                if not (tsne_img is None):
                    wandb.log({f"{save_name} TSNE": plt, f"{save_name} step": step})

    torch.save(dfc.state_dict(), f'{args.log_dir}DFC_{save_name}.pth')
    if len(dataloader_list) > 1:
        torch.save(critic.state_dict(), f'{args.log_dir}CRITIC_{save_name}.pth')
    return dfc