def train(config):
    """Train the base network with adversarial (PADA) domain adaptation.

    The module-level tensors ``pristine_x``/``pristine_y`` (source domain)
    and ``noisy_x``/``noisy_y`` (target domain) are each shuffled and split
    70% / 10% / 20% into train / validation / test sets.  Once per epoch (one
    pass over the source training loader):

    * the network is evaluated on the source validation and source training
      sets;
    * the best model so far is checkpointed, plus a periodic snapshot
      roughly every quarter of training;
    * training and (averaged) validation losses are logged to
      ``config["out_file"]`` and TensorBoard;
    * early stopping is checked on the mean source-validation classifier
      loss.

    Args:
        config: Experiment configuration dict.  This function also writes
            ``num_iterations``, ``test_interval``, ``snapshot_interval`` and
            ``log_iter`` into it.

    Returns:
        The best source-validation accuracy observed.  This value is also
        returned when early stopping ends training.

    Raises:
        ValueError: if ``config["loss"]["ly_type"]`` is neither ``"cosine"``
            nor ``"euclidean"``, or a training set is empty.

    NOTE(review): besides the data tensors above, this reads the module-level
    ``args`` (``args.blobs``).  All of them must exist before calling.
    """
    ly_type = config['loss']['ly_type']
    if ly_type not in ("cosine", "euclidean"):
        raise ValueError("no test method for cls loss: {}".format(ly_type))

    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])
    # Any value other than 'no' adds the Fisher and entropy-minimisation terms.
    use_fisher = config["fisher_or_no"] != 'no'

    ## prepare data
    def _split(x, y):
        """Shuffle (without replacement) and split 70/10/20 into train,
        validation and test ``TensorDataset``s."""
        n = len(x)
        order = torch.randperm(n)
        train_end = int(np.floor(.7 * n))
        valid_end = int(np.floor(.8 * n))
        spans = (slice(0, train_end), slice(train_end, valid_end),
                 slice(valid_end, n))
        return tuple(TensorDataset(x[order[s]], y[order[s]]) for s in spans)

    dsets = {}
    dsets["source"], dsets["source_valid"], dsets["source_test"] = _split(
        pristine_x, pristine_y)
    dsets["target"], dsets["target_valid"], dsets["target_test"] = _split(
        noisy_x, noisy_y)
    # Training loaders use batches of 128; validation/test loaders use 64.
    dset_loaders = {
        key: DataLoader(dset,
                        batch_size=128 if key in ("source", "target") else 64,
                        shuffle=True,
                        num_workers=1)
        for key, dset in dsets.items()
    }
    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    batches_per_epoch = len(dset_loaders["source"])
    if batches_per_epoch == 0:
        raise ValueError("source training set is empty")
    config["num_iterations"] = batches_per_epoch * config["epochs"] + 1
    config["test_interval"] = batches_per_epoch
    config["snapshot_interval"] = batches_per_epoch * config["epochs"] * .25
    config["log_iter"] = batches_per_epoch

    # print the configuration you are using
    config["out_file"].write("config: {}\n".format(config))
    config["out_file"].flush()

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()

    ## collect parameters
    def _group(params, lr_mult, decay_mult=2):
        """One optimizer parameter group with lr/decay multipliers."""
        return {"params": params, "lr_mult": lr_mult, 'decay_mult': decay_mult}

    # NOTE(review): the original code's indentation was lost.  This follows
    # the upstream layout, where the per-architecture lists (chosen by
    # ``args.net`` containing "DeepMerge"/"ResNet18") were always overwritten
    # by this new_cls/use_bottleneck selection, so they are omitted here.
    # Confirm against version control.
    if net_config["params"]["new_cls"]:
        parameter_list = [_group(base_network.feature_layers.parameters(), 1)]
        if net_config["params"]["use_bottleneck"]:
            parameter_list.append(
                _group(base_network.bottleneck.parameters(), 10))
        parameter_list.append(_group(base_network.fc.parameters(), 10))
    else:
        parameter_list = [_group(base_network.parameters(), 1)]
    parameter_list.append(_group(ad_net.parameters(), config["ad_net_mult_lr"]))
    parameter_list.append(
        _group(center_criterion.parameters(), 10, decay_mult=1))

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **optimizer_config["optim_params"])
    # Initial learning rate of every group; handed to the scheduler on each call.
    param_lr = [group["lr"] for group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]
    # These schedules are re-applied every iteration, not only once per epoch.
    per_step_schedule = optimizer_config["lr_type"] in ("one-cycle", "linear")

    ## set up summary writer (closed in the ``finally`` below)
    writer = SummaryWriter(config['output_path'])

    batch_iters = {}

    def _next_batch(key):
        """Return the next ``(inputs, labels)`` batch of ``dset_loaders[key]``.

        The iterator is kept across calls, so every pass visits each sample
        once.  When it is exhausted, a fresh one (new shuffle) is started.
        """
        iterator = batch_iters.get(key)
        if iterator is not None:
            try:
                return next(iterator)
            except StopIteration:
                pass
        iterator = batch_iters[key] = iter(dset_loaders[key])
        try:
            return next(iterator)
        except StopIteration:
            raise ValueError("dataset '{}' is empty".format(key)) from None

    def _align(inputs_source, labels_source, inputs_target):
        """Build one concatenated batch (source rows, then target rows).

        Returns ``(inputs, labels_source, n_source)``.  Both domains are
        trimmed to a common size and moved to the GPU when one is used.
        Downstream code slices the first half as the source part, and the
        adversarial loss / ``domain_cls_accuracy`` presumably assume equal
        halves (TODO confirm).  The trim keeps that true even when a loader
        yields a short final batch.
        """
        n_source = min(inputs_source.size(0), inputs_target.size(0))
        inputs_source = inputs_source[:n_source]
        labels_source = labels_source[:n_source]
        inputs_target = inputs_target[:n_source]
        if use_gpu:
            inputs_source = inputs_source.cuda()
            labels_source = labels_source.cuda()
            inputs_target = inputs_target.cuda()
        return (torch.cat((inputs_source, inputs_target), dim=0),
                labels_source, n_source)

    def _forward(inputs, n_source):
        """Return ``(features, logits, source_logits)`` for a concatenated batch."""
        if ly_type == 'cosine':
            features, logits = base_network(inputs)
        else:  # 'euclidean': logits are negative distances to the centres
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
        return features, logits, logits.narrow(0, 0, n_source)

    def _compute_losses(inputs, labels_source, n_source):
        """Compute every loss term for one concatenated batch.

        Returns ``(total_loss, terms, center_grad)``.  ``terms`` maps the
        logging keys to tensors or numbers.  ``center_grad`` is ``None``
        unless the Fisher loss is enabled.
        """
        features, logits, source_logits = _forward(inputs, n_source)
        weight_ad = torch.ones(inputs.size(0))
        # NOTE(review): this also runs gradient_reverse_layer.  If that layer
        # keeps an iteration counter, validation calls advance it as well.
        transfer_loss = transfer_criterion(features, ad_net,
                                           gradient_reverse_layer, weight_ad,
                                           use_gpu)
        # Domain-classifier accuracy on detached features (no gradient flows
        # back into the base network through this path).
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)
        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())
        terms = {
            "transfer": transfer_loss,
            "classifier": classifier_loss,
            "ad_acc": ad_acc,
            "source_acc_ad": source_acc_ad,
            "target_acc_ad": target_acc_ad,
        }
        center_grad = None
        if use_fisher:
            # fisher loss on labeled source domain
            (fisher_loss, fisher_intra_loss, fisher_inter_loss,
             center_grad) = center_criterion(
                 features.narrow(0, 0, n_source),
                 labels_source,
                 inter_class=loss_params["inter_type"],
                 intra_loss_weight=loss_params["intra_loss_coef"],
                 inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))
            total_loss = loss_params["trade_off"] * transfer_loss \
                + fisher_loss \
                + loss_params["em_loss_coef"] * em_loss \
                + classifier_loss
            terms.update(em=em_loss, fisher=fisher_loss,
                         fisher_intra=fisher_intra_loss,
                         fisher_inter=fisher_inter_loss)
        else:
            total_loss = loss_params["trade_off"] * transfer_loss \
                + classifier_loss
        terms["total"] = total_loss
        return total_loss, terms, center_grad

    def _to_float(value):
        """Convert a 0-d tensor or a plain number to a Python float."""
        return value.item() if torch.is_tensor(value) else float(value)

    # Log-line templates.  ``{s}`` is filled with "train"/"valid" first; the
    # escaped ``{{...}}`` fields then receive the values.
    plain_template = (
        'epoch {{}}: {s} total loss={{:0.4f}}, {s} transfer loss={{:0.4f}}, {s} classifier loss={{:0.4f}},'
        '{s} source+target domain accuracy={{:0.4f}}, {s} source domain accuracy={{:0.4f}}, {s} target domain accuracy={{:0.4f}}\n'
    )
    fisher_template = (
        'epoch {{}}: {s} total loss={{:0.4f}}, {s} transfer loss={{:0.4f}}, {s} classifier loss={{:0.4f}}, '
        '{s} entropy min loss={{:0.4f}}, '
        '{s} fisher loss={{:0.4f}}, {s} intra-group fisher loss={{:0.4f}}, {s} inter-group fisher loss={{:0.4f}}, '
        '{s} source+target domain accuracy={{:0.4f}}, {s} source domain accuracy={{:0.4f}}, {s} target domain accuracy={{:0.4f}}\n'
    )
    fisher_keys = ["em", "fisher", "fisher_intra", "fisher_inter"]
    acc_keys = ["ad_acc", "source_acc_ad", "target_acc_ad"]
    file_keys = (["total", "transfer", "classifier"]
                 + (fisher_keys if use_fisher else []) + acc_keys)
    # (TensorBoard tag suffix, metric key)
    tb_fields = [("total loss", "total"), ("classifier loss", "classifier"),
                 ("transfer loss", "transfer")]
    if use_fisher:
        tb_fields += [("entropy minimization loss", "em"),
                      ("total fisher loss", "fisher"),
                      ("intra-group fisher", "fisher_intra"),
                      ("inter-group fisher", "fisher_inter")]
    tb_fields += [("source+target domain accuracy", "ad_acc"),
                  ("source domain accuracy", "source_acc_ad"),
                  ("target domain accuracy", "target_acc_ad")]

    def _log_metrics(stage, metrics, epoch):
        """Write float ``metrics`` to the log file and TensorBoard.

        ``stage`` is "train" or "valid".  TensorBoard tags are prefixed with
        "training" or "validation" respectively.
        """
        template = (fisher_template if use_fisher else plain_template).format(
            s=stage)
        config['out_file'].write(
            template.format(epoch, *[metrics[key] for key in file_keys]))
        config['out_file'].flush()
        tag_prefix = "training" if stage == "train" else "validation"
        for suffix, key in tb_fields:
            writer.add_scalar(tag_prefix + " " + suffix, metrics[key], epoch)

    def _validate(epoch):
        """Average every loss/accuracy term over the paired validation loaders.

        The averages are logged.  Returns the mean classifier loss, or
        ``None`` when there is no validation data.
        """
        base_network.train(False)
        ad_net.train(False)
        sums = {}
        n_batches = 0
        with torch.no_grad():
            for (inputs_source, labels_source), (inputs_target, _) in zip(
                    dset_loaders["source_valid"], dset_loaders["target_valid"]):
                inputs, labels_source, n_source = _align(
                    inputs_source, labels_source, inputs_target)
                _, terms, _ = _compute_losses(inputs, labels_source, n_source)
                for key, value in terms.items():
                    sums[key] = sums.get(key, 0.0) + _to_float(value)
                n_batches += 1
        if n_batches == 0:
            return None
        means = {key: total / n_batches for key, total in sums.items()}
        _log_metrics("valid", means, epoch)
        return means["classifier"]

    def _accuracy(split):
        """Classification accuracy of the base network on ``dset_loaders[split]``."""
        if ly_type == "cosine":
            acc, _ = image_classification_test(
                dset_loaders, split, base_network,
                gpu=use_gpu, verbose=False, save_where=None)
        else:  # "euclidean" (validated at the top)
            acc, _ = distance_classification_test(
                dset_loaders, split, base_network,
                center_criterion.centers.detach(),
                gpu=use_gpu, verbose=False, save_where=None)
        return acc

    ## train
    scan_lr = []
    scan_loss = []
    best_acc = 0.0
    # Iteration at/after which the next periodic snapshot is written.  Test
    # iterations rarely land exactly on a snapshot boundary, so save at the
    # first test iteration at or past it.
    next_snapshot = config["snapshot_interval"]
    try:
        for i in range(config["num_iterations"]):
            epoch = i / batches_per_epoch

            if i % config["test_interval"] == 0:
                base_network.train(False)
                temp_acc = _accuracy('source_valid')
                train_acc = _accuracy('source')
                snapshot_obj = {
                    'epoch': epoch,
                    "base_network": base_network.state_dict(),
                    'valid accuracy': temp_acc,
                    'train accuracy': train_acc,
                }
                if loss_params["loss_name"] != "laplacian" \
                        and ly_type == "euclidean":
                    snapshot_obj['center_criterion'] = \
                        center_criterion.state_dict()
                if i >= next_snapshot:
                    torch.save(
                        snapshot_obj,
                        osp.join(config["output_path"],
                                 "epoch_{}_model.pth.tar".format(epoch)))
                    next_snapshot += config["snapshot_interval"]
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    # save best model
                    torch.save(
                        snapshot_obj,
                        osp.join(config["output_path"], "best_model.pth.tar"))
                config["out_file"].write(
                    "epoch: {}, {} validation accuracy: {:.5f}, {} training accuracy: {:.5f}\n"
                    .format(epoch, ly_type, temp_acc, ly_type, train_acc))
                config["out_file"].flush()
                writer.add_scalar("validation accuracy", temp_acc, epoch)
                writer.add_scalar("training accuracy", train_acc, epoch)

            ## train one iter
            base_network.train(True)
            if i % config["log_iter"] == 0 or per_step_schedule:
                # One call per iteration: the scheduler receives the
                # iteration index and the initial rates.
                optimizer = lr_scheduler(param_lr, optimizer, i,
                                         config["log_iter"],
                                         config["frozen lr"],
                                         config["cycle_length"],
                                         **schedule_param)
            scan_lr.append(optimizer.param_groups[0]['lr'])
            optimizer.zero_grad()

            inputs_source, labels_source = _next_batch("source")
            inputs_target, _ = _next_batch("target")
            inputs, labels_source, n_source = _align(
                inputs_source, labels_source, inputs_target)

            ad_net.train(True)
            total_loss, terms, center_grad = _compute_losses(
                inputs, labels_source, n_source)
            scan_loss.append(total_loss.item())
            total_loss.backward()

            # Plot embeddings periodically.
            if args.blobs is not None and epoch % 50 == 0:
                blobs_dir = osp.join(config["output_path"], "blobs")
                if use_fisher:
                    visualizePerformance(base_network,
                                         dset_loaders["source"],
                                         dset_loaders["target"],
                                         batch_size=128,
                                         num_of_samples=50,
                                         imgName='embedding_' + str(epoch),
                                         save_dir=blobs_dir)
                else:
                    visualizePerformance(base_network,
                                         dset_loaders["source"],
                                         dset_loaders["target"],
                                         batch_size=128,
                                         domain_classifier=ad_net,
                                         num_of_samples=100,
                                         imgName='embedding_' + str(epoch),
                                         save_dir=blobs_dir)

            if center_grad is not None:
                # Discard autograd's gradient for the centres and use the
                # gradient returned by the criterion instead.
                center_criterion.centers.grad.zero_()
                center_criterion.centers.backward(center_grad)
            optimizer.step()

            if i % config["log_iter"] == 0:
                if config['lr_scan'] != 'no':
                    scan_dir = osp.join(config["output_path"],
                                        "learning_rate_scan")
                    os.makedirs(scan_dir, exist_ok=True)
                    plot_learning_rate_scan(scan_lr, scan_loss, epoch,
                                            scan_dir)
                if config['grad_vis'] != 'no':
                    grad_dir = osp.join(config["output_path"], "gradients")
                    os.makedirs(grad_dir, exist_ok=True)
                    plot_grad_flow(grad_dir, epoch,
                                   base_network.named_parameters())
                _log_metrics("train",
                             {key: _to_float(value)
                              for key, value in terms.items()},
                             epoch)

                # One early-stopping check per validation pass.
                valid_classifier_loss = _validate(epoch)
                if valid_classifier_loss is not None \
                        and early_stop_engine.is_stop_training(
                            valid_classifier_loss):
                    config["out_file"].write(
                        "no improvement after {}, stop training at step {}\n"
                        .format(config["early_stop_patience"], epoch))
                    config["out_file"].flush()
                    return best_acc
    finally:
        writer.close()
    return best_acc
