Example #1
 def callback(output_location, step, model, optimizer, logger):
     if get_platform().is_primary_process:
         nonlocal time_of_last_call
         t = 0.0 if time_of_last_call is None else time.time() - time_of_last_call
         print(f'Ep {step.ep}\tIt {step.it}\tTime Elapsed {t:.2f}')
         time_of_last_call = time.time()
     get_platform().barrier()
Example #2
 def save(self, location):
     if not get_platform().is_primary_process: return
     if not get_platform().exists(location):
         get_platform().makedirs(location)
     with get_platform().open(paths.logger(location), 'w') as fp:
         fp.write("phase,iteration,loss,accuracy,elapsed_time\n")
         fp.write(str(self))
Example #3
    def eval_callback(output_location, step, model, optimizer, logger):
        model.eval()

        with torch.no_grad():
            output = model(None)
            energy_gap = model.loss_criterion(output, None) / len(output)

        # Share the information if distributed.
        if get_platform().is_distributed:
            raise NotImplementedError(
                'distributed evaluation for lossonly callback not implemented'
            )
            torch.distributed.reduce(energy_gap,
                                     0,
                                     op=torch.distributed.ReduceOp.SUM)

        energy_gap = energy_gap.cpu().item()

        if get_platform().is_primary_process:
            for idx, output_entry in enumerate(output):
                logger.add('gap_{}'.format(idx), step, output_entry)

            if verbose:
                nonlocal time_of_last_call
                elapsed = 0 if time_of_last_call is None else time.time() - time_of_last_call
                print('ep {:03d}\tit {:03d}\tgap {:.5f}\ttime {:.2f}s'.format(
                    step.ep, step.it, energy_gap, elapsed))
                time_of_last_call = time.time()
Example #4
def _get_samples(root, y_name, y_num):
    y_dir = os.path.join(root, y_name)
    if not get_platform().isdir(y_dir): return []
    output = [(os.path.join(y_dir, f), y_num)
              for f in get_platform().listdir(y_dir)
              if f.lower().endswith('jpeg')]
    return output
Example #5
    def _establish_initial_weights(self):
        location = self.desc.run_path(self.replicate, 0)
        if models.registry.exists(location, self.desc.train_start_step):
            if get_platform().is_primary_process:
                print('Initial weights loaded from {}'.format(
                    paths.model(location, self.desc.train_start_step)))
            return

        new_model = models.registry.get(self.desc.model_hparams,
                                        outputs=self.desc.train_outputs)

        # If there was a pretrained model, retrieve its final weights and adapt them for training.
        if self.desc.pretrain_training_hparams is not None:
            pretrain_loc = self.desc.run_path(self.replicate, 'pretrain')
            if get_platform().is_primary_process:
                print('Initial weights loaded from pretrained checkpoint {}'.
                      format(
                          paths.model(pretrain_loc,
                                      self.desc.pretrain_end_step)))
            old = models.registry.load(pretrain_loc,
                                       self.desc.pretrain_end_step,
                                       self.desc.model_hparams,
                                       self.desc.pretrain_outputs)
            state_dict = {k: v for k, v in old.state_dict().items()}

            # Select a new output layer if number of classes differs.
            if self.desc.train_outputs != self.desc.pretrain_outputs:
                state_dict.update({
                    k: new_model.state_dict()[k]
                    for k in new_model.output_layer_names
                })

            new_model.load_state_dict(state_dict)

        new_model.save(location, self.desc.train_start_step)
Example #6
 def download(self):
     if get_platform().is_primary_process:
         with get_platform().open(os.devnull, 'w') as fp:
             sys.stdout = fp
             super(CIFAR10, self).download()
             sys.stdout = sys.__stdout__
     get_platform().barrier()
Example #7
def restore_checkpoint(output_location, model, optimizer,
                       iterations_per_epoch):
    checkpoint_location = paths.checkpoint(output_location)
    if not get_platform().exists(checkpoint_location):
        return None, None
    checkpoint = get_platform().load_model(checkpoint_location,
                                           map_location=torch.device('cpu'))

    # Handle DataParallel.
    module_in_name = get_platform().is_parallel
    if module_in_name and not all(
            k.startswith('module.') for k in checkpoint['model_state_dict']):
        checkpoint['model_state_dict'] = {
            'module.' + k: v
            for k, v in checkpoint['model_state_dict'].items()
        }
    elif all(k.startswith('module.')
             for k in checkpoint['model_state_dict']) and not module_in_name:
        checkpoint['model_state_dict'] = {
            k[len('module.'):]: v
            for k, v in checkpoint['model_state_dict'].items()
        }

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = Step.from_epoch(checkpoint['ep'], checkpoint['it'],
                           iterations_per_epoch)
    logger = MetricLogger.create_from_string(checkpoint['logger'])

    return step, logger
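A minimal usage sketch, not from the source, of how restore_checkpoint might be called when resuming training; my_model, my_optimizer, out_dir, and the value 1000 for iterations_per_epoch are hypothetical placeholders.

step, logger = restore_checkpoint(out_dir, my_model, my_optimizer,
                                  iterations_per_epoch=1000)
if step is None:
    # No checkpoint on disk: training would start from epoch 0, iteration 0,
    # and a fresh logger would be created here.
    step = Step.from_epoch(0, 0, 1000)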
Example #8
def save_checkpoint_callback(output_location, step, model, optimizer, logger):
    if get_platform().is_primary_process:
        get_platform().save_model({
            'ep': step.ep,
            'it': step.it,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'logger': str(logger),
        }, paths.checkpoint(output_location))
    get_platform().barrier()
Example #9
    def download(self):
        if self._check_integrity(): return

        if get_platform().is_primary_process:
            if not get_platform().exists(self.root):
                temporary_root = tempfile.mkdtemp()
                torchvision.datasets.utils.download_and_extract_archive(
                    self.url,
                    temporary_root,
                    filename=self.filename,
                    md5=self.tgz_md5)
                get_platform().copytree(temporary_root, self.root)
Example #10
    def save(self, output_location):
        if not get_platform().is_primary_process: return
        if not get_platform().exists(output_location):
            get_platform().makedirs(output_location)

        fields_dict = {f.name: getattr(self, f.name) for f in fields(self)}
        hparams_strs = [
            fields_dict[k].display for k in sorted(fields_dict)
            if isinstance(fields_dict[k], Hparams)
        ]
        with get_platform().open(paths.hparams(output_location), 'w') as fp:
            fp.write('\n'.join(hparams_strs))
Example #11
def _get_samples(root, y_name, y_num):
    y_dir = os.path.join(root, y_name)
    if not get_platform().isdir(y_dir): return []

    output = []

    for f in get_platform().listdir(y_dir):
        if get_platform().isdir(os.path.join(y_dir, f)):
            output += _get_samples(y_dir, f, y_num)
        elif f.lower().endswith('jpeg'):
            output.append((os.path.join(y_dir, f), y_num))

    return output
Example #12
    def __init__(self, dataset: Dataset, batch_size: int, num_workers: int, pin_memory: bool = True, force_sequential: bool = False):
        if get_platform().is_distributed and not force_sequential:
            self._sampler = DistributedShuffleSampler(dataset)
        else:
            self._sampler = ShuffleSampler(len(dataset))

        self._iterations_per_epoch = np.ceil(len(dataset) / batch_size).astype(int)

        if get_platform().is_distributed and not force_sequential:
            batch_size //= get_platform().world_size
            num_workers //= get_platform().world_size

        super(DataLoader, self).__init__(
            dataset, batch_size, sampler=self._sampler, num_workers=num_workers,
            pin_memory=pin_memory and get_platform().torch_device.type == 'cuda' and not force_sequential)
Example #13
    def eval_callback(output_location, step, model, optimizer, logger):
        example_count = torch.tensor(0.0).to(get_platform().torch_device)
        total_loss = torch.tensor(0.0).to(get_platform().torch_device)
        total_correct = torch.tensor(0.0).to(get_platform().torch_device)

        def correct(labels, outputs):
            return torch.sum(torch.eq(labels, outputs.argmax(dim=1)))

        model.eval()

        with torch.no_grad():
            for examples, labels, _, _ in loader:
                examples = examples.to(get_platform().torch_device)
                labels = labels.squeeze().to(get_platform().torch_device)
                output = model(examples)

                labels_size = torch.tensor(len(labels),
                                           device=get_platform().torch_device)
                example_count += labels_size
                total_loss += model.loss_criterion(output,
                                                   labels) * labels_size
                total_correct += correct(labels, output)
        # Share the information if distributed.
        if get_platform().is_distributed:
            torch.distributed.reduce(total_loss,
                                     0,
                                     op=torch.distributed.ReduceOp.SUM)
            torch.distributed.reduce(total_correct,
                                     0,
                                     op=torch.distributed.ReduceOp.SUM)
            torch.distributed.reduce(example_count,
                                     0,
                                     op=torch.distributed.ReduceOp.SUM)

        total_loss = total_loss.cpu().item()
        total_correct = total_correct.cpu().item()
        example_count = example_count.cpu().item()

        if get_platform().is_primary_process:
            logger.add('{}_loss'.format(eval_name), step,
                       total_loss / example_count)
            logger.add('{}_accuracy'.format(eval_name), step,
                       total_correct / example_count)
            logger.add('{}_examples'.format(eval_name), step, example_count)

            if verbose:
                nonlocal time_of_last_call
                elapsed = 0 if time_of_last_call is None else time.time() - time_of_last_call
                print(
                    '{}\tep {:03d}\tit {:03d}\tloss {:.3f}\tacc {:.2f}%\tex {:d}\ttime {:.2f}s'
                    .format(eval_name, step.ep, step.it,
                            total_loss / example_count,
                            100 * total_correct / example_count,
                            int(example_count), elapsed))
                time_of_last_call = time.time()
Example #14
 def get_test_set():
     test_set = torchvision.datasets.MNIST(train=False,
                                           root=os.path.join(
                                               get_platform().dataset_root,
                                               'mnist'),
                                           download=True)
     return Dataset(test_set.data, test_set.targets)
Example #15
    def get_score(trained_model: models.base.Model,
                  current_mask: Mask,
                  prunable_tensors: set,
                  training_hparams: hparams.TrainingHparams,
                  dataset_hparams: hparams.DatasetHparams,
                  data_order_seed: int = None):
        pruned_model = PrunedModel(trained_model, current_mask).to(device=get_platform().torch_device)
        pruned_model._clear_grad()
        # pruned_model._enable_mask_gradient()

        # Calculate the gradient
        train.accumulate_gradient(
            training_hparams, pruned_model,
            dataset_hparams, data_order_seed, verbose=False
        )

        # Calculate the score
        scores = dict()
        for name, param in pruned_model.model.named_parameters():
            if hasattr(pruned_model, PrunedModel.to_mask_name(name)) and name in prunable_tensors:
                scores[name] = (param.grad * param).abs().clone().cpu().detach().numpy()

        score_vector = np.concatenate([v.reshape(-1) for k, v in scores.items()])
        norm = np.sum(score_vector)
        for k in scores.keys():
            scores[k] /= norm

        # Clean up
        pruned_model._clear_grad()
        # model._disable_mask_gradient()

        return scores
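A short sketch, offered as an assumption rather than part of the source, of how the per-tensor scores returned above could be turned into a binary pruning mask by a global threshold at a target sparsity:

import numpy as np

def scores_to_mask(scores, sparsity=0.8):
    # Hypothetical helper: prune the lowest-scoring `sparsity` fraction of weights
    # across all tensors and keep the rest.
    flat = np.concatenate([v.reshape(-1) for v in scores.values()])
    threshold = np.quantile(flat, sparsity)
    return {k: (v > threshold).astype(np.float32) for k, v in scores.items()}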
Example #16
    def run_path(self, replicate, experiment='main'):

        if not isinstance(replicate, int) or replicate <= 0:
            raise ValueError('Bad replicate: {}'.format(replicate))

        return os.path.join(get_platform().root, self.hashname,
                            f'replicate_{replicate}', experiment)
Example #17
def standard_train(model: Model,
                   output_location: str,
                   dataset_hparams: hparams.DatasetHparams,
                   training_hparams: hparams.TrainingHparams,
                   start_step: Step = None,
                   verbose: bool = True,
                   evaluate_every_epoch: bool = True):
    """Train using the standard callbacks according to the provided hparams."""

    # If the model file for the end of training already exists in this location, do not train.
    iterations_per_epoch = datasets.registry.iterations_per_epoch(
        dataset_hparams)
    train_end_step = Step.from_str(training_hparams.training_steps,
                                   iterations_per_epoch)
    if (models.registry.exists(output_location, train_end_step)
            and get_platform().exists(paths.logger(output_location))):
        return

    train_loader = datasets.registry.get(dataset_hparams, train=True)
    test_loader = datasets.registry.get(dataset_hparams, train=False)
    callbacks = standard_callbacks.standard_callbacks(
        training_hparams,
        train_loader,
        test_loader,
        start_step=start_step,
        verbose=verbose,
        evaluate_every_epoch=evaluate_every_epoch)
    train(training_hparams,
          model,
          train_loader,
          output_location,
          callbacks,
          start_step=start_step)
Example #18
def load_from_file(file_location: str,
                   model_hparams: ModelHparams,
                   outputs=None):
    state_dict = get_platform().load_model(file_location)
    model = get(model_hparams, outputs)
    model.load_state_dict(state_dict)
    return model
Example #19
    def __init__(self, loc: str, image_transforms, enumerate_examples=False):
        if loc.endswith('train'):
            super(Dataset, self).__init__(loc, image_transforms,
                                          enumerate_examples)
            return

        # Test and validation sets have an annotation file.
        annotations_file = os.path.join(
            loc, f'{os.path.basename(loc)}_annotations.txt')
        with get_platform().open(annotations_file) as fp:
            annotations = fp.read().split('\n')
        annotations = [(annotation.split()[0], annotation.split()[1])
                       for annotation in annotations if annotation.strip()]

        classes = sorted(list(set([c for _, c in annotations])))
        labels_dict = {c: i for i, c in enumerate(classes)}
        examples, labels = zip(*[(os.path.join(loc, 'images', f),
                                  labels_dict[c]) for f, c in annotations])

        super(imagenet.Dataset,
              self).__init__(np.array(examples),
                             np.array(labels),
                             image_transforms,
                             [self._normalization_transform()],
                             enumerate_examples=enumerate_examples)
Example #20
    def run(self):
        if self.verbose and get_platform().is_primary_process:
            print('='*82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-'*82)
            print(self.desc.display)
            print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '='*82 + '\n')
        self.desc.save(self.desc.run_path(self.replicate))

        #TODO: make mask and model init paths configurable
        init_path = os.path.join(get_platform().root, 'resnet18_lth')
        model = models.registry.load(init_path, Step.from_str('2ep218it', 1000),
                                    self.desc.model_hparams, self.desc.train_outputs)
        pruned_model = PrunedModel(model, Mask.load(init_path))

        train.standard_train(
            #models.registry.get(self.desc.model_hparams), self.desc.run_path(self.replicate),
            pruned_model, self.desc.run_path(self.replicate),
            self.desc.dataset_hparams, self.desc.training_hparams, evaluate_every_epoch=self.evaluate_every_epoch)
Example #21
    def run_path(self, replicate: int, pruning_level: Union[str, int], experiment: str = 'main'):
        """The location where any run is stored."""

        if not isinstance(replicate, int) or replicate <= 0:
            raise ValueError('Bad replicate: {}'.format(replicate))

        return os.path.join(get_platform().root, self.hashname,
                            f'replicate_{replicate}', f'level_{pruning_level}', experiment)
Example #22
    def get_sparsity_ratio(self):
        if not get_platform().is_primary_process: return

        # Fraction of weights that remain unpruned (i.e., the density of the mask).
        total_weights = np.sum([v.size for v in self.numpy().values()]).item()
        total_unpruned = np.sum([np.sum(v)
                                 for v in self.numpy().values()]).item()
        return float(total_unpruned) / total_weights
Example #23
 def load(output_location, suffix=''):
     if not Mask.exists(output_location, suffix):
         error_output_suffix = ' with suffix {}'.format(
             suffix) if suffix != '' else ''
         raise ValueError('Mask not found at {}{}'.format(
             output_location, error_output_suffix))
     return Mask(get_platform().load_model(
         paths.mask(output_location, suffix)))
Example #24
 def get_train_set(use_augmentation):
     # No augmentation for MNIST.
     train_set = torchvision.datasets.MNIST(train=True,
                                            root=os.path.join(
                                                get_platform().dataset_root,
                                                'mnist'),
                                            download=True)
     return Dataset(train_set.data, train_set.targets)
Example #25
 def get_test_set(enumerate_examples=False):
     test_set = CIFAR10(train=False,
                        root=os.path.join(get_platform().dataset_root,
                                          'cifar10'),
                        download=True)
     return Dataset(test_set.data,
                    np.array(test_set.targets),
                    enumerate_examples=enumerate_examples)
Example #26
    def _pretrain(self):
        location = self.desc.run_path(self.replicate, 'pretrain')
        if models.registry.exists(location, self.desc.pretrain_end_step): return

        if self.verbose and get_platform().is_primary_process: print('-'*82 + '\nPretraining\n' + '-'*82)
        model = models.registry.get(self.desc.model_hparams, outputs=self.desc.pretrain_outputs)
        train.standard_train(model, location, self.desc.pretrain_dataset_hparams, self.desc.pretrain_training_hparams,
                             verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch,
                             weight_save_steps=self.weight_save_steps)
Example #27
def load(save_location: str,
         save_step: Step,
         model_hparams: ModelHparams,
         outputs=None):
    state_dict = get_platform().load_model(
        paths.model(save_location, save_step))
    model = get(model_hparams, outputs)
    model.load_state_dict(state_dict)
    return model
Example #28
def get(dataset_hparams: DatasetHparams, train: bool = True):
    """Get the train or test set corresponding to the hyperparameters."""

    seed = dataset_hparams.transformation_seed or 0

    # Get the dataset itself.
    if dataset_hparams.dataset_name in registered_datasets:
        if train:
            use_augmentation = not dataset_hparams.do_not_augment
            dataset = registered_datasets[dataset_hparams.dataset_name].Dataset.get_train_set(use_augmentation)
        else:
            dataset = registered_datasets[dataset_hparams.dataset_name].Dataset.get_test_set()
    else:
        raise ValueError('No such dataset: {}'.format(dataset_hparams.dataset_name))

    # Transform the dataset.
    randomize = False
    if dataset_hparams.label_randomization_targets is not None:
        rand_targets = dataset_hparams.label_randomization_targets.split(',')
        rand_targets = [elem.strip() for elem in rand_targets]
        if train:
            randomize = 'train' in rand_targets
        else:
            randomize = 'test' in rand_targets
    
    if randomize and dataset_hparams.label_randomization_type is not None:
        _type = dataset_hparams.label_randomization_type
        if _type == 'shuffle':
            dataset.shuffle_labels(seed=seed)
        elif _type == 'corrupt':
            dataset.corrupt_labels(seed=seed, corrupt_prob=dataset_hparams.corruption_probability)
        elif _type == 'fraction':
            dataset.randomize_labels(seed=seed, fraction=dataset_hparams.random_labels_fraction)
        else:
            raise ValueError(f"'{dataset_hparams.label_randomization_type}' is not a valid randomization type.")

    if train and dataset_hparams.subsample_fraction is not None:
        dataset.subsample(seed=seed, fraction=dataset_hparams.subsample_fraction)

    if train and dataset_hparams.blur_factor is not None:
        if not isinstance(dataset, base.ImageDataset):
            raise ValueError('Can only blur images.')
        else:
            dataset.blur(seed=seed, blur_factor=dataset_hparams.blur_factor)

    if dataset_hparams.unsupervised_labels is not None:
        if dataset_hparams.unsupervised_labels != 'rotation':
            raise ValueError('Unknown unsupervised labels: {}'.format(dataset_hparams.unsupervised_labels))
        elif not isinstance(dataset, base.ImageDataset):
            raise ValueError('Can only do unsupervised rotation to images.')
        else:
            dataset.unsupervised_rotation(seed=seed)

    # Create the loader.
    return registered_datasets[dataset_hparams.dataset_name].DataLoader(
        dataset, batch_size=dataset_hparams.batch_size, num_workers=get_platform().num_workers)
Example #29
def get(dataset_hparams: DatasetHparams,
        train: bool = True,
        force_sequential: bool = False,
        mask: np.ndarray = None,
        enumerate_examples: bool = False):
    """Get the train or test set corresponding to the hyperparameters."""

    seed = dataset_hparams.transformation_seed or 0

    # Get the dataset itself.
    if dataset_hparams.dataset_name in registered_datasets:
        use_augmentation = train and not dataset_hparams.do_not_augment
        if train:
            dataset = registered_datasets[
                dataset_hparams.dataset_name].Dataset.get_train_set(
                    use_augmentation, enumerate_examples)
        else:
            dataset = registered_datasets[
                dataset_hparams.dataset_name].Dataset.get_test_set(
                    enumerate_examples)
    else:
        raise ValueError('No such dataset: {}'.format(
            dataset_hparams.dataset_name))

    # Transform the dataset.
    if train and dataset_hparams.random_labels_fraction is not None:
        dataset.randomize_labels(
            seed=seed, fraction=dataset_hparams.random_labels_fraction)

    if train and dataset_hparams.subsample_fraction is not None:
        dataset.subsample(seed=seed,
                          fraction=dataset_hparams.subsample_fraction)

    if train and dataset_hparams.blur_factor is not None:
        if not isinstance(dataset, base.ImageDataset):
            raise ValueError('Can only blur images.')
        else:
            dataset.blur(seed=seed, blur_factor=dataset_hparams.blur_factor)

    if dataset_hparams.unsupervised_labels is not None:
        if dataset_hparams.unsupervised_labels != 'rotation':
            raise ValueError('Unknown unsupervised labels: {}'.format(
                dataset_hparams.unsupervised_labels))
        elif not isinstance(dataset, base.ImageDataset):
            raise ValueError('Can only do unsupervised rotation to images.')
        else:
            dataset.unsupervised_rotation(seed=seed)

    if mask is not None: dataset.filter(mask)

    # Create the loader.
    return registered_datasets[dataset_hparams.dataset_name].DataLoader(
        dataset,
        batch_size=dataset_hparams.batch_size,
        num_workers=get_platform().num_workers,
        force_sequential=force_sequential)
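A hedged usage sketch of the loader factory above; the DatasetHparams constructor arguments shown are assumptions based only on the attributes referenced in this function, not a confirmed signature.

hparams = DatasetHparams(dataset_name='cifar10', batch_size=128)
train_loader = get(hparams, train=True)
test_loader = get(hparams, train=False, force_sequential=True)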
Example #30
    def _estimate_theta(self, model):
        rps = []
        diffs = []

        def correct(labels, output):
            return torch.eq(labels, output.argmax(dim=1)).cpu()

        model.eval()
        with torch.no_grad():
            for examples, labels, idx, diff in self._train_loader:
                examples = examples.to(get_platform().torch_device)

                labels = labels.squeeze().to(get_platform().torch_device)
                output = model(examples.float())

                rps.extend(correct(labels, output))
                diffs.extend(diff)

        return scoring.calculate_theta(diffs, rps, 1000)