def evaluate(loader, features, adj, network):
    """Evaluate `network` on `loader` and report per-class accuracy.

    Args:
        loader: yields (image_features, target) batches; targets are class
            indices in [0, features.size(0)).
        features: per-class semantic features, shape (num_classes, dim).
        adj: adjacency matrix passed through to the network.
        network: relation network taking (image_feats, class_feats, adj).

    Returns:
        (avg_mse_loss, per_class_accuracy_strings, mean_per_class_accuracy%)
    """
    losses, all_predictions, all_labels = AverageMeter(), [], []
    network.eval()
    with torch.no_grad():
        class_num = features.size(0)
        for batch_idx, (image_features, target) in enumerate(loader):
            batch = image_features.size(0)  # feature dim is unused here
            relation = network(image_features.cuda(), features.cuda(), adj.cuda())
            one_hot_labels = torch.zeros(batch, class_num).scatter_(
                1, target.view(-1, 1), 1).cuda()
            loss = F.mse_loss(torch.sigmoid(relation), one_hot_labels, reduction='mean')
            losses.update(loss.item(), batch)
            predict_labels = torch.argmax(relation, dim=1).cpu()
            all_predictions.append(predict_labels)
            all_labels.append(target)
        predictions = torch.cat(all_predictions, dim=0)
        labels = torch.cat(all_labels, dim=0)
        all_classes = sorted(list(set(labels.tolist())))
        acc_per_classes = []
        for idx, cls in enumerate(all_classes):
            # class ids must be >= their rank in the sorted unique list
            assert idx <= cls, 'invalid all-classes : {:}'.format(all_classes)
            indexes = labels == cls
            xpreds, xlabels = predictions[indexes], labels[indexes]
            acc_per_classes.append((xpreds == xlabels).float().mean().item())
        acc_per_class = float(np.mean(acc_per_classes))
    return losses.avg, ['{:3.1f}'.format(x * 100) for x in acc_per_classes], acc_per_class * 100
def evaluate(models, loader, prototypes, criterion, distance):
    """Evaluate a (cnn, embedding) model pair against class prototypes.

    Classifies each embedded sample by negative distance to `prototypes`,
    accumulating the criterion loss, top-1 accuracy, and per-class accuracy.

    Returns:
        (avg_loss, avg_top1_accuracy, mean_per_class_accuracy%)
    """
    cnn_model, model = models
    losses, acc1 = AverageMeter(), AverageMeter()
    accuracies = collections.defaultdict(list)
    with torch.no_grad():
        for images, emb, continous, target in loader:
            batch = images.size(0)
            target = target.cuda()
            embedded = model(cnn_model(images))
            # closer prototype -> larger logit
            logits = -distance_func(embedded, prototypes, distance)
            cls_loss = criterion(logits, target)
            losses.update(cls_loss.item(), batch)
            # log
            [accuracy] = obtain_accuracy(logits.data, target.data, (1, ))
            acc1.update(accuracy.item(), batch)
            hits = (logits.argmax(dim=1) == target).cpu().tolist()
            for cls, ok in zip(target.cpu().tolist(), hits):
                accuracies[cls].append(ok)
    per_class = [np.mean(accuracies[cls]) for cls in accuracies.keys()]
    acc_per_class = float(np.mean(per_class))
    return losses.avg, acc1.avg, acc_per_class * 100
