import csv
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from termcolor import colored
from torch.utils.data import DataLoader

# NOTE (assumption): project-local names used below (parser, configure, Model,
# get_dataset_name, get_transform, IndicesDataset, DefaultDataset, InvScheduler,
# initialize_layer, get_sample, compute_accuracy) are assumed to be imported
# from the surrounding project; their definitions are not shown in this file.


def main():
    args = parser.parse_args()
    print(args)
    config = configure(args.config)

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    print(colored(f"Model directory: {args.model_dir}", 'green'))
    assert os.path.isfile(args.model_dir)  # despite the name, this is a checkpoint file

    dataset_name = get_dataset_name(args.src, args.tgt)
    dataset_config = config.data.dataset[dataset_name]
    src_file = os.path.join(args.dataset_root, dataset_name,
                            args.src + '_list.txt')
    tgt_file = os.path.join(args.dataset_root, dataset_name,
                            args.tgt + '_list.txt')

    # Build the backbone and load the shared weights from the checkpoint;
    # the task-specific heads are dropped and retrained from scratch below.
    model = Model(base_net=args.network,
                  num_classes=dataset_config.num_classes,
                  frozen_layer='')
    del model.classifier_layer
    del model.contrast_layer

    model_state_dict = model.state_dict()
    trained_state_dict = torch.load(args.model_dir)['weights']
    keys = set(model_state_dict.keys())
    trained_keys = set(trained_state_dict.keys())
    shared_keys = keys.intersection(trained_keys)
    to_load_state_dict = {key: trained_state_dict[key] for key in shared_keys}
    model.load_state_dict(to_load_state_dict)
    model = model.cuda()

    # source classifier and domain classifier
    src_classifier = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.base_network.out_dim, dataset_config.num_classes))
    initialize_layer(src_classifier)
    parameter_list = [{"params": src_classifier.parameters(), "lr": 1}]
    src_classifier = src_classifier.cuda()

    domain_classifier = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.base_network.out_dim, 2))
    initialize_layer(domain_classifier)
    parameter_list += [{"params": domain_classifier.parameters(), "lr": 1}]
    domain_classifier = domain_classifier.cuda()

    group_ratios = [parameter['lr'] for parameter in parameter_list]
    optimizer = torch.optim.SGD(parameter_list,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    assert args.lr_scheduler == 'inv'
    lr_scheduler = InvScheduler(gamma=args.gamma,
                                decay_rate=args.decay_rate,
                                group_ratios=group_ratios,
                                init_lr=args.lr)

    # split into train and validation sets
    src_size = len(open(src_file).readlines())
    src_train_size = int(args.train_portion * src_size)
    src_train_indices, src_test_indices = np.split(
        np.random.permutation(src_size), [src_train_size])
    tgt_size = len(open(tgt_file).readlines())
    tgt_train_size = int(args.train_portion * tgt_size)
    tgt_train_indices, tgt_test_indices = np.split(
        np.random.permutation(tgt_size), [tgt_train_size])

    # define data loaders
    train_data_loader_kwargs = {
        'shuffle': True,
        'drop_last': True,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers
    }
    test_data_loader_kwargs = {
        'shuffle': False,
        'drop_last': False,
        'batch_size': args.eval_batch_size,
        'num_workers': args.num_workers
    }
    train_transformer = get_transform(training=True)
    test_transformer = get_transform(training=False)

    data_loader = {}
    data_iterator = {}
    src_train_dataset = IndicesDataset(src_file,
                                       list(src_train_indices),
                                       transform=train_transformer)
    data_loader['src_train'] = DataLoader(src_train_dataset,
                                          **train_data_loader_kwargs)
    src_test_dataset = IndicesDataset(src_file,
                                      list(src_test_indices),
                                      transform=test_transformer)
    data_loader['src_test'] = DataLoader(src_test_dataset,
                                         **test_data_loader_kwargs)
    tgt_train_dataset = IndicesDataset(tgt_file,
                                       list(tgt_train_indices),
                                       transform=train_transformer)
    data_loader['tgt_train'] = DataLoader(tgt_train_dataset,
                                          **train_data_loader_kwargs)
    tgt_test_dataset = IndicesDataset(tgt_file,
                                      list(tgt_test_indices),
                                      transform=test_transformer)
    data_loader['tgt_test'] = DataLoader(tgt_test_dataset,
                                         **test_data_loader_kwargs)
    for key in data_loader:
        data_iterator[key] = iter(data_loader[key])
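    # NOTE (assumption): get_sample is a project-local helper not shown in this
    # file. Judging from how it is called with the loaders and iterators built
    # above, it is presumably a "never-ending iterator" along these lines:
    #
    #     def get_sample(data_loader, data_iterator, key):
    #         try:
    #             return next(data_iterator[key])
    #         except StopIteration:
    #             # epoch exhausted: restart the iterator and draw again
    #             data_iterator[key] = iter(data_loader[key])
    #             return next(data_iterator[key])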
    # start training
    total_progress_bar = tqdm.tqdm(desc='Iterations',
                                   total=args.max_iterations,
                                   ascii=True,
                                   smoothing=0.01)
    class_criterion = nn.CrossEntropyLoss()

    # The backbone stays frozen (eval mode, features computed under no_grad);
    # only the two linear heads are trained on top of the fixed features.
    model.base_network.eval()
    src_classifier.train()
    domain_classifier.train()

    iteration = 0
    while iteration < args.max_iterations:
        lr_scheduler.adjust_learning_rate(optimizer, iteration)
        optimizer.zero_grad()

        src_data = get_sample(data_loader, data_iterator, 'src_train')
        src_inputs = src_data['image_1'].cuda()
        src_labels = src_data['true_label'].cuda()
        tgt_data = get_sample(data_loader, data_iterator, 'tgt_train')
        tgt_inputs = tgt_data['image_1'].cuda()

        model.set_bn_domain(domain=0)
        with torch.no_grad():
            src_features = model.base_network(src_inputs)
            src_features = F.normalize(src_features, p=2, dim=1)
        # The heads must run outside no_grad so the loss can backpropagate.
        src_class_logits = src_classifier(src_features)
        src_domain_logits = domain_classifier(src_features)

        model.set_bn_domain(domain=1)
        with torch.no_grad():
            tgt_features = model.base_network(tgt_inputs)
            tgt_features = F.normalize(tgt_features, p=2, dim=1)
        tgt_domain_logits = domain_classifier(tgt_features)

        src_classification_loss = class_criterion(src_class_logits, src_labels)

        # domain label 0 for source samples, 1 for target samples
        domain_logits = torch.cat([src_domain_logits, tgt_domain_logits], dim=0)
        domain_labels = torch.tensor([0] * src_inputs.size(0) +
                                     [1] * tgt_inputs.size(0)).cuda()
        domain_classification_loss = class_criterion(domain_logits,
                                                     domain_labels)

        if iteration % args.print_acc_interval == 0:
            compute_accuracy(src_class_logits,
                             src_labels,
                             acc_metric=dataset_config.acc_metric,
                             print_result=True)
            compute_accuracy(domain_logits,
                             domain_labels,
                             acc_metric='total_mean',
                             print_result=True)

        total_loss = src_classification_loss + domain_classification_loss
        total_loss.backward()
        optimizer.step()

        iteration += 1
        total_progress_bar.update(1)

    # test
    model.base_network.eval()
    src_classifier.eval()
    domain_classifier.eval()
    with torch.no_grad():
        src_all_class_logits = []
        src_all_labels = []
        src_all_domain_logits = []
        model.set_bn_domain(domain=0)
        for src_test_data in tqdm.tqdm(data_loader['src_test'],
                                       desc='src_test',
                                       leave=False,
                                       ascii=True):
            src_test_inputs = src_test_data['image_1'].cuda()
            src_test_labels = src_test_data['true_label'].cuda()
            src_test_features = model.base_network(src_test_inputs)
            src_test_features = F.normalize(src_test_features, p=2, dim=1)
            src_test_class_logits = src_classifier(src_test_features)
            src_test_domain_logits = domain_classifier(src_test_features)
            src_all_class_logits += [src_test_class_logits]
            src_all_labels += [src_test_labels]
            src_all_domain_logits += [src_test_domain_logits]
        src_all_class_logits = torch.cat(src_all_class_logits, dim=0)
        src_all_labels = torch.cat(src_all_labels, dim=0)
        src_all_domain_logits = torch.cat(src_all_domain_logits, dim=0)

        src_test_class_acc = compute_accuracy(
            src_all_class_logits,
            src_all_labels,
            acc_metric=dataset_config.acc_metric,
            print_result=True)
        src_test_domain_acc = compute_accuracy(
            src_all_domain_logits,
            torch.zeros(src_all_domain_logits.size(0)).cuda(),
            acc_metric='total_mean',
            print_result=True)

        tgt_all_domain_logits = []
        model.set_bn_domain(domain=1)
        for tgt_test_data in tqdm.tqdm(data_loader['tgt_test'],
                                       desc='tgt_test',
                                       leave=False,
                                       ascii=True):
            tgt_test_inputs = tgt_test_data['image_1'].cuda()
            tgt_test_features = model.base_network(tgt_test_inputs)
            tgt_test_features = F.normalize(tgt_test_features, p=2, dim=1)
            tgt_test_domain_logits = domain_classifier(tgt_test_features)
            tgt_all_domain_logits += [tgt_test_domain_logits]
        tgt_all_domain_logits = torch.cat(tgt_all_domain_logits, dim=0)
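        # Held-out domain accuracy is the diagnostic here: source samples are
        # scored against domain label 0 (above) and target samples against
        # domain label 1 (below). A domain classifier stuck near chance (~50%)
        # suggests the frozen backbone produces domain-aligned features.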
        tgt_test_domain_acc = compute_accuracy(
            tgt_all_domain_logits,
            torch.ones(tgt_all_domain_logits.size(0)).cuda(),
            acc_metric='total_mean',
            print_result=True)

    write_list = [
        args.model_dir, src_test_class_acc, src_test_domain_acc,
        tgt_test_domain_acc
    ]
    # with open('hyper_search_office_home.csv', 'a') as f:
    with open(args.output_file, 'a') as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(write_list)
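# NOTE (assumption): InvScheduler is imported from the surrounding project and
# its definition is not shown here. A minimal sketch of the standard "inv"
# learning-rate schedule it presumably implements (matching the
# `assert args.lr_scheduler == 'inv'` above) is:
#
#     class InvScheduler:
#         def __init__(self, gamma, decay_rate, group_ratios, init_lr):
#             self.gamma, self.decay_rate = gamma, decay_rate
#             self.group_ratios, self.init_lr = group_ratios, init_lr
#
#         def adjust_learning_rate(self, optimizer, iteration):
#             # lr_i(t) = ratio_i * init_lr * (1 + gamma * t) ** (-decay_rate)
#             lr = self.init_lr * (1. + self.gamma * iteration) ** (-self.decay_rate)
#             for ratio, group in zip(self.group_ratios, optimizer.param_groups):
#                 group['lr'] = lr * ratio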
# Second entry point (apparently from a separate script): extract target
# features with the trained backbone and dump them to pickle files.
def main():
    args = parser.parse_args()
    print(args)
    config = configure(args.config)

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    print(colored(f"Model directory: {args.model_dir}", 'green'))
    assert os.path.isfile(args.model_dir)

    dataset_name = get_dataset_name(args.src, args.tgt)
    dataset_config = config.data.dataset[dataset_name]
    tgt_file = os.path.join(args.dataset_root, dataset_name,
                            args.tgt + '_list.txt')

    # tgt classification
    model = Model(base_net=args.network,
                  num_classes=dataset_config.num_classes,
                  frozen_layer='')
    del model.classifier_layer
    del model.contrast_layer

    model_state_dict = model.state_dict()
    trained_state_dict = torch.load(args.model_dir)['weights']
    keys = set(model_state_dict.keys())
    trained_keys = set(trained_state_dict.keys())
    shared_keys = keys.intersection(trained_keys)
    to_load_state_dict = {key: trained_state_dict[key] for key in shared_keys}
    model.load_state_dict(to_load_state_dict)
    model = model.cuda()

    # define data loader
    test_data_loader_kwargs = {
        'shuffle': False,
        'drop_last': False,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers
    }
    test_transformer = get_transform(training=False)

    data_loader = {}
    data_iterator = {}
    tgt_test_dataset = DefaultDataset(tgt_file, transform=test_transformer)
    data_loader['tgt_test'] = DataLoader(tgt_test_dataset,
                                         **test_data_loader_kwargs)
    for key in data_loader:
        data_iterator[key] = iter(data_loader[key])

    # test
    model.base_network.eval()
    with torch.no_grad():
        tgt_all_features = []
        tgt_all_labels = []
        model.set_bn_domain(domain=1)
        for tgt_test_data in tqdm.tqdm(data_loader['tgt_test'],
                                       desc='tgt_test',
                                       leave=False,
                                       ascii=True):
            tgt_test_inputs = tgt_test_data['image_1'].cuda()
            tgt_test_labels = tgt_test_data['true_label'].cuda()
            tgt_test_features = model.base_network(tgt_test_inputs)
            # tgt_test_features = F.normalize(tgt_test_features, p=2, dim=1)
            tgt_all_features += [tgt_test_features]
            tgt_all_labels += [tgt_test_labels]
        tgt_all_features = torch.cat(tgt_all_features, dim=0)
        tgt_all_labels = torch.cat(tgt_all_labels, dim=0)

    tgt_all_features = tgt_all_features.cpu().numpy()
    tgt_all_labels = tgt_all_labels.cpu().numpy()

    # dump the extracted features and labels for offline analysis
    out_dir = os.path.join('features', args.src + '_' + args.tgt)
    os.makedirs(out_dir, exist_ok=True)  # ensure the output directory exists
    features_pickle_file = os.path.join(out_dir, 'tgt_features.pkl')
    labels_pickle_file = os.path.join(out_dir, 'tgt_labels.pkl')
    pickle.dump(tgt_all_features, open(features_pickle_file, 'wb'))
    pickle.dump(tgt_all_labels, open(labels_pickle_file, 'wb'))
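# Usage sketch (assumption: the src/tgt pair "amazon"/"webcam" below is
# hypothetical; any offline analysis works). The dumped arrays can be reloaded,
# e.g. for a 2-D t-SNE embedding of the target features:
#
#     import pickle
#     from sklearn.manifold import TSNE
#
#     features = pickle.load(open('features/amazon_webcam/tgt_features.pkl', 'rb'))
#     labels = pickle.load(open('features/amazon_webcam/tgt_labels.pkl', 'rb'))
#     embedding = TSNE(n_components=2).fit_transform(features)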