def train(config, data_import):
    """Train with adversarial domain adaptation and return a score for the run.

    The score is the negated total loss of the final training iteration, so
    larger is better.  Presumably this is an objective for hyperparameter
    search (TODO confirm with the caller).

    Args:
        config: Experiment configuration dict.  This function also writes
            ``num_iterations``, ``test_interval``, ``snapshot_interval`` and
            ``log_iter`` into it.  If ``config["ckpt_path"]`` is not None,
            ``<ckpt_path>/best_model.pth.tar`` is loaded into the base network.
        data_import: ``(pristine_x, pristine_y, noisy_x, noisy_y)`` tensors
            for the source (pristine) and target (noisy) domains.

    Returns:
        ``-total_loss`` of the last iteration, as a Python float.

    Raises:
        ValueError: if ``config["loss"]["ly_type"]`` is neither ``"cosine"``
            nor ``"euclidean"``, or a training set is empty.
    """
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]
    ly_type = loss_params['ly_type']
    if ly_type not in ("cosine", "euclidean"):
        raise ValueError("unsupported ly_type: {}".format(ly_type))
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["discriminant_loss"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])
    # Any value other than 'no' adds the Fisher and entropy-minimisation terms.
    use_fisher = config["fisher_or_no"] != 'no'

    ## prepare data
    pristine_x, pristine_y, noisy_x, noisy_y = data_import
    # Shuffle once up front (sampling without replacement).  The loaders also
    # reshuffle at the start of every pass.
    pristine_order = torch.randperm(len(pristine_x))
    noisy_order = torch.randperm(len(noisy_x))
    dset_loaders = {
        "source": DataLoader(
            TensorDataset(pristine_x[pristine_order],
                          pristine_y[pristine_order]),
            batch_size=128, shuffle=True, num_workers=1),
        "target": DataLoader(
            TensorDataset(noisy_x[noisy_order], noisy_y[noisy_order]),
            batch_size=128, shuffle=True, num_workers=1),
    }

    batches_per_epoch = len(dset_loaders["source"])
    if batches_per_epoch == 0:
        raise ValueError("source training set is empty")
    config["num_iterations"] = batches_per_epoch * config["epochs"] + 1
    config["test_interval"] = batches_per_epoch
    config["snapshot_interval"] = batches_per_epoch * config["epochs"] * .25
    config["log_iter"] = batches_per_epoch

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    if config["ckpt_path"] is not None:
        # map_location='cpu' lets GPU-saved checkpoints load on any host.
        ckpt = torch.load(config['ckpt_path'] + '/best_model.pth.tar',
                          map_location=torch.device('cpu'))
        base_network.load_state_dict(ckpt['base_network'])
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()

    ## collect parameters
    def _group(params, lr_mult, decay_mult=2):
        """One optimizer parameter group with lr/decay multipliers."""
        return {"params": params, "lr_mult": lr_mult, 'decay_mult': decay_mult}

    # NOTE(review): the original code's indentation was lost.  This follows
    # the upstream layout, where the per-architecture lists (chosen by
    # ``config["net"]`` containing "DeepMerge"/"ResNet18") were always
    # overwritten by this new_cls/use_bottleneck selection, so they are
    # omitted here.  Confirm against version control.
    if net_config["params"]["new_cls"]:
        parameter_list = [_group(base_network.feature_layers.parameters(), 1)]
        if net_config["params"]["use_bottleneck"]:
            parameter_list.append(
                _group(base_network.bottleneck.parameters(), 10))
        parameter_list.append(_group(base_network.fc.parameters(), 10))
    else:
        parameter_list = [_group(base_network.parameters(), 1)]
    parameter_list.append(_group(ad_net.parameters(), config["ad_net_mult_lr"]))
    parameter_list.append(
        _group(center_criterion.parameters(), 10, decay_mult=1))

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **optimizer_config["optim_params"])
    # Initial learning rate of every group; handed to the scheduler on each call.
    param_lr = [group["lr"] for group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]
    # These schedules are re-applied every iteration, not only once per epoch.
    per_step_schedule = optimizer_config["lr_type"] in ("one-cycle", "linear")

    batch_iters = {}

    def _next_batch(key):
        """Return the next ``(inputs, labels)`` batch of ``dset_loaders[key]``.

        The iterator is kept across calls, so every pass visits each sample
        once.  When it is exhausted, a fresh one (new shuffle) is started.
        """
        iterator = batch_iters.get(key)
        if iterator is not None:
            try:
                return next(iterator)
            except StopIteration:
                pass
        iterator = batch_iters[key] = iter(dset_loaders[key])
        try:
            return next(iterator)
        except StopIteration:
            raise ValueError("dataset '{}' is empty".format(key)) from None

    def _align(inputs_source, labels_source, inputs_target):
        """Build one concatenated batch (source rows, then target rows).

        Returns ``(inputs, labels_source, n_source)``.  Both domains are
        trimmed to a common size and moved to the GPU when one is used.  The
        first ``n_source`` rows are sliced as the source part below, and the
        adversarial loss / ``domain_cls_accuracy`` presumably assume equal
        halves (TODO confirm).
        """
        n_source = min(inputs_source.size(0), inputs_target.size(0))
        inputs_source = inputs_source[:n_source]
        labels_source = labels_source[:n_source]
        inputs_target = inputs_target[:n_source]
        if use_gpu:
            inputs_source = inputs_source.cuda()
            labels_source = labels_source.cuda()
            inputs_target = inputs_target.cuda()
        return (torch.cat((inputs_source, inputs_target), dim=0),
                labels_source, n_source)

    def _forward(inputs, n_source):
        """Return ``(features, logits, source_logits)`` for a concatenated batch."""
        if ly_type == 'cosine':
            features, logits = base_network(inputs)
        else:  # 'euclidean': logits are negative distances to the centres
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
        return features, logits, logits.narrow(0, 0, n_source)

    ## train
    for i in range(config["num_iterations"]):
        ## train one iter
        base_network.train(True)
        if i % config["log_iter"] == 0 or per_step_schedule:
            # One call per iteration: the scheduler receives the iteration
            # index and the initial rates.
            optimizer = lr_scheduler(param_lr, optimizer, i,
                                     config["log_iter"],
                                     config["frozen lr"],
                                     config["cycle_length"],
                                     **schedule_param)
        optimizer.zero_grad()

        inputs_source, labels_source = _next_batch("source")
        inputs_target, _ = _next_batch("target")
        inputs, labels_source, n_source = _align(
            inputs_source, labels_source, inputs_target)
        features, logits, source_logits = _forward(inputs, n_source)

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        # NOTE(review): the transfer loss is computed but does not enter
        # total_loss below, so no adversarial gradient reaches base_network
        # (and ad_net presumably gets none either).  The call is kept in
        # case transfer_criterion / gradient_reverse_layer carry state;
        # confirm this objective is intended.
        transfer_criterion(features, ad_net, gradient_reverse_layer,
                           weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)
        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source.long())

        # The two squared terms scale the classifier loss up as
        # source_acc_ad / target_acc_ad move away from 0.5.
        if use_fisher:
            # fisher loss on labeled source domain
            fisher_loss, _, _, center_grad = center_criterion(
                features.narrow(0, 0, n_source),
                labels_source,
                inter_class=loss_params["inter_type"],
                intra_loss_weight=loss_params["intra_loss_coef"],
                inter_loss_weight=loss_params["inter_loss_coef"])
            # entropy minimization loss
            em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))
            total_loss = classifier_loss + fisher_loss \
                + loss_params["em_loss_coef"] * em_loss \
                + classifier_loss * (0.5 - source_acc_ad)**2 \
                + classifier_loss * (0.5 - target_acc_ad)**2
        else:
            center_grad = None
            total_loss = classifier_loss \
                + classifier_loss * (0.5 - source_acc_ad)**2 \
                + classifier_loss * (0.5 - target_acc_ad)**2
        total_loss.backward()
        if center_grad is not None:
            # Discard autograd's gradient for the centres and use the
            # gradient returned by the criterion instead.
            center_criterion.centers.grad.zero_()
            center_criterion.centers.backward(center_grad)
        optimizer.step()

    return -1 * total_loss.item()
