Beispiel #1
0
    def test_train_two_epoch_late_start(self):
        """Train from (ep 0, it 5) to (ep 2, it 5) and check the callback's final readings."""
        ipe = len(self.train_loader)
        first = Step.from_epoch(0, 5, ipe)
        last = Step.from_epoch(2, 5, ipe)
        train.train(self.hparams.training_hparams, self.model,
                    self.train_loader, self.root,
                    callbacks=[self.callback],
                    start_step=first, end_step=last)

        self.assertEqual(self.step_counter, 25)
        self.assertEqual(self.ep, 2)
        self.assertEqual(self.it, 5)
        self.assertEqual(self.lr, 0.01)
Beispiel #2
0
def standard_callbacks(training_hparams: hparams.TrainingHparams,
                       train_set_loader: DataLoader,
                       test_set_loader: DataLoader,
                       eval_on_train: bool = False,
                       verbose: bool = True,
                       start_step: Step = None,
                       evaluate_every_epoch: bool = True,
                       weight_save_steps: List[Step] = None):
    """Build the standard list of callbacks for a training run.

    The list always contains these callbacks:
      * save the model at the start and end steps;
      * save the logger at the end step;
      * write a checkpoint every epoch.

    The test set is evaluated every epoch when ``evaluate_every_epoch`` is set. It is
    also evaluated explicitly at the start and end steps when either step is not at an
    epoch boundary (``it != 0``) or per-epoch evaluation is off. ``eval_on_train`` adds
    the same evaluations for the training set.

    Args:
      * training_hparams: Training hyperparameters; ``training_steps`` determines the end step.
      * train_set_loader: Loader used for training-set evaluation. It also supplies
        ``iterations_per_epoch`` for step arithmetic.
      * test_set_loader: Loader used for test-set evaluation.
      * eval_on_train: Also evaluate on the training set.
      * verbose: Forwarded to the evaluation callbacks. When ``evaluate_every_epoch`` is
        false, this also adds a per-epoch timekeeper callback.
      * start_step: The first step of training. Defaults to step zero.
      * evaluate_every_epoch: Evaluate at every epoch boundary.
      * weight_save_steps: Additional steps at which to save the model. Defaults to none.

    Returns:
      The list of callbacks. Evaluation callbacks are placed ahead of the saving callbacks.
    """
    # Avoid a shared mutable default argument.
    if weight_save_steps is None:
        weight_save_steps = []

    start = start_step or Step.zero(train_set_loader.iterations_per_epoch)
    end = Step.from_str(training_hparams.training_steps,
                        train_set_loader.iterations_per_epoch)
    test_eval_callback = create_eval_callback('test',
                                              test_set_loader,
                                              verbose=verbose)
    train_eval_callback = create_eval_callback('train',
                                               train_set_loader,
                                               verbose=verbose)

    # Basic checkpointing and state saving at the beginning and end.
    result = [
        run_at_step(start, save_model),
        run_at_step(end, save_model),
        run_at_step(end, save_logger),
        run_every_epoch(checkpointing.save_checkpoint_callback),
    ]

    for s in weight_save_steps:
        result.append(run_at_step(s, save_model))

    # Test every epoch if requested.
    if evaluate_every_epoch:
        result = [run_every_epoch(test_eval_callback)] + result
    elif verbose:
        result.append(run_every_epoch(create_timekeeper_callback()))

    # Ensure that testing occurs at least at the beginning and end of training.
    if start.it != 0 or not evaluate_every_epoch:
        result = [run_at_step(start, test_eval_callback)] + result
    if end.it != 0 or not evaluate_every_epoch:
        result = [run_at_step(end, test_eval_callback)] + result

    # Do the same for the train set if requested.
    if eval_on_train:
        if evaluate_every_epoch:
            result = [run_every_epoch(train_eval_callback)] + result
        if start.it != 0 or not evaluate_every_epoch:
            result = [run_at_step(start, train_eval_callback)] + result
        if end.it != 0 or not evaluate_every_epoch:
            result = [run_at_step(end, train_eval_callback)] + result

    return result
Beispiel #3
0
def restore_checkpoint(output_location, model, optimizer,
                       iterations_per_epoch):
    """Load the checkpoint stored in ``output_location`` into ``model`` and ``optimizer``.

    The ``'module.'`` key prefix on the saved model state is added or stripped so that
    it matches the platform's ``is_parallel`` flag.

    Returns:
      ``(None, None)`` when no checkpoint file exists. Otherwise, the ``Step`` rebuilt
      from the checkpoint's ``ep``/``it`` and the ``MetricLogger`` rebuilt from its
      ``logger`` entry.
    """
    checkpoint_location = paths.checkpoint(output_location)
    if not get_platform().exists(checkpoint_location):
        return None, None
    checkpoint = get_platform().load_model(checkpoint_location,
                                           map_location=torch.device('cpu'))

    # Handle DataParallel: reconcile the 'module.' key prefix with the platform.
    # NOTE(review): only `is_parallel` is consulted here; confirm that this also
    # covers distributed runs.
    prefix = 'module.'
    state = checkpoint['model_state_dict']
    wants_prefix = get_platform().is_parallel
    has_prefix = all(key.startswith(prefix) for key in state)
    if wants_prefix and not has_prefix:
        state = {prefix + key: value for key, value in state.items()}
    elif has_prefix and not wants_prefix:
        state = {key[len(prefix):]: value for key, value in state.items()}

    model.load_state_dict(state)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    restored_step = Step.from_epoch(checkpoint['ep'], checkpoint['it'],
                                    iterations_per_epoch)
    restored_logger = MetricLogger.create_from_string(checkpoint['logger'])
    return restored_step, restored_logger
