Пример #1
0
def cluster_subheads_eval(args,
                          net,
                          mapping_assignment_dataloader,
                          mapping_test_dataloader,
                          get_data_fn=_clustering_get_data):
    """
    Compute test-set accuracy using a cluster->class mapping learned on train data.

    Used by both clustering and segmentation. The mapping from predicted
    cluster ids to ground-truth class ids is obtained by Hungarian matching on
    ``mapping_assignment_dataloader`` (via ``_get_assignment_data_matches``)
    and is then applied to the predictions on ``mapping_test_dataloader``.
    This distinction only matters for IID+/semi-supervised setups with a
    train/test split: the mapping never sees test labels.

    Only a single set of predictions is evaluated (no per-sub-head mean/std
    or best-head selection happens here), so ``"test_accs"``, ``"best"`` and
    ``"worst"`` all hold the same value. Presumably the three keys are kept
    for compatibility with multi-head callers -- verify before removing any.

    Args:
        args: config namespace; reads ``args.gpu`` (CUDA device index) and
            ``args.output_k`` (number of clusters/classes).
        net: model, passed through to ``get_data_fn``.
        mapping_assignment_dataloader: loader used to compute the mapping
            (and ``train_accs``).
        mapping_test_dataloader: loader on which accuracy is reported.
        get_data_fn: callable ``(args, net, dataloader) -> (flat_preds,
            flat_targets)``. The results are expected to be flat tensors of
            equal length (assumed from how they are indexed here -- confirm
            against ``get_data_fn``).

    Returns:
        dict with keys ``"test_accs"``, ``"best"``, ``"worst"`` (test accuracy
        as returned by ``_acc``), ``"train_accs"`` (list of accuracies on the
        assignment set) and ``"conf_mat"`` (confusion matrix from ``_acc``).
    """
    all_matches, train_accs = _get_assignment_data_matches(
        net, mapping_assignment_dataloader, args, get_data_fn=get_data_fn)

    flat_predss_all, flat_targets_all = get_data_fn(
        args, net, mapping_test_dataloader)

    device = torch.device("cuda", int(args.gpu))
    num_samples = flat_targets_all.shape[0]
    # Allocate directly on the GPU rather than building on CPU and copying.
    reordered_preds = torch.zeros(num_samples,
                                  dtype=flat_predss_all.dtype,
                                  device=device)
    # Relabel each predicted cluster id with its matched class id. Masks are
    # taken from the untouched original predictions, so earlier writes never
    # interfere with later matches. A plain int fill value avoids creating
    # and copying a fresh GPU tensor for every matched pair.
    for pred_i, target_i in all_matches:
        reordered_preds[flat_predss_all == pred_i] = int(target_i)

    test_acc, conf_mat = _acc(reordered_preds,
                              flat_targets_all,
                              args.output_k,
                              verbose=0)

    return {
        "test_accs": test_acc,
        "best": test_acc,
        "worst": test_acc,
        "train_accs": list(train_accs),
        "conf_mat": conf_mat
    }
Пример #2
0
def _get_assignment_data_matches(net, mapping_assignment_dataloader, args,
                                 get_data_fn=None,
                                 just_matches=False,
                                 verbose=0):
    """
    Hungarian-match predicted clusters to classes on the assignment set.

    Runs ``get_data_fn`` on ``mapping_assignment_dataloader`` (the training /
    "mapping_assign" data) and matches predicted cluster ids to target ids
    with ``_hungarian_match``. Unless ``just_matches`` is set, it also scores
    the remapped predictions on that same data.

    Args:
        net: model, passed through to ``get_data_fn``.
        mapping_assignment_dataloader: loader the matching is computed on.
        args: config namespace; reads ``args.output_k`` (number of clusters
            and classes) and ``args.gpu`` (CUDA device index).
        get_data_fn: callable ``(args, net, dataloader) -> (flat_preds,
            flat_targets)``. Required, despite the ``None`` default.
        just_matches: if True, return only the matches and skip scoring.
        verbose: truthy -> progress timestamps; ``2`` -> additionally
            per-class target counts and every matched pair. Also forwarded
            to ``_acc``.

    Returns:
        ``matches`` if ``just_matches`` else ``(matches, accs)``. Here
        ``matches`` is the output of ``_hungarian_match`` (iterable of
        ``(pred_id, target_id)`` pairs) and ``accs`` is a one-element list
        holding the accuracy from ``_acc``.

    Raises:
        TypeError: if ``get_data_fn`` is not provided.
        AssertionError: if prediction and target shapes differ, or the
            matching does not cover all ``args.output_k`` predicted ids.
    """
    if get_data_fn is None:
        # Calling None would fail with an opaque TypeError; say why instead.
        raise TypeError("get_data_fn must be provided")

    if verbose:
        print("calling cluster eval direct (helper) %s" % datetime.now())
        sys.stdout.flush()

    flat_predss_all, flat_targets_all = \
        get_data_fn(args, net, mapping_assignment_dataloader)

    if verbose:
        print("getting data fn has completed %s" % datetime.now())
        print("flat_targets_all %s, flat_predss_all[0] %s" %
              (list(flat_targets_all.shape), list(flat_predss_all.shape)))
        sys.stdout.flush()

    num_samples = flat_targets_all.shape[0]
    if verbose == 2:
        print("num_test: %d" % num_samples)
        for c in range(args.output_k):
            print("output_k: %d count: %d" % (c, (flat_targets_all == c).sum()))

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if flat_predss_all.shape != flat_targets_all.shape:
        raise AssertionError(
            "prediction shape %s does not match target shape %s"
            % (list(flat_predss_all.shape), list(flat_targets_all.shape)))

    if verbose:
        print("starting head %d with eval mode hung, %s" % (0, datetime.now()))
        sys.stdout.flush()

    match = _hungarian_match(flat_predss_all, flat_targets_all,
                             preds_k=args.output_k,
                             targets_k=args.output_k)
    if verbose:
        print("got match %s" % (datetime.now()))
        sys.stdout.flush()

    if just_matches:
        return match

    # Reorder predictions so cluster ids line up with target class ids.
    found = torch.zeros(args.output_k)
    reordered_preds = torch.zeros(num_samples,
                                  dtype=flat_predss_all.dtype,
                                  device=torch.device("cuda", int(args.gpu)))
    for pred_i, target_i in match:
        # Fill with a plain int: building a GPU tensor and calling .item()
        # on it per pair yields the same value but forces a device sync.
        reordered_preds[flat_predss_all == int(pred_i)] = int(target_i)
        found[pred_i] = 1
        if verbose == 2:
            print((pred_i, target_i))
    # Every one of the output_k predicted ids must receive a mapping.
    if found.sum() != args.output_k:
        raise AssertionError("each output_k must get mapped")

    if verbose:
        print("reordered %s" % (datetime.now()))
        sys.stdout.flush()

    acc, _ = _acc(reordered_preds, flat_targets_all, args.output_k, verbose)
    return match, [acc]