def _split_train_valid_test(x, y, train_cut=0.7, valid_cut=0.8):
    """Randomly partition paired tensors into train / validation / test datasets.

    Rows are sampled without replacement via ``torch.randperm``: the first
    ``floor(train_cut * N)`` permuted rows form the training split, rows up to
    ``floor(valid_cut * N)`` the validation split, and the rest the test split
    (70% / 10% / 20% with the defaults).

    Args:
        x: input tensor, indexed along dim 0.
        y: label tensor aligned with ``x`` along dim 0.
        train_cut: cumulative fraction at which the training split ends.
        valid_cut: cumulative fraction at which the validation split ends.

    Returns:
        Tuple ``(train, valid, test)`` of ``TensorDataset``.
    """
    n = len(x)
    order = torch.randperm(n)
    train_end = int(np.floor(train_cut * n))
    valid_end = int(np.floor(valid_cut * n))
    splits = (order[:train_end], order[train_end:valid_end], order[valid_end:])
    return tuple(TensorDataset(x[idx], y[idx]) for idx in splits)


def _next_cycled_batch(loaders, iterators, key):
    """Return the next batch of ``loaders[key]``, restarting the loader when exhausted.

    ``iterators`` maps loader names to their live iterators and is updated in
    place, so every loader keeps its own position independently of the others.

    Raises:
        ValueError: if ``loaders[key]`` yields no batches at all.
    """
    iterator = iterators.get(key)
    if iterator is not None:
        try:
            return next(iterator)
        except StopIteration:
            pass  # pass over the data finished: start a new one below
    iterators[key] = iter(loaders[key])
    try:
        return next(iterators[key])
    except StopIteration:
        raise ValueError("data loader '{}' yields no batches".format(key))