Beispiel #4
0
 def branch_function(self,
                     retrain_d: hparams.DatasetHparams,
                     retrain_t: hparams.TrainingHparams,
                     start_at_step_zero: bool = False,
                     transfer_learn: bool = False):
     """Retrain this level's pruned network with the given dataset and training hparams.

     Weights are loaded from ``self.level_root``. With ``transfer_learn`` they are taken at
     the lottery's train end step; otherwise they are taken at its train start step. The
     level's mask is then applied with ``PrunedModel``. Training begins at iteration 0
     when ``start_at_step_zero`` is set, and otherwise at the lottery train start step's
     iteration.
     """
     # Choose which saved weights to start from.
     source_step = (self.lottery_desc.train_end_step if transfer_learn
                    else self.lottery_desc.train_start_step)
     base_model = models.registry.load(self.level_root, source_step,
                                       self.lottery_desc.model_hparams)
     pruned = PrunedModel(base_model, Mask.load(self.level_root))

     first_iteration = (0 if start_at_step_zero
                        else self.lottery_desc.train_start_step.iteration)
     start_step = Step.from_iteration(
         first_iteration, datasets.registry.iterations_per_epoch(retrain_d))
     train.standard_train(pruned, self.branch_root, retrain_d, retrain_t,
                          start_step=start_step, verbose=self.verbose)
Beispiel #5
0
def standard_train(model: Model,
                   output_location: str,
                   dataset_hparams: hparams.DatasetHparams,
                   training_hparams: hparams.TrainingHparams,
                   start_step: Step = None,
                   verbose: bool = True,
                   evaluate_every_epoch: bool = True):
    """Train using the standard callbacks according to the provided hparams.

    Returns without training if both the end-of-training model and the logger already
    exist in ``output_location``.
    """
    iterations_per_epoch = datasets.registry.iterations_per_epoch(
        dataset_hparams)
    final_step = Step.from_str(training_hparams.training_steps,
                               iterations_per_epoch)
    already_trained = (models.registry.exists(output_location, final_step)
                       and get_platform().exists(paths.logger(output_location)))
    if already_trained:
        return

    train_loader = datasets.registry.get(dataset_hparams, train=True)
    test_loader = datasets.registry.get(dataset_hparams, train=False)
    callback_list = standard_callbacks.standard_callbacks(
        training_hparams, train_loader, test_loader,
        start_step=start_step, verbose=verbose,
        evaluate_every_epoch=evaluate_every_epoch)
    train(training_hparams, model, train_loader, output_location,
          callback_list, start_step=start_step)
Beispiel #6
0
    def test_train_zero_steps_late_start(self):
        """Start and end at the same step: weights unchanged and the callback never fires."""
        initial = TestTrain.get_state(self.model)

        ipe = len(self.train_loader)
        train.train(self.hparams.training_hparams, self.model,
                    self.train_loader, self.root,
                    callbacks=[self.callback],
                    start_step=Step.from_epoch(0, 5, ipe),
                    end_step=Step.from_epoch(0, 5, ipe))

        final = TestTrain.get_state(self.model)
        for name in initial:
            self.assertTrue(np.array_equal(initial[name], final[name]))
        self.assertEqual(self.step_counter, 0)
        self.assertEqual(self.ep, 0)
        self.assertEqual(self.it, 0)
