# Example #1
# 0
def run_filter(filter_net, box_filter_net, tree, boxes, w2i, CNN, precision_f1, instance_idx, gold):
    """Run a two-stage box filter on one instance.

    Stage 1 scores every candidate box of ``boxes[instance_idx]`` with
    ``filter_net`` and keeps either the ``int(precision_f1)`` highest-scoring
    boxes (when ``precision_f1 >= 1``) or every box whose ``exp(score)``
    exceeds ``precision_f1`` (when ``precision_f1 < 1``).  Stage 2 scores each
    kept box on its own with ``box_filter_net``.

    Args:
        filter_net: stage-1 network, called as ``filter_net([0] * 5, box_rep, tree)``.
        box_filter_net: stage-2 network, called once per kept box.
        tree: parsed expression; labels of ``tree.leaves()`` go through ``w2i``.
        boxes: per-instance box lists forwarded to ``get_box_feats``.
        w2i: word-to-index mapping; unknown words map to index 0.
        CNN: feature store forwarded to ``get_box_feats``.
        precision_f1: top-k count (>= 1) or probability threshold (< 1).
        instance_idx: index of the instance inside ``boxes``.
        gold: unused; kept so existing callers keep working.

    Returns:
        ``(bf_pred, f1_pred)``: ``f1_pred`` is an integer numpy array of kept
        box indices and ``bf_pred`` holds ``prediction[0][0]`` from
        ``box_filter_net`` for each of them, in the same order.
    """
    box_feats, spat_feats = get_box_feats(boxes[instance_idx], CNN)
    box_rep = torch.cat((box_feats, spat_feats), 1)

    stage1_scores = filter_net([0] * 5, box_rep, tree)
    pred_np = stage1_scores.cpu().data.numpy()

    if precision_f1 >= 1:
        top_k = int(precision_f1)
        f1_pred = np.argsort(-pred_np)[0][:min(top_k, pred_np.shape[1])]
    else:
        # The exp() suggests filter_net emits log-probabilities (assumption --
        # TODO confirm).  flatnonzero always yields an integer array, even when
        # no box passes, so the result stays usable as an index.
        f1_pred = np.flatnonzero(np.exp(pred_np[0]) > precision_f1)

    # The word indices depend only on the tree, so build them once.
    word_ids = [w2i.get(node.label, 0) for node in tree.leaves()]

    bf_pred = []
    for box_index in f1_pred:
        box_feats, spat_feats = get_box_feats([boxes[instance_idx][box_index]], CNN)
        # Assumes box_usage = 0 (visual + spatial features) for box_filter_net.
        box_rep = torch.cat((box_feats, spat_feats), 1)
        prediction = box_filter_net(word_ids, box_rep, tree).data
        bf_pred.append(prediction[0][0])

    return bf_pred, f1_pred
def run_filter(box_filter_net, tree, boxes, w2i, CNN, gold, instance_idx,
               precision_f2):
    """Score each candidate box on its own and keep the best-scoring ones.

    Every box of ``boxes[instance_idx]`` is fed separately to
    ``box_filter_net``; its score is ``prediction[0][0]``.  When
    ``precision_f2 >= 1`` the ``int(precision_f2)`` highest-scoring box
    indices are kept, best first; otherwise every box whose score is strictly
    greater than ``precision_f2`` is kept, in box order.

    Returns:
        ``(final_prediction, n_gold_kept)``: a numpy array of kept box indices
        and the number of distinct ``gold[instance_idx]`` entries among them.
    """
    instance_boxes = boxes[instance_idx]

    scores = []
    for single_box in instance_boxes:
        feats, spatial = get_box_feats([single_box], CNN)
        # box_filter_net is assumed to consume visual + spatial features
        # (box_usage = 0).
        joint = torch.cat((feats, spatial), 1)
        words = [w2i.get(leaf.label, 0) for leaf in tree.leaves()]
        output = box_filter_net(words, joint, tree).data
        scores.append(output[0][0])

    candidate_ids = np.array(range(len(instance_boxes)))
    if precision_f2 >= 1:
        ranking = np.argsort(-np.array(scores))
        keep = min(int(precision_f2), len(scores))
        final_prediction = candidate_ids[ranking[:keep]]
    else:
        final_prediction = np.array(
            [candidate_ids[pos] for pos, score in enumerate(scores)
             if score > precision_f2])

    n_gold_kept = len(set(final_prediction) & set(gold[instance_idx]))
    return final_prediction, n_gold_kept