def _next_domain_batch_pair(loaders, iterators, source_key, target_key, use_gpu):
    """Draw one (source inputs, source labels, target inputs) triple of equal batch size.

    The loss code treats the first rows of ``cat(source, target)`` as the source
    batch and the original implementation sliced at ``inputs.size(0) / 2``, i.e.
    it assumed both domains contribute the same number of rows; presumably
    ``loss.PADA`` / ``domain_cls_accuracy`` assume the same (TODO confirm). Both
    domains are therefore trimmed to a common size, which only matters when a
    loader returns its shorter final batch. Target labels are not used.
    Source labels are cast to ``long`` because they are used as class indices.
    """
    inputs_source, labels_source = _next_cycled_batch(loaders, iterators, source_key)
    inputs_target, _ = _next_cycled_batch(loaders, iterators, target_key)
    n = min(inputs_source.size(0), inputs_target.size(0))
    inputs_source = inputs_source[:n]
    labels_source = labels_source[:n].long()
    inputs_target = inputs_target[:n]
    if use_gpu:
        inputs_source = inputs_source.cuda()
        labels_source = labels_source.cuda()
        inputs_target = inputs_target.cuda()
    return inputs_source, labels_source, inputs_target


def _log_domain_losses(config, writer, step, stats, phase, tag_prefix,
                       domain_tag_prefix=""):
    """Write one line of loss / domain-accuracy stats to ``config['out_file']`` and TensorBoard.

    Args:
        config: experiment config; only ``config['out_file']`` is used.
        writer: summary writer receiving ``add_scalar`` calls.
        step: iteration number used in the text line and as the scalar step.
        stats: dict built by ``train``; loss entries are tensors, the
            ``*_acc*`` entries are whatever ``domain_cls_accuracy`` returned.
        phase: word used in the text log (``"train"`` / ``"valid"``).
        tag_prefix: prefix for the loss scalar tags (``"training"`` / ``"validation"``).
        domain_tag_prefix: prefix for the domain-accuracy scalar tags, so the
            training and validation curves do not share a tag.
    """
    def value(name):
        return stats[name].detach().float().item()

    config['out_file'].write(
        'iter {0}: {1} total loss={2:0.4f}, {1} transfer loss={3:0.4f}, '
        '{1} classifier loss={4:0.4f}, '
        '{1} entropy min loss={5:0.4f}, '
        '{1} fisher loss={6:0.4f}, {1} intra-group fisher loss={7:0.4f}, '
        '{1} inter-group fisher loss={8:0.4f}, '
        'source+target domain accuracy={9:0.4f}, source domain accuracy={10:0.4f}, '
        'target domain accuracy={11:0.4f}\n'.format(
            step, phase,
            value("total"), value("transfer"), value("classifier"), value("em"),
            value("fisher"), value("fisher_intra"), value("fisher_inter"),
            stats["ad_acc"], stats["source_acc_ad"], stats["target_acc_ad"]))
    config['out_file'].flush()
    loss_tags = (("total loss", "total"),
                 ("classifier loss", "classifier"),
                 ("transfer_loss", "transfer"),
                 ("total fisher loss", "fisher"),
                 ("intra-group fisher", "fisher_intra"),
                 ("inter-group fisher", "fisher_inter"))
    for suffix, name in loss_tags:
        writer.add_scalar("{} {}".format(tag_prefix, suffix), value(name), step)
    accuracy_tags = (("source+target domain accuracy", "ad_acc"),
                     ("source domain accuracy", "source_acc_ad"),
                     ("target domain accuracy", "target_acc_ad"))
    for suffix, name in accuracy_tags:
        writer.add_scalar(domain_tag_prefix + suffix, stats[name], step)


