Exemplo n.º 1
0
def evaluate_all(epoch_str, protos, models, train_loader, test_unseen_loader,
                 test_seen_loader, cls_loss, best_accs, distance, logger):
    """Run the zero-shot and generalized zero-shot evaluations for one epoch.

    Four ``evaluate`` passes are made:
      1. zero-shot: the unseen test split, labels in 'new' mode, against
         ``protos`` indexed by the FIRST list of its ``get_all_wnid_index()``;
      2. the training split, labels in 'original' mode, against ``protos``
         indexed by the SECOND list of ``train_loader``'s index lists;
      3. the unseen test split, labels in 'original' mode, against ``protos``
         indexed by the SECOND list of ``test_unseen_loader``'s index lists;
      4. the seen test split, labels in 'original' mode, reusing the
         prototypes of pass 3.

    ``best_accs`` is updated in place: 'zs', 'xtrain', 'gzs-unseen' and
    'gzs-seen' keep the best per-class accuracy so far; 'gzs-H' keeps the best
    harmonic mean of the seen/unseen per-class accuracies and 'best-info' a
    description of the epoch that produced it ('best-info' must already exist
    if 'gzs-H' does not improve). All results go through ``logger.print``.
    The loaders' datasets are left in 'original' label mode.
    """
    def keep_best(key, value):
        # Overwrite the stored record only on strict improvement.
        if best_accs[key] < value:
            best_accs[key] = value

    def report(template, loss, acc, per_cls_acc, key):
        # Every per-split log line shares this argument layout.
        logger.print(template.format(time_string(), epoch_str, loss, acc,
                                     per_cls_acc, best_accs[key]))

    logger.print('Evaluate [{:}] with distance={:}'.format(epoch_str, distance))

    # --- zero-shot: unseen images vs. unseen prototypes only
    zs_indexes, _, _ = test_unseen_loader.dataset.get_all_wnid_index()
    test_unseen_loader.dataset.set_return_label_mode('new')
    zs_loss, zs_acc, zs_per_cls = evaluate(models, test_unseen_loader,
                                           protos[zs_indexes], cls_loss,
                                           distance)
    keep_best('zs', zs_per_cls)
    report('Test {:} [zero-zero-zero-zero-shot----] {:} done, loss={:.3f}, accuracy={:5.2f}%, per-class-acc={:5.2f}% (TTEST-best={:5.2f}%).',
           zs_loss, zs_acc, zs_per_cls, 'zs')

    # --- generalized zero-shot on the training split
    _, train_indexes, _ = train_loader.dataset.get_all_wnid_index()
    train_protos = protos[train_indexes]
    train_loader.dataset.set_return_label_mode('original')
    tr_loss, tr_acc, tr_per_cls = evaluate(models, train_loader, train_protos,
                                           cls_loss, distance)
    keep_best('xtrain', tr_per_cls)
    report('Test {:} [train-train-train-train-----] {:} done, loss={:.3f}, accuracy={:5.2f}%, per-class-acc={:5.2f}% (TRAIN-best={:5.2f}%).',
           tr_loss, tr_acc, tr_per_cls, 'xtrain')

    # --- generalized zero-shot on the unseen test split
    _, gzs_indexes, _ = test_unseen_loader.dataset.get_all_wnid_index()
    gzs_protos = protos[gzs_indexes]
    test_unseen_loader.dataset.set_return_label_mode('original')
    un_loss, un_acc, un_per_cls = evaluate(models, test_unseen_loader,
                                           gzs_protos, cls_loss, distance)
    keep_best('gzs-unseen', un_per_cls)
    report('Test {:} [generalized-zero-shot-unseen] {:} done, loss={:.3f}, accuracy={:5.2f}%, per-class-acc={:5.2f}% (TUNSN-best={:5.2f}%).',
           un_loss, un_acc, un_per_cls, 'gzs-unseen')

    # --- generalized zero-shot on the seen test split (same prototypes as above)
    test_seen_loader.dataset.set_return_label_mode('original')
    se_loss, se_acc, se_per_cls = evaluate(models, test_seen_loader,
                                           gzs_protos, cls_loss, distance)
    keep_best('gzs-seen', se_per_cls)
    report('Test {:} [generalized-zero-shot---seen] {:} done, loss={:.3f}, accuracy={:5.2f}%, per-class-acc={:5.2f}% (TSEEN-best={:5.2f}%).',
           se_loss, se_acc, se_per_cls, 'gzs-seen')

    # Harmonic mean of seen/unseen per-class accuracy; 1e-8 avoids 0/0.
    h_mean = (2 * se_per_cls * un_per_cls) / (se_per_cls + un_per_cls + 1e-8)
    if best_accs['gzs-H'] < h_mean:
        best_accs['gzs-H'] = h_mean
        best_accs['best-info'] = '[{:}] seen={:5.2f}% unseen={:5.2f}%, H={:5.2f}%'.format(
            epoch_str, se_per_cls, un_per_cls, h_mean)
    logger.print(
        'Test [generalized-zero-shot-h-mean] {:} H={:.3f}% (HH-best={:.3f}%). ||| Best comes from {:}'
        .format(epoch_str, h_mean, best_accs['gzs-H'], best_accs['best-info']))
