Ejemplo n.º 1
0
def evaluate(models, loader, prototypes, criterion, distance):
    """Run a gradient-free evaluation pass over ``loader`` against ``prototypes``.

    Args:
        models: pair ``(cnn_model, model)``; every image batch is fed to
            ``cnn_model`` and that output is fed to ``model``.
        loader: iterable yielding ``(images, emb, continous, target)`` batches;
            only ``images`` and ``target`` are used here.
        prototypes: prototypes handed to ``distance_func`` together with the
            vectors produced by ``model``.
        criterion: loss callable, invoked as ``criterion(logits, target)``.
        distance: distance selector forwarded unchanged to ``distance_func``.

    Returns:
        A 3-tuple: ``.avg`` of the loss meter, ``.avg`` of the accuracy meter
        (fed by ``obtain_accuracy`` with top-k ``(1,)``), and the mean of the
        per-class accuracies multiplied by 100.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    hits_by_class = collections.defaultdict(list)
    # NOTE: prototypes are used as given; an L2-normalisation step was left
    # disabled (commented out) at this point in the original code.
    cnn_model, model = models
    with torch.no_grad():
        for images, _emb, _continous, target in loader:
            batch_size = images.size(0)
            target = target.cuda()
            vectors = model(cnn_model(images))
            # target is 0, 1, 2, ..., test-classes-1
            # Negated distances act as logits: a smaller distance gives a
            # larger score.
            logits = -distance_func(vectors, prototypes, distance)
            loss_meter.update(criterion(logits, target).item(), batch_size)
            [top1] = obtain_accuracy(logits.data, target.data, (1, ))
            top1_meter.update(top1.item(), batch_size)
            # Record a hit/miss flag per sample, grouped by its label.
            hit_flags = (logits.argmax(dim=1) == target).cpu().tolist()
            labels = target.cpu().tolist()
            for label, hit in zip(labels, hit_flags):
                hits_by_class[label].append(hit)
    per_class_means = [np.mean(hits) for hits in hits_by_class.values()]
    mean_class_acc = float(np.mean(per_class_means))
    return loss_meter.avg, top1_meter.avg, mean_class_acc * 100
Ejemplo n.º 2
0
 def get_attention(self, attributes):
     """Compute thresholded pairwise scores and their row-wise softmax.

     ``attributes`` is multiplied by ``self.att_g``; the projected rows are
     compared with each other through ``distance_func(..., 'cosine')``;
     entries that are not strictly above ``self.thresh`` are replaced by
     ``-9e15``; finally a softmax over ``dim=1`` is applied to the result
     scaled by ``self.T``.

     Returns:
         ``(raw_attss, attention)``: the thresholded score matrix and the
         softmax computed from it.
     """
     # NOTE(review): both projections use ``self.att_g``; confirm that a
     # separate weight was not intended for the second one.
     proj_g = torch.mm(attributes, self.att_g)
     proj_h = torch.mm(attributes, self.att_g)
     scores = distance_func(proj_g, proj_h, 'cosine')
     # Huge negative filler for entries that fail the threshold test.
     filler = -9e15 * torch.ones_like(scores)
     raw_attss = torch.where(scores > self.thresh, scores, filler)
     return raw_attss, F.softmax(raw_attss * self.T, dim=1)
Ejemplo n.º 3
0
 def get_attention(self, attributes, choice="attribute"):
     """Compute pairwise scores and softmax attention for one weight set.

     Args:
         attributes: matrix multiplied by the selected weight via ``torch.mm``.
         choice: ``"attribute"`` selects ``self.att_g_att``; ``"img"`` selects
             ``self.att_g_img``.

     Returns:
         ``(distances, attention)``: the *unthresholded* output of
         ``distance_func(..., 'cosine')`` and the softmax (over ``dim=1``) of
         the thresholded scores scaled by ``self.T``.

     Raises:
         ValueError: if ``choice`` is neither ``"attribute"`` nor ``"img"``.
     """
     if choice == "attribute":
         weight = self.att_g_att
     elif choice == "img":
         weight = self.att_g_img
     else:
         raise ValueError("invalid choice {:}".format(choice))
     # NOTE(review): both projections use the same weight; confirm that a
     # separate weight was not intended for the second one.
     proj_g = torch.mm(attributes, weight)
     proj_h = torch.mm(attributes, weight)
     distances = distance_func(proj_g, proj_h, 'cosine')
     # Entries not strictly above the threshold get a huge negative value
     # before the softmax; the thresholded matrix itself is not returned.
     filler = -9e15 * torch.ones_like(distances)
     thresholded = torch.where(distances > self.thresh, distances, filler)
     attention = F.softmax(thresholded * self.T, dim=1)
     return distances, attention
Ejemplo n.º 4
0
 def get_new_attribute(self, attributes):
     """Refine ``attributes`` by ``self.n_hop`` rounds of attention mixing.

     In each round the current attributes are projected with ``self.att_g``,
     compared pairwise through ``distance_func(..., 'cosine')``, thresholded
     at ``self.thresh`` (failing entries become ``-9e15``), turned into a
     softmax over ``dim=1`` after scaling by ``self.T``, and finally used to
     mix the attributes via ``torch.mm(attention, attributes)``.

     Returns:
         ``attributes`` unchanged (a single tensor) when ``self.n_hop == 0``;
         otherwise ``(att_outs, mask)`` where ``mask`` is
         ``distances > self.thresh`` from the last round.
     """
     # NOTE(review): the zero-hop path returns a bare tensor while the other
     # path returns a 2-tuple; callers have to handle both shapes.
     if self.n_hop == 0:
         return attributes
     for _ in range(self.n_hop):
         # NOTE(review): both projections use ``self.att_g``; confirm intent.
         proj_g = torch.mm(attributes, self.att_g)
         proj_h = torch.mm(attributes, self.att_g)
         pairwise = distance_func(proj_g, proj_h, 'cosine')
         filler = -9e15 * torch.ones_like(pairwise)
         thresholded = torch.where(pairwise > self.thresh, pairwise, filler)
         weights = F.softmax(thresholded * self.T, dim=1)
         mixed = torch.mm(weights, attributes)
         # the mixed result is the input of the next round
         attributes = mixed
     return mixed, pairwise > self.thresh
Ejemplo n.º 5
0
def main(xargs):
  """Train the combined network and periodically evaluate it.

  Steps, in order:
    1. Create the log directory and the logger, require CUDA, seed the RNGs.
    2. Load ``graph_info`` from ``xargs.data_root`` and build the train,
       test-seen, test-unseen and all-class data loaders.
    3. Run KMeans on the training-class rows of ``ori_attributes``; the
       cluster centres are passed to ``obtain_combine_models_v2``.
    4. Build Adam + StepLR and resume from the last checkpoint if it exists.
    5. Per epoch: train, step the scheduler, evaluate every
       ``xargs.test_interval`` epochs (and on the final epoch), then save a
       checkpoint.

  Args:
    xargs: parsed arguments. Fields read directly here: ``log_dir``,
      ``manual_seed``, ``data_root``, ``class_per_it``, ``num_shot``,
      ``num_workers``, ``clusters``, ``semantic_name``, ``relation_name``,
      ``lr``, ``weight_decay``, ``epochs``, ``log_interval``, ``loss_type``,
      ``consistency_coef``, ``consistency_type``, ``test_interval``. The
      whole object is also forwarded to ``get_train_protos``.

  Raises:
    AssertionError: if no CUDA device is available.
  """
  # create the log directory (no error when it already exists) and the logger
  os.makedirs(xargs.log_dir, exist_ok=True)
  logger = Logger(xargs.log_dir, xargs.manual_seed)
  logger.print('args :\n{:}'.format(xargs))
  logger.print('PyTorch: {:}'.format(torch.__version__))

  assert torch.cuda.is_available(), 'You must have at least one GPU'

  # set random seed
  #torch.backends.cudnn.benchmark = True
  torch.backends.cudnn.deterministic = True
  random.seed(xargs.manual_seed)
  np.random.seed(xargs.manual_seed)
  torch.manual_seed(xargs.manual_seed)
  torch.cuda.manual_seed(xargs.manual_seed)

  logger.print('Start Main with this file : {:}'.format(__file__))
  graph_info = torch.load(Path(xargs.data_root))
  unseen_classes = graph_info['unseen_classes']
  train_classes  = graph_info['train_classes']

  # All labels return original value between 0-49
  train_dataset       = AwA2_IMG_Rotate_Save(graph_info, 'train')
  batch_size          = xargs.class_per_it * xargs.num_shot
  # Number of episodes: batches per pass, bumped to the next multiple of 100.
  # True division yields a float, so cast to int before handing the count to
  # the sampler (the numeric value is unchanged).
  total_episode       = int((len(train_dataset) / batch_size) // 100 + 1) * 100
  #train_sampler       = MetaSampler(train_dataset, total_episode, xargs.class_per_it, xargs.num_shot)
  train_sampler       = DualMetaSampler(train_dataset, total_episode, xargs.class_per_it, xargs.num_shot)
  train_loader        = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=xargs.num_workers)
  #train_loader        = torch.utils.data.DataLoader(train_dataset,       batch_size=batch_size, shuffle=True , num_workers=xargs.num_workers, drop_last=True)
  test_seen_dataset   = AwA2_IMG_Rotate_Save(graph_info, 'test-seen')
  test_seen_dataset.set_return_img_mode('original')
  test_seen_loader    = torch.utils.data.DataLoader(test_seen_dataset,   batch_size=batch_size, shuffle=False, num_workers=xargs.num_workers)
  test_unseen_dataset = AwA2_IMG_Rotate_Save(graph_info, 'test-unseen')
  test_unseen_dataset.set_return_img_mode('original')
  test_unseen_loader  = torch.utils.data.DataLoader(test_unseen_dataset, batch_size=batch_size, shuffle=False, num_workers=xargs.num_workers)
  # NOTE: this loader wraps the same dataset object as train_loader, so mode
  # switches made through either loader affect both.
  all_class_sampler   = AllClassSampler(train_dataset)
  all_class_loader    = torch.utils.data.DataLoader(train_dataset, batch_sampler=all_class_sampler, num_workers=xargs.num_workers, pin_memory=True)
  logger.print('train-dataset       : {:}'.format(train_dataset))
  #logger.print('train_sampler       : {:}'.format(train_sampler))
  logger.print('test-seen-dataset   : {:}'.format(test_seen_dataset))
  logger.print('test-unseen-dataset : {:}'.format(test_unseen_dataset))
  logger.print('all-class-train-sam : {:}'.format(all_class_sampler))

  features       = graph_info['ori_attributes'].float().cuda()
  train_features = features[graph_info['train_classes'], :]
  logger.print('feature-shape={:}, train-feature-shape={:}'.format(list(features.shape), list(train_features.shape)))

  # cluster the training-class features; the centres initialise the network
  kmeans = KMeans(n_clusters=xargs.clusters, random_state=1337).fit(train_features.cpu().numpy())
  att_centers = torch.tensor(kmeans.cluster_centers_).float().cuda()
  for cls in range(xargs.clusters):
    logger.print('[cluster : {:}] has {:} elements.'.format(cls, (kmeans.labels_ == cls).sum()))
  logger.print('Train-Feature-Shape={:}, use {:} clusters, shape={:}'.format(train_features.shape, xargs.clusters, att_centers.shape))

  # build adjacent matrix
  distances     = distance_func(graph_info['attributes'], graph_info['attributes'], 'euclidean-pow').float().cuda()
  xallx_adj_dis = distances.clone()
  train_adj_dis = distances[graph_info['train_classes'],:][:,graph_info['train_classes']]

  network = obtain_combine_models_v2(xargs.semantic_name, xargs.relation_name, att_centers, 2048)
  network = network.cuda()

  #parameters = [{'params': list(C_Net.parameters()), 'lr': xargs.lr*5, 'weight_decay': xargs.weight_decay*0.1},
  #              {'params': list(R_Net.parameters()), 'lr': xargs.lr  , 'weight_decay': xargs.weight_decay}]
  parameters = network.parameters()
  optimizer  = torch.optim.Adam(parameters, lr=xargs.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=xargs.weight_decay, amsgrad=False)
  #optimizer = torch.optim.SGD(parameters, lr=xargs.lr, momentum=0.9, weight_decay=xargs.weight_decay, nesterov=True)
  lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, gamma=0.1, step_size=xargs.epochs*2//3)
  logger.print('network : {:.2f} MB =>>>\n{:}'.format(count_parameters_in_MB(network), network))
  logger.print('optimizer : {:}'.format(optimizer))

  # resume from the last checkpoint of this seed when one exists
  model_lst_path  = logger.checkpoint('ckp-last-{:}.pth'.format(xargs.manual_seed))
  if os.path.isfile(model_lst_path):
    checkpoint  = torch.load(model_lst_path)
    start_epoch = checkpoint['epoch'] + 1
    best_accs   = checkpoint['best_accs']
    network.load_state_dict(checkpoint['network'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    logger.print('load checkpoint from {:}'.format(model_lst_path))
  else:
    start_epoch, best_accs = 0, {'train': -1, 'xtrain': -1, 'zs': -1, 'gzs-seen': -1, 'gzs-unseen': -1, 'gzs-H':-1, 'best-info': None}

  # get_lr() is only meant to be called from inside step(); called from
  # outside it can report an extra-decayed value at a StepLR boundary.
  # Prefer get_last_lr() and fall back to get_lr() on PyTorch versions that
  # do not provide it.
  current_lr_fn = getattr(lr_scheduler, 'get_last_lr', lr_scheduler.get_lr)

  epoch_time, start_time = AverageMeter(), time.time()
  # training
  for iepoch in range(start_epoch, xargs.epochs):
    # estimate the remaining time from the duration of the latest epoch
    time_str = convert_secs2time(epoch_time.val * (xargs.epochs- iepoch), True)
    epoch_str= '{:03d}/{:03d}'.format(iepoch, xargs.epochs)
    last_lr  = current_lr_fn()
    logger.print('Train the {:}-th epoch, {:}, LR={:1.6f} ~ {:1.6f}'.format(epoch_str, time_str, min(last_lr), max(last_lr)))

    config_train = load_configure(None, {'epoch_str': epoch_str, 'log_interval': xargs.log_interval,
                                         'loss_type': xargs.loss_type,
                                         'consistency_coef': xargs.consistency_coef,
                                         'consistency_type': xargs.consistency_type}, None)

    train_cls_loss, train_acc = train_model(train_loader, train_features, train_adj_dis, network, optimizer, config_train, logger)

    lr_scheduler.step()
    if train_acc > best_accs['train']: best_accs['train'] = train_acc
    logger.print('Train {:} done, cls-loss={:.3f}, accuracy={:.2f}%, (best={:.2f}).\n'.format(epoch_str, train_cls_loss, train_acc, best_accs['train']))

    if iepoch % xargs.test_interval == 0 or iepoch == xargs.epochs -1:
      with torch.no_grad():
        xinfo = {'train_classes' : graph_info['train_classes'], 'unseen_classes': graph_info['unseen_classes']}
        train_loader.dataset.set_return_img_mode('original')
        all_class_loader.dataset.set_return_label_mode('original')
        all_class_loader.dataset.set_return_img_mode('original')
        seen_protos, unseen_att = get_train_protos(network, features, train_classes, unseen_classes, all_class_loader, xargs)
        # only top-1 is evaluated; widen the range to try larger k
        for test_topK in range(1, 2):
          logger.print('-----test--init with top-{:} seen protos-------'.format(test_topK))
          # build each unseen prototype from its top-k seen prototypes,
          # weighted by the softmax of the selected attention values
          topkATT, topkIDX = torch.topk(unseen_att, test_topK, dim=1)
          norm_att      = F.softmax(topkATT, dim=1)
          unseen_protos = norm_att.view(len(unseen_classes), test_topK, 1) * seen_protos[topkIDX]
          unseen_protos = unseen_protos.mean(dim=1)
          # interleave seen and unseen prototypes back into class-index order
          protos = []
          for icls in range(features.size(0)):
            if icls in train_classes: protos.append( seen_protos[ train_classes.index(icls) ] )
            else                    : protos.append( unseen_protos[ unseen_classes.index(icls) ] )
          protos = torch.stack(protos)
          train_loader.dataset.set_return_img_mode('original')
          evaluate_all_dual(epoch_str, train_loader, test_unseen_loader, test_seen_loader, features, protos, xallx_adj_dis, network, xinfo, best_accs, logger)

    semantic_lists = network.get_semantic_list(features)
    # save the info
    info = {'epoch'           : iepoch,
            'args'            : deepcopy(xargs),
            'finish'          : iepoch+1==xargs.epochs,
            'best_accs'       : best_accs,
            'semantic_lists'  : semantic_lists,
            'adj_distances'   : xallx_adj_dis,
            'network'         : network.state_dict(),
            'optimizer'       : optimizer.state_dict(),
            'scheduler'       : lr_scheduler.state_dict(),
            }
    # Checkpointing is best-effort: a permission problem must not abort the
    # run. (The exception name was previously misspelled, which turned any
    # save failure into a NameError.)
    try:
      torch.save(info, model_lst_path)
      logger.print('--->>> joint-arch :: save into {:}.\n'.format(model_lst_path))
    except PermissionError:
      print('unsuccessful write log')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
  # Recover the semantic lists from the resumed checkpoint if there was one,
  # otherwise from the last saved info. The value is not used further in
  # this function.
  if 'info' in locals() or 'checkpoint' in locals():
    if 'checkpoint' in locals():
      semantic_lists = checkpoint['semantic_lists']
    else:
      semantic_lists = info['semantic_lists']
  # The string literal below is a disabled final-evaluation step, kept for
  # reference; it has no runtime effect.
  '''
  # the final evaluation
  logger.print('final evaluation --->>>')
  with torch.no_grad():
    xinfo = {'train_classes' : graph_info['train_classes'], 'unseen_classes': graph_info['unseen_classes']}
    train_loader.dataset.set_return_img_mode('original')
    evaluate_all('final-eval', train_loader, test_unseen_loader, test_seen_loader, features, xallx_adj_dis, network, xinfo, best_accs, logger)
  logger.print('-'*200)
  '''
  logger.close()