Beispiel #7
0
def get_lr_schedule(training_hparams: TrainingHparams, optimizer: torch.optim.Optimizer, iterations_per_epoch: int):
    """Build a ``LambdaLR`` scheduler from the training hyperparameters.

    At iteration ``it`` the optimizer's base learning rate is multiplied by the product of:
      * ``gamma ** (number of milestones <= it)``, when ``milestone_steps`` and ``gamma``
        are set. ``milestone_steps`` is a comma-separated list of step strings such as
        ``'80ep,120ep'``, and may be given in any order.
      * ``min(1, it / warmup_iters)``, when ``warmup_steps`` is set and nonzero.

    Raises:
      ValueError: If exactly one of ``gamma`` and ``milestone_steps`` is set.
    """
    factors = [lambda it: 1.0]

    # Drop the learning rate according to gamma at the specified milestones.
    if bool(training_hparams.gamma) != bool(training_hparams.milestone_steps):
        raise ValueError('milestones and gamma hyperparameters must both be set or not at all.')
    if training_hparams.milestone_steps:
        # bisect requires a sorted list; an unsorted milestone string would otherwise
        # produce a wrong count of elapsed milestones.
        milestones = sorted(Step.from_str(x, iterations_per_epoch).iteration
                            for x in training_hparams.milestone_steps.split(','))
        factors.append(lambda it: training_hparams.gamma ** bisect.bisect(milestones, it))

    # Add linear learning rate warmup if specified.
    if training_hparams.warmup_steps:
        warmup_iters = Step.from_str(training_hparams.warmup_steps, iterations_per_epoch).iteration
        if warmup_iters != 0:
            factors.append(lambda it: min(1.0, it / warmup_iters))

    # Combine the factors. np.prod replaces np.product, which NumPy 2.0 removed.
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda it: np.prod([fn(it) for fn in factors]))
Beispiel #8
0
    def test_train_more_than_two_epochs(self):
        """Train from the beginning to (ep 2, it 1) and check the callback's final readings."""
        stop = Step.from_epoch(2, 1, len(self.train_loader))
        train.train(self.hparams.training_hparams, self.model,
                    self.train_loader, self.root,
                    callbacks=[self.callback], end_step=stop)

        self.assertEqual(self.step_counter, 26)
        self.assertEqual(self.ep, 2)
        self.assertEqual(self.it, 1)
        self.assertEqual(self.lr, 0.01)
Beispiel #9
0
    def test_train_in_full_later_start(self):
        """Resume at (ep 1, it 5) and train to the end step implied by the hparams."""
        resume_at = Step.from_epoch(1, 5, len(self.train_loader))
        train.train(self.hparams.training_hparams, self.model,
                    self.train_loader, self.root,
                    callbacks=[self.callback], start_step=resume_at)

        for observed, expected in ((self.step_counter, 20), (self.ep, 3),
                                   (self.it, 0), (self.lr, 0.01)):
            self.assertEqual(observed, expected)
Beispiel #10
0
    def test_end_to_end(self):
        """Train 3 epochs with the standard callbacks and verify every saved artifact."""
        ipe = len(self.train_loader)
        init_loc = paths.model(self.root, Step.zero(ipe))
        end_loc = paths.model(self.root, Step.from_epoch(3, 0, ipe))

        init_state = TestStandardCallbacks.get_state(self.model)
        train.train(self.hparams.training_hparams, self.model,
                    self.train_loader, self.root,
                    callbacks=self.callbacks,
                    start_step=Step.from_epoch(0, 0, ipe),
                    end_step=Step.from_epoch(3, 0, ipe))
        end_state = TestStandardCallbacks.get_state(self.model)

        # The initial and final models were saved, and the checkpoint file still exists.
        for location in (init_loc, end_loc, paths.checkpoint(self.root)):
            self.assertTrue(os.path.exists(location))

        # The saved weights match the in-memory weights captured before and after training.
        for location, expected in ((init_loc, init_state), (end_loc, end_state)):
            self.model.load_state_dict(torch.load(location))
            self.assertStateEqual(expected,
                                  TestStandardCallbacks.get_state(self.model))

        # The logger holds one entry per metric for each of the 4 evaluations.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        logger = MetricLogger.create_from_file(self.root)
        for metric in ('train_loss', 'test_loss', 'train_accuracy', 'test_accuracy'):
            self.assertEqual(len(logger.get_data(metric)), 4)
Beispiel #11
0
    def test_train_one_epoch(self):
        """Train exactly one epoch from the beginning and check the callback's readings."""
        one_epoch = Step.from_epoch(1, 0, len(self.train_loader))
        train.train(self.hparams.training_hparams, self.model,
                    self.train_loader, self.root,
                    callbacks=[self.callback], end_step=one_epoch)

        # Expected count is len(self.train_loader) + 1.
        self.assertEqual(self.step_counter, 13)
        self.assertEqual(self.ep, 1)
        self.assertEqual(self.it, 0)
        self.assertEqual(self.lr, 0.1)
Beispiel #12
0
    def test_constructor(self):
        """Step(iteration, iterations_per_epoch) validates its arguments and splits iteration."""
        # A negative iteration or a non-positive iterations_per_epoch is rejected.
        for bad_args in ((-1, 20), (1, 0), (1, -5)):
            with self.assertRaises(ValueError):
                Step(*bad_args)

        cases = [
            ((0, 1), (0, 0, 0)),
            ((0, 100), (0, 0, 0)),
            ((10, 100), (10, 0, 10)),
            ((110, 100), (110, 1, 10)),
            ((11010, 100), (11010, 110, 10)),
        ]
        for (iteration, per_epoch), (exp_iteration, exp_ep, exp_it) in cases:
            self.assertStepEquals(Step(iteration, per_epoch), exp_iteration, exp_ep, exp_it)