Exemplo n.º 2
0
def get_train_protos(network, attributes, train_classes, unseen_classes, all_class_loader, xargs):
  """Return the seen-class image prototypes and the unseen->seen raw attention.

  Per-class prototypes are the mean of the features yielded by
  ``all_class_loader``. They are cached with ``torch.save`` at
  ``{xargs.log_dir}/../train_proto_lists-{xargs.dataset}.pth``; if that file
  already exists it is loaded instead of being recomputed.

  Each batch from ``all_class_loader`` must contain a single label, and
  ``all_class_loader.batch_sampler.label2index[label]`` must list every sample
  of that label (its length is used as the class size).

  Returns:
    train_protos: the prototypes of ``train_classes`` (in that order) stacked
      into one tensor and moved to the GPU.
    unseen_att: rows ``unseen_classes`` / columns ``train_classes`` of the raw
      attention returned by ``network.get_attention(attributes)``.

  Side effect: puts ``network`` into eval mode.
  """
  train_proto_path = '{:}/../train_proto_lists-{:}.pth'.format(xargs.log_dir, xargs.dataset)
  if os.path.exists(train_proto_path):
    # NOTE: torch.load unpickles; this is our own cache file, not external input.
    train_proto_lists = torch.load(train_proto_path)
  else:
    # get the training protos over all images
    train_proto_lists = dict()
    num_per_class     = defaultdict(lambda: 0)
    data_time, xend   = AverageMeter(), time.time()
    all_class_sampler = all_class_loader.batch_sampler
    num_batches       = len(all_class_loader)
    for ibatch, (feats, labels) in enumerate(all_class_loader):
      assert len(set(labels.tolist())) == 1
      label = labels[0].item()
      num_per_class[label] += feats.size(0)
      # Each batch adds sum(feats) / class_size, so after every batch of the
      # class has been seen the accumulator holds the class mean.
      contribution = torch.sum(feats, dim=0) / len(all_class_sampler.label2index[label])
      if label in train_proto_lists:
        train_proto_lists[label] += contribution
      else:
        train_proto_lists[label] = contribution

      data_time.update(time.time() - xend)
      xend = time.time()
      if ibatch % 100 == 0 or ibatch + 1 == num_batches:
        Tstring = '{:} [{:03d}/{:03d}] AVG=({:.2f}, {:.2f})'.format(time_string(), ibatch, num_batches, data_time.val, data_time.avg)
        Tstring+= ' :: {:}'.format(convert_secs2time(data_time.avg * (num_batches-ibatch), True))
        print('***extract features*** : {:}'.format(Tstring))
    # Every class must have contributed exactly its full sample count, otherwise
    # the accumulated "mean" is wrong.
    for key, count in num_per_class.items():
      expected = len(all_class_sampler.label2index[key])
      # Report the size of the class being checked (``key``), not the last batch's label.
      assert count == expected, '[{:}] : {:} vs {:} \n:::{:}'.format(key, count, expected, num_per_class)
    torch.save(train_proto_lists, train_proto_path)

  train_protos = torch.stack([train_proto_lists[cls] for cls in train_classes]).cuda()
  with torch.no_grad():
    network.eval()
    raw_atts, _ = network.get_attention(attributes.cuda())
    # attention of unseen classes (rows) over seen/train classes (columns)
    unseen_att = raw_atts[unseen_classes,:][:,train_classes]
  return train_protos, unseen_att
