Example #1
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)
    for samples in progress:
        if hasattr(trainer.criterion, 'set_num_updates'):
            trainer.criterion.set_num_updates(trainer.get_num_updates())

        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
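
A minimal sketch of the per-epoch `update_freq` selection used above, isolated into plain Python (the helper name is hypothetical): the list supplies one gradient-accumulation factor per epoch, and the last entry covers all later epochs.

def select_update_freq(update_freq, epoch):
    """Return update_freq[epoch - 1], falling back to the last entry."""
    # `epoch` is 1-indexed, matching epoch_itr.epoch above
    if epoch <= len(update_freq):
        return update_freq[epoch - 1]
    return update_freq[-1]

assert select_update_freq([8, 4, 2], 1) == 8
assert select_update_freq([8, 4, 2], 3) == 2
assert select_update_freq([8, 4, 2], 10) == 2  # past the list: reuse last
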
Example #2
def downstream_train_pytorch(args, trainer, task, epoch_itr, train_prefix):
    """Fine-tune PyTorch classifier on downstream training set for one epoch"""
    task.split = 'train'
    num_updates = trainer.get_num_updates()

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    progress = maybe_wrap_neptune_logging(progress, args)

    # Task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:

            # Train for one step
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = get_ft_train_stats(agg.get_smoothed_values())
            progress.log(stats, tag=train_prefix, step=num_updates)

            if num_updates >= max_update:
                break

    # log end-of-epoch stats
    stats = get_ft_train_stats(agg.get_smoothed_values())
    try:
        # some progress-bar wrappers accept log=False to skip remote logging
        progress.print(stats, tag=train_prefix, step=num_updates, log=False)
    except TypeError:
        # fall back for bars whose print() lacks the `log` kwarg
        progress.print(stats, tag=train_prefix, step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters(train_prefix)
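
`iterators.GroupedIterator` chunks the batch iterator so that each `trainer.train_step(samples)` call receives `update_freq` batches to accumulate gradients over. A hand-rolled stand-in (not fairseq's class) showing the grouping behaviour:

def grouped(iterable, chunk_size):
    """Yield lists of up to chunk_size consecutive items."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk  # trailing partial group

print(list(grouped(range(5), 2)))  # [[0, 1], [2, 3], [4]]
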
Example #3
    def test_nested_duplicate_names(self):
        name = str(uuid.uuid4())
        metrics.reset_meters(name)

        with metrics.aggregate(name):
            metrics.log_scalar('loss', 1)
            with metrics.aggregate() as other:
                with metrics.aggregate(name):
                    metrics.log_scalar('loss', 2)
            metrics.log_scalar('loss', 6)

        self.assertEqual(metrics.get_smoothed_values(name)['loss'], 3)
        self.assertEqual(other.get_smoothed_values()['loss'], 2)
Example #4
    def test_named(self):
        name = str(uuid.uuid4())
        metrics.reset_meters(name)

        with metrics.aggregate(name):
            metrics.log_scalar('loss', 1)

        metrics.log_scalar('loss', 3)

        with metrics.aggregate(name):
            metrics.log_scalar('loss', 2)

        self.assertEqual(metrics.get_smoothed_values(name)['loss'], 1.5)
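
The asserted values follow from plain averaging. Below is a hand-rolled sketch of the aggregation semantics these tests rely on (an assumption about the behaviour, not fairseq's implementation): every distinct active aggregator records each logged value, named aggregators persist across `with` blocks, and smoothed values are means.

from contextlib import contextmanager

_active = []   # stack of active aggregators; one may appear twice when nested
_named = {}    # named aggregators persist across `with` blocks

@contextmanager
def aggregate(name=None):
    agg = _named.setdefault(name, []) if name is not None else []
    _active.append(agg)
    try:
        yield agg
    finally:
        _active.pop()

def log_scalar(key, value):
    # `key` kept only to mirror the metrics API; log once per *distinct*
    # active aggregator, even when the same one is nested twice
    for agg in {id(a): a for a in _active}.values():
        agg.append(value)

def smoothed(agg):
    return sum(agg) / len(agg)

with aggregate('v') as named:
    log_scalar('loss', 1)
    with aggregate() as other:
        with aggregate('v'):
            log_scalar('loss', 2)
    log_scalar('loss', 6)

assert smoothed(named) == 3  # mean of [1, 2, 6]
assert smoothed(other) == 2  # mean of [2]
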
Example #5
def downstream_validate_sklearn(args, task, model, epoch_for_logging,
                                task_name, num_updates, classifier, scaler):
    """Evaluate classifier on downstream validation set"""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_sklearn,
        max_sentences=args.max_sentences_sklearn,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=1,
        shard_id=0,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_for_logging,
        prefix='sklearn valid on \'{}\''.format(task_name),
        no_progress_bar='simple')
    progress = maybe_wrap_neptune_logging(progress, args)

    # Reset validation meters
    metrics.reset_meters(task_name)

    # Load downstream validation data
    with torch.no_grad():
        model.eval()
        features, targets = load_downstream_data(args, progress, model,
                                                 'valid', scaler, None)

    # Compute class predictions and probabilities
    class_predictions = classifier.predict(features)
    class_probabilities = classifier.predict_proba(features)

    # Compute and log downstream validation stats
    stats = compute_sklearn_stats(targets, class_predictions,
                                  class_probabilities, args.num_classes,
                                  args.eval_metric)
    stats = get_sklearn_stats(stats, num_updates)
    progress.print(stats, tag=task_name + '_sk_valid', step=num_updates)
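
`compute_sklearn_stats` is project-specific; as a labelled assumption, here is the kind of summary it might produce from the `predict`/`predict_proba` output above, built only on standard scikit-learn metrics:

import numpy as np
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score

def sketch_sklearn_stats(targets, predictions, probabilities, num_classes):
    # hypothetical stand-in for compute_sklearn_stats above
    stats = {
        'accuracy': accuracy_score(targets, predictions),
        'log_loss': log_loss(targets, probabilities,
                             labels=np.arange(num_classes)),
    }
    if num_classes == 2:
        # AUROC needs the probability of the positive class
        stats['auroc'] = roc_auc_score(targets, probabilities[:, 1])
    return stats
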
Example #6
def validate(args, trainer, task, epoch_for_logging, valid_name, ckpt_idx):
    """Evaluate the model on the validation set(s) and return the losses."""
    task.split = 'valid'

    if args.fixed_validation_seed is not None:
        # Set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        epoch=epoch_for_logging,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_for_logging,
        prefix='valid on \'{}\' subset'.format(valid_name),
        no_progress_bar='simple')
    progress = maybe_wrap_neptune_logging(progress, args)

    # Reset validation meters
    metrics.reset_meters(valid_name)

    with metrics.aggregate(valid_name) as agg:
        for sample in progress:
            trainer.valid_step(sample)

    # Log validation stats
    stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
    if args.log_valid_progress:
        valid_progress_prefix = '{}_ckpt{}'.format(valid_name, ckpt_idx)
        progress.print({args.eval_metric: stats[args.eval_metric]},
                       tag=valid_progress_prefix,
                       step=epoch_for_logging)

    # Return validation scores
    return (stats[args.best_checkpoint_metric], stats[args.eval_metric],
            progress)
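
The `fixed_validation_seed` branch above makes every validation pass deterministic. A sketch of what a helper like `utils.set_torch_seed` plausibly does (an assumption; the real helper is fairseq's): reseed all relevant RNGs so any sampled ops repeat exactly between validations.

import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # CPU generator
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # all GPU generators
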
Example #7
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    task.split = 'train'

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = maybe_wrap_neptune_logging(
        progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            no_progress_bar='simple',
        ),
        args=args,
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = get_training_stats(agg.get_smoothed_values())
            progress.log(stats, tag='train', step=num_updates)

            if num_updates >= max_update:
                break

    # log end-of-epoch stats
    stats = get_training_stats(agg.get_smoothed_values())
    try:
        # some progress-bar wrappers accept log=False to skip remote logging
        progress.print(stats, tag='train', step=num_updates, log=False)
    except TypeError:
        # fall back for bars whose print() lacks the `log` kwarg
        progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
Example #8
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation meters
        metrics.reset_meters('valid')

        with metrics.aggregate() as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
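
`utils.resolve_max_positions` merges the task's and model's sequence-length limits into one constraint for the batch iterator. A sketch of the idea (an assumption, simplified to tuple limits; fairseq's real helper also handles ints and None):

def resolve_max_positions(*limits):
    """Elementwise minimum over (source, target) position limits."""
    limits = [l for l in limits if l is not None]
    return tuple(map(min, zip(*limits))) if limits else None

assert resolve_max_positions((1024, 1024), (512, 2048)) == (512, 1024)
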
Example #9
def downstream_validate_pytorch(args,
                                task,
                                model,
                                criterion,
                                epoch_for_logging,
                                subsets,
                                valid_name,
                                num_updates,
                                global_epoch=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    task.split = 'valid'
    valid_name_ = valid_name if valid_name is not None else 'valid'

    if args.fixed_validation_seed is not None:
        # Set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                model.max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            epoch=1,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_for_logging,
            prefix='valid on \'{}\' subset'.format(valid_name_),
            no_progress_bar='simple')

        # Add global epoch to beginning of progress bar description
        if global_epoch is not None:
            try:
                progress.wrapped_bar.tqdm.set_description(
                    desc='epoch {:03d} | \'{}\' {}'.format(
                        global_epoch, valid_name_,
                        progress.wrapped_bar.prefix),
                    refresh=True)
            except AttributeError:
                # fall back when the bar has no wrapped_bar attribute
                progress.tqdm.set_description(
                    desc='epoch {:03d} | \'{}\' {}'.format(
                        global_epoch, valid_name_, progress.tqdm.desc),
                    refresh=True)

        progress = maybe_wrap_neptune_logging(progress, args)

        # Reset validation meters
        metrics.reset_meters(valid_name_)

        with metrics.aggregate(valid_name_) as agg:
            dummy_batch = "DUMMY"  # sentinel until the first real batch is seen
            for sample in progress:
                dummy_batch = sample if dummy_batch == "DUMMY" else dummy_batch
                valid_step(args, sample, task, model, criterion, dummy_batch,
                           logger)

        # Log validation stats
        stats = get_ft_valid_stats(args, agg.get_smoothed_values(),
                                   num_updates)
        progress.print(stats, tag=valid_name_, step=num_updates)

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
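
The description rewrite above targets tqdm directly. A minimal illustration of the same call against a bare tqdm bar (the `wrapped_bar` attributes are project-specific, so this strips them away):

from tqdm import tqdm

bar = tqdm(range(3), desc="valid on 'valid' subset")
# set_description() rewrites the text to the left of the bar in place
bar.set_description("epoch 001 | valid on 'valid' subset", refresh=True)
for _ in bar:
    pass
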
Example #10
def downstream_train_sklearn(args, task, model, epoch_for_logging, task_name,
                             num_updates):
    """Fine-tune sklearn classifier on downstream training set"""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset('train'),
        max_tokens=args.max_tokens_sklearn,
        max_sentences=args.max_sentences_sklearn,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=1,
        shard_id=0,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_for_logging,
        prefix='sklearn fine-tune on \'{}\''.format(task_name),
        no_progress_bar='simple')
    progress = maybe_wrap_neptune_logging(progress, args)

    # Reset meters
    metrics.reset_meters(task_name)

    # Load downstream train data
    with torch.no_grad():
        model.eval()
        features, targets, scaler = load_downstream_data(
            args, progress, model, 'train', None, args.scaler_type)

    # Train classifier
    logger.info('fine-tuning LogisticRegression classifier on \'{}\''.format(
        task_name))
    timer_start = timer()
    best_C = LogRegCV(args, features, targets)
    classifier = LogisticRegression(
        multi_class=args.multi_class,
        solver=args.solver,
        C=best_C,
        n_jobs=min(os.cpu_count(), args.num_classes, args.num_workers_sklearn)
        if args.solver != 'liblinear' else None,
        tol=args.tol,
        random_state=args.seed,
        max_iter=args.max_iter,
        verbose=args.verbose).fit(features, targets)
    timer_end = timer()
    logger.info(
        'finished sklearn fine-tuning in {:.2f} seconds'.format(timer_end -
                                                                timer_start))

    # Compute class predictions and probabilities
    class_predictions = classifier.predict(features)
    class_probabilities = classifier.predict_proba(features)

    # Compute and log downstream training stats
    stats = compute_sklearn_stats(targets, class_predictions,
                                  class_probabilities, args.num_classes,
                                  args.eval_metric)
    stats = get_sklearn_stats(stats, num_updates)
    progress.print(stats, tag=task_name + '_sk_train', step=num_updates)

    return classifier, scaler
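
`LogRegCV` above is a project-specific search over the regularization strength C; a common off-the-shelf equivalent (an assumption, not the code used here) is scikit-learn's LogisticRegressionCV:

import numpy as np
from sklearn.linear_model import LogisticRegressionCV

rng = np.random.RandomState(0)
X = rng.randn(200, 8)
y = (X[:, 0] + 0.1 * rng.randn(200) > 0).astype(int)  # toy binary labels

# search C over 10 values with 5-fold cross-validation
clf = LogisticRegressionCV(Cs=10, cv=5, max_iter=1000,
                           random_state=0).fit(X, y)
print('chosen C per class:', clf.C_)
print('train accuracy:', clf.score(X, y))
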
Example #11
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    start_time = time.time()
    step = 0
    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        step += 1
        """
        if step % 10 == 0:
            print(step)

        if step >= 200:
            pr.disable()
            #pr.dump_stats( "torch_profile")

            sys.exit()

        step += 1
        """

        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            print("validate and save_checkpoint")
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    train_epoch_cost = time.time() - start_time

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)
    print("epoch_cost: %.5f s, avg_speed: %.5f steps/s" %
          (train_epoch_cost, float(step) / train_epoch_cost))

    # reset epoch-level meters
    metrics.reset_meters('train')
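
The epoch-cost bookkeeping at the end of this example reduces to a simple pattern; isolated here (with a guard against division by zero that the example omits):

import time

start = time.time()
steps = 0
for _ in range(1000):  # stand-in for `for samples in progress`
    steps += 1
cost = time.time() - start
print("epoch_cost: %.5f s, avg_speed: %.5f steps/s"
      % (cost, steps / max(cost, 1e-9)))
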