Example No. 1
def main():
    ## Parse command line args
    args = parse_eval_args()
    from pprint import PrettyPrinter
    PrettyPrinter(indent=4).pprint(vars(args))
    print()

    ## Use CUDA if available
    print('=> CUDA availability / use: "{}" / "{}"'.format(
        str(torch.cuda.is_available()), str(args.CUDA)))
    args.CUDA = args.CUDA and torch.cuda.is_available()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and args.CUDA) else 'cpu')

    ## Dataset + Dataloader ## NOTE Update DATASET here
    from arg_utils import construct_hyperpartisan_flair_dataset, \
                        construct_propaganda_flair_dataset
    eval_dataset, input_shape = construct_eval_dataset(
        construct_propaganda_flair_dataset, args)

    from datasets import CachedDataset
    ## NOTE Use Cached dataset ?? (useful for ensemble runs, but not with TensorDataset)
    eval_dataloader = torch.utils.data.DataLoader(
        # CachedDataset(eval_dataset),
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.dataloader_workers,
        pin_memory=args.CUDA)

    ## Construct Model                                  ## NOTE Update MODEL here
    from nn_architectures import construct_cnn_bertha_von_suttner, \
                                 construct_hierarch_att_net, \
                                 construct_lstm, construct_AttnBiLSTM
    model_constructor = construct_AttnBiLSTM

    ## Load models from checkpoints
    models = list()
    for m_path in args.model_path:
        model = model_constructor(input_shape[-1])
        load_checkpoint(m_path, model)
        models.append(model)

    ## Model Summary
    from torchsummary import summary
    print('\n ** Model Summary ** ')
    print(models[0], end='\n\n')

    ## Evaluate documents
    predictions = evaluate_ensemble(models, eval_dataloader, device=device)

    ## Write predictions to file
    # write_hyperpartisan_predictions(predictions, eval_dataset, args.output_dir)
    write_propaganda_predictions(
        predictions,
        eval_dataset if args.tensor_dataset is None else
        construct_base_propaganda_dataset(args.input_dir, None),
        ## PropagandaDataset must always be provided (even if a tensor-dataset is provided), to properly write predictions to output file
        args.output_dir)
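
The helpers load_checkpoint and evaluate_ensemble used above are not shown in this example. Below is a minimal sketch of evaluate_ensemble, assuming each ensemble member returns per-example probabilities and the dataloader yields (inputs, ...) tuples; both the signature and the averaging strategy are assumptions, not the original implementation.

import torch

def evaluate_ensemble(models, dataloader, device=torch.device('cpu')):
    # Hypothetical sketch: average per-example predictions over the ensemble.
    for model in models:
        model.to(device)
        model.eval()
    batches = []
    with torch.no_grad():
        for inputs, *_ in dataloader:
            inputs = inputs.to(device)
            # Stack each member's output and average over the ensemble axis.
            member_preds = torch.stack([m(inputs) for m in models], dim=0)
            batches.append(member_preds.mean(dim=0).cpu())
    return torch.cat(batches, dim=0)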
Example No. 2
def construct_model(model_constructor, input_shape, args):
    ## Construct model
    model = model_constructor(input_shape[-1], args)
    if torch.cuda.is_available() and args.CUDA:
        print('=> Moving model to CUDA')
        model.cuda()

    ## Construct optimizer
    from optimizers import RAdam, AdaBound
    # optimizer = RAdam(model.parameters())   ## NOTE Experimenting with RAdam and AdaBound
    # optimizer = AdaBound(model.parameters())   ## NOTE Experimenting with RAdam and AdaBound
    optimizer = torch.optim.Adam(model.parameters())

    ## Construct scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.2,
        patience=args.reduce_lr_patience,
        verbose=True)

    ## Optionally, resume training from checkpoint
    if args.resume is not None and os.path.isfile(args.resume):
        print('\n=> ** Resuming training from checkpoint **')
        load_checkpoint(args.resume, model, optimizer, scheduler)

    ## Optionally, freeze specific layers/sub-modules
    if args.freeze is not None and len(args.freeze) > 0:
        print('=> Freezing layers/sub-modules with indices: {}'.format(
            args.freeze))
        freeze_layers(model, args.freeze)
    else:
        print('=> All layers unfrozen for training')

    ## Model summary #1
    from nn_utils import count_parameters
    print('Model has {} trainable parameters'.format(count_parameters(model)))
    print(model)

    ## Model Summary #2
    # from torchsummary import summary
    # print('\nModel Summary:')
    # summary(model, input_shape, device='cuda' if args.CUDA else 'cpu')

    ## Loss criterion: Binary Cross Entropy
    loss_criterion = torch.nn.BCELoss()

    return model, optimizer, scheduler, loss_criterion
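
A hedged usage sketch of construct_model follows; the constructor, the args fields, the input shape, the dummy data, and the binary output shape are illustrative assumptions rather than code taken from the original project.

import argparse
import torch
from nn_architectures import construct_AttnBiLSTM  # as imported in Example No. 1

# Hypothetical args namespace with just the fields construct_model reads.
args = argparse.Namespace(CUDA=False, reduce_lr_patience=3, resume=None, freeze=None)
model, optimizer, scheduler, loss_criterion = construct_model(
    construct_AttnBiLSTM, input_shape=(16, 128, 300), args=args)

inputs = torch.randn(16, 128, 300)              # (batch, seq_len, embedding_dim)
targets = torch.randint(0, 2, (16, 1)).float()  # binary targets for BCELoss

optimizer.zero_grad()
loss = loss_criterion(model(inputs), targets)
loss.backward()
optimizer.step()
scheduler.step(loss.item())  # ReduceLROnPlateau steps on a monitored metric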
Example No. 3
    def test_checkpointing_model(self):
        key = jax.random.PRNGKey(42)

        model, _ = _make_deterministic_model()
        input_shape = (2, 224, 224, 3)
        key, subkey = jax.random.split(key)
        params = _init_model(subkey, model, input_shape=input_shape)
        checkpoint_path = self._save_temp_checkpoint(params)

        key, subkey = jax.random.split(key)
        new_params = _init_model(subkey, model, input_shape=input_shape)
        restored_params = checkpoint_utils.load_checkpoint(
            new_params, checkpoint_path)
        restored_leaves = jax.tree_util.tree_leaves(restored_params)
        leaves = jax.tree_util.tree_leaves(params)
        for arr, restored_arr in zip(leaves, restored_leaves):
            self.assertAllClose(arr, restored_arr)

        key, subkey = jax.random.split(key)
        inputs = jax.random.normal(subkey, input_shape, jnp.float32)
        _, out = model.apply({"params": params}, inputs, train=False)
        _, new_out = model.apply({"params": new_params}, inputs, train=False)
        _, restored_out = model.apply({"params": restored_params},
                                      inputs,
                                      train=False)
        self.assertNotAllClose(out["pre_logits"], new_out["pre_logits"])
        self.assertAllClose(out["pre_logits"], restored_out["pre_logits"])
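
The fixtures _make_deterministic_model, _init_model, and _save_temp_checkpoint are defined elsewhere in the test file. A plausible sketch of _init_model, assuming the model is a Flax module whose __call__ accepts a train flag (an assumption, not the actual fixture):

import flax
import jax.numpy as jnp

def _init_model(rng, model, input_shape):
    # Initialize Flax variables from a dummy all-zeros batch and return params.
    dummy_input = jnp.zeros(input_shape, jnp.float32)
    variables = model.init(rng, dummy_input, train=False)
    return flax.core.unfreeze(variables)['params']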
Example No. 4
    def test_sngp_script(self, dataset_name, classifier, representation_size,
                         correct_train_loss, correct_val_loss,
                         correct_fewshot_acc_sum, simulate_failure):
        data_dir = self.data_dir
        config = test_utils.get_config(dataset_name=dataset_name,
                                       classifier=classifier,
                                       representation_size=representation_size,
                                       use_sngp=True,
                                       use_gp_layer=True)
        output_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        config.dataset_dir = data_dir
        num_examples = config.batch_size * config.total_steps

        if not simulate_failure:
            # Check for any errors.
            with tfds.testing.mock_data(num_examples=num_examples,
                                        data_dir=data_dir):
                train_loss, val_loss, fewshot_results = sngp.main(
                    config, output_dir)
        else:
            # Check for the ability to restart from a previous checkpoint (after
            # failure, etc.).
            output_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
            # NOTE: Use this flag to simulate failing at a certain step.
            config.testing_failure_step = config.total_steps - 1
            config.checkpoint_steps = config.testing_failure_step
            config.keep_checkpoint_steps = config.checkpoint_steps
            with tfds.testing.mock_data(num_examples=num_examples,
                                        data_dir=data_dir):
                sngp.main(config, output_dir)

            checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')
            self.assertTrue(os.path.exists(checkpoint_path))
            checkpoint = checkpoint_utils.load_checkpoint(
                None, checkpoint_path)
            self.assertEqual(int(checkpoint['opt']['state']['step']),
                             config.testing_failure_step)

            # This should resume from the failed step.
            del config.testing_failure_step
            with tfds.testing.mock_data(num_examples=num_examples,
                                        data_dir=data_dir):
                train_loss, val_loss, fewshot_results = sngp.main(
                    config, output_dir)

        # Check for reproducibility.
        fewshot_acc_sum = sum(jax.tree_util.tree_flatten(fewshot_results)[0])
        logging.info('(train_loss, val_loss, fewshot_acc_sum) = %s, %s, %s',
                     train_loss, val_loss['val'], fewshot_acc_sum)
        # TODO(dusenberrymw): Determine why the SNGP script is non-deterministic.
        self.assertAllClose(train_loss,
                            correct_train_loss,
                            atol=0.025,
                            rtol=0.3)
        self.assertAllClose(val_loss['val'],
                            correct_val_loss,
                            atol=0.02,
                            rtol=0.3)
Example No. 5
    def test_checkpointing(self):
        key = jax.random.PRNGKey(42)

        key, subkey = jax.random.split(key)
        tree = _make_pytree(subkey)
        checkpoint_path = self._save_temp_checkpoint(tree)

        key, subkey = jax.random.split(key)
        new_tree = _make_pytree(subkey)

        leaves = jax.tree_util.tree_leaves(tree)
        new_leaves = jax.tree_util.tree_leaves(new_tree)
        for arr, new_arr in zip(leaves, new_leaves):
            self.assertNotAllClose(arr, new_arr)

        restored_tree = checkpoint_utils.load_checkpoint(
            new_tree, checkpoint_path)
        restored_leaves = jax.tree_util.tree_leaves(restored_tree)
        for arr, restored_arr in zip(leaves, restored_leaves):
            self.assertAllClose(arr, restored_arr)
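
These tests only rely on checkpoint_utils.load_checkpoint restoring saved leaves into the structure of the target pytree. A minimal sketch of that round-trip idea using plain numpy; the positional-key file format below is an assumption for illustration, not the library's actual on-disk layout.

import jax
import numpy as np

def save_pytree(tree, path):
    # Flatten the pytree and store each leaf under its positional index.
    leaves = jax.tree_util.tree_leaves(tree)
    np.savez(path, **{str(i): np.asarray(leaf) for i, leaf in enumerate(leaves)})

def load_pytree(target_tree, path):
    # Restore saved leaves back into the structure of target_tree.
    treedef = jax.tree_util.tree_structure(target_tree)
    with np.load(path) as data:
        leaves = [data[str(i)] for i in range(len(data.files))]
    return jax.tree_util.tree_unflatten(treedef, leaves)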
Example No. 6
def load_checkpoints(config):
    """Load the checkpoints for each ensemble members."""
    if not (config.model_init and isinstance(config.model_init,
                                             (tuple, list))):
        raise ValueError(
            ('deep_ensemble.py expects a list/tuple of ckpts to load; '
             f'got instead config.model_init={config.model_init}.'))

    def load_fn(path):
        return checkpoint_utils.load_checkpoint({}, path)['opt']['target']
    params = {}
    ensemble_size = len(config.model_init)
    for model_idx, path in enumerate(config.model_init, start=1):
        prefix = f'[{model_idx}/{ensemble_size}]'
        logging_msg = f'{prefix} Starting to load checkpoint: {path}.'
        logging.info(logging_msg)
        params[path] = load_fn(path)
        logging_msg = f'{prefix} Finished loading checkpoint: {path}.'
        logging.info(logging_msg)
    return params
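
A hedged usage sketch of load_checkpoints; the ml_collections.ConfigDict config type and the checkpoint paths below are illustrative assumptions.

import ml_collections

config = ml_collections.ConfigDict()
config.model_init = ['/tmp/member_0.npz', '/tmp/member_1.npz']  # hypothetical paths

# Maps each checkpoint path to that member's restored 'opt'/'target' params.
ensemble_params = load_checkpoints(config)
for path, member_params in ensemble_params.items():
    print(path, type(member_params))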
Example No. 7
def main(args, init_distributed=False):
    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)

    #  set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    print(args, flush=True)

    # Setup task, e.g., translation, language modeling, etc.
    task = None
    if args.task == 'bert':
        task = tasks.LanguageModelingTask.setup_task(args)
    elif args.task == 'mnist':
        task = tasks.MNISTTask.setup_task(args)
    assert task is not None

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model
    model = task.build_model(args)

    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build controller
    controller = Controller(args, task, model)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator

    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, controller)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf

    lr = controller.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()

    while (lr > args.min_lr and (epoch_itr.epoch < max_epoch or
                                 (epoch_itr.epoch == max_epoch
                                  and epoch_itr._next_epoch_itr is not None))
           and controller.get_num_updates() < max_update):
        # train for one epoch
        train(args, controller, task, epoch_itr)  # #revise-task 6

        # Validation is disabled here (debug); lr_step and save_checkpoint
        # below receive None as the validation loss.
        valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = controller.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, controller, epoch_itr,
                                             valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = controller.get_train_iterator(epoch_itr.epoch,
                                                  load_dataset=reload_dataset)

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example No. 8
def main(config, output_dir):

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.host_id() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.host_count()
    local_batch_size_eval = batch_size_eval // jax.host_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.host_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(nval_img / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      num_epochs=1,
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=data_dir)

        return val_ds

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        if config.get('ood_methods'):  # config.ood_methods is not an empty list
            logging.info('loading OOD dataset = %s',
                         config.get('ood_datasets'))
            ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
                config.dataset,
                config.ood_datasets,
                config.ood_split,
                config.pp_eval,
                config.pp_eval_ood,
                config.ood_methods,
                config.train_split,
                config.get('data_dir'),
                _get_val_split,
            )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = int(ntrain_img / batch_size)

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))
    model = ub.models.bit_resnet(num_classes=config.num_classes,
                                 **config.get('model', {}))

    # We want all parameters to be created in host RAM, not on any device;
    # they'll be sent to the devices later as needed. Otherwise we risk
    # allocating them twice, which we have already run into in two situations.
    @partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        params = flax.core.unfreeze(model.init(rng, dummy_input,
                                               train=False))['params']

        # Set bias in the head to a low value, such that loss is small initially.
        params['head']['bias'] = jnp.full_like(params['head']['bias'],
                                               config.get('init_head_bias', 0))

        # init head kernel to all zeros for fine-tuning
        if config.get('model_init'):
            params['head']['kernel'] = jnp.full_like(params['head']['kernel'],
                                                     0)

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.host_id() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)

        losses = getattr(train_utils,
                         config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        _, outputs = model.apply({'params': flax.core.freeze(params)},
                                 images,
                                 train=False)
        representation = outputs[config.fewshot.representation_layer]
        representation = jax.lax.all_gather(representation, 'batch')
        labels = jax.lax.all_gather(labels, 'batch')
        mask = jax.lax.all_gather(mask, 'batch')
        return representation, labels, mask

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this so that the arrays it creates live on the same device as the
    # input (the CPU here); otherwise they would end up on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            accuracy = jnp.mean(
                jnp.equal(jnp.argmax(logits, axis=-1),
                          jnp.argmax(labels, axis=-1)))
            return getattr(train_utils,
                           config.get('loss',
                                      'sigmoid_xent'))(logits=logits,
                                                       labels=labels), accuracy

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (l, train_accuracy), g = grad_fn(opt.target, images, labels)
        l, g = jax.lax.pmean((l, g), axis_name='batch')
        measurements['accuracy'] = train_accuracy

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        decay_rules = config.get('weight_decay', []) or []
        if isinstance(decay_rules, numbers.Number):
            decay_rules = [('.*kernel.*', decay_rules)]
        sched_m = lr / config.lr.base if config.get(
            'weight_decay_decouple') else lr
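        # NOTE: weight decay is applied below as a multiplicative shrink of the
        # params after the gradient step; `decay_rules` maps param-name regexes
        # to decay strengths, and `weight_decay_decouple` scales the decay by
        # the normalized schedule (lr / lr.base) instead of the raw lr.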

        def decay_fn(v, wd):
            return (1.0 - sched_m * wd) * v

        opt = opt.replace(target=train_utils.tree_map_with_regex(
            decay_fn, opt.target, decay_rules))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements

    # Other things besides optimizer state to be stored.
    rng, rng_loop = jax.random.split(rng, 2)
    rngs_loop = flax_utils.replicate(rng_loop)
    checkpoint_extra = dict(accum_train_time=0.0, rngs_loop=rngs_loop)

    # Decide how to initialize training. The order is important.
    # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
    # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
    # 3. Initialize model from something, e.g. start a fine-tuning job.
    # 4. Train from scratch.
    resume_checkpoint_path = None
    if save_checkpoint_path and gfile.exists(save_checkpoint_path):
        resume_checkpoint_path = save_checkpoint_path
    elif config.get('resume'):
        resume_checkpoint_path = config.resume
    if resume_checkpoint_path:
        write_note('Resume training from checkpoint...')
        checkpoint_tree = {'opt': opt_cpu, 'extra': checkpoint_extra}
        checkpoint = checkpoint_utils.load_checkpoint(checkpoint_tree,
                                                      resume_checkpoint_path)
        opt_cpu, checkpoint_extra = checkpoint['opt'], checkpoint['extra']
        rngs_loop = checkpoint_extra['rngs_loop']
    elif config.get('model_init'):
        write_note(f'Initialize model from {config.model_init}...')
        reinit_params = config.get('model_reinit_params',
                                   ('head/kernel', 'head/bias'))
        logging.info('Reinitializing these parameters: %s', reinit_params)
        # We only support "no head" fine-tuning for now.
        loaded_params = checkpoint_utils.load_checkpoint(
            tree=None, path=config.model_init)
        loaded = checkpoint_utils.restore_from_pretrained_params(
            params_cpu,
            loaded_params,
            model_representation_size=None,
            model_classifier=None,
            reinit_params=reinit_params)
        opt_cpu = opt_cpu.replace(target=loaded)
        if jax.host_id() == 0:
            logging.info('Restored parameter overview:')
            parameter_overview.log_parameter_overview(loaded)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.host_id() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                checkpoint_extra['accum_train_time'])
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
            opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
                opt_repl,
                lr_repl,
                train_batch['image'],
                train_batch['labels'],
                rng=rngs_loop)

        if jax.host_id() == 0:
            profiler(step)

        # Checkpoint saving
        if train_utils.itstime(step,
                               config.get('checkpoint_steps'),
                               total_steps,
                               process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            checkpoint_extra['accum_train_time'] = chrono.accum_train_time
            checkpoint_extra['rngs_loop'] = rngs_loop
            # We need to transfer the weights over now or else we risk keeping them
            # alive while they'll be updated in a future step, creating hard to debug
            # memory errors (see b/160593526). Also, takes device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or Flax dataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint = {'opt': opt_cpu, 'extra': checkpoint_extra}
            checkpoint_writer = pool.apply_async(
                checkpoint_utils.save_checkpoint,
                (checkpoint, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if train_utils.itstime(step,
                               config.log_training_steps,
                               total_steps,
                               process=0):
            write_note('Reporting training progress...')
            train_accuracy = extra_measurements['accuracy']
            train_accuracy = jnp.mean(train_accuracy)
            train_loss = loss_value[
                0]  # Keep to return for reproducibility tests.
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
                'training_accuracy': train_accuracy,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, val_ds in val_ds_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()

                # Runs evaluation loop.
                val_iter = input_utils.start_input_pipeline(
                    val_ds, config.get('prefetch_to_device', 1))
                ncorrect, loss, nseen = 0, 0, 0
                for batch in val_iter:
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, batch['image'],
                                          batch['labels'], batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))
                    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                        # Here we parse batch_metric_args to compute uncertainty metrics.
                        # (e.g., ECE or Calibration AUC).
                        logits, labels, _, masks = batch_metric_args
                        masks = np.array(masks[0], dtype=bool)
                        logits = np.array(logits[0])
                        probs = jax.nn.softmax(logits)
                        # From one-hot to integer labels, as required by ECE.
                        int_labels = np.argmax(np.array(labels[0]), axis=-1)
                        int_preds = np.argmax(logits, axis=-1)
                        confidence = np.max(probs, axis=-1)
                        for p, c, l, d, m, label in zip(
                                probs, confidence, int_labels, int_preds,
                                masks, labels[0]):
                            ece.add_batch(p[m, :], label=l[m])
                            calib_auc.add_batch(d[m],
                                                label=l[m],
                                                confidence=c[m])
                            # TODO(jereliu): Extend to support soft multi-class probabilities.
                            oc_auc_0_5.add_batch(d[m],
                                                 label=l[m],
                                                 custom_binning_score=c[m])
                            oc_auc_1.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_2.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_5.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])

                            if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                                batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                                    label[m], p[m, :], config.num_classes)
                                label_diversity.update_state(
                                    batch_label_diversity)
                                sample_diversity.update_state(
                                    batch_sample_diversity)
                                ged.update_state(batch_ged)

                val_loss[
                    val_name] = loss / nseen  # Keep for reproducibility tests.
                val_measurements = {
                    f'{val_name}_prec@1': ncorrect / nseen,
                    f'{val_name}_loss': val_loss[val_name],
                }
                if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                    val_measurements[f'{val_name}_ece'] = ece.result()['ece']
                    val_measurements[
                        f'{val_name}_calib_auc'] = calib_auc.result(
                        )['calibration_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
                        )['collaborative_auc']
                writer.write_scalars(step, val_measurements)

                if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # Entries in the ood_ds dict include:
            # (ind_dataset, ood_dataset1, ood_dataset2, ...).
            # OOD metrics are computed using ind_dataset paired with each of the
            # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
            # is also included in the ood_ds.
            if ood_ds and config.ood_methods:
                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds,
                    ood_ds_names,
                    config.ood_methods,
                    evaluation_fn,
                    opt_repl.target,
                    n_prefetch=config.get('prefetch_to_device', 1))
                writer.write_scalars(step, ood_measurements)
            chrono.resume()

        if 'fewshot' in config and fewshotter is not None:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target, config.fewshot.datasets)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    def writer_measure(name, value):
                        writer.write_scalars(step, {name: value})

                    return writer_measure

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results
Example No. 9
def main(config, output_dir):
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=10)

    logging.info(config)

    acquisition_method = config.get('acquisition_method')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)
    writer.write_hparams(dict(config))

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note(f'Initializing for {acquisition_method}')

    # Download dataset
    data_builder = tfds.builder(config.dataset)
    data_builder.download_and_prepare()

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()

    val_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.val_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    test_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.test_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    # Init model
    if config.model_type == 'deterministic':
        model_utils = deterministic_utils
        reinit_params = config.get('model_reinit_params',
                                   ('head/kernel', 'head/bias'))
        model = ub.models.vision_transformer(num_classes=config.num_classes,
                                             **config.get('model', {}))
    elif config.model_type == 'batchensemble':
        model_utils = batchensemble_utils
        reinit_params = ('batchensemble_head/bias',
                         'batchensemble_head/kernel',
                         'batchensemble_head/fast_weight_alpha',
                         'batchensemble_head/fast_weight_gamma')
        model = ub.models.PatchTransformerBE(num_classes=config.num_classes,
                                             **config.model)
    else:
        raise ValueError('Expected config.model_type to be "deterministic" or '
                         f'"batchensemble", but received {config.model_type}.')

    init = model_utils.create_init(model, config, test_ds)

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this so that the arrays it creates live on the same device as the
    # input (the CPU here); otherwise they would end up on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    loaded_params = checkpoint_utils.load_checkpoint(tree=None,
                                                     path=config.model_init)
    loaded = checkpoint_utils.restore_from_pretrained_params(
        params_cpu,
        loaded_params,
        config.model.representation_size,
        config.model.classifier,
        reinit_params,
    )

    opt_cpu = opt_cpu.replace(target=loaded)

    # TODO(joost,andreas): This shouldn't be needed but opt_cpu is being
    # donated otherwise. Ensure opt_cpu is really on the cpu this way.
    opt_cpu = jax.device_get(opt_cpu)

    update_fn = model_utils.create_update_fn(model, config)
    evaluation_fn = model_utils.create_evaluation_fn(model, config)

    # NOTE: We need this because we need an Id field of type int.
    # TODO(andreas): Rename to IdSubsetDatasetBuilder?
    pool_subset_data_builder = al_utils.SubsetDatasetBuilder(data_builder,
                                                             subset_ids=None)

    rng, pool_ds_rng = jax.random.split(rng)

    # NOTE: below line is necessary on multi host setup
    # pool_ds_rng = jax.random.fold_in(pool_ds_rng, jax.process_index())

    pool_train_ds = input_utils.get_data(
        dataset=pool_subset_data_builder,
        split=config.train_split,
        rng=pool_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        drop_remainder=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Don't repeat
    )

    # Potentially acquire an initial training set.
    initial_training_set_size = config.get('initial_training_set_size', 10)

    if initial_training_set_size > 0:
        current_opt_repl = flax_utils.replicate(opt_cpu)
        pool_ids, _, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            config=config)

        rng, initial_uniform_rng = jax.random.split(rng)
        pool_scores = get_uniform_scores(pool_masks, initial_uniform_rng)

        initial_training_set_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=initial_training_set_size,
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=set(),
        )
    else:
        initial_training_set_batch_ids = []

    # NOTE: if we could `enumerate` before `filter` in `create_dataset` of CLU
    # then this dataset creation could be simplified.
    # https://github.com/google/CommonLoopUtils/blob/main/clu/deterministic_data.py#L340
    # CLU is explicitly not accepting outside contributions at the moment.
    train_subset_data_builder = al_utils.SubsetDatasetBuilder(
        data_builder, subset_ids=set(initial_training_set_batch_ids))

    test_accuracies = []
    training_sizes = []

    rng, rng_loop = jax.random.split(rng)
    rngs_loop = flax_utils.replicate(rng_loop)
    if config.model_type == 'batchensemble':
        rngs_loop = {'dropout': rngs_loop}

    # TODO(joost,andreas): double check if below is still necessary
    # (train_split is independent of this)
    # NOTE: train_ds_rng is re-used for all train_ds creations
    rng, train_ds_rng = jax.random.split(rng)

    measurements = {}
    accumulated_steps = 0
    while True:
        current_train_ds_length = len(train_subset_data_builder.subset_ids)
        if current_train_ds_length >= config.get('max_training_set_size', 150):
            break
        write_note(f'Training set size: {current_train_ds_length}')

        current_opt_repl = flax_utils.replicate(opt_cpu)

        # Only fine-tune if there is anything to fine-tune with.
        if current_train_ds_length > 0:
            # Repeat dataset to have oversampled epochs and bootstrap more batches
            number_of_batches = current_train_ds_length / config.batch_size
            num_repeats = math.ceil(config.total_steps / number_of_batches)
            write_note(f'Repeating dataset {num_repeats} times')

            # We repeat the dataset several times, such that we can obtain batches
            # of size batch_size, even at start of training. These batches will be
            # effectively 'bootstrap' sampled, meaning they are sampled with
            # replacement from the original training set.
            repeated_train_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_train,
                    available_ops=preprocess_utils.all_ops()),
                shuffle_buffer_size=config.shuffle_buffer_size,
                prefetch_size=config.get('prefetch_to_host', 2),
                # TODO(joost,andreas): double check if below leads to bootstrap
                # sampling.
                num_epochs=num_repeats,
            )

            # We use this dataset to evaluate how well we perform on the training set.
            # We need this to evaluate if we fit well within max_steps budget.
            train_eval_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_eval,
                    available_ops=preprocess_utils.all_ops()),
                shuffle=False,
                drop_remainder=False,
                prefetch_size=config.get('prefetch_to_host', 2),
                num_epochs=1,
            )

            # NOTE: warmup and decay are not a good fit for the small training set
            # lr_fn = train_utils.create_learning_rate_schedule(config.total_steps,
            #                                                   **config.get('lr', {})
            #                                                   )
            lr_fn = lambda x: config.lr.base

            early_stopping_patience = config.get('early_stopping_patience', 15)
            current_opt_repl, rngs_loop, measurements = finetune(
                update_fn=update_fn,
                opt_repl=current_opt_repl,
                lr_fn=lr_fn,
                ds=repeated_train_ds,
                rngs_loop=rngs_loop,
                total_steps=config.total_steps,
                train_eval_ds=train_eval_ds,
                val_ds=val_ds,
                evaluation_fn=evaluation_fn,
                early_stopping_patience=early_stopping_patience,
                profiler=profiler)
            train_val_accuracies = measurements.pop('train_val_accuracies')
            current_steps = 0
            for step, train_acc, val_acc in train_val_accuracies:
                writer.write_scalars(accumulated_steps + step, {
                    'train_accuracy': train_acc,
                    'val_accuracy': val_acc
                })
                current_steps = step
            accumulated_steps += current_steps + 10

        test_accuracy = get_accuracy(evaluation_fn=evaluation_fn,
                                     opt_repl=current_opt_repl,
                                     ds=test_ds)

        write_note(f'Accuracy at {current_train_ds_length}: {test_accuracy}')

        test_accuracies.append(test_accuracy)
        training_sizes.append(current_train_ds_length)

        pool_ids, pool_outputs, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            use_pre_logits=acquisition_method == 'density',
            config=config)

        if acquisition_method == 'uniform':
            rng_loop, rng_acq = jax.random.split(rng_loop, 2)
            pool_scores = get_uniform_scores(pool_masks, rng_acq)
        elif acquisition_method == 'entropy':
            pool_scores = get_entropy_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'margin':
            pool_scores = get_margin_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'density':
            if current_train_ds_length > 0:
                pool_scores = get_density_scores(model=model,
                                                 opt_repl=current_opt_repl,
                                                 train_ds=train_eval_ds,
                                                 pool_pre_logits=pool_outputs,
                                                 pool_masks=pool_masks,
                                                 config=config)
            else:
                rng_loop, rng_acq = jax.random.split(rng_loop, 2)
                pool_scores = get_uniform_scores(pool_masks, rng_acq)
        else:
            raise ValueError(
                f'Unknown acquisition method: {acquisition_method}')

        acquisition_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=config.get('acquisition_batch_size', 10),
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=train_subset_data_builder.subset_ids)

        train_subset_data_builder.subset_ids.update(acquisition_batch_ids)

        measurements.update({'test_accuracy': test_accuracy})
        writer.write_scalars(current_train_ds_length, measurements)

    write_note(f'Final acquired training ids: '
               f'{train_subset_data_builder.subset_ids}\n'
               f'Accuracies: {test_accuracies}')

    pool.close()
    pool.join()
    writer.close()
    # TODO(joost,andreas): save the final checkpoint
    return (train_subset_data_builder.subset_ids, test_accuracies)
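
The acquisition helpers dispatched on above (get_uniform_scores, get_entropy_scores, get_margin_scores, get_density_scores) are defined elsewhere and not shown in this example. Below is a minimal, hypothetical sketch of the entropy and margin scores for logits of shape [num_examples, num_classes] and a boolean validity mask, assuming select_acquisition_batch_indices picks the highest-scoring examples; the real helpers may differ.

import jax
import jax.numpy as jnp


def entropy_scores_sketch(logits, masks):
    # Predictive entropy: larger means the model is more uncertain.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)
    # Masked-out (padded) examples get -inf so they are never selected.
    return jnp.where(masks, entropy, -jnp.inf)


def margin_scores_sketch(logits, masks):
    # Margin between the two most probable classes; a small margin means
    # high uncertainty, so it is negated to keep "higher score = acquire".
    probs = jax.nn.softmax(logits, axis=-1)
    top2, _ = jax.lax.top_k(probs, k=2)
    margin = top2[:, 0] - top2[:, 1]
    return jnp.where(masks, -margin, -jnp.inf)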
Exemplo n.º 10
0
def main():
    ## Parse command line args
    args = parse_train_args()

    ## Use CUDA if available
    print('=> CUDA availability / use: "{}" / "{}"'.format(
        str(torch.cuda.is_available()), str(args.CUDA)))
    args.CUDA = args.CUDA and torch.cuda.is_available()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and args.CUDA) else 'cpu')

    ## Load embeddings
    from arg_utils import load_embeddings
    embeddings = load_embeddings(args)

    ## Construct Datasets (Train/Validation/Test)
    from arg_utils import construct_hyperpartisan_flair_dataset, \
                          construct_hyperpartisan_flair_and_features_dataset, \
                          construct_propaganda_flair_dataset
    train_dataset, val_dataset, test_dataset, input_shape = construct_datasets(
        construct_propaganda_flair_dataset,
        embeddings,
        args,
    )

    ## Construct Dataloaders
    train_dataloader, val_dataloader, test_dataloader = construct_dataloaders(
        train_dataset, val_dataset, test_dataset, args)

    ## Construct model + optimizer + scheduler
    from nn_architectures import construct_hierarch_att_net, \
                                 construct_cnn_bertha_von_suttner, \
                                 construct_HAN_with_features, \
                                 construct_lstm
    model, optimizer, scheduler, loss_criterion = construct_model(
        construct_lstm,  ## NOTE change model here
        input_shape,
        args)

    ## Train model
    best_path = None  # stays None if training is skipped (no train_dataloader)
    if train_dataloader is not None:
        random.seed(args.seed)
        best_path, _ = train_pytorch(
            model,
            optimizer,
            loss_criterion,
            train_dataloader,
            args=args,
            val_loader=test_dataloader
            if val_dataloader is None else val_dataloader,
            device=device,
            scheduler=scheduler)
        checkpoint(model,
                   optimizer,
                   scheduler,
                   args.epochs,
                   args.checkpoint_dir,
                   name='final.' + args.name)

    ## Test model
    # Load best model's checkpoint for testing (if available)
    if best_path:
        load_checkpoint(best_path, model)
    test_model(model, test_dataloader, device=device)
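
The checkpoint(...) and load_checkpoint(...) helpers used here come from elsewhere in this code base and are not shown. A minimal sketch of what they might look like, assuming they simply wrap torch.save/torch.load around a single state dict (the names and exact layout are assumptions):

import os
import torch


def checkpoint_sketch(model, optimizer, scheduler, epoch, checkpoint_dir, name):
    # Persist model/optimizer/scheduler state so training can be resumed.
    state = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
    }
    path = os.path.join(checkpoint_dir, name + '.pt')
    torch.save(state, path)
    return path


def load_checkpoint_sketch(path, model, optimizer=None, scheduler=None):
    # Restore previously saved state; optimizer/scheduler are optional.
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    if optimizer is not None and state.get('optimizer') is not None:
        optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None and state.get('scheduler') is not None:
        scheduler.load_state_dict(state['scheduler'])
    return state.get('epoch', 0)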
Exemplo n.º 11
0
def main():
    ## Parse command line args
    args = parse_train_args()
    assert args.k_fold is not None, 'Use "--k-fold <N>" for specifying the number of folds to use'

    ## Use CUDA if available
    print('=> CUDA availability / use: "{}" / "{}"'.format(
        str(torch.cuda.is_available()), str(args.CUDA)))
    args.CUDA = args.CUDA and torch.cuda.is_available()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and args.CUDA) else 'cpu')

    ## Extract data
    *Xs, Y = extract_data(args)
    Xs = tuple(Xs)
    input_shape = Xs[0].shape[1:]

    ## Optionally, undersample majority class
    if args.undersampling:
        balanced_indices = balanced_sampling(Y)
        Y = Y[balanced_indices]
        Xs = tuple(x[balanced_indices] for x in Xs)

    ## Construct TensorDataset
    main_tensor_dataset = TensorDataset(*(Xs + (Y, )))

    ## Model constructor ## NOTE Change MODEL here
    from nn_architectures import construct_lstm, construct_AttnBiLSTM
    model_constructor = construct_AttnBiLSTM

    ## k-fold split of Train/Test
    stats = np.zeros((args.k_fold, 4))
    kfold = StratifiedKFold(n_splits=args.k_fold)
    base_name = args.name  # keep the base run name so per-fold suffixes do not accumulate
    for i, (train_indices,
            test_indices) in enumerate(kfold.split(Xs[0].numpy(), Y.numpy())):
        print('K-Fold: [{:02}/{:02}]'.format(i + 1, args.k_fold))

        train_dataset = Subset(main_tensor_dataset, train_indices)
        test_dataset = Subset(main_tensor_dataset, test_indices)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0,
                                  pin_memory=torch.cuda.is_available())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=0,
                                 pin_memory=torch.cuda.is_available())

        ## Construct model + optimizer + scheduler
        args.name = '{}.k{}'.format(base_name, i)  # fresh fold suffix (avoids '.k0.k1...' build-up)
        model, optimizer, scheduler, loss_criterion = construct_model(
            model_constructor, input_shape, args)

        ## Train model
        best_model, _ = train_pytorch(model,
                                      optimizer,
                                      loss_criterion,
                                      train_loader,
                                      args=args,
                                      val_loader=test_loader,
                                      device=device,
                                      scheduler=scheduler)

        ## Load best model
        if best_model:
            load_checkpoint(best_model, model)

        stats[i] = test_model(model, test_loader, device)

    ## Final stats
    print('** Final Statistics **')
    print('Mean:\t', np.mean(stats, axis=0))
    print('STD: \t', np.std(stats, axis=0))
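
test_model(...) above is defined elsewhere and is assumed to return four metrics per fold, since the stats array has four columns. A hypothetical sketch, assuming those columns are accuracy, precision, recall and F1 computed from argmax predictions on a two-class task:

import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


def test_model_sketch(model, dataloader, device):
    # Collect hard predictions over the whole test loader.
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for *inputs, targets in dataloader:
            outputs = model(*(x.to(device) for x in inputs))
            y_pred.extend(outputs.argmax(dim=-1).cpu().tolist())
            y_true.extend(targets.tolist())
    # 'binary' assumes a two-class setup, as in hyperpartisan/propaganda detection.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary')
    return np.array([accuracy_score(y_true, y_pred), precision, recall, f1])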