Exemplo n.º 3
0
def evaluate_all(epoch_str, train_loader, test_unseen_loader,
                 test_seen_loader, features, adj_dis, network,
                 info, best_accs, logger):
    """Run the zero-shot and generalized zero-shot evaluations for one epoch.

    Passes made through ``evaluate``:
      1. zero-shot: the unseen test split (labels in 'new' mode) against the
         rows of ``features`` and the sub-matrix of ``adj_dis`` restricted to
         ``info['unseen_classes']``;
      2-4. the training, unseen-test and seen-test splits (labels in
         'original' mode) against the full ``features`` / ``adj_dis``.

    ``best_accs`` is updated in place: 'zs', 'xtrain', 'gzs-unseen' and
    'gzs-seen' keep the best per-class accuracy so far; 'gzs-H' keeps the best
    harmonic mean of the seen/unseen per-class accuracies and 'best-info' a
    description of the epoch that produced it ('best-info' must already exist
    if 'gzs-H' does not improve). All results go through ``logger.print``.
    The loaders' datasets are left in 'original' label mode.
    """
    unseen_classes = info['unseen_classes']

    def keep_best(key, value):
        # Overwrite the stored record only on strict improvement.
        if best_accs[key] < value:
            best_accs[key] = value

    def report(template, loss, per_cls_acc, key):
        # Every per-split log line shares this argument layout.
        logger.print(template.format(time_string(), epoch_str, loss,
                                     per_cls_acc, best_accs[key]))

    logger.print('Evaluate [{:}]'.format(epoch_str))

    # --- zero-shot: only unseen classes are candidates
    test_unseen_loader.dataset.set_return_label_mode('new')
    target_semantics = features[unseen_classes, :]
    target_adj_dis = adj_dis[unseen_classes, :][:, unseen_classes]
    test_loss, _, test_per_cls_acc = evaluate(test_unseen_loader,
                                              target_semantics, target_adj_dis,
                                              network)
    keep_best('zs', test_per_cls_acc)
    report('Test {:} [zero-zero-zero-zero-shot----] {:} done, loss={:.3f}, per-class-acc={:5.2f}% (TTEST-best={:5.2f}%).',
           test_loss, test_per_cls_acc, 'zs')

    # --- generalized zero-shot: all classes are candidates
    train_loader.dataset.set_return_label_mode('original')
    train_loss, _, train_per_cls_acc = evaluate(train_loader, features,
                                                adj_dis, network)
    keep_best('xtrain', train_per_cls_acc)
    report('Test {:} [train-train-train-train-----] {:} done, loss={:.3f}, per-class-acc={:5.2f}% (TRAIN-best={:5.2f}%).',
           train_loss, train_per_cls_acc, 'xtrain')

    test_unseen_loader.dataset.set_return_label_mode('original')
    test_loss_unseen, _, test_per_cls_acc_unseen = evaluate(
        test_unseen_loader, features, adj_dis, network)
    keep_best('gzs-unseen', test_per_cls_acc_unseen)
    report('Test {:} [generalized-zero-shot-unseen] {:} done, loss={:.3f}, per-class-acc={:5.2f}% (TUNSN-best={:5.2f}%).',
           test_loss_unseen, test_per_cls_acc_unseen, 'gzs-unseen')

    # for test data with seen classes
    test_seen_loader.dataset.set_return_label_mode('original')
    test_loss_seen, _, test_per_cls_acc_seen = evaluate(
        test_seen_loader, features, adj_dis, network)
    keep_best('gzs-seen', test_per_cls_acc_seen)
    report('Test {:} [generalized-zero-shot---seen] {:} done, loss={:.3f}, per-class-acc={:5.2f}% (TSEEN-best={:5.2f}%).',
           test_loss_seen, test_per_cls_acc_seen, 'gzs-seen')

    # Harmonic mean of seen/unseen per-class accuracy; 1e-8 avoids 0/0.
    harmonic_mean = (2 * test_per_cls_acc_seen * test_per_cls_acc_unseen) / (
        test_per_cls_acc_seen + test_per_cls_acc_unseen + 1e-8)
    if best_accs['gzs-H'] < harmonic_mean:
        best_accs['gzs-H'] = harmonic_mean
        best_accs['best-info'] = '[{:}] seen={:5.2f}% unseen={:5.2f}%, H={:5.2f}%'.format(
            epoch_str, test_per_cls_acc_seen, test_per_cls_acc_unseen,
            harmonic_mean)
    logger.print(
        'Test [generalized-zero-shot-h-mean] {:} H={:.3f}% (HH-best={:.3f}%). ||| Best comes from {:}'
        .format(epoch_str, harmonic_mean, best_accs['gzs-H'],
                best_accs['best-info']))