def run_filter(filter_net, tree, boxes, CNN, gold, instance_idx, precision_f1,
               annid2catid):
    """Filter candidate boxes with ``filter_net``, then keep gold-category boxes.

    Stage 1 keeps either the ``int(precision_f1)`` highest-scoring boxes (when
    ``precision_f1 >= 1``) or every box whose ``exp(score)`` exceeds
    ``precision_f1``.  Stage 2 keeps only boxes whose object category (looked
    up through ``CNN['meta']`` and ``annid2catid``) matches the category of
    some gold box.  Because it reads ``gold[instance_idx]``, this is an oracle
    filter.

    Returns:
        ``(final_pred, n_gold_kept)``: an integer numpy array of kept box
        indices and how many distinct gold indices are among them.
    """
    box_feats, spat_feats = get_box_feats(boxes[instance_idx], CNN)
    box_rep = torch.cat((box_feats, spat_feats), 1)

    prediction = filter_net([0] * 5, box_rep, tree)
    pred_np = prediction.cpu().data.numpy()

    if precision_f1 >= 1:
        top_k = int(precision_f1)
        f1_pred = np.argsort(-pred_np)[0][:min(top_k, pred_np.shape[1])]
    else:
        # Integer dtype even when nothing passes: an empty float array would
        # make the fancy indexing of the meta array below raise IndexError.
        f1_pred = np.flatnonzero(np.exp(pred_np[0]) > precision_f1)

    meta = CNN['meta']
    instance_boxes = boxes[instance_idx]

    # Object categories of the gold boxes (annotation id -> category id).
    gold_cats = set()
    for gold_box in gold[instance_idx]:
        gold_cats.add(annid2catid[str(meta[instance_boxes[gold_box]])])

    # Category of every stage-1 box, then keep those matching a gold category.
    cat_ids = [
        annid2catid[str(ann_id)] for ann_id in meta[instance_boxes][f1_pred]
    ]
    keep = [idx for idx, cat_id in enumerate(cat_ids) if cat_id in gold_cats]
    final_pred = f1_pred[keep]

    return final_pred, len(
        set(final_pred).intersection(set(gold[instance_idx])))
# Example #4
# 0
def evaluate(net, split, CNN, config, experiment_log, box_usage=0, verbose=False, tst_json=[], out_file=''):
    """Evaluate ``net`` on ``split`` by top-1 box accuracy.

    Args:
        net: model called as ``net(word_ids, box_rep, tree)``; the argmax over
            dimension 1 of its output is the predicted box index.
        split: ``(trees, boxes, iou, gold)``; ``iou`` is not used here.
        CNN: feature store forwarded to ``get_box_feats``.
        config: must provide ``'box_usage'`` (0 = visual + spatial features,
            1 = visual only, 2 = spatial only).
        experiment_log: writable file-like object receiving the summary.
        box_usage: ignored -- ``config['box_usage']`` always takes precedence.
        verbose: wrap the instance loop in a tqdm progress bar.
        tst_json: optional list of per-instance dicts; when non-empty each gets
            ``'predicted_bounding_boxes'`` and the list is dumped to ``out_file``.
        out_file: path of the JSON dump.

    Returns:
        ``(accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unsupported ``box_usage``.
    """
    # config['box_usage'] always overrides the box_usage argument.
    box_usage = config['box_usage']

    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split

    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes

    preds = []
    net.eval()
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)

        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()

        # w2i is a module-level vocabulary, not a parameter of this function.
        prediction = net([w2i.get(node.label, 0) for node in tree.leaves()], box_rep, tree).data
        _, pred = torch.max(prediction, 1)
        best = pred[0][0]

        correct += 1.0 if best in set(gold[j]) else 0.0
        n += 1
        preds.append(int(best))
    eval_time = time.time() - eval_start

    log = ["Total correct {}/{}".format(correct, n), "=" * 20]
    for line in log:
        print(line)
        # One newline-terminated line per entry ("\n".join(line) would put a
        # newline between every character of the string).
        experiment_log.write(line + "\n")

    if tst_json != []:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
# Example #5
# 0
def run_filter(filter_net, tree, boxes, CNN, gold, instance_idx, precision_f1):
    """Select candidate boxes for one instance with ``filter_net``.

    Keeps either the ``int(precision_f1)`` highest-scoring boxes (when
    ``precision_f1 >= 1``) or every box whose ``exp(score)`` exceeds
    ``precision_f1`` (when ``precision_f1 < 1``).

    Returns:
        ``(f1_pred, n_gold_kept)``: an integer numpy array of kept box indices
        and how many distinct entries of ``gold[instance_idx]`` are among them.
    """
    box_feats, spat_feats = get_box_feats(boxes[instance_idx], CNN)
    box_rep = torch.cat((box_feats, spat_feats), 1)

    prediction = filter_net([0] * 5, box_rep, tree)
    pred_np = prediction.cpu().data.numpy()

    if precision_f1 >= 1:
        top_k = int(precision_f1)
        f1_pred = np.argsort(-pred_np)[0][:min(top_k, pred_np.shape[1])]
    else:
        # Integer dtype even when nothing passes, so callers can always use the
        # result as an index array.
        f1_pred = np.flatnonzero(np.exp(pred_np[0]) > precision_f1)

    return f1_pred, len(set(f1_pred).intersection(set(gold[instance_idx])))