Beispiel #13
0
    def create_from_args(cls, args: argparse.Namespace) -> 'LotteryDesc':
        """Build a LotteryDesc from parsed command-line arguments.

        Pretraining hparams are populated in one of two ways:
          * ``args.pretrain`` with a nonzero ``pretrain_training_steps``: separate dataset
            and training hparams are parsed from the ``pretrain``-prefixed arguments.
          * Otherwise, a nonzero ``rewinding_steps`` (when that argument exists): deep copies
            of the main dataset/training hparams are used, with ``training_steps`` set to
            ``rewinding_steps``.
        In both cases, each pretraining hparams' ``_name`` is prefixed with 'Pretraining '.
        """
        dataset_hparams = hparams.DatasetHparams.create_from_args(args)
        model_hparams = hparams.ModelHparams.create_from_args(args)
        training_hparams = hparams.TrainingHparams.create_from_args(args)
        pruning_hparams = pruning.registry.get_pruning_hparams(args.pruning_strategy).create_from_args(args)

        desc = cls(model_hparams, dataset_hparams, training_hparams, pruning_hparams)

        if args.pretrain and not Step.str_is_zero(args.pretrain_training_steps):
            desc.pretrain_dataset_hparams = hparams.DatasetHparams.create_from_args(args, prefix='pretrain')
            desc.pretrain_training_hparams = hparams.TrainingHparams.create_from_args(args, prefix='pretrain')
        elif 'rewinding_steps' in args and args.rewinding_steps and not Step.str_is_zero(args.rewinding_steps):
            desc.pretrain_dataset_hparams = copy.deepcopy(dataset_hparams)
            desc.pretrain_training_hparams = copy.deepcopy(training_hparams)
            desc.pretrain_training_hparams.training_steps = args.rewinding_steps
        else:
            # No pretraining requested.
            return desc

        for pretrain_hp in (desc.pretrain_dataset_hparams, desc.pretrain_training_hparams):
            pretrain_hp._name = 'Pretraining ' + pretrain_hp._name
        return desc
Beispiel #14
0
    def run(self):
        """Train a pruned model whose weights and mask come from a fixed init location.

        Prints a banner (verbose mode, primary process only), saves the description to
        the replicate's run path, then runs standard training there.
        """
        if self.verbose and get_platform().is_primary_process:
            heavy_rule = '=' * 82
            print(heavy_rule + f'\nTraining a Model (Replicate {self.replicate})\n' + '-' * 82)
            print(self.desc.display)
            print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + heavy_rule + '\n')
        self.desc.save(self.desc.run_path(self.replicate))

        # TODO: make mask and model init paths configurable.
        # NOTE(review): the '2ep218it' step is parsed with a fixed 1000 iterations per epoch
        # rather than the dataset's value; confirm this matches how the init model was saved.
        init_path = os.path.join(get_platform().root, 'resnet18_lth')
        init_step = Step.from_str('2ep218it', 1000)
        base_model = models.registry.load(init_path, init_step,
                                          self.desc.model_hparams, self.desc.train_outputs)
        pruned_model = PrunedModel(base_model, Mask.load(init_path))

        train.standard_train(
            pruned_model, self.desc.run_path(self.replicate),
            self.desc.dataset_hparams, self.desc.training_hparams,
            evaluate_every_epoch=self.evaluate_every_epoch)
Beispiel #15
0
    def test_train_two_steps(self):
        """Two training iterations change every state entry; the callback fires three times."""
        before = TestTrain.get_state(self.model)

        two_iterations = Step.from_iteration(2, len(self.train_loader))
        train.train(self.hparams.training_hparams, self.model,
                    self.train_loader, self.root,
                    callbacks=[self.callback], end_step=two_iterations)

        after = TestTrain.get_state(self.model)
        for name in before:
            with self.subTest(k=name):
                self.assertFalse(np.array_equal(before[name], after[name]), name)

        self.assertEqual(self.step_counter, 3)
        self.assertEqual(self.ep, 0)
        self.assertEqual(self.it, 2)
        self.assertEqual(self.lr, 0.02)
Beispiel #16
0
    def test_save_load_exists(self):
        """A saved model is reported as existing and reloads to an identical state."""
        hp = registry.get_default_hparams('cifar_resnet_20')
        original = registry.get(hp.model_hparams)
        step = Step.from_iteration(27, 17)
        model_location = paths.model(self.root, step)
        original_state = TestSaveLoadExists.get_state(original)

        # Nothing has been saved yet.
        self.assertFalse(registry.exists(self.root, step))
        self.assertFalse(os.path.exists(model_location))

        # Saving creates the file and makes `exists` true.
        original.save(self.root, step)
        self.assertTrue(registry.exists(self.root, step))
        self.assertTrue(os.path.exists(model_location))

        # Loading through torch.load into a fresh model restores the state.
        fresh = registry.get(hp.model_hparams)
        fresh.load_state_dict(torch.load(model_location))
        self.assertStateEqual(original_state, TestSaveLoadExists.get_state(fresh))

        # Loading through the registry restores the state as well.
        reloaded = registry.load(self.root, step, hp.model_hparams)
        self.assertStateEqual(original_state, TestSaveLoadExists.get_state(reloaded))