def train(config):
    """Train the base network with adversarial (PADA) transfer, Fisher and entropy losses.

    Source data are the module-level ``pristine_x`` / ``pristine_y`` tensors and
    target data the module-level ``noisy_x`` / ``noisy_y`` tensors; each is split
    70/10/20 into train / validation / test. Every ``config["test_interval"]``
    iterations source validation and source training accuracy are evaluated,
    the best model is saved to ``<output_path>/best_model.pth.tar`` and early
    stopping is checked. Every ``config["log_iter"]`` iterations the training
    losses, and the same losses on one validation batch pair, are logged to
    ``config["out_file"]`` and TensorBoard.

    Args:
        config: experiment configuration dict. Keys read here include
            ``output_path``, ``early_stop_patience``, ``network``, ``loss``,
            ``optimizer``, ``high``, ``num_iterations``, ``test_interval``,
            ``snapshot_interval``, ``log_iter`` and ``out_file``.

    Returns:
        The best source-validation accuracy observed.

    Raises:
        ValueError: if ``config["loss"]["ly_type"]`` is neither ``"cosine"``
            nor ``"euclidean"``, or a training/validation split is empty.
    """
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]
    ly_type = loss_params["ly_type"]
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](
        num_classes=class_num,
        feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data: pristine samples form the source domain, noisy samples the target
    dsets = {}
    dsets["source"], dsets["source_valid"], dsets["source_test"] = \
        _split_train_valid_test(pristine_x, pristine_y)
    dsets["target"], dsets["target_valid"], dsets["target_test"] = \
        _split_train_valid_test(noisy_x, noisy_y)

    dset_loaders = {}
    for name, dataset in dsets.items():
        # training loaders use batches of 36; validation / test loaders batches of 4
        batch_size = 36 if name in ("source", "target") else 4
        dset_loaders[name] = DataLoader(dataset, batch_size=batch_size,
                                        shuffle=True, num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameter groups with per-group lr_mult / decay_mult settings
    if net_config["params"]["new_cls"]:
        parameter_list = [{"params": base_network.feature_layers.parameters(),
                           "lr_mult": 1, 'decay_mult': 2}]
        if net_config["params"]["use_bottleneck"]:
            parameter_list.append({"params": base_network.bottleneck.parameters(),
                                   "lr_mult": 10, 'decay_mult': 2})
        parameter_list.append({"params": base_network.fc.parameters(),
                               "lr_mult": 10, 'decay_mult': 2})
    else:
        parameter_list = [{"params": base_network.parameters(),
                           "lr_mult": 1, 'decay_mult': 2}]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr_mult": 10, 'decay_mult': 2})
    parameter_list.append({"params": center_criterion.parameters(),
                           "lr_mult": 10, 'decay_mult': 1})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](
        parameter_list, **(optimizer_config["optim_params"]))
    # initial LR of every group, handed to the scheduler on each step
    param_lr = [param_group["lr"] for param_group in optimizer.param_groups]
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    def _compute_losses(inputs_source, labels_source, inputs_target):
        """Forward one source+target batch pair; return (stats dict, center_grad)."""
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)
        if ly_type == 'cosine':
            features, logits = base_network(inputs)
        elif ly_type == 'euclidean':
            features, _ = base_network(inputs)
            # logits are negative distances to the (detached) class centers
            logits = -1.0 * loss.distance_to_centroids(
                features, center_criterion.centers.detach())
        else:
            raise ValueError("no test method for cls loss: {}".format(ly_type))
        source_logits = logits.narrow(0, 0, source_batch_size)

        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer,
                                           weight_ad, use_gpu)
        # domain-classifier accuracy on detached features; only used for logging
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source)
        # fisher loss on labeled source domain (the first source_batch_size rows)
        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
            features.narrow(0, 0, source_batch_size), labels_source,
            inter_class=loss_params["inter_type"],
            intra_loss_weight=loss_params["intra_loss_coef"],
            inter_loss_weight=loss_params["inter_loss_coef"])
        # entropy minimization loss
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))
        # final loss
        total_loss = (loss_params["trade_off"] * transfer_loss
                      + fisher_loss
                      + loss_params["em_loss_coef"] * em_loss
                      + classifier_loss)
        stats = {
            "total": total_loss,
            "transfer": transfer_loss,
            "classifier": classifier_loss,
            "em": em_loss,
            "fisher": fisher_loss,
            "fisher_intra": fisher_intra_loss,
            "fisher_inter": fisher_inter_loss,
            "ad_acc": ad_acc,
            "source_acc_ad": source_acc_ad,
            "target_acc_ad": target_acc_ad,
        }
        return stats, center_grad

    ## train
    data_iters = {}  # live iterator per loader name, advanced by _next_cycled_batch
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if ly_type == "cosine":
                temp_acc, _ = image_classification_test(
                    dset_loaders, 'source_valid', base_network, gpu=use_gpu)
                train_acc, _ = image_classification_test(
                    dset_loaders, 'source', base_network, gpu=use_gpu)
            elif ly_type == "euclidean":
                centers = center_criterion.centers.detach()
                temp_acc, _ = distance_classification_test(
                    dset_loaders, 'source_valid', base_network, centers, gpu=use_gpu)
                # same signature as the call above: the centers are required here too
                train_acc, _ = distance_classification_test(
                    dset_loaders, 'source', base_network, centers, gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(ly_type))

            snapshot_obj = {
                'step': i,
                "base_network": base_network.state_dict(),
                'precision': temp_acc,
                'train accuracy': train_acc,
            }
            if loss_params["loss_name"] != "laplacian" and ly_type == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict()
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj,
                           osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = ("iter: {:05d}, {} validation accuracy: {:.5f}, "
                       "{} training accuracy: {:.5f}\n").format(
                           i, ly_type, temp_acc, ly_type, train_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("validation accuracy", temp_acc, i)
            writer.add_scalar("training accuracy", train_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write(
                    "no improvement after {}, stop training at step {}\n".format(
                        config["early_stop_patience"], i))
                break

        if (i + 1) % config["snapshot_interval"] == 0:
            # saves the snapshot taken at the most recent evaluation above
            torch.save(snapshot_obj,
                       osp.join(config["output_path"],
                                "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        train_batch = _next_domain_batch_pair(dset_loaders, data_iters,
                                              "source", "target", use_gpu)
        train_stats, center_grad = _compute_losses(*train_batch)
        train_stats["total"].backward()
        if center_grad is not None:
            # discard the autograd gradient of the centers ...
            center_criterion.centers.grad.zero_()
            # ... and assign the centers' gradient manually from center_grad
            center_criterion.centers.backward(center_grad)
        optimizer.step()

        if i % config["log_iter"] == 0:
            _log_domain_losses(config, writer, i, train_stats, "train", "training")
            # same losses on one held-out batch pair; no backprop in eval mode
            base_network.eval()
            ad_net.train(False)
            with torch.no_grad():
                valid_batch = _next_domain_batch_pair(dset_loaders, data_iters,
                                                      "source_valid", "target_valid",
                                                      use_gpu)
                valid_stats, _ = _compute_losses(*valid_batch)
            _log_domain_losses(config, writer, i, valid_stats, "valid", "validation",
                               domain_tag_prefix="validation ")

    writer.close()
    return best_acc