def evaluate(net, split, CNN, config, experiment_log, box_usage=0, verbose=False, tst_json=[], out_file='', precision_k=10, no_lang=False):
    """Evaluate ``net`` with top-1 accuracy and precision@1..``precision_k``.

    Args:
        net: model called as ``net([0] * 5, box_rep, tree)``; its output ranks
            the candidate boxes along dimension 1.
        split: ``(trees, boxes, iou, gold)``; ``iou`` is not used here.
        CNN: feature store forwarded to ``get_box_feats``.
        config: must provide ``'box_usage'`` (0 = visual + spatial features,
            1 = visual only, 2 = spatial only).
        experiment_log: accepted for interface compatibility; not written to.
        box_usage: ignored -- ``config['box_usage']`` always takes precedence.
        verbose: wrap the instance loop in a tqdm progress bar.
        tst_json: optional list of per-instance dicts; when non-empty each gets
            the full ranked list of box names and an (empty) context box list,
            and the list is dumped to ``out_file``.
        out_file: path of the JSON dump.
        precision_k: largest k for which precision@k is printed.
        no_lang: accepted for interface compatibility; unused.

    Returns:
        ``(accuracy, instances_per_second)``.
    """
    box_usage = config['box_usage']

    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split

    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes

    # After the loop, precision_count[k] is the number of instances with more
    # than k candidate boxes, and precision_hit[k] the number of those whose
    # gold box appears within the top k+1 ranked boxes.
    precision_count = [0] * precision_k
    precision_hit = [0] * precision_k

    preds = []
    all_supporting = []
    net.eval()

    for j in pbar:
        tree = trees[j]
        gold_set = set(gold[j])
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)

        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()

        prediction = net([0] * 5, box_rep, tree)
        # This variant does not compute context boxes; kept for the JSON dump.
        supporting = []
        _, pred = torch.max(prediction.data, 1)

        pred_np = prediction.cpu().data.numpy()
        ranking = np.argsort(-pred_np)[0]
        limit = min(precision_k, pred_np.shape[1])
        for idx, k in enumerate(ranking[:limit]):
            if k in gold_set:
                # A hit at rank idx is a hit for every precision@k with k > idx.
                for jj in range(idx, limit):
                    precision_hit[jj] += 1.0
                    precision_count[jj] += 1.0
                break
            precision_count[idx] += 1.0

        correct += 1.0 if pred[0][0] in gold_set else 0.0
        n += 1
        if len(tst_json) > 0:
            preds.append(np.array(tst_json[j]['box_names'])[ranking])
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    n_trees = len(trees)
    print("_" * 10)
    for k in range(precision_k):
        count = precision_count[k]
        # count stays 0 when no instance has more than k boxes; avoid dividing by it.
        rate = precision_hit[k] / count if count else 0.0
        overall = precision_hit[k] / float(n_trees) if n_trees else 0.0
        print("precision@{} for {} instances: {:5.3f} and overall {:5.3f}".format(k + 1, count, rate, overall))

    if tst_json != []:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [list(p) for p in preds[ii]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
# Example #7
# 0
def evaluate(net, split, CNN, config, experiment_log, box_usage=0, verbose=False, tst_json=[], out_file=''):
    """Evaluate ``net`` by top-1 accuracy, broken down by tree complexity.

    Complexity is ``len(tree.nonterms())``.  Per-complexity accuracy and the
    share of instances at each complexity are printed and written to
    ``experiment_log``, one newline-terminated line each.

    Args:
        net: model whose output is argmax-ed over dimension 1.  Models named in
            ``tree_input_models`` receive the tree itself as first argument;
            others receive word indices from the module-level ``w2i``.
        split: ``(trees, boxes, iou, gold)``; ``iou`` is not used here.
        CNN: feature store forwarded to ``get_box_feats``.
        config: must provide ``'box_usage'`` and ``'model'``.
        experiment_log: writable file-like object receiving the summary.
        box_usage: ignored -- ``config['box_usage']`` always takes precedence.
        verbose: wrap the instance loop in a tqdm progress bar.
        tst_json: optional list of per-instance dicts; when non-empty each gets
            ``'predicted_bounding_boxes'`` and the list is dumped to ``out_file``.
        out_file: path of the JSON dump.

    Returns:
        ``(accuracy, instances_per_second)``.
    """
    box_usage = config['box_usage']
    model = config['model']
    # Tree-structured models take the tree itself as their first argument.
    tree_input_models = set(["groundnet", "groundnetflexall", "groundnetflexrel", "treernn"])

    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split

    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes

    preds = []
    net.eval()
    stats = {'hit': defaultdict(int), 'cnt': defaultdict(int)}
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)

        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()

        if model in tree_input_models:
            prediction = net(tree, box_rep, tree).data
        else:
            prediction = net([w2i.get(node.label, 0) for node in tree.leaves()], box_rep, tree).data
        _, pred = torch.max(prediction, 1)
        best = pred[0][0]

        hit = 1.0 if best in set(gold[j]) else 0.0
        correct += hit
        complexity = len(tree.nonterms())
        stats['hit'][complexity] += hit
        stats['cnt'][complexity] += 1
        n += 1
        preds.append(int(best))
    eval_time = time.time() - eval_start

    log = ["", "=" * 20, "Stats"]
    # Sorted so the report lists complexities in increasing order.
    for complexity in sorted(stats['hit']):
        log.append("Complexity {:3d} acc {:5.3f} %{:5.3f}".format(
            complexity,
            stats['hit'][complexity] * 1.0 / stats['cnt'][complexity],
            stats['cnt'][complexity] * 1.0 / n))
    log.append("Total correct {}/{}".format(correct, n))
    log.append("=" * 20)

    for line in log:
        print(line)
        experiment_log.write(line + "\n")

    if tst_json != []:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