Beispiel #17
0
 def str_to_step(self, s: str, pretrain: bool = False) -> Step:
     """Parse step string ``s`` using the selected dataset's iterations per epoch.

     The pretraining dataset hparams are used when ``pretrain`` is true; otherwise the
     main dataset hparams are used.
     """
     if pretrain:
         chosen = self.pretrain_dataset_hparams
     else:
         chosen = self.dataset_hparams
     per_epoch = datasets_registry.iterations_per_epoch(chosen)
     return Step.from_str(s, per_epoch)
Beispiel #18
0
 def test_overwrite(self):
     """Adding a value at an already-recorded step replaces the earlier value."""
     logger = TestMetricLogger.create_logger()
     first_step = Step.from_iteration(0, 400)
     logger.add('train_accuracy', first_step, 1.0)
     expected = [(0, 1.0), (1, 0.6)]
     self.assertEqual(logger.get_data('train_accuracy'), expected)
Beispiel #19
0
 def create_logger():
     """Return a MetricLogger with two train_accuracy points and one test_accuracy point."""
     logger = MetricLogger()
     entries = (
         ('train_accuracy', 0, 0.5),
         ('train_accuracy', 1, 0.6),
         ('test_accuracy', 0, 0.4),
     )
     for metric, iteration, value in entries:
         logger.add(metric, Step.from_iteration(iteration, 400), value)
     return logger
Beispiel #20
0
 def to_step(self, s):
     """Convert step string ``s`` using the iterations per epoch of the description's dataset."""
     per_epoch = datasets.registry.iterations_per_epoch(self.desc.dataset_hparams)
     return Step.from_str(s, per_epoch)
Beispiel #21
0
    def test_create_restore_delete(self):
        """Save a checkpoint and restore it into a fresh model/optimizer pair.

        Checks that the restored step, logger, model weights and optimizer momentum
        buffer all match what was saved.
        """
        # Create the hyperparameters and objects to save.
        hp = models.registry.get_default_hparams('cifar_resnet_20')
        model = models.registry.get(hp.model_hparams)
        optimizer = optimizers.get_optimizer(hp.training_hparams, model)
        dataloader = datasets.registry.get(hp.dataset_hparams)
        step = Step.from_epoch(13, 27, 400)

        # Run one step of SGD so that the optimizer holds a momentum buffer.
        examples, labels = next(iter(dataloader))
        optimizer.zero_grad()
        model.train()
        model.loss_criterion(model(examples), labels).backward()
        optimizer.step()

        # Create a fake logger.
        logger = MetricLogger()
        logger.add('test_accuracy', Step.from_epoch(0, 0, 400), 0.1)
        logger.add('test_accuracy', Step.from_epoch(10, 0, 400), 0.5)
        logger.add('test_accuracy', Step.from_epoch(100, 0, 400), 0.8)

        # Save a checkpoint.
        checkpointing.save_checkpoint_callback(self.root, step, model,
                                               optimizer, logger)
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # Create new models. The new optimizer must wrap model2's parameters. If it wrapped
        # model's, restoring into it would never exercise model2's optimizer state.
        model2 = models.registry.get(hp.model_hparams)
        optimizer2 = optimizers.get_optimizer(hp.training_hparams, model2)

        # Ensure the new model has different weights.
        sd1, sd2 = model.state_dict(), model2.state_dict()
        for k in model.prunable_layer_names:
            self.assertFalse(np.array_equal(sd1[k].numpy(), sd2[k].numpy()))

        # Each optimizer's state is keyed by its own parameters.
        self.assertIn('momentum_buffer',
                      optimizer.state[optimizer.param_groups[0]['params'][0]])
        self.assertNotIn(
            'momentum_buffer',
            optimizer2.state[optimizer2.param_groups[0]['params'][0]])

        # Restore the checkpoint.
        step2, logger2 = checkpointing.restore_checkpoint(
            self.root, model2, optimizer2, 400)

        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
        self.assertEqual(step, step2)
        self.assertEqual(str(logger), str(logger2))

        # Ensure the new model is now the same.
        sd1, sd2 = model.state_dict(), model2.state_dict()
        self.assertEqual(set(sd1.keys()), set(sd2.keys()))
        for k in sd1:
            self.assertTrue(np.array_equal(sd1[k].numpy(), sd2[k].numpy()))

        # Ensure the new optimizer is now the same.
        mom1 = optimizer.state[optimizer.param_groups[0]['params']
                               [0]]['momentum_buffer']
        mom2 = optimizer2.state[optimizer2.param_groups[0]['params']
                                [0]]['momentum_buffer']
        self.assertTrue(np.array_equal(mom1.numpy(), mom2.numpy()))
