Example #1
0
 def test_make_weights_for_balanced_classes(self):
     """Check that per-example weights balance the class distribution.

     The dataset has three examples of class 0 (A, C, F) and one example
     each of classes 1, 2 and 3. Examples of the same class must receive
     the same weight, each class-0 example must weigh a third of a
     singleton-class example, and all weights must sum to 1.
     """
     dataset = [('A', 0), ('B', 1), ('C', 0), ('D', 2), ('E', 3), ('F', 0)]
     result = misc.make_weights_for_balanced_classes(dataset)
     # The weights are floating point: compare with a tolerance instead of
     # exact equality so the test does not depend on summation rounding.
     # float() converts 0-d array/tensor values into plain Python floats.
     self.assertAlmostEqual(float(result.sum()), 1.0)
     # All class-0 examples share one weight.
     self.assertAlmostEqual(float(result[0]), float(result[2]))
     self.assertAlmostEqual(float(result[0]), float(result[5]))
     # All singleton classes (1, 2, 3) share one weight.
     self.assertAlmostEqual(float(result[1]), float(result[3]))
     self.assertAlmostEqual(float(result[1]), float(result[4]))
     # Class 0 has three times as many examples, so each of its examples
     # weighs one third as much as a singleton-class example.
     self.assertAlmostEqual(3 * float(result[0]), float(result[1]))
Example #2
0
    # Split each environment into `out` and `in_` parts. For test
    # environments, also split off a `uda` part from `in_`.
    # NOTE(review): `in_splits` is appended to below but is not initialized
    # in this span; it is assumed to be created (as an empty list) just
    # before this point — confirm.
    out_splits = []
    uda_splits = []
    for env_i, env in enumerate(dataset):
        # Only test environments get a UDA split. Starting from an empty list
        # lets the len() checks below treat every environment uniformly.
        uda = []

        # The split size is holdout_fraction of the environment, seeded from
        # (trial_seed, env_i); presumably `out` is the part of that size.
        out, in_ = misc.split_dataset(env,
                                      int(len(env) * args.holdout_fraction),
                                      misc.seed_hash(args.trial_seed, env_i))

        if env_i in args.test_envs:
            # Carve the UDA subset out of this test environment's `in_` part.
            uda, in_ = misc.split_dataset(
                in_, int(len(in_) * args.uda_holdout_fraction),
                misc.seed_hash(args.trial_seed, env_i))

        if hparams['class_balanced']:
            in_weights = misc.make_weights_for_balanced_classes(in_)
            out_weights = misc.make_weights_for_balanced_classes(out)
            # `uda` is never None (it starts as []), so a `uda is not None`
            # guard would always fire. Weight the UDA split only when it is
            # non-empty, which is the only case where it is used below.
            uda_weights = (misc.make_weights_for_balanced_classes(uda)
                           if len(uda) else None)
        else:
            in_weights, out_weights, uda_weights = None, None, None
        in_splits.append((in_, in_weights))
        out_splits.append((out, out_weights))
        if len(uda):
            uda_splits.append((uda, uda_weights))

    train_loaders = [
        InfiniteDataLoader(dataset=env,
                           weights=env_weights,
                           batch_size=hparams['batch_size'],
                           num_workers=dataset.N_WORKERS)