# Example #8
# 0
  closs = 0.0
  cinst = 0
  correct = 0.0
  trn_start = time.time()

  if args.verbose and not args.debug_mode:
    pbar = tqdm(indexes, desc='trn_loss')
  else:
    pbar = indexes

  done = 1
  for ii in pbar:
    tree = Xtrn_tree[ii]

    # get cnn feat for boxes
    box_feats, spat_feats = get_box_feats(Xtrn_box[ii], CNN)

    if config['box_usage'] == 0:
      box_rep = torch.cat((box_feats,spat_feats),1)
    elif config['box_usage'] == 1:
      box_rep = box_feats
    elif config['box_usage'] == 2:
      box_rep = spat_feats
    else:
      raise NotImplementedError()

    if args.debug_mode:
      raise NotImplementedError()
    else:
      if config['model'] in set(["groundnet", "groundnetflexall", "groundnetflexrel","treernn"]):
        prediction = net(tree, box_rep, tree)
def evaluate(box_filter_net,
             net,
             split,
             CNN,
             config,
             precision_f2,
             experiment_log,
             box_usage=0,
             verbose=False,
             tst_json=[],
             out_file='',
             no_lang=False):
    """Evaluate ``net`` on the boxes that survive a per-box pre-filter.

    For every instance, ``run_filter(box_filter_net, ...)`` selects candidate
    boxes.  If none of the gold boxes survive, the instance counts as a miss.
    Otherwise ``net`` ranks the surviving boxes, and its top-1 is checked
    against the surviving gold boxes.

    Args:
        box_filter_net: per-box scorer passed to ``run_filter``.
        net: model called as ``net(word_ids, box_rep, tree)`` on the filtered
            boxes; word ids come from the module-level ``w2i``.
        split: ``(trees, boxes, iou, gold)``; ``iou`` is not used here.
        CNN: feature store forwarded to ``get_box_feats``.
        config: must provide ``'box_usage'``.
        precision_f2: filter budget/threshold forwarded to ``run_filter``.
        experiment_log: accepted for interface compatibility; not written to.
        box_usage: ignored -- ``config['box_usage']`` always takes precedence.
        verbose: wrap the instance loop in a tqdm progress bar.
        tst_json: optional list of per-instance dicts; when non-empty each gets
            ``'predicted_bounding_boxes'`` (``[[-1, -1, -1, -1]]`` for filtered
            misses) and ``'context_box'``, and the list is dumped to ``out_file``.
        out_file: path of the JSON dump.
        no_lang: accepted for interface compatibility; unused.

    Returns:
        ``(accuracy, instances_per_second)``.
    """
    box_usage = config['box_usage']

    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split

    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes

    preds = []
    all_supporting = []
    net.eval()

    for j in pbar:
        tree = trees[j]
        filter_pred, gold_predicted = run_filter(box_filter_net, tree, boxes,
                                                 w2i, CNN, gold, j,
                                                 precision_f2)
        if gold_predicted == 0:
            # No gold box survived filtering: counted as a miss.
            n += 1
            preds.append(None)
            all_supporting.append(None)
            continue

        # Positions, within filter_pred, of the gold boxes that survived.
        kept = set(filter_pred)
        gold_instance = [(filter_pred == g).nonzero()[0][0]
                         for g in gold[j] if g in kept]

        box_feats, spat_feats = get_box_feats(
            list(np.array(boxes[j])[filter_pred]), CNN)

        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        prediction = net([w2i.get(node.label, 0) for node in tree.leaves()],
                         box_rep, tree)
        supporting = []
        _, pred = torch.max(prediction.data, 1)
        best = pred[0][0]
        correct += 1.0 if best in set(gold_instance) else 0.0
        n += 1
        # Map the position in the filtered list back to the original box index.
        preds.append(filter_pred[int(best)])
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    if tst_json != []:
        for ii, inst in enumerate(tst_json):
            if preds[ii] is None:
                inst['predicted_bounding_boxes'] = [[-1, -1, -1, -1]]
                inst['context_box'] = [-1]
                continue
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)

    return correct / n, len(trees) / eval_time