Beispiel #22
0
 def test_from_iteration(self):
     """Step.from_iteration splits an iteration count into (ep, it) for a given epoch size."""
     cases = (
         (0, 1, 0, 0),
         (0, 100, 0, 0),
         (10, 100, 0, 10),
         (110, 100, 1, 10),
         (11010, 100, 110, 10),
     )
     for iteration, per_epoch, ep, it in cases:
         self.assertStepEquals(Step.from_iteration(iteration, per_epoch), iteration, ep, it)
Beispiel #23
0
def distill(
    training_hparams: hparams.TrainingHparams,
    distill_hparams: hparams.DistillHparams,
    student: Model,
    teacher: Model,
    train_loader: DataLoader,
    output_location: str,
    callbacks: typing.List[typing.Callable] = None,
    start_step: Step = None,
    end_step: Step = None
):

    """Train ``student`` by distilling knowledge from ``teacher``.

    The per-batch loss is a weighted sum of these terms:
      * ``alpha_ce`` times the KL divergence between the student's and the teacher's
        softmax distributions at ``temperature``, scaled by ``temperature ** 2``.
      * ``alpha_cls`` times the student's own ``loss_criterion`` on the labels (when > 0).
      * ``alpha_mse`` times the summed squared error between student and teacher outputs,
        divided by the batch size (when > 0).
      * ``alpha_cos`` times a cosine-embedding loss aligning each student output row with
        the matching teacher output row (when > 0).

    Args:
      * training_hparams: The training hyperparameters whose schema is specified in hparams.py.
      * distill_hparams: The knowledge distillation hyperparameters whose schema is specified in hparams.py.
      * student: The student model to train. Must be a models.base.Model.
        Only its parameters are passed to the optimizer.
      * teacher: The teacher model to distill the knowledge from. Must be a models.base.Model.
        It is run in eval mode under ``torch.no_grad``.
      * train_loader: The training data. Must be a datasets.base.DataLoader.
      * output_location: The string path where all outputs should be stored.
      * callbacks: A list of functions that are called before each training step and once more
        after the last training step. Each is called with five positional arguments:
        ``callback(output_location, step, student, optimizer, logger)``.
        Callbacks are used for running the test set, saving the logger, saving the state of the
        model, etc. They provide hooks into the training loop for customization so that the
        training loop itself can remain simple. Defaults to no callbacks.
      * start_step: The step at which the training data and learning rate schedule should begin.
        Defaults to step 0. A checkpoint found in ``output_location`` takes precedence.
      * end_step: The step at which training should cease. Otherwise, training will go for the
        full `training_hparams.training_steps` steps.
    """

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Avoid a shared mutable default argument.
    if callbacks is None:
        callbacks = []

    # Create the output location if it doesn't already exist.
    if not get_platform().exists(output_location) and get_platform().is_primary_process:
        get_platform().makedirs(output_location)

    # Get the optimizer and learning rate schedule (student parameters only).
    student.to(get_platform().torch_device)
    teacher.to(get_platform().torch_device)
    optimizer = optimizers.get_optimizer(training_hparams, student)
    step_optimizer = optimizer
    lr_schedule = optimizers.get_lr_schedule(training_hparams, optimizer, train_loader.iterations_per_epoch)

    # Loss functions for the enabled distillation terms.
    ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
    if distill_hparams.alpha_mse > 0.0:
        mse_loss_fct = nn.MSELoss(reduction='sum')
    if distill_hparams.alpha_cos > 0.0:
        cos_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

    # Adapt for FP16.
    if training_hparams.apex_fp16:
        if NO_APEX: raise ImportError('Must install nvidia apex to use this model.')
        (student, teacher), step_optimizer = apex.amp.initialize(
            [student, teacher], optimizer, loss_scale='dynamic', verbosity=0
        )

    # Handle parallelism if applicable.
    if get_platform().is_distributed:
        student = DistributedDataParallel(student, device_ids=[get_platform().rank])
        teacher = DistributedDataParallel(teacher, device_ids=[get_platform().rank])
    elif get_platform().is_parallel:
        student = DataParallel(student)
        teacher = DataParallel(teacher)

    # Get the random seed for the data order.
    data_order_seed = training_hparams.data_order_seed

    # Restore the model from a saved checkpoint if the checkpoint exists.
    cp_step, cp_logger = restore_checkpoint(output_location, student, optimizer, train_loader.iterations_per_epoch)
    start_step = cp_step or start_step or Step.zero(train_loader.iterations_per_epoch)
    logger = cp_logger or MetricLogger()
    with warnings.catch_warnings():  # Filter unnecessary warning.
        warnings.filterwarnings("ignore", category=UserWarning)
        for _ in range(start_step.iteration): lr_schedule.step()

    # Determine when to end training.
    end_step = end_step or Step.from_str(training_hparams.training_steps, train_loader.iterations_per_epoch)
    if end_step <= start_step: return

    # The training loop. `finished` is used to leave both loops at the end step, so that
    # the barrier below is actually reached.
    finished = False
    for ep in range(start_step.ep, end_step.ep + 1):

        # Ensure the data order is different for each epoch.
        train_loader.shuffle(None if data_order_seed is None else (data_order_seed + ep))

        for it, (examples, labels) in enumerate(train_loader):

            # Advance the data loader until the start epoch and iteration.
            if ep == start_step.ep and it < start_step.it: continue

            # Run the callbacks.
            step = Step.from_epoch(ep, it, train_loader.iterations_per_epoch)
            for callback in callbacks: callback(output_location, step, student, optimizer, logger)

            # Exit at the end step.
            if ep == end_step.ep and it == end_step.it:
                finished = True
                break

            # Otherwise, train.
            examples = examples.to(device=get_platform().torch_device)
            labels = labels.to(device=get_platform().torch_device)

            loss = 0.0
            step_optimizer.zero_grad()
            student.train()
            teacher.eval()

            s_logits = student(examples)
            with torch.no_grad():
                t_logits = teacher(examples)

            # KL divergence between temperature-softened distributions, rescaled by T^2.
            temperature = distill_hparams.temperature
            loss_ce = ce_loss_fct(
                F.log_softmax(s_logits / temperature, dim=-1),
                F.softmax(t_logits / temperature, dim=-1),
            ) * temperature**2
            loss += distill_hparams.alpha_ce * loss_ce

            if distill_hparams.alpha_cls > 0.0:
                loss_cls = student.loss_criterion(s_logits, labels)
                loss += distill_hparams.alpha_cls * loss_cls

            if distill_hparams.alpha_mse > 0.0:
                loss_mse = mse_loss_fct(s_logits, t_logits) / s_logits.size(0)
                loss += distill_hparams.alpha_mse * loss_mse

            if distill_hparams.alpha_cos > 0.0:
                # A +1 target for every row asks the loss to align each student output
                # row with the corresponding teacher row.
                # NOTE(review): assumes outputs are (batch, features), as CosineEmbeddingLoss expects.
                cos_target = s_logits.new_ones(s_logits.size(0))
                loss_cos = cos_loss_fct(s_logits, t_logits, cos_target)
                loss += distill_hparams.alpha_cos * loss_cos

            if training_hparams.apex_fp16:
                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Step forward. Ignore extraneous warnings that the lr_schedule generates.
            step_optimizer.step()
            with warnings.catch_warnings():  # Filter unnecessary warning.
                warnings.filterwarnings("ignore", category=UserWarning)
                lr_schedule.step()

        if finished:
            break

    get_platform().barrier()