Exemplo n.º 4
0
def extract_rotate_feats(dataset):
    """Extract backbone features for the four rotation views of every image.

    Loads the info file ``{root_dir}/info-files/x-{dataset}-data-image.pth``
    and runs a ``DataParallel`` ``resnet101`` backbone (from
    ``obtain_backbone``) in eval mode over the ``tensor_000`` / ``tensor_090``
    / ``tensor_180`` / ``tensor_270`` views yielded by ``SIMPLE_DATA``.

    One ``.pth`` file per image is written to
    ``root_dir/rotate-infos/<dataset>/<image parent dir name>/<name up to first '.'>.pth``
    with keys 'feats-000', 'feats-090', 'feats-180', 'feats-270'. The mapping
    image path -> feature path is stored in the info file under 'image2feat',
    and the info file is saved back in place.

    Raises:
      AssertionError: if two images map to the same feature file. The check
        runs before writing, so no earlier file is overwritten.
    """
    data_root = '{:}/info-files/x-{:}-data-image.pth'.format(root_dir, dataset)
    xdata = torch.load(data_root)
    files = xdata['image_files']
    save_dir = root_dir / 'rotate-infos' / dataset
    save_dir.mkdir(parents=True, exist_ok=True)
    imagepath2featpath = dict()
    avoid_duplicate = set()

    cnn_name = 'resnet101'
    backbone = obtain_backbone(cnn_name).cuda()
    backbone = torch.nn.DataParallel(backbone)
    backbone.eval()

    # 3 is the number of augmentations (argument of SIMPLE_DATA)
    simple_data = SIMPLE_DATA(files, 3, 'imagenet')
    simple_loader = torch.utils.data.DataLoader(simple_data,
                                                batch_size=32,
                                                shuffle=False,
                                                num_workers=8,
                                                pin_memory=True)

    batch_time, xend = AverageMeter(), time.time()
    for idx, (indexes, tensor_000, tensor_090, tensor_180,
              tensor_270) in enumerate(simple_loader):
        with torch.no_grad():
            # Iterating each view presumably yields one batch per augmentation
            # (TODO confirm against SIMPLE_DATA); backbone runs on each.
            feats_000 = [backbone(x) for x in tensor_000]
            feats_090 = [backbone(x) for x in tensor_090]
            feats_180 = [backbone(x) for x in tensor_180]
            feats_270 = [backbone(x) for x in tensor_270]

            for ii, image_idx in enumerate(indexes):
                image_file = files[image_idx.item()]
                ori_file_p = Path(image_file)
                save_dir_xx = save_dir / ori_file_p.parent.name
                save_f_path = save_dir_xx / (ori_file_p.name.split('.')[0] +
                                             '.pth')
                save_key = str(save_f_path)
                # Detect collisions BEFORE writing: saving first would silently
                # overwrite the features of the earlier image that owns this
                # path. Explicit raise so the check survives `python -O`.
                if save_key in avoid_duplicate:
                    raise AssertionError('invalid path : {:}'.format(save_f_path))
                avoid_duplicate.add(save_key)

                save_dir_xx.mkdir(parents=True, exist_ok=True)
                # Stack image ii's feature across augmentations for each view.
                torch.save(
                    {
                        'feats-000': torch.stack([x[ii] for x in feats_000]).cpu(),
                        'feats-090': torch.stack([x[ii] for x in feats_090]).cpu(),
                        'feats-180': torch.stack([x[ii] for x in feats_180]).cpu(),
                        'feats-270': torch.stack([x[ii] for x in feats_270]).cpu()
                    }, save_f_path)
                imagepath2featpath[image_file] = save_key
        need_time = convert_secs2time(
            batch_time.val * (len(simple_loader) - idx), True)
        print('{:} : {:5d} / {:5d} : {:} : {:}'.format(time_string(), idx,
                                                       len(simple_loader),
                                                       need_time, save_f_path))
        batch_time.update(time.time() - xend)
        xend = time.time()
    xdata['image2feat'] = imagepath2featpath
    torch.save(xdata, data_root)
    print('Update all-info in {:}, file size : {:.2f} GB'.format(
        data_root,
        os.path.getsize(data_root) / 1e9))
