def main():
    ## Parse command line args
    args = parse_eval_args()
    from pprint import PrettyPrinter
    PrettyPrinter(indent=4).pprint(vars(args))
    print()

    ## Use CUDA if available
    print('=> CUDA availability / use: "{}" / "{}"'.format(
        str(torch.cuda.is_available()), str(args.CUDA)))
    args.CUDA = args.CUDA and torch.cuda.is_available()
    device = torch.device(
        'cuda' if (torch.cuda.is_available() and args.CUDA) else 'cpu')

    ## Dataset + Dataloader
    ## NOTE Update DATASET here
    from arg_utils import construct_hyperpartisan_flair_dataset, \
                          construct_propaganda_flair_dataset
    eval_dataset, input_shape = construct_eval_dataset(
        construct_propaganda_flair_dataset, args)

    from datasets import CachedDataset
    ## NOTE Use Cached dataset ?? (useful for ensemble runs, but not with TensorDataset)
    eval_dataloader = torch.utils.data.DataLoader(
        # CachedDataset(eval_dataset),
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.dataloader_workers,
        pin_memory=args.CUDA)

    ## Construct Model
    ## NOTE Update MODEL here
    from nn_architectures import construct_cnn_bertha_von_suttner, \
                                 construct_hierarch_att_net, \
                                 construct_lstm, construct_AttnBiLSTM
    model_constructor = construct_AttnBiLSTM

    ## Load models from checkpoints
    models = list()
    for m_path in args.model_path:
        model = model_constructor(input_shape[-1])
        load_checkpoint(m_path, model)
        models.append(model)

    ## Model Summary
    from torchsummary import summary
    print('\n ** Model Summary ** ')
    print(models[0], end='\n\n')

    ## Evaluate documents
    predictions = evaluate_ensemble(models, eval_dataloader, device=device)

    ## Write predictions to file
    # write_hyperpartisan_predictions(predictions, eval_dataset, args.output_dir)
    write_propaganda_predictions(
        predictions,
        ## PropagandaDataset must always be provided (even if a tensor-dataset is
        ## provided), to properly write predictions to output file
        eval_dataset if args.tensor_dataset is None
        else construct_base_propaganda_dataset(args.input_dir, None),
        args.output_dir)

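# The eval entry point above relies on an `evaluate_ensemble` helper that is not
# shown here. The sketch below is an assumption about its behaviour, not the
# project's actual implementation: it runs every loaded model over the dataloader
# and averages the per-document outputs. The batch layout (*inputs, label) is
# also an assumption, based on the TensorDataset construction used elsewhere.
import torch


@torch.no_grad()
def evaluate_ensemble_sketch(models, dataloader, device='cpu'):
    """Average the predictions of several models over one pass of `dataloader`."""
    for model in models:
        model.to(device).eval()
    all_preds = []
    for *inputs, _ in dataloader:
        inputs = [x.to(device) for x in inputs]
        # Stack per-model outputs and average them into a single ensemble score.
        outputs = torch.stack([model(*inputs) for model in models], dim=0)
        all_preds.append(outputs.mean(dim=0).cpu())
    return torch.cat(all_preds)
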
def construct_model(model_constructor, input_shape, args):
    ## Construct model
    model = model_constructor(input_shape[-1], args)
    if torch.cuda.is_available() and args.CUDA:
        print('=> Moving model to CUDA')
        model.cuda()

    ## Construct optimizer
    from optimizers import RAdam, AdaBound
    # optimizer = RAdam(model.parameters())     ## NOTE Experimenting with RAdam and AdaBound
    # optimizer = AdaBound(model.parameters())  ## NOTE Experimenting with RAdam and AdaBound
    optimizer = torch.optim.Adam(model.parameters())

    ## Construct scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.2,
        patience=args.reduce_lr_patience, verbose=True)

    ## Optionally, resume training from checkpoint
    if args.resume is not None and os.path.isfile(args.resume):
        print('\n=> ** Resuming training from checkpoint **')
        load_checkpoint(args.resume, model, optimizer, scheduler)

    ## Optionally, freeze specific layers/sub-modules
    if args.freeze is not None and len(args.freeze) > 0:
        print('=> Freezing layers/sub-modules with indices: {}'.format(args.freeze))
        freeze_layers(model, args.freeze)
    else:
        print('=> All layers unfrozen for training')

    ## Model summary #1
    from nn_utils import count_parameters
    print('Model has {} trainable parameters'.format(count_parameters(model)))
    print(model)

    ## Model Summary #2
    # from torchsummary import summary
    # print('\nModel Summary:')
    # summary(model, input_shape, device='cuda' if args.CUDA else 'cpu')

    ## Loss criterion: Binary Cross Entropy
    loss_criterion = torch.nn.BCELoss()

    return model, optimizer, scheduler, loss_criterion

def test_checkpointing_model(self):
  key = jax.random.PRNGKey(42)

  model, _ = _make_deterministic_model()
  input_shape = (2, 224, 224, 3)
  key, subkey = jax.random.split(key)
  params = _init_model(subkey, model, input_shape=input_shape)
  checkpoint_path = self._save_temp_checkpoint(params)

  key, subkey = jax.random.split(key)
  new_params = _init_model(subkey, model, input_shape=input_shape)
  restored_params = checkpoint_utils.load_checkpoint(new_params, checkpoint_path)

  restored_leaves = jax.tree_util.tree_leaves(restored_params)
  leaves = jax.tree_util.tree_leaves(params)
  for arr, restored_arr in zip(leaves, restored_leaves):
    self.assertAllClose(arr, restored_arr)

  key, subkey = jax.random.split(key)
  inputs = jax.random.normal(subkey, input_shape, jnp.float32)
  _, out = model.apply({"params": params}, inputs, train=False)
  _, new_out = model.apply({"params": new_params}, inputs, train=False)
  _, restored_out = model.apply({"params": restored_params}, inputs, train=False)
  self.assertNotAllClose(out["pre_logits"], new_out["pre_logits"])
  self.assertAllClose(out["pre_logits"], restored_out["pre_logits"])

def test_sngp_script(self, dataset_name, classifier, representation_size,
                     correct_train_loss, correct_val_loss,
                     correct_fewshot_acc_sum, simulate_failure):
  data_dir = self.data_dir
  config = test_utils.get_config(
      dataset_name=dataset_name,
      classifier=classifier,
      representation_size=representation_size,
      use_sngp=True,
      use_gp_layer=True)
  output_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
  config.dataset_dir = data_dir
  num_examples = config.batch_size * config.total_steps

  if not simulate_failure:
    # Check for any errors.
    with tfds.testing.mock_data(num_examples=num_examples, data_dir=data_dir):
      train_loss, val_loss, fewshot_results = sngp.main(config, output_dir)
  else:
    # Check for the ability to restart from a previous checkpoint (after
    # failure, etc.).
    output_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    # NOTE: Use this flag to simulate failing at a certain step.
    config.testing_failure_step = config.total_steps - 1
    config.checkpoint_steps = config.testing_failure_step
    config.keep_checkpoint_steps = config.checkpoint_steps
    with tfds.testing.mock_data(num_examples=num_examples, data_dir=data_dir):
      sngp.main(config, output_dir)

    checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')
    self.assertTrue(os.path.exists(checkpoint_path))
    checkpoint = checkpoint_utils.load_checkpoint(None, checkpoint_path)
    self.assertEqual(
        int(checkpoint['opt']['state']['step']), config.testing_failure_step)

    # This should resume from the failed step.
    del config.testing_failure_step
    with tfds.testing.mock_data(num_examples=num_examples, data_dir=data_dir):
      train_loss, val_loss, fewshot_results = sngp.main(config, output_dir)

  # Check for reproducibility.
  fewshot_acc_sum = sum(jax.tree_util.tree_flatten(fewshot_results)[0])
  logging.info('(train_loss, val_loss, fewshot_acc_sum) = %s, %s, %s',
               train_loss, val_loss['val'], fewshot_acc_sum)
  # TODO(dusenberrymw): Determine why the SNGP script is non-deterministic.
  self.assertAllClose(train_loss, correct_train_loss, atol=0.025, rtol=0.3)
  self.assertAllClose(val_loss['val'], correct_val_loss, atol=0.02, rtol=0.3)

def test_checkpointing(self):
  key = jax.random.PRNGKey(42)

  key, subkey = jax.random.split(key)
  tree = _make_pytree(subkey)
  checkpoint_path = self._save_temp_checkpoint(tree)

  key, subkey = jax.random.split(key)
  new_tree = _make_pytree(subkey)

  leaves = jax.tree_util.tree_leaves(tree)
  new_leaves = jax.tree_util.tree_leaves(new_tree)
  for arr, new_arr in zip(leaves, new_leaves):
    self.assertNotAllClose(arr, new_arr)

  restored_tree = checkpoint_utils.load_checkpoint(new_tree, checkpoint_path)
  restored_leaves = jax.tree_util.tree_leaves(restored_tree)
  for arr, restored_arr in zip(leaves, restored_leaves):
    self.assertAllClose(arr, restored_arr)

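# The two checkpointing tests above assume small fixtures such as `_make_pytree`
# and `self._save_temp_checkpoint`. A minimal sketch of the pytree helper is
# given below for illustration only; the exact shapes and dict layout are
# assumptions, not the original test fixtures.
import jax


def _make_pytree_sketch(key):
  """Build a small nested dict of random arrays, i.e. a JAX pytree."""
  key1, key2, key3 = jax.random.split(key, 3)
  return {
      'a': jax.random.normal(key1, (3, 2)),
      'b': {
          'c': jax.random.uniform(key2, (3, 2)),
          'd': jax.random.normal(key3, (2, 4)),
      },
  }
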
def load_checkpoints(config):
  """Load the checkpoints for each ensemble member."""
  if not (config.model_init and isinstance(config.model_init, (tuple, list))):
    raise ValueError(
        ('deep_ensemble.py expects a list/tuple of ckpts to load; '
         f'got instead config.model_init={config.model_init}.'))

  load_fn = lambda p: checkpoint_utils.load_checkpoint({}, p)['opt']['target']

  params = {}
  ensemble_size = len(config.model_init)
  for model_idx, path in enumerate(config.model_init, start=1):
    prefix = f'[{model_idx}/{ensemble_size}]'
    logging.info(f'{prefix} Start to load checkpoint: {path}.')
    params[path] = load_fn(path)
    logging.info(f'{prefix} Finish to load checkpoint: {path}.')
  return params

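# Hedged usage sketch (not part of the original file): given the per-path params
# returned by load_checkpoints(), a deep-ensemble prediction can average the
# member softmax probabilities. The `model.apply` call returning (logits, out)
# mirrors the training scripts below, but is an assumption here.
import jax
import jax.numpy as jnp


def ensemble_predict_sketch(model, params_per_path, images):
  """Average softmax probabilities over all ensemble members."""
  member_probs = [
      jax.nn.softmax(model.apply({'params': p}, images, train=False)[0], axis=-1)
      for p in params_per_path.values()
  ]
  return jnp.mean(jnp.stack(member_probs, axis=0), axis=0)
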
def main(args, init_distributed=False):
    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    print(args, flush=True)

    # Setup task, e.g., translation, language modeling, etc.
    task = None
    if args.task == 'bert':
        task = tasks.LanguageModelingTask.setup_task(args)
    elif args.task == 'mnist':
        task = tasks.MNISTTask.setup_task(args)
    assert task is not None

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model
    model = task.build_model(args)
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build controller
    controller = Controller(args, task, model)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, controller)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = controller.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while (lr > args.min_lr
           and (epoch_itr.epoch < max_epoch
                or (epoch_itr.epoch == max_epoch
                    and epoch_itr._next_epoch_itr is not None))
           and controller.get_num_updates() < max_update):
        # train for one epoch
        train(args, controller, task, epoch_itr)

        # #revise-task 6
        # debug
        valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = controller.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, controller, epoch_itr,
                                             valid_losses[0])

        # sharded data: get train iterator for next epoch
        reload_dataset = ':' in getattr(args, 'data', '')
        epoch_itr = controller.get_train_iterator(epoch_itr.epoch,
                                                  load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

def main(config, output_dir):
  seed = config.get('seed', 0)
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.host_id() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(
        f'Batch sizes ({batch_size} and {batch_size_eval}) must '
        f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.host_count()
  local_batch_size_eval = batch_size_eval // jax.host_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d devices per host (%d devices total), that\'s a %d per-device '
      'batch size.', batch_size, jax.host_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_ds = input_utils.get_data(
      dataset=config.dataset,
      split=config.train_split,
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_utils.get_num_examples(
        dataset,
        split=split,
        process_batch_size=local_batch_size_eval,
        drop_remainder=False,
        data_dir=data_dir)
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info('Running validation for %d steps for %s, %s', val_steps,
                 dataset, split)

    if isinstance(pp_eval, str):
      pp_eval = preprocess_spec.parse(
          spec=pp_eval, available_ops=preprocess_utils.all_ops())

    val_ds = input_utils.get_data(
        dataset=dataset,
        split=split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=pp_eval,
        cache=config.get('val_cache', 'batched'),
        num_epochs=1,
        repeat_after_batching=True,
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        drop_remainder=False,
        data_dir=data_dir)

    return val_ds

  val_ds_splits = {
      'val': _get_val_split(
          config.dataset,
          split=config.val_split,
          pp_eval=config.pp_eval,
          data_dir=config.get('data_dir'))
  }

  if config.get('test_split'):
    val_ds_splits.update({
        'test': _get_val_split(
            config.dataset,
            split=config.test_split,
            pp_eval=config.pp_eval,
            data_dir=config.get('data_dir'))
    })

  if config.get('eval_on_cifar_10h'):
    cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
        config.get('data_dir', None))
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
    val_ds_splits['cifar_10h'] = _get_val_split(
        'cifar10',
        split=config.get('cifar_10h_split') or 'test',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
  elif config.get('eval_on_imagenet_real'):
    imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn()
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_imagenet_real,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
    val_ds_splits['imagenet_real'] = _get_val_split(
        'imagenet2012_real',
        split=config.get('imagenet_real_split') or 'validation',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))

  ood_ds = {}
  if config.get('ood_datasets') and config.get('ood_methods'):
    if config.get('ood_methods'):  # config.ood_methods is not an empty list
      logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
      ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
          config.dataset,
          config.ood_datasets,
          config.ood_split,
          config.pp_eval,
          config.pp_eval_ood,
          config.ood_methods,
          config.train_split,
          config.get('data_dir'),
          _get_val_split,
      )

  ntrain_img = input_utils.get_num_examples(
      config.dataset,
      split=config.train_split,
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = int(ntrain_img / batch_size)

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))
  model = ub.models.bit_resnet(
      num_classes=config.num_classes, **config.get('model', {}))

  # We want all parameters to be created in host RAM, not on any device; they'll
  # be sent there later as needed. Otherwise we have already encountered two
  # situations where we allocated them twice.
  @partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(
        model.init(rng, dummy_input, train=False))['params']

    # Set bias in the head to a low value, such that loss is small initially.
    params['head']['bias'] = jnp.full_like(
        params['head']['bias'], config.get('init_head_bias', 0))

    # init head kernel to all zeros for fine-tuning
    if config.get('model_init'):
      params['head']['kernel'] = jnp.full_like(params['head']['kernel'], 0)

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.host_id() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, out = model.apply({'params': flax.core.freeze(params)},
                              images, train=False)

    losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
    n = jax.lax.psum(mask, axis_name='batch')

    metric_args = jax.lax.all_gather(
        [logits, labels, out['pre_logits'], mask], axis_name='batch')
    return ncorrect, loss, n, metric_args

  @partial(jax.pmap, axis_name='batch')
  def cifar_10h_evaluation_fn(params, images, labels, mask):
    logits, out = model.apply({'params': flax.core.freeze(params)},
                              images, train=False)
    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(
        one_hot_labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather(
        [logits, labels, out['pre_logits'], mask], axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name='batch')
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply({'params': flax.core.freeze(params)},
                             images, train=False)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, 'batch')
    labels = jax.lax.all_gather(labels, 'batch')
    mask = jax.lax.all_gather(mask, 'batch')
    return representation, labels, mask

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, images, labels):
      logits, _ = model.apply({'params': flax.core.freeze(params)},
                              images, train=True,
                              rngs={'dropout': rng_model_local})
      accuracy = jnp.mean(
          jnp.equal(jnp.argmax(logits, axis=-1), jnp.argmax(labels, axis=-1)))
      return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels), accuracy

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (l, train_accuracy), g = grad_fn(opt.target, images, labels)
    l, g = jax.lax.pmean((l, g), axis_name='batch')
    measurements['accuracy'] = train_accuracy

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get('grad_clip_norm'):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get('weight_decay', []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [('.*kernel.*', decay_rules)]
    sched_m = lr / config.lr.base if config.get('weight_decay_decouple') else lr

    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v

    opt = opt.replace(target=train_utils.tree_map_with_regex(
        decay_fn, opt.target, decay_rules))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  checkpoint_extra = dict(accum_train_time=0.0, rngs_loop=rngs_loop)

  # Decide how to initialize training. The order is important.
  # 1. Always resume from the existing checkpoint, e.g. resume a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize the model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get('resume'):
    resume_checkpoint_path = config.resume
  if resume_checkpoint_path:
    write_note('Resume training from checkpoint...')
    checkpoint_tree = {'opt': opt_cpu, 'extra': checkpoint_extra}
    checkpoint = checkpoint_utils.load_checkpoint(checkpoint_tree,
                                                  resume_checkpoint_path)
    opt_cpu, checkpoint_extra = checkpoint['opt'], checkpoint['extra']
    rngs_loop = checkpoint_extra['rngs_loop']
  elif config.get('model_init'):
    write_note(f'Initialize model from {config.model_init}...')
    reinit_params = config.get('model_reinit_params',
                               ('head/kernel', 'head/bias'))
    logging.info('Reinitializing these parameters: %s', reinit_params)
    # We only support "no head" fine-tuning for now.
    loaded_params = checkpoint_utils.load_checkpoint(
        tree=None, path=config.model_init)
    loaded = checkpoint_utils.restore_from_pretrained_params(
        params_cpu,
        loaded_params,
        model_representation_size=None,
        model_classifier=None,
        reinit_params=reinit_params)
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info('Restored parameter overview:')
      parameter_overview.log_parameter_overview(loaded)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.host_id() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              checkpoint_extra['accum_train_time'])

  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))
  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f'Initializing few-shotters...\n{chrono.note}')
  fewshotter = None
  if 'fewshot' in config and fewshot is not None:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get('batch_size') or batch_size_eval)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  train_loss = -jnp.inf
  val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
  fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):

    with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl,
          lr_repl,
          train_batch['image'],
          train_batch['labels'],
          rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if train_utils.itstime(step, config.get('checkpoint_steps'), total_steps,
                           process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(
          checkpoint_writer, config.get('checkpoint_timeout', 1))
      checkpoint_extra['accum_train_time'] = chrono.accum_train_time
      checkpoint_extra['rngs_loop'] = rngs_loop
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint = {'opt': opt_cpu, 'extra': checkpoint_extra}
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.save_checkpoint,
          (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if train_utils.itstime(step, config.log_training_steps, total_steps,
                           process=0):
      write_note('Reporting training progress...')
      train_accuracy = extra_measurements['accuracy']
      train_accuracy = jnp.mean(train_accuracy)
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
          'training_accuracy': train_accuracy,
      })
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance
    if train_utils.itstime(step, config.log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()
      for val_name, val_ds in val_ds_splits.items():
        # Sets up evaluation metrics.
        ece_num_bins = config.get('ece_num_bins', 15)
        auc_num_bins = config.get('auc_num_bins', 1000)
        ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
        calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
        oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.005, num_bins=auc_num_bins)
        oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.01, num_bins=auc_num_bins)
        oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.02, num_bins=auc_num_bins)
        oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.05, num_bins=auc_num_bins)
        label_diversity = tf.keras.metrics.Mean()
        sample_diversity = tf.keras.metrics.Mean()
        ged = tf.keras.metrics.Mean()

        # Runs evaluation loop.
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))
        ncorrect, loss, nseen = 0, 0, 0
        for batch in val_iter:
          if val_name == 'cifar_10h':
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                cifar_10h_evaluation_fn(opt_repl.target, batch['image'],
                                        batch['labels'], batch['mask']))
          else:
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                evaluation_fn(opt_repl.target, batch['image'],
                              batch['labels'], batch['mask']))
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
          if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
            # Here we parse batch_metric_args to compute uncertainty metrics
            # (e.g., ECE or Calibration AUC).
            logits, labels, _, masks = batch_metric_args
            masks = np.array(masks[0], dtype=bool)
            logits = np.array(logits[0])
            probs = jax.nn.softmax(logits)
            # From one-hot to integer labels, as required by ECE.
            int_labels = np.argmax(np.array(labels[0]), axis=-1)
            int_preds = np.argmax(logits, axis=-1)
            confidence = np.max(probs, axis=-1)
            for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                            int_preds, masks, labels[0]):
              ece.add_batch(p[m, :], label=l[m])
              calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
              # TODO(jereliu): Extend to support soft multi-class probabilities.
              oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])

              if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                batch_label_diversity, batch_sample_diversity, batch_ged = (
                    data_uncertainty_utils.generalized_energy_distance(
                        label[m], p[m, :], config.num_classes))
                label_diversity.update_state(batch_label_diversity)
                sample_diversity.update_state(batch_sample_diversity)
                ged.update_state(batch_ged)

        val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
        val_measurements = {
            f'{val_name}_prec@1': ncorrect / nseen,
            f'{val_name}_loss': val_loss[val_name],
        }
        if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
          val_measurements[f'{val_name}_ece'] = ece.result()['ece']
          val_measurements[f'{val_name}_calib_auc'] = (
              calib_auc.result()['calibration_auc'])
          val_measurements[f'{val_name}_oc_auc_0.5%'] = (
              oc_auc_0_5.result()['collaborative_auc'])
          val_measurements[f'{val_name}_oc_auc_1%'] = (
              oc_auc_1.result()['collaborative_auc'])
          val_measurements[f'{val_name}_oc_auc_2%'] = (
              oc_auc_2.result()['collaborative_auc'])
          val_measurements[f'{val_name}_oc_auc_5%'] = (
              oc_auc_5.result()['collaborative_auc'])
        writer.write_scalars(step, val_measurements)

        if val_name == 'cifar_10h' or val_name == 'imagenet_real':
          cifar_10h_measurements = {
              f'{val_name}_label_diversity': label_diversity.result(),
              f'{val_name}_sample_diversity': sample_diversity.result(),
              f'{val_name}_ged': ged.result(),
          }
          writer.write_scalars(step, cifar_10h_measurements)

      # OOD eval
      # Entries in the ood_ds dict include:
      # (ind_dataset, ood_dataset1, ood_dataset2, ...).
      # OOD metrics are computed using ind_dataset paired with each of the
      # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
      # is also included in the ood_ds.
      if ood_ds and config.ood_methods:
        ood_measurements = ood_utils.eval_ood_metrics(
            ood_ds,
            ood_ds_names,
            config.ood_methods,
            evaluation_fn,
            opt_repl.target,
            n_prefetch=config.get('prefetch_to_device', 1))
        writer.write_scalars(step, ood_measurements)
      chrono.resume()

    if 'fewshot' in config and fewshotter is not None:
      # Compute few-shot on-the-fly evaluation.
      if train_utils.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f'Few-shot evaluation...\n{chrono.note}')
        # Keep `results` to return for reproducibility tests.
        fewshot_results, best_l2 = fewshotter.run_all(
            opt_repl.target, config.fewshot.datasets)

        # TODO(dusenberrymw): Remove this once fewshot.py is updated.
        def make_writer_measure_fn(step):
          def writer_measure(name, value):
            writer.write_scalars(step, {name: value})
          return writer_measure

        fewshotter.walk_results(make_writer_measure_fn(step),
                                fewshot_results, best_l2)
        chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()

  # Return final training loss, validation loss, and fewshot results for
  # reproducibility test cases.
  return train_loss, val_loss, fewshot_results

def main(config, output_dir):
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir, first_profile=10)

  logging.info(config)

  acquisition_method = config.get('acquisition_method')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)
  writer.write_hparams(dict(config))

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note(f'Initializing for {acquisition_method}')

  # Download dataset
  data_builder = tfds.builder(config.dataset)
  data_builder.download_and_prepare()

  seed = config.get('seed', 0)
  rng = jax.random.PRNGKey(seed)

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()

  val_ds = input_utils.get_data(
      dataset=config.dataset,
      split=config.val_split,
      rng=None,
      process_batch_size=local_batch_size_eval,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
      shuffle=False,
      prefetch_size=config.get('prefetch_to_host', 2),
      num_epochs=1,  # Only repeat once.
  )

  test_ds = input_utils.get_data(
      dataset=config.dataset,
      split=config.test_split,
      rng=None,
      process_batch_size=local_batch_size_eval,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
      shuffle=False,
      prefetch_size=config.get('prefetch_to_host', 2),
      num_epochs=1,  # Only repeat once.
  )

  # Init model
  if config.model_type == 'deterministic':
    model_utils = deterministic_utils
    reinit_params = config.get('model_reinit_params',
                               ('head/kernel', 'head/bias'))
    model = ub.models.vision_transformer(
        num_classes=config.num_classes, **config.get('model', {}))
  elif config.model_type == 'batchensemble':
    model_utils = batchensemble_utils
    reinit_params = ('batchensemble_head/bias', 'batchensemble_head/kernel',
                     'batchensemble_head/fast_weight_alpha',
                     'batchensemble_head/fast_weight_gamma')
    model = ub.models.PatchTransformerBE(
        num_classes=config.num_classes, **config.model)
  else:
    raise ValueError('Expect config.model_type to be "deterministic" or '
                     f'"batchensemble", but received {config.model_type}.')

  init = model_utils.create_init(model, config, test_ds)

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  loaded_params = checkpoint_utils.load_checkpoint(
      tree=None, path=config.model_init)
  loaded = checkpoint_utils.restore_from_pretrained_params(
      params_cpu,
      loaded_params,
      config.model.representation_size,
      config.model.classifier,
      reinit_params,
  )

  opt_cpu = opt_cpu.replace(target=loaded)

  # TODO(joost,andreas): This shouldn't be needed but opt_cpu is being
  # donated otherwise. Ensure opt_cpu is really on the cpu this way.
  opt_cpu = jax.device_get(opt_cpu)

  update_fn = model_utils.create_update_fn(model, config)
  evaluation_fn = model_utils.create_evaluation_fn(model, config)

  # NOTE: We need this because we need an Id field of type int.
  # TODO(andreas): Rename to IdSubsetDatasetBuilder?
  pool_subset_data_builder = al_utils.SubsetDatasetBuilder(
      data_builder, subset_ids=None)

  rng, pool_ds_rng = jax.random.split(rng)

  # NOTE: below line is necessary on multi host setup
  # pool_ds_rng = jax.random.fold_in(pool_ds_rng, jax.process_index())
  pool_train_ds = input_utils.get_data(
      dataset=pool_subset_data_builder,
      split=config.train_split,
      rng=pool_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
      shuffle=False,
      drop_remainder=False,
      prefetch_size=config.get('prefetch_to_host', 2),
      num_epochs=1,  # Don't repeat
  )

  # Potentially acquire an initial training set.
  initial_training_set_size = config.get('initial_training_set_size', 10)

  if initial_training_set_size > 0:
    current_opt_repl = flax_utils.replicate(opt_cpu)
    pool_ids, _, _, pool_masks = get_ids_logits_masks(
        model=model,
        opt_repl=current_opt_repl,
        ds=pool_train_ds,
        config=config)

    rng, initial_uniform_rng = jax.random.split(rng)
    pool_scores = get_uniform_scores(pool_masks, initial_uniform_rng)

    initial_training_set_batch_ids, _ = select_acquisition_batch_indices(
        acquisition_batch_size=initial_training_set_size,
        scores=pool_scores,
        ids=pool_ids,
        ignored_ids=set(),
    )
  else:
    initial_training_set_batch_ids = []

  # NOTE: if we could `enumerate` before `filter` in `create_dataset` of CLU
  # then this dataset creation could be simplified.
  # https://github.com/google/CommonLoopUtils/blob/main/clu/deterministic_data.py#L340
  # CLU is explicitly not accepting outside contributions at the moment.
  train_subset_data_builder = al_utils.SubsetDatasetBuilder(
      data_builder, subset_ids=set(initial_training_set_batch_ids))

  test_accuracies = []
  training_sizes = []

  rng, rng_loop = jax.random.split(rng)
  rngs_loop = flax_utils.replicate(rng_loop)
  if config.model_type == 'batchensemble':
    rngs_loop = {'dropout': rngs_loop}

  # TODO(joost,andreas): double check if below is still necessary
  # (train_split is independent of this)
  # NOTE: train_ds_rng is re-used for all train_ds creations
  rng, train_ds_rng = jax.random.split(rng)

  measurements = {}
  accumulated_steps = 0
  while True:
    current_train_ds_length = len(train_subset_data_builder.subset_ids)
    if current_train_ds_length >= config.get('max_training_set_size', 150):
      break

    write_note(f'Training set size: {current_train_ds_length}')

    current_opt_repl = flax_utils.replicate(opt_cpu)

    # Only fine-tune if there is anything to fine-tune with.
    if current_train_ds_length > 0:
      # Repeat dataset to have oversampled epochs and bootstrap more batches
      number_of_batches = current_train_ds_length / config.batch_size
      num_repeats = math.ceil(config.total_steps / number_of_batches)
      write_note(f'Repeating dataset {num_repeats} times')

      # We repeat the dataset several times, such that we can obtain batches
      # of size batch_size, even at start of training. These batches will be
      # effectively 'bootstrap' sampled, meaning they are sampled with
      # replacement from the original training set.
      repeated_train_ds = input_utils.get_data(
          dataset=train_subset_data_builder,
          split=config.train_split,
          rng=train_ds_rng,
          process_batch_size=local_batch_size,
          preprocess_fn=preprocess_spec.parse(
              spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
          shuffle_buffer_size=config.shuffle_buffer_size,
          prefetch_size=config.get('prefetch_to_host', 2),
          # TODO(joost,andreas): double check if below leads to bootstrap
          # sampling.
          num_epochs=num_repeats,
      )

      # We use this dataset to evaluate how well we perform on the training set.
      # We need this to evaluate if we fit well within max_steps budget.
      train_eval_ds = input_utils.get_data(
          dataset=train_subset_data_builder,
          split=config.train_split,
          rng=train_ds_rng,
          process_batch_size=local_batch_size,
          preprocess_fn=preprocess_spec.parse(
              spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
          shuffle=False,
          drop_remainder=False,
          prefetch_size=config.get('prefetch_to_host', 2),
          num_epochs=1,
      )

      # NOTE: warmup and decay are not a good fit for the small training set
      # lr_fn = train_utils.create_learning_rate_schedule(config.total_steps,
      #                                                   **config.get('lr', {}))
      lr_fn = lambda x: config.lr.base

      early_stopping_patience = config.get('early_stopping_patience', 15)
      current_opt_repl, rngs_loop, measurements = finetune(
          update_fn=update_fn,
          opt_repl=current_opt_repl,
          lr_fn=lr_fn,
          ds=repeated_train_ds,
          rngs_loop=rngs_loop,
          total_steps=config.total_steps,
          train_eval_ds=train_eval_ds,
          val_ds=val_ds,
          evaluation_fn=evaluation_fn,
          early_stopping_patience=early_stopping_patience,
          profiler=profiler)
      train_val_accuracies = measurements.pop('train_val_accuracies')

      current_steps = 0
      for step, train_acc, val_acc in train_val_accuracies:
        writer.write_scalars(accumulated_steps + step, {
            'train_accuracy': train_acc,
            'val_accuracy': val_acc
        })
        current_steps = step
      accumulated_steps += current_steps + 10

    test_accuracy = get_accuracy(
        evaluation_fn=evaluation_fn, opt_repl=current_opt_repl, ds=test_ds)

    write_note(f'Accuracy at {current_train_ds_length}: {test_accuracy}')

    test_accuracies.append(test_accuracy)
    training_sizes.append(current_train_ds_length)

    pool_ids, pool_outputs, _, pool_masks = get_ids_logits_masks(
        model=model,
        opt_repl=current_opt_repl,
        ds=pool_train_ds,
        use_pre_logits=acquisition_method == 'density',
        config=config)

    if acquisition_method == 'uniform':
      rng_loop, rng_acq = jax.random.split(rng_loop, 2)
      pool_scores = get_uniform_scores(pool_masks, rng_acq)
    elif acquisition_method == 'entropy':
      pool_scores = get_entropy_scores(pool_outputs, pool_masks)
    elif acquisition_method == 'margin':
      pool_scores = get_margin_scores(pool_outputs, pool_masks)
    elif acquisition_method == 'density':
      if current_train_ds_length > 0:
        pool_scores = get_density_scores(
            model=model,
            opt_repl=current_opt_repl,
            train_ds=train_eval_ds,
            pool_pre_logits=pool_outputs,
            pool_masks=pool_masks,
            config=config)
      else:
        rng_loop, rng_acq = jax.random.split(rng_loop, 2)
        pool_scores = get_uniform_scores(pool_masks, rng_acq)
    else:
      raise ValueError('Acquisition method not found.')

    acquisition_batch_ids, _ = select_acquisition_batch_indices(
        acquisition_batch_size=config.get('acquisition_batch_size', 10),
        scores=pool_scores,
        ids=pool_ids,
        ignored_ids=train_subset_data_builder.subset_ids)

    train_subset_data_builder.subset_ids.update(acquisition_batch_ids)

    measurements.update({'test_accuracy': test_accuracy})
    writer.write_scalars(current_train_ds_length, measurements)

  write_note(f'Final acquired training ids: '
             f'{train_subset_data_builder.subset_ids}'
             f'Accuracies: {test_accuracies}')

  pool.close()
  pool.join()
  writer.close()

  # TODO(joost,andreas): save the final checkpoint
  return (train_subset_data_builder.subset_ids, test_accuracies)

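# Hedged sketch (an assumption, not the repo's get_entropy_scores implementation)
# of the entropy acquisition score used in the active-learning loop above: higher
# predictive entropy marks a pool example as more informative, and masked-out
# examples get -inf so they are never selected.
import jax
import jax.numpy as jnp


def entropy_scores_sketch(logits, masks):
  """Predictive entropy per example, with masked-out examples forced to -inf."""
  log_p = jax.nn.log_softmax(logits, axis=-1)
  entropy = -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)
  return jnp.where(masks > 0, entropy, -jnp.inf)
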
def main():
    ## Parse command line args
    args = parse_train_args()

    ## Use CUDA if available
    print('=> CUDA availability / use: "{}" / "{}"'.format(
        str(torch.cuda.is_available()), str(args.CUDA)))
    args.CUDA = args.CUDA and torch.cuda.is_available()
    device = torch.device(
        'cuda' if (torch.cuda.is_available() and args.CUDA) else 'cpu')

    ## Load embeddings
    from arg_utils import load_embeddings
    embeddings = load_embeddings(args)

    ## Construct Datasets (Train/Validation/Test)
    from arg_utils import construct_hyperpartisan_flair_dataset, \
                          construct_hyperpartisan_flair_and_features_dataset, \
                          construct_propaganda_flair_dataset
    train_dataset, val_dataset, test_dataset, input_shape = construct_datasets(
        construct_propaganda_flair_dataset,
        embeddings,
        args,
    )

    ## Construct Dataloaders
    train_dataloader, val_dataloader, test_dataloader = construct_dataloaders(
        train_dataset, val_dataset, test_dataset, args)

    ## Construct model + optimizer + scheduler
    from nn_architectures import construct_hierarch_att_net, \
                                 construct_cnn_bertha_von_suttner, \
                                 construct_HAN_with_features, \
                                 construct_lstm
    model, optimizer, scheduler, loss_criterion = construct_model(
        construct_lstm,  ## NOTE change model here
        input_shape,
        args)

    ## Train model
    best_path = None  # Stays None if no training is run.
    if train_dataloader is not None:
        random.seed(args.seed)
        best_path, _ = train_pytorch(
            model,
            optimizer,
            loss_criterion,
            train_dataloader,
            args=args,
            val_loader=test_dataloader if val_dataloader is None else val_dataloader,
            device=device,
            scheduler=scheduler)
        checkpoint(model, optimizer, scheduler, args.epochs,
                   args.checkpoint_dir, name='final.' + args.name)

    ## Test model
    # Load best model's checkpoint for testing (if available)
    if best_path:
        load_checkpoint(best_path, model)
    test_model(model, test_dataloader, device=device)

def main():
    ## Parse command line args
    args = parse_train_args()
    assert args.k_fold is not None, \
        'Use "--k-fold <N>" for specifying the number of folds to use'

    ## Use CUDA if available
    print('=> CUDA availability / use: "{}" / "{}"'.format(
        str(torch.cuda.is_available()), str(args.CUDA)))
    args.CUDA = args.CUDA and torch.cuda.is_available()
    device = torch.device(
        'cuda' if (torch.cuda.is_available() and args.CUDA) else 'cpu')

    ## Extract data
    *Xs, Y = extract_data(args)
    Xs = tuple(Xs)
    input_shape = Xs[0].shape[1:]

    ## Optionally, undersample majority class
    if args.undersampling:
        balanced_indices = balanced_sampling(Y)
        Y = Y[balanced_indices]
        Xs = tuple(x[balanced_indices] for x in Xs)

    ## Construct TensorDataset
    main_tensor_dataset = TensorDataset(*(Xs + (Y, )))

    ## Model constructor
    ## NOTE Change MODEL here
    from nn_architectures import construct_lstm, construct_AttnBiLSTM
    model_constructor = construct_AttnBiLSTM

    ## k-fold split of Train/Test
    stats = np.zeros((args.k_fold, 4))
    kfold = StratifiedKFold(n_splits=args.k_fold)
    for i, (train_indices, test_indices) in enumerate(
            kfold.split(Xs[0].numpy(), Y.numpy())):
        print('K-Fold: [{:02}/{:02}]'.format(i + 1, args.k_fold))
        train_dataset = Subset(main_tensor_dataset, train_indices)
        test_dataset = Subset(main_tensor_dataset, test_indices)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0,
                                  pin_memory=torch.cuda.is_available())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=0,
                                 pin_memory=torch.cuda.is_available())

        ## Construct model + optimizer + scheduler
        args.name = args.name + '.k{}'.format(i)
        model, optimizer, scheduler, loss_criterion = construct_model(
            model_constructor, input_shape, args)

        ## Train model
        best_model, _ = train_pytorch(model,
                                      optimizer,
                                      loss_criterion,
                                      train_loader,
                                      args=args,
                                      val_loader=test_loader,
                                      device=device,
                                      scheduler=scheduler)

        ## Load best model
        if best_model:
            load_checkpoint(best_model, model)
        stats[i] = test_model(model, test_loader, device)

    ## Final stats
    print('** Final Statistics **')
    print('Mean:\t', np.mean(stats, axis=0))
    print('STD: \t', np.std(stats, axis=0))

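# Hedged sketch (an assumption, not the project's balanced_sampling helper) of how
# the majority class could be undersampled for the k-fold script above: keep an
# equal number of indices from each class and shuffle them.
import numpy as np
import torch


def balanced_sampling_sketch(Y, seed=0):
    """Return indices that undersample the majority class to a 1:1 ratio."""
    rng = np.random.default_rng(seed)
    y = Y.numpy() if isinstance(Y, torch.Tensor) else np.asarray(Y)
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    n = min(len(pos), len(neg))
    keep = np.concatenate([rng.choice(pos, n, replace=False),
                           rng.choice(neg, n, replace=False)])
    rng.shuffle(keep)
    return keep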