Beispiel #24
0
 def end_step(self):
     """Return the end-of-training Step parsed from ``training_hparams.training_steps``."""
     per_epoch = datasets_registry.iterations_per_epoch(self.dataset_hparams)
     steps_text = self.training_hparams.training_steps
     return Step.from_str(steps_text, per_epoch)
Beispiel #25
0
 def str_to_step(self, s: str) -> Step:
     """Parse step string ``s`` using the iterations per epoch of ``self.dataset_hparams``."""
     per_epoch = datasets_registry.iterations_per_epoch(self.dataset_hparams)
     return Step.from_str(s, per_epoch)
Beispiel #26
0
    def test_from_str(self):
        """Step.from_str accepts 'Nit', 'Nep' and 'NepMit' and rejects malformed strings."""
        valid = [
            ('0it', 0, 0, 0),
            ('0ep', 0, 0, 0),
            ('0ep0it', 0, 0, 0),
            ('50it', 50, 0, 50),
            ('100it', 100, 1, 0),
            ('2021it', 2021, 20, 21),
            ('5ep', 500, 5, 0),
            ('20ep', 2000, 20, 0),
            ('5ep3it', 503, 5, 3),
        ]
        for text, iteration, ep, it in valid:
            self.assertStepEquals(Step.from_str(text, 100), iteration, ep, it)

        for text in ('', '0', 'it', 'ep', 'ep0', '20it50ep'):
            with self.assertRaises(ValueError):
                Step.from_str(text, 100)
Beispiel #27
0
    def test_comparisons(self):
        """Ordering operators on Step follow the underlying iteration counts."""
        def parse(text):
            return Step.from_str(text, 100)

        self.assertLessEqual(parse('100it'), parse('100it'))
        self.assertLess(parse('100it'), parse('101it'))
        self.assertLessEqual(parse('100it'), parse('101it'))

        self.assertGreaterEqual(parse('100it'), parse('100it'))
        self.assertGreater(parse('102it'), parse('101it'))
        self.assertGreaterEqual(parse('102it'), parse('101it'))