Exemplo n.º 5
0
def _train_consistency_loss(network, consistency_type, eps):
  """Return the consistency loss between the relation module's attribute and image attentions.

  'mse' compares the raw attentions; 'kla2i' / 'kli2a' use ``F.kl_div`` with
  the attribute / image attention (shifted by ``eps`` before ``log``) as the
  log-input. Raises ValueError for any other ``consistency_type``.
  """
  raw_att_att, att_att = network.relation_module.get_attention_attribute()
  raw_att_img, att_img = network.relation_module.get_attention_img()
  if consistency_type == 'mse':
    return F.mse_loss(raw_att_att, raw_att_img)
  if consistency_type == 'kla2i':
    return F.kl_div((att_att + eps).log(), att_img + eps, reduction='batchmean')
  if consistency_type == 'kli2a':
    return F.kl_div((att_img + eps).log(), att_att + eps, reduction='batchmean')
  raise ValueError('Unknown consistency type: {:}'.format(consistency_type))


def _train_cls_loss(relations, one_hot_labels, target_labels, loss_type):
  """Return the classification loss selected by ``loss_type``.

  Supported values:
    'sigmoid-mse', 'mse' : MSE between sigmoid(relations) and one-hot labels.
    'softmax'            : cross entropy on the raw relations.
    'softmax-T-E'        : cross entropy on relations / T; when E > 0 the
                           targets are label-smoothed with factor E.
    'none'               : mean of (-positive scores, +negative scores).
  Raises ValueError for anything else.
  """
  if loss_type == 'sigmoid-mse':
    return F.mse_loss(torch.sigmoid(relations), one_hot_labels, reduction='mean')
  if loss_type == 'softmax':
    return F.cross_entropy(relations, target_labels, weight=None, reduction='mean')
  # Anchored pattern: the old re.match('softmax-*-*') matched anything starting
  # with "softmax" (the '*' only repeats '-'), so plain 'softmax' crashed here.
  parsed = re.fullmatch(r'softmax-([^-]+)-([^-]+)', loss_type, re.I)
  if parsed:
    temperature, epsilon = float(parsed.group(1)), float(parsed.group(2))
    if epsilon <= 0:
      return F.cross_entropy(relations / temperature, target_labels, weight=None, reduction='mean')
    log_probs = F.log_softmax(relations / temperature, dim=1)
    smoothed  = torch.zeros_like(log_probs).scatter_(1, target_labels.unsqueeze(1), 1)
    smoothed  = (1-epsilon) * smoothed + epsilon / relations.size(1)
    return (-smoothed * log_probs).sum(dim=1).mean()
  if loss_type == 'mse':
    return F.mse_loss(torch.sigmoid(relations), one_hot_labels, reduction='mean')
  if loss_type == 'none':
    positive = -torch.masked_select(relations, one_hot_labels == 1)
    negative = torch.masked_select(relations, one_hot_labels == 0)
    return torch.cat([positive, negative]).mean()
  raise ValueError('invalid loss type : {:}'.format(loss_type))