# Example #11
# 0
def getCategories(category,
                  trees,
                  boxes,
                  gold_labels,
                  CNN,
                  prediction,
                  cut_off={
                      'loc': 5,
                      'boxes': 20,
                      'words': 18,
                      'depth': 3
                  },
                  annid2catid={},
                  annid2imgid={},
                  imgid2catid2count={},
                  vocabulary={},
                  imgid2catid2bsize={},
                  annid2bsize={},
                  n_supporting=0):
    """Bucket evaluation instances by one property, for error analysis.

    Only ``categories[category]`` is filled; it maps a bucket key to the list
    of instance indices in that bucket.  Supported categories:

    * ``'loc'``: number of ``loc`` nonterminals, capped at ``cut_off['loc']``.
    * ``'boxes'``: number of candidate boxes, capped at ``cut_off['boxes']``.
    * ``'words'``: number of space-separated words of ``tree.original_text``,
      capped at ``cut_off['words']``.
    * ``'depth'``: number of ``')'`` in the last space-separated token of
      ``tree.getRaw()``, capped at ``cut_off['depth']``.
    * ``'box_order'``: rank (capped at 9) of the first gold box when boxes are
      sorted by their last spatial feature, descending.
    * ``'box_size'``: last spatial feature of the first gold box, bucketed.
    * ``'box_distance'``: squared norm of the midpoint of spatial columns
      (0, 2) and (1, 3) of the first gold box, bucketed.
    * ``'box2d'``: grid cells ``'x<i>y<j>'`` spanned by the first gold box.
    * ``'obj_cat'`` / ``'obj_box_count'`` / ``'obj_box_order'``: object
      category, same-category box count and size rank, looked up via
      ``prediction[i]['annotation_id']``.
    * ``'oov'``: number of leaf labels missing from ``vocabulary``.
    * ``'n_sup'`` / ``'sup_acc'`` / ``'missing'`` / ``'missing_all'``: results
      of ``supportingCheck`` on the first ``n_supporting`` predictions; these
      branches also print summaries.

    Returns:
        The dict of all category tables (only the requested one populated).

    Raises:
        NotImplementedError: for an unknown ``category``.
    """
    categories = {
        'loc': defaultdict(list),
        'boxes': defaultdict(list),
        'words': defaultdict(list),
        'depth': defaultdict(list),
        'obj_cat': defaultdict(list),
        'obj_box_count': defaultdict(list),
        'oov': defaultdict(list),
        'box_size': defaultdict(list),
        'box_order': defaultdict(list),
        'obj_box_order': defaultdict(list),
        'box_distance': defaultdict(list),
        'n_sup': defaultdict(list),
        'sup_acc': defaultdict(list),
        'missing': defaultdict(list),
        'box2d': defaultdict(list),
        'missing_all': defaultdict(list)
    }
    # None for unknown categories; those fall through to NotImplementedError.
    table = categories.get(category)

    # Bucket indices use floor division (//) so they stay integers on both
    # Python 2 and Python 3.
    if category == 'loc':
        for i, tree in enumerate(trees):
            n_loc = Counter([n.label for n in tree.nonterms()])['loc']
            table[min(n_loc, cut_off[category])].append(i)

    elif category == 'box_order':
        for i, box in enumerate(boxes):
            _, spat_feats = get_box_feats(box,
                                          CNN,
                                          convert_spat=False,
                                          convert_box=False)
            for j in gold_labels[i][:1]:  # only the first gold box is bucketed
                order = list(np.argsort(spat_feats[:, -1])[::-1])
                table[min(order.index(j), 9)].append(i)

    elif category == 'box_size':
        for i, box in enumerate(boxes):
            _, spat_feats = get_box_feats(box,
                                          CNN,
                                          convert_spat=False,
                                          convert_box=False)
            for j in gold_labels[i][:1]:  # only the first gold box is bucketed
                bracket = int(min(99, spat_feats[j, -1] * 100)) // 40
                table[bracket].append(i)

    elif category == 'boxes':
        for i, box in enumerate(boxes):
            table[min(len(box), cut_off[category])].append(i)

    elif category == 'words':
        for i, tree in enumerate(trees):
            n_words = len(tree.original_text.split(" "))
            table[min(n_words, cut_off[category])].append(i)

    elif category == 'depth':
        for i, tree in enumerate(trees):
            depth = Counter(tree.getRaw().split(" ")[-1])[')']
            table[min(depth, cut_off[category])].append(i)

    elif category == 'obj_cat':
        for i, pred in enumerate(prediction):
            table[annid2catid[str(pred['annotation_id'])]].append(i)

    elif category == 'obj_box_count':
        for i, pred in enumerate(prediction):
            ann_key = str(pred['annotation_id'])
            imgid = annid2imgid[ann_key]
            catid = annid2catid[ann_key]
            table[imgid2catid2count[imgid][catid]].append(i)

    elif category == 'obj_box_order':
        for i, pred in enumerate(prediction):
            ann_key = str(pred['annotation_id'])
            imgid = annid2imgid[ann_key]
            catid = annid2catid[ann_key]
            sizes = imgid2catid2bsize[imgid][catid]

            # Rank of this annotation's size among same-category boxes,
            # largest first.
            idx = sizes.index(annid2bsize[ann_key])
            bracket = list(np.argsort(sizes)[::-1]).index(idx)
            table[bracket].append(i)

    elif category == 'box_distance':
        for i, box in enumerate(boxes):
            _, spat_feats = get_box_feats(box,
                                          CNN,
                                          convert_spat=False,
                                          convert_box=False)
            distances = ((spat_feats[:, 2] + spat_feats[:, 0]) / 2.0)**2 + (
                (spat_feats[:, 3] + spat_feats[:, 1]) / 2.0)**2
            for j in gold_labels[i][:1]:  # only the first gold box is bucketed
                bracket = int(min(15000, distances[j] * 100)) // 33
                table[bracket].append(i)

    elif category == 'box2d':
        for i, box in enumerate(boxes):
            _, spat_feats = get_box_feats(box,
                                          CNN,
                                          convert_spat=False,
                                          convert_box=False)

            idx = gold_labels[i][0]
            density = 25
            x1 = int(min(spat_feats[idx, 0] * 100 + 100, 199) / 2) // density
            x2 = int(min(spat_feats[idx, 2] * 100 + 100, 199) / 2) // density
            y1 = int(min(spat_feats[idx, 1] * 100 + 100, 199) / 2) // density
            y2 = int(min(spat_feats[idx, 3] * 100 + 100, 199) / 2) // density

            for xx in range(x1, x2 + 1):
                for yy in range(y1, y2 + 1):
                    table['x' + str(xx) + 'y' + str(yy)].append(i)

    elif category == 'oov':
        for i, tree in enumerate(trees):
            oov = sum(1 for n in tree.leaves() if n.label not in vocabulary)
            table[oov].append(i)

    elif category == 'n_sup':
        n_total = 1.0
        n_same = 0.0

        n_total_all = 1.0
        # NOTE(review): n_same_all is never updated, so the "all" line always
        # reports 1 / n_total_all -- confirm which statistic was intended.
        n_same_all = 1.0

        # NOTE(review): Ytst is a module-level global, not a parameter --
        # presumably the test-split gold labels; confirm against the caller.
        for i in range(n_supporting):
            n_sup = supportingCheck(i,
                                    prediction[i]['context_box'],
                                    Ytst[i],
                                    mode=category)
            table[n_sup].append(i)

            pred_box = "_".join(
                [str(v) for v in prediction[i]['predicted_bounding_boxes'][0]])
            box_keys = [
                "_".join([str(v) for v in box])
                for box in prediction[i]['box_names']
            ]
            target_box = box_keys.index(pred_box)

            if n_sup > 0:
                context = prediction[i]['context_box']
                if len(context) > 0 and context[0] == target_box:
                    n_same += 1
                n_total += 1
                n_total_all += 1

        for key in table:
            print("%s %s" % (key, len(table[key])))

        print("=" * 20)
        print("total {} the same {} percentage {}".format(
            n_total, n_same, n_same / n_total))
        print("all {} the same {} percentage {}".format(
            n_total_all, n_same_all, n_same_all / n_total_all))
        print("=" * 20)

    elif category == 'sup_acc':
        for i in range(n_supporting):
            sup = supportingCheck(i,
                                  prediction[i]['context_box'],
                                  Ytst[i],
                                  mode=category)
            table[sup].append(i)
        print("# of instances with no supporting objects: %s" % len(table[0]))
        total_sup = (len(table[1]) + len(table[2]) + len(table[3])) * 1.0
        print("# of instances with supporting objects: %s" % total_sup)
        # With no supporting instances the ratios are undefined; report n/a
        # instead of raising ZeroDivisionError.
        if total_sup > 0:
            print("# of instances all supporting objects are correct: %s" %
                  (len(table[1]) / total_sup))
            print("# of instances at least one supporting objects is correct: %s" %
                  ((len(table[1]) + len(table[2])) / total_sup))
        else:
            print("# of instances all supporting objects are correct: n/a")
            print("# of instances at least one supporting objects is correct: n/a")

    elif category == 'missing':
        for i in range(n_supporting):
            missing = supportingCheck(i,
                                      prediction[i]['context_box'],
                                      Ytst[i],
                                      mode=category)
            table[missing > 0].append(i)
        for key in table:
            print("%s %s" % (key, len(table[key])))

    elif category == 'missing_all':
        for i in range(n_supporting):
            missing_all = supportingCheck(i,
                                          prediction[i]['context_box'],
                                          Ytst[i],
                                          mode=category)
            table[missing_all].append(i)
        for key in table:
            print("%s %s" % (key, len(table[key])))

    else:
        raise NotImplementedError()

    return categories