Beispiel #28
0
def train(training_hparams: hparams.TrainingHparams,
          model: Model,
          train_loader: DataLoader,
          output_location: str,
          callbacks: typing.List[typing.Callable] = None,
          start_step: Step = None,
          end_step: Step = None):
    """The main training loop for this framework.

    Args:
      * training_hparams: The training hyperparameters whose schema is specified in hparams.py.
      * model: The model to train. Must be a models.base.Model
      * train_loader: The training data. Must be a datasets.base.DataLoader
      * output_location: The string path where all outputs should be stored.
      * callbacks: A list of functions that are called before each training step and once more
        after the last training step. Each is called with five positional arguments:
        ``callback(output_location, step, model, optimizer, logger)``.
        Callbacks are used for running the test set, saving the logger, saving the state of the
        model, etc. They provide hooks into the training loop for customization so that the
        training loop itself can remain simple. Defaults to no callbacks.
      * start_step: The step at which the training data and learning rate schedule should begin.
        Defaults to step 0. A checkpoint found in ``output_location`` takes precedence.
      * end_step: The step at which training should cease. Otherwise, training will go for the
        full `training_hparams.training_steps` steps.
    """
    # Avoid a shared mutable default argument.
    if callbacks is None:
        callbacks = []

    # Create the output location if it doesn't already exist.
    if not get_platform().exists(output_location) and get_platform(
    ).is_primary_process:
        get_platform().makedirs(output_location)

    # Get the optimizer and learning rate schedule.
    model.to(get_platform().torch_device)
    optimizer = optimizers.get_optimizer(training_hparams, model)
    step_optimizer = optimizer
    lr_schedule = optimizers.get_lr_schedule(training_hparams, optimizer,
                                             train_loader.iterations_per_epoch)

    # Adapt for FP16.
    if training_hparams.apex_fp16:
        if NO_APEX:
            raise ImportError('Must install nvidia apex to use this model.')
        model, step_optimizer = apex.amp.initialize(model,
                                                    optimizer,
                                                    loss_scale='dynamic',
                                                    verbosity=0)

    # Handle parallelism if applicable.
    if get_platform().is_distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[get_platform().rank])
    elif get_platform().is_parallel:
        model = DataParallel(model)

    # Get the random seed for the data order.
    data_order_seed = training_hparams.data_order_seed

    # Restore the model from a saved checkpoint if the checkpoint exists.
    cp_step, cp_logger = restore_checkpoint(output_location, model, optimizer,
                                            train_loader.iterations_per_epoch)
    start_step = cp_step or start_step or Step.zero(
        train_loader.iterations_per_epoch)
    logger = cp_logger or MetricLogger()
    with warnings.catch_warnings():  # Filter unnecessary warning.
        warnings.filterwarnings("ignore", category=UserWarning)
        for _ in range(start_step.iteration):
            lr_schedule.step()

    # Determine when to end training.
    end_step = end_step or Step.from_str(training_hparams.training_steps,
                                         train_loader.iterations_per_epoch)
    if end_step <= start_step: return

    # The training loop. `finished` is used to leave both loops at the end step, so that
    # the barrier below is actually reached.
    finished = False
    for ep in range(start_step.ep, end_step.ep + 1):

        # Ensure the data order is different for each epoch.
        train_loader.shuffle(None if data_order_seed is None else (
            data_order_seed + ep))

        for it, (examples, labels) in enumerate(train_loader):

            # Advance the data loader until the start epoch and iteration.
            if ep == start_step.ep and it < start_step.it: continue

            # Run the callbacks.
            step = Step.from_epoch(ep, it, train_loader.iterations_per_epoch)
            for callback in callbacks:
                callback(output_location, step, model, optimizer, logger)

            # Exit at the end step.
            if ep == end_step.ep and it == end_step.it:
                finished = True
                break

            # Otherwise, train.
            examples = examples.to(device=get_platform().torch_device)
            labels = labels.to(device=get_platform().torch_device)

            step_optimizer.zero_grad()
            model.train()
            loss = model.loss_criterion(model(examples), labels)
            if training_hparams.apex_fp16:
                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Step forward. Ignore extraneous warnings that the lr_schedule generates.
            step_optimizer.step()
            with warnings.catch_warnings():  # Filter unnecessary warning.
                warnings.filterwarnings("ignore", category=UserWarning)
                lr_schedule.step()

        if finished:
            break

    get_platform().barrier()
Beispiel #29
0
 def test_zero(self):
     """Step.zero yields iteration 0, epoch 0, in-epoch iteration 0."""
     origin = Step.zero(100)
     self.assertStepEquals(origin, 0, 0, 0)
Beispiel #30
0
    def test_equal(self):
        """Steps compare equal by iteration; mixing epoch sizes raises ValueError."""
        def parse(text, per_epoch=100):
            return Step.from_str(text, per_epoch)

        self.assertEqual(parse('100it'), parse('100it'))
        self.assertNotEqual(parse('101it'), parse('100it'))
        self.assertEqual(parse('1ep'), parse('100it'))
        self.assertEqual(parse('5ep6it'), parse('506it'))

        with self.assertRaises(ValueError):
            parse('100it', 101) == parse('100it', 100)