def train_model(loader, semantics, adj_distances, network, optimizer, config, logger):
  """Train ``network`` for one pass over ``loader``.

  Each batch is split into two equal halves. The first half builds one image
  prototype per label (mean feature). The second half is classified against
  those prototypes together with the labels' rows of ``semantics`` and
  sub-matrix of ``adj_distances``.

  The loss is the classification loss chosen by ``config.loss_type``, plus
  ``config.consistency_coef`` times the attention-consistency loss chosen by
  ``config.consistency_type`` when that coefficient is > 0.

  Assumes an even batch size, and that every label present in the batch
  also appears in its first half (presumably guaranteed by the sampler);
  otherwise the prototype mean is taken over an empty selection.

  Returns:
    (average total loss, average top-1 accuracy in percent) over the pass.
  """
  batch_time, Xlosses, CLSFlosses, Rlosses, accs, end = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
  labelMeter, eps = AverageMeter(), 1e-7
  network.train()

  loader.dataset.set_return_label_mode('new')
  loader.dataset.set_return_img_mode('original_augment')

  logger.print('[TRAIN---{:}], semantics-shape={:}, adj_distances-shape={:}, config={:}'.format(config.epoch_str, semantics.shape, adj_distances.shape, config))

  for batch_idx, (img_feat, targets) in enumerate(loader):
    batch_label_lst = list(set(targets.tolist()))
    class_num       = len(batch_label_lst)
    label2position  = {lab: pos for pos, lab in enumerate(batch_label_lst)}
    batch           = targets.shape[0] // 2 # assume train and val has the same amount
    batch_attrs     = semantics[batch_label_lst, :]
    batch_adj_dis   = adj_distances[batch_label_lst,:][:,batch_label_lst]

    train_targets, targets = torch.split(targets, (batch, batch))
    img_feat_train, img_feat_val = torch.split(img_feat, (batch, batch), dim=0)
    # one prototype per label, in batch_label_lst order
    img_proto = torch.stack([img_feat_train[train_targets == lab].mean(dim=0) for lab in batch_label_lst])

    relations = network(img_feat_val.cuda(), img_proto.cuda(), batch_attrs.cuda(), batch_adj_dis.cuda())
    consistency_loss = _train_consistency_loss(network, config.consistency_type, eps)

    # map dataset labels of the second half to their column in `relations`
    new_target_idxs = torch.LongTensor([label2position[x] for x in targets.tolist()])
    one_hot_labels  = torch.zeros(batch, class_num).scatter_(1, new_target_idxs.view(-1,1), 1).cuda()
    target__labels  = new_target_idxs.cuda()

    cls_loss = _train_cls_loss(relations, one_hot_labels, target__labels, config.loss_type)
    if config.consistency_coef > 0:
      loss = cls_loss + config.consistency_coef * consistency_loss
    else:
      loss = cls_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # analysis
    Xlosses.update(loss.item(), batch)
    CLSFlosses.update(cls_loss.item(), batch)
    Rlosses.update(consistency_loss.item(), batch)
    with torch.no_grad():
      predict_labels = torch.argmax(relations, dim=1)
      accuracy = (predict_labels.cpu() == new_target_idxs.cpu()).float().mean().item()
      accs.update(accuracy*100, batch)
      labelMeter.update(class_num, 1)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if batch_idx % config.log_interval == 0 or batch_idx + 1 == len(loader):
      Tstring = 'TIME[{batch_time.val:.2f} ({batch_time.avg:.2f})]'.format(batch_time=batch_time)
      Sstring = '{:} [{:}] [{:03d}/{:03d}]'.format(time_string(), config.epoch_str, batch_idx, len(loader))
      Astring = 'loss={:.7f} ({:.5f}), cls_loss={:.7f} ({:.5f}), consistency_loss={:.7f} ({:.5f}), acc@1={:.1f} ({:.1f})'.format(Xlosses.val, Xlosses.avg, CLSFlosses.val, CLSFlosses.avg, Rlosses.val, Rlosses.avg, accs.val, accs.avg)
      logger.print('{:} {:} {:} B={:}, L={:} ({:.1f}) : {:}'.format(Sstring, Tstring, Astring, batch, class_num, labelMeter.avg, batch_label_lst[:3]))
  return Xlosses.avg, accs.avg