def lr_step(self):
    """Learning rate scheduler."""
    if self.args.optimizer == 'SGD':
        self.tgt_opt = utils.inv_lr_scheduler(
            self.param_lr_c, self.tgt_opt, self.current_step, init_lr=self.args.lr)

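# A minimal sketch of the inverse-decay schedule that utils.inv_lr_scheduler is assumed to
# implement (the standard DANN/CDAN-style decay: lr = init_lr * (1 + gamma * iter)^(-power)).
# The gamma/power defaults and the (param_lr, optimizer, iter_num) signature are assumptions
# matching the call above; the repo's own utils module remains the reference.
def inv_lr_scheduler_sketch(param_lr, optimizer, iter_num, init_lr=0.001, gamma=0.001, power=0.75):
    lr = init_lr * (1 + gamma * iter_num) ** (-power)
    for i, param_group in enumerate(optimizer.param_groups):
        # each group keeps the base lr (and any lr_mult) recorded at construction time
        param_group['lr'] = lr * param_lr[i]
    return optimizer
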
def train_source(config, base_network, classifier_gnn, dset_loaders):
    # define loss functions
    criterion_gedge = nn.BCELoss(reduction='mean')
    ce_criterion = nn.CrossEntropyLoss()

    # configure optimizer
    optimizer_config = config['optimizer']
    parameter_list = base_network.get_parameters() + \
        [{'params': classifier_gnn.parameters(), 'lr_mult': 10, 'decay_mult': 2}]
    optimizer = optimizer_config['type'](parameter_list, **(optimizer_config['optim_params']))

    # configure learning rates
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])
    schedule_param = optimizer_config['lr_param']

    # start train loop
    base_network.train()
    classifier_gnn.train()
    len_train_source = len(dset_loaders["source"])
    for i in range(config['source_iters']):
        optimizer = utils.inv_lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        # get input data
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        batch_source = next(iter_source)
        inputs_source, labels_source = batch_source['img'].to(DEVICE), batch_source['target'].to(DEVICE)

        # make forward pass for encoder and mlp head
        features_source, logits_mlp = base_network(inputs_source)
        mlp_loss = ce_criterion(logits_mlp, labels_source)

        # make forward pass for gnn head
        logits_gnn, edge_sim = classifier_gnn(features_source)
        gnn_loss = ce_criterion(logits_gnn, labels_source)
        # compute edge loss
        edge_gt, edge_mask = classifier_gnn.label2edge(labels_source.unsqueeze(dim=0))
        edge_loss = criterion_gedge(edge_sim.masked_select(edge_mask), edge_gt.masked_select(edge_mask))

        # total loss and backpropagation
        loss = mlp_loss + config['lambda_node'] * gnn_loss + config['lambda_edge'] * edge_loss
        loss.backward()
        optimizer.step()

        # printout train loss
        if i % 20 == 0 or i == config['source_iters'] - 1:
            log_str = 'Iters:(%4d/%d)\tMLP loss:%.4f\tGNN loss:%.4f\tEdge loss:%.4f' % (
                i, config['source_iters'], mlp_loss.item(), gnn_loss.item(), edge_loss.item())
            utils.write_logs(config, log_str)

        # evaluate network every test_interval
        if i % config['test_interval'] == config['test_interval'] - 1:
            evaluate(i, config, base_network, classifier_gnn, dset_loaders['target_test'])

    return base_network, classifier_gnn

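# A hedged sketch of what classifier_gnn.label2edge is assumed to produce for the edge loss
# above: a pairwise ground-truth matrix (1 where two nodes share a class label, 0 otherwise)
# plus a boolean mask that drops pairs involving unlabeled nodes (label == mask_val). The
# actual GNN/edge-network head in this repo may differ in details (batch handling, self-edges).
def label2edge_sketch(labels, mask_val=-1):
    # labels: (1, N) integer class labels; mask_val marks unlabeled (pseudo-masked) nodes
    label_i = labels.unsqueeze(-1)                  # (1, N, 1)
    label_j = labels.unsqueeze(-2)                  # (1, 1, N)
    edge_gt = torch.eq(label_i, label_j).float()    # (1, N, N) same-class indicator
    valid = (labels != mask_val).float()
    edge_mask = (valid.unsqueeze(-1) * valid.unsqueeze(-2)).bool()
    return edge_gt, edge_mask
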
def main():
    set_seed(args.seed)
    torch.cuda.set_device(args.gpu)

    encoder = Encoder().cuda()
    encoder_group_0 = Encoder().cuda()
    encoder_group_1 = Encoder().cuda()

    dfc = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_0 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_1 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()

    critic = AdversarialNetwork(in_feature=args.k,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_mult).cuda()

    # encoder pre-trained with self-reconstruction
    encoder.load_state_dict(torch.load("./save/encoder_pretrain.pth"))
    # encoders and clustering models trained by DEC
    encoder_group_0.load_state_dict(torch.load("./save/encoder_mnist.pth"))
    encoder_group_1.load_state_dict(torch.load("./save/encoder_usps.pth"))
    dfc_group_0.load_state_dict(torch.load("./save/dec_mnist.pth"))
    dfc_group_1.load_state_dict(torch.load("./save/dec_usps.pth"))

    # load clustering centroids given by k-means
    centers = np.loadtxt("./save/centers.txt")
    cluster_centers = torch.tensor(centers, dtype=torch.float, requires_grad=True).cuda()
    with torch.no_grad():
        print("loading clustering centers...")
        dfc.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    optimizer = torch.optim.Adam(dfc.get_parameters() + encoder.get_parameters() + critic.get_parameters(),
                                 lr=args.lr,
                                 weight_decay=5e-4)
    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    encoder_group_0.eval(), encoder_group_1.eval()
    dfc_group_0.eval(), dfc_group_1.eval()

    data_loader = mnist_usps(args)
    len_image_0 = len(data_loader[0])
    len_image_1 = len(data_loader[1])

    for step in range(args.iters):
        encoder.train()
        dfc.train()
        if step % len_image_0 == 0:
            iter_image_0 = iter(data_loader[0])
        if step % len_image_1 == 0:
            iter_image_1 = iter(data_loader[1])
        image_0, _ = next(iter_image_0)
        image_1, _ = next(iter_image_1)
        image_0, image_1 = image_0.cuda(), image_1.cuda()
        image = torch.cat((image_0, image_1), dim=0)

        # soft assignments from the pre-trained, group-specific DEC models
        predict_0 = dfc_group_0(encoder_group_0(image_0)[0])
        predict_1 = dfc_group_1(encoder_group_1(image_1)[0])

        z, _, _ = encoder(image)
        output = dfc(z)
        output_0, output_1 = output[0:args.bs, :], output[args.bs:args.bs * 2, :]
        target_0 = target_distribution(output_0).detach()
        target_1 = target_distribution(output_1).detach()

        clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
            + 0.5 * criterion_c(output_1.log(), target_1)
        fair_loss = adv_loss(output, critic)
        partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
            + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
        total_loss = clustering_loss + args.coeff_fair * fair_loss + args.coeff_par * partition_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        C_LOSS.update(clustering_loss)
        F_LOSS.update(fair_loss)
        P_LOSS.update(partition_loss)

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(data_loader, encoder, dfc)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, 10)
            nmi = normalized_mutual_info_score(labels, predicted, average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted, 60000)
            print("Step:[{:03d}/{:03d}] "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "C.loss:{C_Loss.avg:3.2f};"
                  "F.loss:{F_Loss.avg:3.2f};"
                  "P.loss:{P_Loss.avg:3.2f};".format(step + 1, args.iters, accuracy, nmi, bal,
                                                     en_0, en_1, C_Loss=C_LOSS, F_Loss=F_LOSS, P_Loss=P_LOSS))
    return

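# A hedged sketch of target_distribution as used above: the DEC-style auxiliary target Q
# built from the soft assignments P (square each column, normalize by cluster frequency,
# then renormalize per sample). This mirrors Eq. (4) referenced in the training comments
# further below; the repo's own implementation remains the reference.
def target_distribution_sketch(batch):
    # batch: (N, K) soft cluster assignments
    weight = (batch ** 2) / torch.sum(batch, dim=0)
    return (weight.t() / torch.sum(weight, dim=1)).t()
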
def adapt_target(config, base_network, classifier_gnn, dset_loaders, max_inherit_domain):
    # define loss functions
    criterion_gedge = nn.BCELoss(reduction='mean')
    ce_criterion = nn.CrossEntropyLoss()

    # add random layer and adversarial network
    class_num = config['encoder']['params']['class_num']
    random_layer = networks.RandomLayer([base_network.output_num(), class_num],
                                        config['random_dim'], DEVICE)
    adv_net = networks.AdversarialNetwork(config['random_dim'], config['random_dim'], config['ndomains'])
    random_layer.to(DEVICE)
    adv_net = adv_net.to(DEVICE)

    # configure optimizer
    optimizer_config = config['optimizer']
    parameter_list = base_network.get_parameters() + adv_net.get_parameters() \
        + [{'params': classifier_gnn.parameters(), 'lr_mult': 10, 'decay_mult': 2}]
    optimizer = optimizer_config['type'](parameter_list, **(optimizer_config['optim_params']))

    # configure learning rates
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])
    schedule_param = optimizer_config['lr_param']

    # start train loop
    len_train_source = len(dset_loaders['source'])
    len_train_target = len(dset_loaders['target_train'][max_inherit_domain])
    # set nets in train mode
    base_network.train()
    classifier_gnn.train()
    adv_net.train()
    random_layer.train()
    for i in range(config['adapt_iters']):
        optimizer = utils.inv_lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        # get input data
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders['source'])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders['target_train'][max_inherit_domain])
        batch_source = next(iter_source)
        batch_target = next(iter_target)
        inputs_source, inputs_target = batch_source['img'].to(DEVICE), batch_target['img'].to(DEVICE)
        labels_source = batch_source['target'].to(DEVICE)
        domain_source, domain_target = batch_source['domain'].to(DEVICE), batch_target['domain'].to(DEVICE)
        domain_input = torch.cat([domain_source, domain_target], dim=0)

        # make forward pass for encoder and mlp head
        features_source, logits_mlp_source = base_network(inputs_source)
        features_target, logits_mlp_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        logits_mlp = torch.cat((logits_mlp_source, logits_mlp_target), dim=0)
        softmax_mlp = nn.Softmax(dim=1)(logits_mlp)
        mlp_loss = ce_criterion(logits_mlp_source, labels_source)

        # *** GNN at work ***
        # make forward pass for gnn head
        logits_gnn, edge_sim = classifier_gnn(features)
        gnn_loss = ce_criterion(logits_gnn[:labels_source.size(0)], labels_source)
        # compute pseudo-labels for affinity matrix by mlp classifier
        out_target_class = torch.softmax(logits_mlp_target, dim=1)
        target_score, target_pseudo_labels = out_target_class.max(1, keepdim=True)
        idx_pseudo = target_score > config['threshold']
        target_pseudo_labels[~idx_pseudo] = classifier_gnn.mask_val
        # combine source labels and target pseudo-labels for edge_net
        node_labels = torch.cat((labels_source, target_pseudo_labels.squeeze(dim=1)), dim=0).unsqueeze(dim=0)
        # compute source-target mask and ground truth for edge_net
        edge_gt, edge_mask = classifier_gnn.label2edge(node_labels)
        # compute edge loss
        edge_loss = criterion_gedge(edge_sim.masked_select(edge_mask), edge_gt.masked_select(edge_mask))

        # *** Adversarial net at work ***
        if config['method'] == 'CDAN+E':
            entropy = transfer_loss.Entropy(softmax_mlp)
            trans_loss = transfer_loss.CDAN(config['ndomains'], [features, softmax_mlp], adv_net,
                                            entropy, networks.calc_coeff(i), random_layer, domain_input)
        elif config['method'] == 'CDAN':
            trans_loss = transfer_loss.CDAN(config['ndomains'], [features, softmax_mlp], adv_net,
                                            None, None, random_layer, domain_input)
        else:
            raise ValueError('Method cannot be recognized.')

        # total loss and backpropagation
        loss = config['lambda_adv'] * trans_loss + mlp_loss + \
            config['lambda_node'] * gnn_loss + config['lambda_edge'] * edge_loss
        loss.backward()
        optimizer.step()

        # printout train loss
        if i % 20 == 0 or i == config['adapt_iters'] - 1:
            log_str = 'Iters:(%4d/%d)\tMLP loss: %.4f\t GNN Loss: %.4f\t Edge Loss: %.4f\t Transfer loss:%.4f' % (
                i, config["adapt_iters"], mlp_loss.item(), config['lambda_node'] * gnn_loss.item(),
                config['lambda_edge'] * edge_loss.item(), config['lambda_adv'] * trans_loss.item())
            utils.write_logs(config, log_str)

        # evaluate network every test_interval
        if i % config['test_interval'] == config['test_interval'] - 1:
            evaluate(i, config, base_network, classifier_gnn, dset_loaders['target_test'])

    return base_network, classifier_gnn

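# A hedged sketch of networks.calc_coeff as passed to transfer_loss.CDAN above: the usual
# CDAN/DANN ramp-up coefficient that grows from `low` to `high` over training and weights the
# entropy conditioning / gradient reversal. The default values are assumptions; numpy is
# assumed to be imported as np elsewhere in this module; the repo's networks module is the reference.
def calc_coeff_sketch(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)
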
def train(args, dataloader_list, encoder, encoder_group_0=None, encoder_group_1=None,
          dfc_group_0=None, dfc_group_1=None, device='cpu', centers=None,
          get_loss_trade_off=lambda step: (10, 10, 10), save_name='DFC'):
    """Train DFC (and optionally the critic); the model is saved automatically when training finishes.

    Args:
        args: Namespace with the config set by the argument parser
              { lr, seed, iters, log_dir, test_interval, adv_multiplier, dfc_hidden_dim }
        dataloader_list (list): one or two dataloaders (one per sensitive group)
        encoder: Encoder to use
        encoder_group_0: Optional pre-trained golden-standard encoder for group 0
        encoder_group_1: Optional pre-trained golden-standard encoder for group 1
        dfc_group_0: Optional pre-trained clustering model paired with encoder_group_0
        dfc_group_1: Optional pre-trained clustering model paired with encoder_group_1
        device: Device configuration
        centers: Initial cluster centers, if available
        get_loss_trade_off: Proportional importance of the individual loss terms
        save_name: Prefix for save files

    Returns:
        DFC: A trained DFC model
    """
    set_seed(args.seed)

    if args.half_tensor:
        torch.set_default_tensor_type('torch.HalfTensor')

    dfc = DFC(cluster_number=args.cluster_number, hidden_dimension=args.dfc_hidden_dim).to(device)
    wandb.watch(dfc)

    critic = AdversarialNetwork(in_feature=args.cluster_number,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_multiplier).to(device)
    wandb.watch(critic)

    if centers is not None:
        cluster_centers = centers.clone().detach().requires_grad_(True).to(device)
        with torch.no_grad():
            print("loading clustering centers...")
            dfc.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    encoder_param = encoder.get_parameters() if args.encoder_type == 'vae' else [{
        "params": encoder.parameters(),
        "lr_mult": 1
    }]
    optimizer = torch.optim.Adam(dfc.get_parameters() + encoder_param + critic.get_parameters(),
                                 lr=args.dec_lr,
                                 weight_decay=5e-4)

    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    partition_loss_enabled = True
    if not encoder_group_0 or not encoder_group_1 or not dfc_group_0 or not dfc_group_1:
        print("Missing golden-standard models, switching to DEC mode instead of DFC.")
        partition_loss_enabled = False

    if partition_loss_enabled:
        encoder_group_0.eval(), encoder_group_1.eval()
        dfc_group_0.eval(), dfc_group_1.eval()

    print("Start training")
    assert 0 < len(dataloader_list) < 3
    len_image_0 = len(dataloader_list[0])
    len_image_1 = len(dataloader_list[1]) if len(dataloader_list) == 2 else None

    for step in range(args.iters):
        encoder.train()
        dfc.train()
        if step % len_image_0 == 0:
            iter_image_0 = iter(dataloader_list[0])
        if len_image_1 and step % len_image_1 == 0:
            iter_image_1 = iter(dataloader_list[1])

        image_0, _ = next(iter_image_0)
        image_0 = image_0.to(device)
        if len_image_1 is not None:
            image_1, _ = next(iter_image_1)
            image_1 = image_1.to(device)
            image = torch.cat((image_0, image_1), dim=0)
        else:
            image_1 = None
            image = torch.cat((image_0,), dim=0)

        if args.encoder_type == 'vae':
            z, _, _ = encoder(image)
        elif args.encoder_type == 'resnet50':
            z = encoder(image)
        else:
            raise Exception('Wrong encoder type, how did you get this far in running the code?')

        output = dfc(z)

        # forward passes through the pre-trained group models (only needed for the partition loss)
        if partition_loss_enabled:
            features_enc_0 = encoder_group_0(image_0)[0] if args.encoder_type == 'vae' else encoder_group_0(image_0)
            predict_0 = dfc_group_0(features_enc_0)
            if image_1 is not None:
                features_enc_1 = encoder_group_1(image_1)[0] if args.encoder_type == 'vae' else encoder_group_1(image_1)
                predict_1 = dfc_group_1(features_enc_1)
            else:
                predict_1 = None
        else:
            predict_0, predict_1 = None, None

        output_0 = output[0:args.bs, :]
        output_1 = output[args.bs:args.bs * 2, :] if image_1 is not None else None
        target_0 = target_distribution(output_0).detach()
        target_1 = target_distribution(output_1).detach() if output_1 is not None else None

        # Equation (5) in the paper
        # output_0 and output_1 are the probability distributions P of samples being assigned to one of the k clusters
        # target_0 and target_1 are the auxiliary distributions Q calculated from P, Equation (4) in the paper
        if output_1 is not None:
            clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
                + 0.5 * criterion_c(output_1.log(), target_1)
        else:
            clustering_loss = criterion_c(output_0.log(), target_0)

        # Equation (2) in the paper
        # output = D(A(F(X)))
        # the critic models the distribution of the categorical sensitive-subgroup variable G (?)
        if len(dataloader_list) > 1:
            fair_loss, critic_acc = adv_loss(output, critic, device=device)
        else:
            fair_loss, critic_acc = 0, 0

        if partition_loss_enabled:
            # Equation (3) in the paper
            # predict_0 and predict_1 are the soft assignments of the pre-trained group-specific models
            # output_0 and output_1 are the soft cluster assignments of the DFC being trained
            # the loss is high if the two assignments (and thus the cluster structures) differ
            if predict_1 is not None:
                partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
                    + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
            else:
                partition_loss = criterion_p(aff(output_0), aff(predict_0).detach())
        else:
            partition_loss = 0

        loss_trade_off = get_loss_trade_off(step)
        if args.encoder_type == 'resnet50' and args.dataset == 'office_31':  # alpha_s
            loss_trade_off = list(loss_trade_off)
            loss_trade_off[1] = ((512 / 128) ** 2) * (31 / 10)

        total_loss = loss_trade_off[0] * fair_loss + loss_trade_off[1] * partition_loss \
            + loss_trade_off[2] * clustering_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        C_LOSS.update(clustering_loss)
        F_LOSS.update(fair_loss)
        P_LOSS.update(partition_loss)

        wandb.log({
            f"{save_name} Train C Loss Avg": C_LOSS.avg,
            f"{save_name} Train F Loss Avg": F_LOSS.avg,
            f"{save_name} Train P Loss Avg": P_LOSS.avg,
            f"{save_name} step": step,
            f"{save_name} Critic ACC": critic_acc
        })
        wandb.log({
            f"{save_name} Train C Loss Cur": C_LOSS.val,
            f"{save_name} Train F Loss Cur": F_LOSS.val,
            f"{save_name} Train P Loss Cur": P_LOSS.val,
            f"{save_name} step": step
        })

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(dataloader_list, encoder, dfc,
                                        device=device, encoder_type=args.encoder_type)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, args.cluster_number)
            nmi = normalized_mutual_info_score(labels, predicted, average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted, len_image_0, k=args.cluster_number)

            wandb.log({
                f"{save_name} Train Accuracy": accuracy,
                f"{save_name} Train NMI": nmi,
                f"{save_name} Train Bal": bal,
                f"{save_name} Train Entropy 0": en_0,
                f"{save_name} Train Entropy 1": en_1,
                f"{save_name} step": step
            })

            print("Step:[{:03d}/{:03d}] "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "Clustering.loss:{C_Loss.avg:3.2f};"
                  "Fairness.loss:{F_Loss.avg:3.2f};"
                  "Partition.loss:{P_Loss.avg:3.2f};".format(step + 1, args.iters, accuracy, nmi, bal,
                                                             en_0, en_1, C_Loss=C_LOSS, F_Loss=F_LOSS, P_Loss=P_LOSS))

            # log tsne visualisation
            if args.encoder_type == "vae":
                tsne_img = tsne_visualization(dataloader_list, encoder, args.cluster_number,
                                              encoder_type=args.encoder_type, device=device)
                if tsne_img is not None:
                    wandb.log({f"{save_name} TSNE": plt, f"{save_name} step": step})

    torch.save(dfc.state_dict(), f'{args.log_dir}DFC_{save_name}.pth')
    if len(dataloader_list) > 1:
        torch.save(critic.state_dict(), f'{args.log_dir}CRITIC_{save_name}.pth')

    return dfc
