def test_make_weights_for_balanced_classes(self):
    """Check the weights computed for a small class-imbalanced dataset.

    The fixture holds six ``(input, label)`` pairs: label 0 appears three
    times (indices 0, 2, 5) while labels 1, 2 and 3 appear once each
    (indices 1, 3, 4). The test verifies that:

    * the weights sum to one,
    * two examples of the same class receive the same weight,
    * two examples from equally frequent classes receive the same weight,
    * an example of a singleton class weighs three times as much as an
      example of the class that occurs three times.
    """
    dataset = [('A', 0), ('B', 1), ('C', 0), ('D', 2), ('E', 3), ('F', 0)]
    result = misc.make_weights_for_balanced_classes(dataset)

    # NOTE(review): assumes `result` is an array-like of floating-point
    # weights supporting `.sum()` and integer indexing -- confirm against
    # `misc.make_weights_for_balanced_classes`.
    # Quantities that go through arithmetic (a sum, a multiplication) are
    # compared with a tolerance so the test does not hinge on bit-exact
    # floating-point rounding.
    self.assertAlmostEqual(float(result.sum()), 1.0, places=6)

    # Indices 0 and 2 share label 0; indices 1 and 3 carry labels that each
    # occur exactly once. These pairs are expected to be identical values,
    # so exact equality is kept.
    self.assertEqual(result[0], result[2])
    self.assertEqual(result[1], result[3])

    # Label 0 is three times as frequent as label 1.
    self.assertAlmostEqual(3 * float(result[0]), float(result[1]), places=6)
out_splits = [] uda_splits = [] for env_i, env in enumerate(dataset): uda = [] out, in_ = misc.split_dataset(env, int(len(env) * args.holdout_fraction), misc.seed_hash(args.trial_seed, env_i)) if env_i in args.test_envs: uda, in_ = misc.split_dataset( in_, int(len(in_) * args.uda_holdout_fraction), misc.seed_hash(args.trial_seed, env_i)) if hparams['class_balanced']: in_weights = misc.make_weights_for_balanced_classes(in_) out_weights = misc.make_weights_for_balanced_classes(out) if uda is not None: uda_weights = misc.make_weights_for_balanced_classes(uda) else: in_weights, out_weights, uda_weights = None, None, None in_splits.append((in_, in_weights)) out_splits.append((out, out_weights)) if len(uda): uda_splits.append((uda, uda_weights)) train_loaders = [ InfiniteDataLoader(dataset=env, weights=env_weights, batch_size=hparams['batch_size'], num_workers=dataset.N_WORKERS)