def evaluate(net,
             split,
             CNN,
             config,
             experiment_log,
             box_usage=0,
             verbose=False):
    """Evaluate a binary classifier on ``split``.

    The network output ``[0][0]`` is rounded to the nearest integer and
    compared with ``gold[j]``.  The total accuracy and the number of instances
    predicted as 0 and as 1 are printed and written to ``experiment_log``.

    Args:
        net: model called as ``net(word_ids, box_rep, tree)``; word ids come
            from the module-level ``w2i``.
        split: ``(trees, boxes, iou, gold)``; ``iou`` is not used here.
        CNN: feature store forwarded to ``get_box_feats``.
        config: must provide ``'box_usage'`` (0 = visual + spatial features,
            1 = visual only, 2 = spatial only).
        experiment_log: writable file-like object receiving the summary.
        box_usage: ignored -- ``config['box_usage']`` always takes precedence.
        verbose: wrap the instance loop in a tqdm progress bar.

    Returns:
        ``(accuracy, instances_per_second)``.
    """
    box_usage = config['box_usage']

    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split

    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes

    net.eval()
    # Instance indices grouped by predicted label; 0 and 1 are always reported.
    predict_label = {0: [], 1: []}
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)

        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()

        prediction = net([w2i.get(node.label, 0) for node in tree.leaves()],
                         box_rep, tree).data
        pred = int(np.round(prediction[0][0]))

        correct += 1.0 if pred == gold[j] else 0.0
        n += 1
        # setdefault: a score rounding outside {0, 1} no longer raises KeyError.
        predict_label.setdefault(pred, []).append(j)
    eval_time = time.time() - eval_start

    # Both the accuracy line and the label-count line are logged.
    log = [
        "",
        "=" * 20,
        "Total correct {}/{}".format(correct, n),
        "0 : {}, 1 : {}".format(len(predict_label[0]), len(predict_label[1])),
        "=" * 20,
    ]

    for line in log:
        print(line)
        experiment_log.write(line + "\n")

    return correct / n, len(trees) / eval_time