def get_train_protos(network, attributes, train_classes, unseen_classes, all_class_loader, xargs):
    """Compute (or load cached) image prototypes for seen classes and the
    seen->unseen attention used to synthesize unseen prototypes.

    Args:
        network: model exposing `get_attention(attributes)`.
        attributes: per-class attribute matrix (all classes).
        train_classes / unseen_classes: lists of class indices.
        all_class_loader: loader yielding single-class batches of features;
            its batch_sampler must expose `label2index`.
        xargs: namespace with `log_dir` and `dataset` (used for the cache path).

    Returns:
        (train_protos, unseen_att): stacked CUDA prototypes for train classes,
        and raw attention rows of unseen classes over train classes.
    """
    train_proto_path = '{:}/../train_proto_lists-{:}.pth'.format(xargs.log_dir, xargs.dataset)
    if os.path.exists(train_proto_path):
        train_proto_lists = torch.load(train_proto_path)
    else:
        # get the training protos over all images (streaming mean per class)
        train_proto_lists = dict()
        num_per_class = defaultdict(int)
        data_time, xend = AverageMeter(), time.time()
        all_class_sampler = all_class_loader.batch_sampler
        for ibatch, (feats, labels) in enumerate(all_class_loader):
            # each batch is expected to contain exactly one class
            assert len(set(labels.tolist())) == 1
            label = labels[0].item()
            num_per_class[label] += feats.size(0)
            # partial mean: divide each chunk by the class's total count
            chunk = torch.sum(feats, dim=0) / len(all_class_sampler.label2index[label])
            if label not in train_proto_lists:
                train_proto_lists[label] = chunk
            else:
                train_proto_lists[label] += chunk
            data_time.update(time.time() - xend)
            xend = time.time()
            if ibatch % 100 == 0 or ibatch + 1 == len(all_class_loader):
                Tstring = '{:} [{:03d}/{:03d}] AVG=({:.2f}, {:.2f})'.format(time_string(), ibatch, len(all_class_loader), data_time.val, data_time.avg)
                Tstring += ' :: {:}'.format(convert_secs2time(data_time.avg * (len(all_class_loader) - ibatch), True))
                print('***extract features*** : {:}'.format(Tstring))
        # check numbers
        for key, item in num_per_class.items():
            # BUG FIX: the message previously used the stale loop variable
            # `label` (last class seen), reporting the wrong class's count.
            assert item == len(all_class_sampler.label2index[key]), '[{:}] : {:} vs {:} \n:::{:}'.format(key, item, len(all_class_sampler.label2index[key]), num_per_class)
        torch.save(train_proto_lists, train_proto_path)
    train_protos = [train_proto_lists[cls] for cls in train_classes]
    train_protos = torch.stack(train_protos).cuda()
    with torch.no_grad():
        network.eval()
        raw_atts, attention = network.get_attention(attributes.cuda())
        # unseen-class attention restricted to train-class columns
        unseen_att = raw_atts[unseen_classes, :][:, train_classes]
    return train_protos, unseen_att
def extract_rotate_feats(dataset):
    # Extract CNN features for 4 rotations (0/90/180/270 degrees) of every
    # image in `dataset`, save one .pth per image under rotate-infos/, and
    # record an image-path -> feature-path map back into the dataset info file.
    # NOTE(review): relies on module-level `root_dir` (a Path) — confirm it is
    # defined in the enclosing file.
    data_root = '{:}/info-files/x-{:}-data-image.pth'.format(root_dir, dataset)
    xdata = torch.load(data_root)
    files = xdata['image_files']
    save_dir = root_dir / 'rotate-infos' / dataset
    save_dir.mkdir(parents=True, exist_ok=True)
    imagepath2featpath = dict()
    avoid_duplicate = set()  # guards against two images mapping to one save path
    cnn_name = 'resnet101'
    backbone = obtain_backbone(cnn_name).cuda()
    backbone = torch.nn.DataParallel(backbone)
    backbone.eval()
    #print("CNN-Backbone ----> \n {:}".format(backbone))
    # 3 is the number of augmentations
    simple_data = SIMPLE_DATA(files, 3, 'imagenet')
    #simple_loader = torch.utils.data.DataLoader(simple_data, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)
    simple_loader = torch.utils.data.DataLoader(simple_data, batch_size=32, shuffle=False, num_workers=8, pin_memory=True)
    batch_time, xend = AverageMeter(), time.time()
    for idx, (indexes, tensor_000, tensor_090, tensor_180, tensor_270) in enumerate(simple_loader):
        with torch.no_grad():
            # each tensor_* is a list of augmented views; one forward pass per view
            feats_000 = [backbone(x) for x in tensor_000]
            feats_090 = [backbone(x) for x in tensor_090]
            feats_180 = [backbone(x) for x in tensor_180]
            feats_270 = [backbone(x) for x in tensor_270]
        for ii, image_idx in enumerate(indexes):
            # gather this image's features across augmentations: (num_aug, dim)
            x_feats_000 = torch.stack([x[ii] for x in feats_000]).cpu()
            x_feats_090 = torch.stack([x[ii] for x in feats_090]).cpu()
            x_feats_180 = torch.stack([x[ii] for x in feats_180]).cpu()
            x_feats_270 = torch.stack([x[ii] for x in feats_270]).cpu()
            # mirror the source class sub-directory under save_dir
            ori_file_p = Path(files[image_idx.item()])
            save_dir_xx = save_dir / ori_file_p.parent.name
            save_dir_xx.mkdir(parents=True, exist_ok=True)
            save_f_path = save_dir_xx / (ori_file_p.name.split('.')[0] + '.pth')
            torch.save(
                {
                    'feats-000': x_feats_000,
                    'feats-090': x_feats_090,
                    'feats-180': x_feats_180,
                    'feats-270': x_feats_270
                }, save_f_path)
            imagepath2featpath[files[image_idx.item()]] = str(save_f_path)
            assert str(
                save_f_path
            ) not in avoid_duplicate, 'invalid path : {:}'.format(
                save_f_path)
            avoid_duplicate.add(str(save_f_path))
        # ETA estimate from the running per-batch time average
        need_time = convert_secs2time(
            batch_time.val * (len(simple_loader) - idx), True)
        print('{:} : {:5d} / {:5d} : {:} : {:}'.format(time_string(), idx, len(simple_loader), need_time, save_f_path))
        batch_time.update(time.time() - xend)
        xend = time.time()
    # persist the image->feature mapping back into the info file
    xdata['image2feat'] = imagepath2featpath
    torch.save(xdata, data_root)
    print('Update all-info in {:}, file size : {:.2f} GB'.format(
        data_root, os.path.getsize(data_root) / 1e9))