def evaluate(filter_net,
             net,
             split,
             CNN,
             config,
             precision_f1,
             verbose=False,
             tst_json=[],
             out_file='',
             annid2catid={}):
    """Evaluate ``net`` on the boxes kept by the category-aware filter.

    For every instance, ``run_filter(filter_net, ..., annid2catid)`` selects
    candidate boxes.  If no gold box survives, the instance counts as a miss.
    Otherwise ``net`` ranks the surviving boxes (its input words are perturbed
    with ``perturb(..., config['perturbation'])``), and the top-1 is checked
    against the surviving gold boxes.

    Args:
        filter_net: stage-1 box scorer forwarded to ``run_filter``.
        net: model called as ``net(word_ids, box_rep, tree)``; word ids come
            from the module-level ``w2i``.
        split: ``(trees, boxes, iou, gold)``; ``iou`` is not used here.
        CNN: feature store forwarded to ``get_box_feats``/``run_filter``.
        config: must provide ``'box_usage'``, ``'model'`` and ``'perturbation'``.
        precision_f1: filter budget/threshold forwarded to ``run_filter``.
        verbose: wrap the instance loop in a tqdm progress bar.
        tst_json: optional list of per-instance dicts; when non-empty each gets
            the ranked names of the filtered boxes (``[[-1, -1, -1, -1]]`` for
            filtered misses), and the list is dumped to ``out_file``.
        out_file: path of the JSON dump.
        annid2catid: annotation-id -> category-id map forwarded to ``run_filter``.

    Returns:
        ``(accuracy, instances_per_second)``.
    """
    box_usage = config['box_usage']
    model = config['model']
    perturbation = config['perturbation']

    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split

    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes

    preds = []
    all_supporting = []
    net.eval()

    for j in pbar:
        tree = trees[j]
        filter_pred, gold_predicted = run_filter(filter_net, tree, boxes, CNN,
                                                 gold, j, precision_f1,
                                                 annid2catid)

        if gold_predicted == 0:
            # No gold box survived filtering: counted as a miss.
            n += 1
            preds.append([[-1, -1, -1, -1]])
            all_supporting.append([])
            continue

        # Positions, within filter_pred, of the gold boxes that survived.
        kept = set(filter_pred)
        gold_instance = [(filter_pred == g).nonzero()[0][0]
                         for g in gold[j] if g in kept]

        box_feats, spat_feats = get_box_feats(
            list(np.array(boxes[j])[filter_pred]), CNN)

        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        prediction = net([
            w2i.get(w, 0)
            for w in perturb([node.label
                              for node in tree.leaves()], perturbation)
        ], box_rep, tree)
        supporting = []
        _, pred = torch.max(prediction.data, 1)
        correct += 1.0 if pred[0][0] in set(gold_instance) else 0.0
        n += 1
        if len(tst_json) > 0:
            pred_np = prediction.cpu().data.numpy()
            # pred_np scores only the filtered boxes, so map each ranked
            # position back to its original box index before the name lookup.
            ranked_boxes = filter_pred[np.argsort(-pred_np)[0]]
            preds.append(np.array(tst_json[j]['box_names'])[ranked_boxes])
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    if tst_json != []:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [list(p) for p in preds[ii]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)

    return correct / n, len(trees) / eval_time
# Example #14
# 0
def evaluate(net,
             split,
             CNN,
             config,
             experiment_log,
             box_usage=0,
             verbose=False,
             tst_json=[],
             out_file='',
             precision_k=10):
    """Evaluate ``net`` with top-1 accuracy, precision@k and context boxes.

    For models whose name contains ``"ground"``, the context ("supporting")
    boxes are the argmax of every ``loc`` nonterminal after the first one.
    For other models, the context box is the argmax of the returned ``obj``
    scores, when that is available.

    Args:
        net: model; ``net.evaluate`` is set to True before the loop.
        split: ``(trees, boxes, iou, gold)``; ``iou`` is not used here.
        CNN: feature store forwarded to ``get_box_feats``.
        config: must provide ``'box_usage'`` and ``'model'``.
        experiment_log: accepted for interface compatibility; not written to.
        box_usage: ignored -- ``config['box_usage']`` always takes precedence.
        verbose: wrap the instance loop in a tqdm progress bar.
        tst_json: optional list of per-instance dicts; when non-empty each gets
            ``'predicted_bounding_boxes'`` and ``'context_box'``, and the list
            is dumped to ``out_file``.
        out_file: path of the JSON dump.
        precision_k: largest k for which precision@k is printed.

    Returns:
        ``(accuracy, instances_per_second)``.
    """
    box_usage = config['box_usage']
    model = config['model']

    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split

    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes

    # After the loop, precision_count[k] is the number of instances with more
    # than k candidate boxes, and precision_hit[k] the number of those whose
    # gold box appears within the top k+1 ranked boxes.
    precision_count = [0] * precision_k
    precision_hit = [0] * precision_k

    preds = []
    all_supporting = []
    net.eval()
    net.evaluate = True
    for j in pbar:
        tree = trees[j]
        gold_set = set(gold[j])
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)

        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()

        if "ground" in model.lower():
            prediction = net(tree, box_rep, tree, decorate=True)
            supporting = []
            loc_count = 0
            for node in tree.nonterms():
                _, phrase_pred = torch.max(node._expr.data, 1)
                # Free the per-node tensors attached by decorate=True.
                del node._expr, node._attn
                if node.label == "loc":
                    loc_count += 1
                    # Every loc after the first is treated as a context box.
                    if loc_count >= 2:
                        supporting += [phrase_pred[0][0]]
        else:
            prediction, sub, obj, rel = net(
                [w2i.get(node.label, 0) for node in tree.leaves()], box_rep,
                tree)
            try:
                _, obj_pred = torch.max(obj.data, 0)
                supporting = [obj_pred[0][0]]
            except Exception:
                # No usable object scores for this instance: no context box.
                supporting = []

        _, pred = torch.max(prediction.data, 1)

        pred_np = prediction.cpu().data.numpy()
        limit = min(precision_k, pred_np.shape[1])
        for idx, k in enumerate(np.argsort(-pred_np)[0][:limit]):
            if k in gold_set:
                # A hit at rank idx is a hit for every precision@k with k > idx.
                for jj in range(idx, limit):
                    precision_hit[jj] += 1.0
                    precision_count[jj] += 1.0
                break
            precision_count[idx] += 1.0

        correct += 1.0 if pred[0][0] in gold_set else 0.0
        n += 1
        preds.append(int(pred[0][0]))
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    print("_" * 10)
    for k in range(precision_k):
        count = precision_count[k]
        # count stays 0 when no instance has more than k boxes; avoid dividing by it.
        rate = precision_hit[k] / count if count else 0.0
        print("precision@{} for {} instances: {}".format(k + 1, count, rate))
    if tst_json != []:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time