def train_model(loader, semantics, adj_distances, network, optimizer, config, logger):
    """Train `network` for one epoch with episodic (dual-half) batches.

    Each batch is split in half: the first half builds per-class image
    prototypes, the second half is classified against them.  A consistency
    loss aligns attribute-attention with image-attention.

    Args:
        loader: episodic loader; its dataset supports set_return_*_mode.
        semantics: per-class semantic features (indexed by original labels).
        adj_distances: pairwise class-distance matrix.
        config: needs epoch_str, log_interval, loss_type, consistency_type,
            consistency_coef.

    Returns:
        (average total loss, average top-1 accuracy in percent)
    """
    batch_time, Xlosses, CLSFlosses, Rlosses, accs, end = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
    labelMeter, eps = AverageMeter(), 1e-7  # eps avoids log(0) in the KL terms
    network.train()
    loader.dataset.set_return_label_mode('new')
    loader.dataset.set_return_img_mode('original_augment')
    logger.print('[TRAIN---{:}], semantics-shape={:}, adj_distances-shape={:}, config={:}'.format(config.epoch_str, semantics.shape, adj_distances.shape, config))
    for batch_idx, (img_feat, targets) in enumerate(loader):
        batch_label_lst = list(set(targets.tolist()))
        class_num = len(batch_label_lst)
        batch = targets.shape[0] // 2  # assume train and val halves have the same amount
        batch_attrs = semantics[batch_label_lst, :]
        batch_adj_dis = adj_distances[batch_label_lst, :][:, batch_label_lst]
        train_targets, targets = torch.split(targets, (batch, batch))
        img_feat_train, img_feat_val = torch.split(img_feat, (batch, batch), dim=0)
        # per-class image prototypes from the "train" half of the episode
        img_proto_list = []
        for lab in batch_label_lst:
            proto = img_feat_train[train_targets == lab].mean(dim=0)
            img_proto_list.append(proto)
        img_proto = torch.stack(img_proto_list)
        relations = network(img_feat_val.cuda(), img_proto.cuda(), batch_attrs.cuda(), batch_adj_dis.cuda())
        raw_att_att, att_att = network.relation_module.get_attention_attribute()
        raw_att_img, att_img = network.relation_module.get_attention_img()
        # attention-consistency regularizer between attribute and image branches
        if config.consistency_type == 'mse':
            consistency_loss = F.mse_loss(raw_att_att, raw_att_img)
        elif config.consistency_type == 'kla2i':
            consistency_loss = F.kl_div((att_att + eps).log(), att_img + eps, reduction='batchmean')
        elif config.consistency_type == 'kli2a':
            consistency_loss = F.kl_div((att_img + eps).log(), att_att + eps, reduction='batchmean')
        else:
            raise ValueError('Unknown consistency type: {:}'.format(config.consistency_type))
        # remap original labels to their position inside this episode
        new_target_idxs = [batch_label_lst.index(x) for x in targets.tolist()]
        new_target_idxs = torch.LongTensor(new_target_idxs)
        one_hot_labels = torch.zeros(batch, class_num).scatter_(1, new_target_idxs.view(-1, 1), 1).cuda()
        target__labels = new_target_idxs.cuda()
        if config.loss_type == 'sigmoid-mse':
            prediction = torch.sigmoid(relations)
            cls_loss = F.mse_loss(prediction, one_hot_labels, reduction='mean')
        elif re.match(r'softmax-.+-.+', config.loss_type, re.I):
            # BUG FIX: the old pattern 'softmax-*-*' quantified the '-' characters,
            # so it also matched plain 'softmax', making the dedicated 'softmax'
            # branch below unreachable and crashing the 3-way split here.
            _, temperature, epsilon = config.loss_type.split('-')
            temperature, epsilon = float(temperature), float(epsilon)
            if epsilon <= 0:
                cls_loss = F.cross_entropy(relations / temperature, target__labels, weight=None, reduction='mean')
            else:
                # label smoothing with smoothing factor `epsilon`
                log_probs = F.log_softmax(relations / temperature, dim=1)
                _targets = torch.zeros_like(log_probs).scatter_(1, target__labels.unsqueeze(1), 1)
                _targets = (1 - epsilon) * _targets + epsilon / relations.size(1)
                cls_loss = (-_targets * log_probs).sum(dim=1).mean()
        elif config.loss_type == 'softmax':
            cls_loss = F.cross_entropy(relations, target__labels, weight=None, reduction='mean')
        elif config.loss_type == 'mse':
            cls_loss = F.mse_loss(torch.sigmoid(relations), one_hot_labels, reduction='mean')
        elif config.loss_type == 'none':
            # raw margin-style objective: push positives up, negatives down
            positive = -torch.masked_select(relations, one_hot_labels == 1)
            negative = torch.masked_select(relations, one_hot_labels == 0)
            losses = torch.cat([positive, negative])
            cls_loss = losses.mean()
        else:
            raise ValueError('invalid loss type : {:}'.format(config.loss_type))
        if config.consistency_coef > 0:
            loss = cls_loss + config.consistency_coef * consistency_loss
        else:
            loss = cls_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # analysis
        Xlosses.update(loss.item(), batch)
        CLSFlosses.update(cls_loss.item(), batch)
        Rlosses.update(consistency_loss.item(), batch)
        predict_labels = torch.argmax(relations, dim=1)
        with torch.no_grad():
            accuracy = (predict_labels.cpu() == new_target_idxs.cpu()).float().mean().item()
            accs.update(accuracy * 100, batch)
            labelMeter.update(class_num, 1)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % config.log_interval == 0 or batch_idx + 1 == len(loader):
            Tstring = 'TIME[{batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(batch_time=batch_time)
            Sstring = '{:} [{:}] [{:03d}/{:03d}]'.format(time_string(), config.epoch_str, batch_idx, len(loader))
            Astring = 'loss={:.7f} ({:.5f}), cls_loss={:.7f} ({:.5f}), consistency_loss={:.7f} ({:.5f}), acc@1={:.1f} ({:.1f})'.format(Xlosses.val, Xlosses.avg, CLSFlosses.val, CLSFlosses.avg, Rlosses.val, Rlosses.avg, accs.val, accs.avg)
            logger.print('{:} {:} {:} B={:}, L={:} ({:.1f}) : {:}'.format(Sstring, Tstring, Astring, batch, class_num, labelMeter.avg, batch_label_lst[:3]))
    return Xlosses.avg, accs.avg
def main(xargs):
    """Entry point: build datasets/model/optimizer, train with periodic
    zero-shot evaluation, and checkpoint after every epoch.

    Args:
        xargs: parsed CLI namespace (log_dir, manual_seed, data_root,
            class_per_it, num_shot, num_workers, clusters, semantic_name,
            relation_name, lr, weight_decay, epochs, test_interval,
            log_interval, loss_type, consistency_coef, consistency_type).
    """
    # create logger
    if not os.path.exists(xargs.log_dir):
        os.makedirs(xargs.log_dir)
    logger = Logger(xargs.log_dir, xargs.manual_seed)
    logger.print('args :\n{:}'.format(xargs))
    logger.print('PyTorch: {:}'.format(torch.__version__))
    assert torch.cuda.is_available(), 'You must have at least one GPU'
    # set random seed (deterministic cudnn for reproducibility)
    torch.backends.cudnn.deterministic = True
    random.seed(xargs.manual_seed)
    np.random.seed(xargs.manual_seed)
    torch.manual_seed(xargs.manual_seed)
    torch.cuda.manual_seed(xargs.manual_seed)
    logger.print('Start Main with this file : {:}'.format(__file__))
    graph_info = torch.load(Path(xargs.data_root))
    unseen_classes = graph_info['unseen_classes']
    train_classes = graph_info['train_classes']
    # All labels return original value between 0-49
    train_dataset = AwA2_IMG_Rotate_Save(graph_info, 'train')
    batch_size = xargs.class_per_it * xargs.num_shot
    # round the number of episodes up to a multiple of 100
    total_episode = ((len(train_dataset) / batch_size) // 100 + 1) * 100
    train_sampler = DualMetaSampler(train_dataset, total_episode, xargs.class_per_it, xargs.num_shot)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=xargs.num_workers)
    test_seen_dataset = AwA2_IMG_Rotate_Save(graph_info, 'test-seen')
    test_seen_dataset.set_return_img_mode('original')
    test_seen_loader = torch.utils.data.DataLoader(test_seen_dataset, batch_size=batch_size, shuffle=False, num_workers=xargs.num_workers)
    test_unseen_dataset = AwA2_IMG_Rotate_Save(graph_info, 'test-unseen')
    test_unseen_dataset.set_return_img_mode('original')
    test_unseen_loader = torch.utils.data.DataLoader(test_unseen_dataset, batch_size=batch_size, shuffle=False, num_workers=xargs.num_workers)
    all_class_sampler = AllClassSampler(train_dataset)
    all_class_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=all_class_sampler, num_workers=xargs.num_workers, pin_memory=True)
    logger.print('train-dataset : {:}'.format(train_dataset))
    logger.print('test-seen-dataset : {:}'.format(test_seen_dataset))
    logger.print('test-unseen-dataset : {:}'.format(test_unseen_dataset))
    logger.print('all-class-train-sam : {:}'.format(all_class_sampler))
    features = graph_info['ori_attributes'].float().cuda()
    train_features = features[graph_info['train_classes'], :]
    logger.print('feature-shape={:}, train-feature-shape={:}'.format(list(features.shape), list(train_features.shape)))
    # cluster seen-class attributes to initialize semantic centers
    kmeans = KMeans(n_clusters=xargs.clusters, random_state=1337).fit(train_features.cpu().numpy())
    att_centers = torch.tensor(kmeans.cluster_centers_).float().cuda()
    for cls in range(xargs.clusters):
        logger.print('[cluster : {:}] has {:} elements.'.format(cls, (kmeans.labels_ == cls).sum()))
    logger.print('Train-Feature-Shape={:}, use {:} clusters, shape={:}'.format(train_features.shape, xargs.clusters, att_centers.shape))
    # build adjacent matrix
    distances = distance_func(graph_info['attributes'], graph_info['attributes'], 'euclidean-pow').float().cuda()
    xallx_adj_dis = distances.clone()
    train_adj_dis = distances[graph_info['train_classes'], :][:, graph_info['train_classes']]
    network = obtain_combine_models_v2(xargs.semantic_name, xargs.relation_name, att_centers, 2048)
    network = network.cuda()
    parameters = network.parameters()
    optimizer = torch.optim.Adam(parameters, lr=xargs.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=xargs.weight_decay, amsgrad=False)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, gamma=0.1, step_size=xargs.epochs * 2 // 3)
    logger.print('network : {:.2f} MB =>>>\n{:}'.format(count_parameters_in_MB(network), network))
    logger.print('optimizer : {:}'.format(optimizer))
    # resume from the last checkpoint if one exists
    model_lst_path = logger.checkpoint('ckp-last-{:}.pth'.format(xargs.manual_seed))
    if os.path.isfile(model_lst_path):
        checkpoint = torch.load(model_lst_path)
        start_epoch = checkpoint['epoch'] + 1
        best_accs = checkpoint['best_accs']
        network.load_state_dict(checkpoint['network'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['scheduler'])
        logger.print('load checkpoint from {:}'.format(model_lst_path))
    else:
        start_epoch, best_accs = 0, {'train': -1, 'xtrain': -1, 'zs': -1, 'gzs-seen': -1, 'gzs-unseen': -1, 'gzs-H': -1, 'best-info': None}
    epoch_time, start_time = AverageMeter(), time.time()
    # training
    for iepoch in range(start_epoch, xargs.epochs):
        time_str = convert_secs2time(epoch_time.val * (xargs.epochs - iepoch), True)
        epoch_str = '{:03d}/{:03d}'.format(iepoch, xargs.epochs)
        # NOTE(review): get_lr() is deprecated in newer PyTorch in favor of
        # get_last_lr(); kept for compatibility with the pinned version.
        last_lr = lr_scheduler.get_lr()
        logger.print('Train the {:}-th epoch, {:}, LR={:1.6f} ~ {:1.6f}'.format(epoch_str, time_str, min(last_lr), max(last_lr)))
        config_train = load_configure(None, {'epoch_str': epoch_str, 'log_interval': xargs.log_interval, 'loss_type': xargs.loss_type, 'consistency_coef': xargs.consistency_coef, 'consistency_type': xargs.consistency_type}, None)
        train_cls_loss, train_acc = train_model(train_loader, train_features, train_adj_dis, network, optimizer, config_train, logger)
        lr_scheduler.step()
        if train_acc > best_accs['train']:
            best_accs['train'] = train_acc
        logger.print('Train {:} done, cls-loss={:.3f}, accuracy={:.2f}%, (best={:.2f}).\n'.format(epoch_str, train_cls_loss, train_acc, best_accs['train']))
        if iepoch % xargs.test_interval == 0 or iepoch == xargs.epochs - 1:
            with torch.no_grad():
                xinfo = {'train_classes': graph_info['train_classes'], 'unseen_classes': graph_info['unseen_classes']}
                train_loader.dataset.set_return_img_mode('original')
                all_class_loader.dataset.set_return_label_mode('original')
                all_class_loader.dataset.set_return_img_mode('original')
                seen_protos, unseen_att = get_train_protos(network, features, train_classes, unseen_classes, all_class_loader, xargs)
                for test_topK in range(1, 2):
                    logger.print('-----test--init with top-{:} seen protos-------'.format(test_topK))
                    # synthesize unseen prototypes as an attention-weighted
                    # mixture of the top-K most related seen prototypes
                    topkATT, topkIDX = torch.topk(unseen_att, test_topK, dim=1)
                    norm_att = F.softmax(topkATT, dim=1)
                    unseen_protos = norm_att.view(len(unseen_classes), test_topK, 1) * seen_protos[topkIDX]
                    unseen_protos = unseen_protos.mean(dim=1)
                    protos = []
                    for icls in range(features.size(0)):
                        if icls in train_classes:
                            protos.append(seen_protos[train_classes.index(icls)])
                        else:
                            protos.append(unseen_protos[unseen_classes.index(icls)])
                    protos = torch.stack(protos)
                    train_loader.dataset.set_return_img_mode('original')
                    evaluate_all_dual(epoch_str, train_loader, test_unseen_loader, test_seen_loader, features, protos, xallx_adj_dis, network, xinfo, best_accs, logger)
        semantic_lists = network.get_semantic_list(features)
        # save the info
        info = {'epoch': iepoch,
                'args': deepcopy(xargs),
                'finish': iepoch + 1 == xargs.epochs,
                'best_accs': best_accs,
                'semantic_lists': semantic_lists,
                'adj_distances': xallx_adj_dis,
                'network': network.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': lr_scheduler.state_dict(),
                }
        try:
            torch.save(info, model_lst_path)
            logger.print('--->>> joint-arch :: save into {:}.\n'.format(model_lst_path))
        except PermissionError:
            # BUG FIX: was `except PermmisionError` (undefined name), which
            # raised a NameError instead of the intended best-effort skip.
            print('unsuccessful write log')
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    if 'info' in locals() or 'checkpoint' in locals():
        if 'checkpoint' in locals():
            semantic_lists = checkpoint['semantic_lists']
        else:
            semantic_lists = info['semantic_lists']
    '''
    # the final evaluation
    logger.print('final evaluation --->>>')
    with torch.no_grad():
        xinfo = {'train_classes' : graph_info['train_classes'], 'unseen_classes': graph_info['unseen_classes']}
        train_loader.dataset.set_return_img_mode('original')
        evaluate_all('final-eval', train_loader, test_unseen_loader, test_seen_loader, features, xallx_adj_dis, network, xinfo, best_accs, logger)
        logger.print('-'*200)
    '''
    logger.close()