def callback(output_location, step, model, optimizer, logger): if get_platform().is_primary_process: nonlocal time_of_last_call t = 0.0 if time_of_last_call is None else time.time() - time_of_last_call print(f'Ep {step.ep}\tIt {step.it}\tTime Elapsed {t:.2f}') time_of_last_call = time.time() get_platform().barrier()
def save(self, location):
    """Write this object as CSV text to the logger file under ``location``.

    Non-primary processes do nothing. ``location`` is created when it does
    not exist; the file at ``paths.logger(location)`` receives a header row
    followed by ``str(self)``.
    """
    platform = get_platform()
    if platform.is_primary_process:
        if not platform.exists(location):
            platform.makedirs(location)

        header = "phase,iteration,loss,accuracy,elapsed_time\n"
        with platform.open(paths.logger(location), 'w') as fp:
            fp.write(header)
            fp.write(str(self))
def eval_callback(output_location, step, model, optimizer, logger): model.eval() with torch.no_grad(): output = model(None) energy_gap = model.loss_criterion(output, None) / len(output) # Share the information if distributed. if get_platform().is_distributed: raise NotImplementedError( 'distribution evaluation for lossonly callback not implemented' ) torch.distributed.reduce(energy_gap, 0, op=torch.distributed.ReduceOp.SUM) energy_gap = energy_gap.cpu().item() if get_platform().is_primary_process: for idx, output_entry in enumerate(output): logger.add('gap_{}'.format(idx), step, output_entry) if verbose: nonlocal time_of_last_call elapsed = 0 if time_of_last_call is None else time.time( ) - time_of_last_call print('ep {:03d}\tit {:03d}\tgap {:.5f}\ttime {:.2f}s'.format( step.ep, step.it, energy_gap, elapsed)) time_of_last_call = time.time()
def _get_samples(root, y_name, y_num):
    """Pair every JPEG directly inside ``root/y_name`` with ``y_num``.

    Returns an empty list when that directory does not exist. A file is
    included when its lower-cased name ends with ``'jpeg'``.
    """
    y_dir = os.path.join(root, y_name)
    if not get_platform().isdir(y_dir):
        return []

    samples = []
    for filename in get_platform().listdir(y_dir):
        if filename.lower().endswith('jpeg'):
            samples.append((os.path.join(y_dir, filename), y_num))
    return samples
def _establish_initial_weights(self):
    """Make sure a model is saved at the training start step of level 0.

    If one already exists it is left untouched. Otherwise a fresh model is
    created, optionally initialised from the pretrained checkpoint, and saved.
    """
    location = self.desc.run_path(self.replicate, 0)
    start_step = self.desc.train_start_step

    if models.registry.exists(location, start_step):
        if get_platform().is_primary_process:
            print('Initial weights loaded from {}'.format(
                paths.model(location, start_step)))
        return

    new_model = models.registry.get(self.desc.model_hparams,
                                    outputs=self.desc.train_outputs)

    # If there was a pretrained model, retrieve its final weights and adapt
    # them for training.
    if self.desc.pretrain_training_hparams is not None:
        pretrain_loc = self.desc.run_path(self.replicate, 'pretrain')
        if get_platform().is_primary_process:
            print('Initial weights loaded from pretrained checkpoint {}'.format(
                paths.model(pretrain_loc, self.desc.pretrain_end_step)))

        pretrained = models.registry.load(pretrain_loc,
                                          self.desc.pretrain_end_step,
                                          self.desc.model_hparams,
                                          self.desc.pretrain_outputs)
        # Work on a plain-dict copy so the update below does not touch the
        # mapping returned by the pretrained model.
        state_dict = dict(pretrained.state_dict().items())

        # Select a new output layer if number of classes differs.
        if self.desc.train_outputs != self.desc.pretrain_outputs:
            fresh_output_layers = {
                layer_name: new_model.state_dict()[layer_name]
                for layer_name in new_model.output_layer_names
            }
            state_dict.update(fresh_output_layers)

        new_model.load_state_dict(state_dict)

    new_model.save(location, self.desc.train_start_step)
def download(self):
    """Run the parent class's download on the primary process, silenced.

    Standard output is redirected to ``os.devnull`` for the duration of the
    parent ``download()`` call. All processes then wait on the platform
    barrier.
    """
    import contextlib

    if get_platform().is_primary_process:
        with get_platform().open(os.devnull, 'w') as fp:
            # redirect_stdout puts back whatever sys.stdout was before, even
            # when the download raises, instead of forcing sys.__stdout__.
            with contextlib.redirect_stdout(fp):
                super(CIFAR10, self).download()
    get_platform().barrier()
def restore_checkpoint(output_location, model, optimizer, iterations_per_epoch):
    """Load the checkpoint under ``output_location`` into model and optimizer.

    Returns:
        ``(step, logger)`` rebuilt from the checkpoint, or ``(None, None)``
        when nothing exists at ``paths.checkpoint(output_location)``.
    """
    checkpoint_location = paths.checkpoint(output_location)
    if not get_platform().exists(checkpoint_location):
        return None, None

    checkpoint = get_platform().load_model(checkpoint_location,
                                           map_location=torch.device('cpu'))

    # Handle DataParallel: parameter names carry a 'module.' prefix exactly
    # when the platform reports a parallel run, so add or strip it to match.
    prefix = 'module.'
    weights = checkpoint['model_state_dict']
    wants_prefix = get_platform().is_parallel
    all_prefixed = all(name.startswith(prefix) for name in weights)

    if wants_prefix and not all_prefixed:
        weights = {prefix + name: value for name, value in weights.items()}
    elif all_prefixed and not wants_prefix:
        weights = {name[len(prefix):]: value for name, value in weights.items()}
    checkpoint['model_state_dict'] = weights

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    step = Step.from_epoch(checkpoint['ep'], checkpoint['it'], iterations_per_epoch)
    logger = MetricLogger.create_from_string(checkpoint['logger'])
    return step, logger
def save_checkpoint_callback(output_location, step, model, optimizer, logger):
    """Save step, model, optimizer and logger state as a checkpoint.

    Only the primary process writes, to ``paths.checkpoint(output_location)``;
    every process then waits on the platform barrier.
    """
    if get_platform().is_primary_process:
        checkpoint = {
            'ep': step.ep,
            'it': step.it,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'logger': str(logger),
        }
        get_platform().save_model(checkpoint, paths.checkpoint(output_location))

    get_platform().barrier()
def download(self):
    """Download and extract the archive into ``self.root`` if it is missing.

    Returns immediately when ``self._check_integrity()`` passes. Otherwise
    the primary process, when ``self.root`` does not exist, downloads and
    extracts ``self.url`` into a temporary directory and copies that
    directory to ``self.root`` through the platform.
    """
    if self._check_integrity():
        return

    if get_platform().is_primary_process and not get_platform().exists(self.root):
        # The context manager removes the scratch directory once the copy has
        # finished (or failed), so repeated downloads do not pile up on disk.
        with tempfile.TemporaryDirectory() as temporary_root:
            torchvision.datasets.utils.download_and_extract_archive(
                self.url, temporary_root, filename=self.filename, md5=self.tgz_md5)
            get_platform().copytree(temporary_root, self.root)
def save(self, output_location):
    """Write the display strings of this object's Hparams fields to disk.

    Non-primary processes do nothing. ``output_location`` is created when
    missing. Fields whose value is an ``Hparams`` instance are visited in
    sorted field-name order and their ``display`` strings are written,
    newline-separated, to ``paths.hparams(output_location)``.
    """
    platform = get_platform()
    if not platform.is_primary_process:
        return
    if not platform.exists(output_location):
        platform.makedirs(output_location)

    values_by_name = {}
    for field in fields(self):
        values_by_name[field.name] = getattr(self, field.name)

    displays = []
    for name in sorted(values_by_name):
        value = values_by_name[name]
        if isinstance(value, Hparams):
            displays.append(value.display)

    with platform.open(paths.hparams(output_location), 'w') as fp:
        fp.write('\n'.join(displays))
def _get_samples(root, y_name, y_num):
    """Recursively pair every JPEG under ``root/y_name`` with ``y_num``.

    Returns an empty list when the directory does not exist. Sub-directories
    are descended into; a file is included when its lower-cased name ends
    with ``'jpeg'``.
    """
    y_dir = os.path.join(root, y_name)
    if not get_platform().isdir(y_dir):
        return []

    samples = []
    for entry in get_platform().listdir(y_dir):
        entry_path = os.path.join(y_dir, entry)
        if get_platform().isdir(entry_path):
            samples.extend(_get_samples(y_dir, entry, y_num))
        elif entry.lower().endswith('jpeg'):
            samples.append((entry_path, y_num))
    return samples
def __init__(self, dataset: Dataset, batch_size: int, num_workers: int,
             pin_memory: bool = True, force_sequential: bool = False):
    """Build the loader with a shuffling sampler suited to the platform.

    A ``DistributedShuffleSampler`` is used when the platform is distributed
    and ``force_sequential`` is false; otherwise a ``ShuffleSampler``. In the
    distributed case ``batch_size`` and ``num_workers`` are floor-divided by
    the platform's world size before being passed to the parent loader.
    """
    split_across_processes = get_platform().is_distributed and not force_sequential

    if split_across_processes:
        self._sampler = DistributedShuffleSampler(dataset)
    else:
        self._sampler = ShuffleSampler(len(dataset))

    # Uses the batch size as passed in, before the per-process split below.
    self._iterations_per_epoch = np.ceil(len(dataset) / batch_size).astype(int)

    if split_across_processes:
        batch_size //= get_platform().world_size
        num_workers //= get_platform().world_size

    use_pinned_memory = (pin_memory
                         and get_platform().torch_device.type == 'cuda'
                         and not force_sequential)
    super(DataLoader, self).__init__(
        dataset,
        batch_size,
        sampler=self._sampler,
        num_workers=num_workers,
        pin_memory=use_pinned_memory)
def eval_callback(output_location, step, model, optimizer, logger): example_count = torch.tensor(0.0).to(get_platform().torch_device) total_loss = torch.tensor(0.0).to(get_platform().torch_device) total_correct = torch.tensor(0.0).to(get_platform().torch_device) def correct(labels, outputs): return torch.sum(torch.eq(labels, output.argmax(dim=1))) model.eval() with torch.no_grad(): for examples, labels, _, _ in loader: examples = examples.to(get_platform().torch_device) labels = labels.squeeze().to(get_platform().torch_device) output = model(examples) labels_size = torch.tensor(len(labels), device=get_platform().torch_device) example_count += labels_size total_loss += model.loss_criterion(output, labels) * labels_size #print(labels, output) total_correct += correct(labels, output) print(total_correct) # Share the information if distributed. if get_platform().is_distributed: torch.distributed.reduce(total_loss, 0, op=torch.distributed.ReduceOp.SUM) torch.distributed.reduce(total_correct, 0, op=torch.distributed.ReduceOp.SUM) torch.distributed.reduce(example_count, 0, op=torch.distributed.ReduceOp.SUM) total_loss = total_loss.cpu().item() total_correct = total_correct.cpu().item() example_count = example_count.cpu().item() if get_platform().is_primary_process: logger.add('{}_loss'.format(eval_name), step, total_loss / example_count) logger.add('{}_accuracy'.format(eval_name), step, total_correct / example_count) logger.add('{}_examples'.format(eval_name), step, example_count) if verbose: nonlocal time_of_last_call elapsed = 0 if time_of_last_call is None else time.time( ) - time_of_last_call print( '{}\tep {:03d}\tit {:03d}\tloss {:.3f}\tacc {:.2f}%\tex {:d}\ttime {:.2f}s' .format(eval_name, step.ep, step.it, total_loss / example_count, 100 * total_correct / example_count, int(example_count), elapsed)) time_of_last_call = time.time()
def get_test_set():
    """Return a ``Dataset`` wrapping the torchvision MNIST test split.

    The data lives under ``<platform dataset_root>/mnist`` and is downloaded
    if necessary.
    """
    mnist_root = os.path.join(get_platform().dataset_root, 'mnist')
    mnist_test = torchvision.datasets.MNIST(train=False, root=mnist_root, download=True)
    return Dataset(mnist_test.data, mnist_test.targets)
def get_score(trained_model: models.base.Model,
              current_mask: Mask,
              prunable_tensors: set,
              training_hparams: hparams.TrainingHparams,
              dataset_hparams: hparams.DatasetHparams,
              data_order_seed: int = None):
    """Score prunable parameters by ``|grad * param|``, normalised to sum to 1.

    The model is wrapped in a ``PrunedModel`` with ``current_mask``, gradients
    are accumulated via ``train.accumulate_gradient``, and a score array is
    produced for every parameter that has a mask attribute on the pruned
    model and whose name is in ``prunable_tensors``.

    Returns:
        dict mapping parameter name to a numpy array of normalised scores.
    """
    device = get_platform().torch_device
    pruned_model = PrunedModel(trained_model, current_mask).to(device=device)
    pruned_model._clear_grad()

    # Calculate the gradient
    train.accumulate_gradient(
        training_hparams, pruned_model, dataset_hparams, data_order_seed,
        verbose=False
    )

    # Calculate the score
    scores = {}
    for name, param in pruned_model.model.named_parameters():
        has_mask = hasattr(pruned_model, PrunedModel.to_mask_name(name))
        if has_mask and name in prunable_tensors:
            saliency = (param.grad * param).abs()
            scores[name] = saliency.clone().cpu().detach().numpy()

    # Normalise by the sum over all scored tensors together.
    flat_scores = np.concatenate([score.reshape(-1) for score in scores.values()])
    norm = np.sum(flat_scores)
    for name in scores:
        scores[name] /= norm

    # Clean up
    pruned_model._clear_grad()

    return scores
def run_path(self, replicate, experiment='main'):
    """Return ``<platform root>/<hashname>/replicate_<n>/<experiment>``.

    Raises:
        ValueError: if ``replicate`` is not a positive ``int``.
    """
    is_valid = isinstance(replicate, int) and replicate > 0
    if not is_valid:
        raise ValueError('Bad replicate: {}'.format(replicate))

    replicate_dir = f'replicate_{replicate}'
    return os.path.join(get_platform().root, self.hashname, replicate_dir, experiment)
def standard_train(model: Model,
                   output_location: str,
                   dataset_hparams: hparams.DatasetHparams,
                   training_hparams: hparams.TrainingHparams,
                   start_step: Step = None,
                   verbose: bool = True,
                   evaluate_every_epoch: bool = True):
    """Train using the standard callbacks according to the provided hparams.

    Training is skipped when both the end-of-training model and the logger
    file already exist under ``output_location``.
    """
    iterations_per_epoch = datasets.registry.iterations_per_epoch(dataset_hparams)
    train_end_step = Step.from_str(training_hparams.training_steps,
                                   iterations_per_epoch)

    # If the model file for the end of training already exists in this
    # location, do not train.
    already_trained = (models.registry.exists(output_location, train_end_step)
                       and get_platform().exists(paths.logger(output_location)))
    if already_trained:
        return

    train_loader = datasets.registry.get(dataset_hparams, train=True)
    test_loader = datasets.registry.get(dataset_hparams, train=False)

    callbacks = standard_callbacks.standard_callbacks(
        training_hparams,
        train_loader,
        test_loader,
        start_step=start_step,
        verbose=verbose,
        evaluate_every_epoch=evaluate_every_epoch)

    train(training_hparams, model, train_loader, output_location, callbacks,
          start_step=start_step)
def load_from_file(file_location: str, model_hparams: ModelHparams, outputs=None):
    """Build a model from ``model_hparams`` and load weights from a file.

    The state dict is read through the platform from ``file_location`` before
    the model is constructed with ``get(model_hparams, outputs)``.
    """
    weights = get_platform().load_model(file_location)
    loaded_model = get(model_hparams, outputs)
    loaded_model.load_state_dict(weights)
    return loaded_model
def __init__(self, loc: str, image_transforms, enumerate_examples=False):
    """Initialise from a directory, using an annotation file for non-train splits.

    When ``loc`` ends with ``'train'`` the parent initialiser handles
    everything. Otherwise ``<loc>/<basename>_annotations.txt`` is read; each
    non-blank line supplies a file name (first token) and a class (second
    token). Classes are numbered in sorted order and images are expected
    under ``<loc>/images``.
    """
    if loc.endswith('train'):
        super(Dataset, self).__init__(loc, image_transforms, enumerate_examples)
        return

    # Test and validation sets have an annotation file.
    split_name = os.path.basename(loc)
    annotations_file = os.path.join(loc, f'{split_name}_annotations.txt')
    with get_platform().open(annotations_file) as fp:
        raw_lines = fp.read().split('\n')

    annotations = []
    for line in raw_lines:
        if line.strip():
            tokens = line.split()
            annotations.append((tokens[0], tokens[1]))

    classes = sorted({class_name for _, class_name in annotations})
    labels_dict = {class_name: index for index, class_name in enumerate(classes)}

    examples, labels = zip(*[
        (os.path.join(loc, 'images', file_name), labels_dict[class_name])
        for file_name, class_name in annotations
    ])

    # Deliberately starts the MRO lookup after imagenet.Dataset, so that
    # class's own initialiser is not run for annotated splits.
    super(imagenet.Dataset, self).__init__(
        np.array(examples),
        np.array(labels),
        image_transforms,
        [self._normalization_transform()],
        enumerate_examples=enumerate_examples)
def run(self):
    """Save the description, load a fixed model and mask, and train it.

    A banner is printed on the primary process when ``self.verbose`` is set.
    The model and mask are loaded from ``<platform root>/resnet18_lth`` at
    step ``'2ep218it'`` (with 1000 iterations per epoch) and trained via
    ``train.standard_train`` into this replicate's run path.
    """
    if self.verbose and get_platform().is_primary_process:
        print('='*82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-'*82)
        print(self.desc.display)
        print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '='*82 + '\n')

    run_dir = self.desc.run_path(self.replicate)
    self.desc.save(run_dir)

    # TODO: make mask and model init paths configurable
    init_path = os.path.join(get_platform().root, 'resnet18_lth')
    init_step = Step.from_str('2ep218it', 1000)
    model = models.registry.load(init_path, init_step,
                                 self.desc.model_hparams,
                                 self.desc.train_outputs)
    pruned_model = PrunedModel(model, Mask.load(init_path))

    train.standard_train(
        pruned_model,
        self.desc.run_path(self.replicate),
        self.desc.dataset_hparams,
        self.desc.training_hparams,
        evaluate_every_epoch=self.evaluate_every_epoch)
def run_path(self, replicate: int, pruning_level: Union[str, int], experiment: str = 'main'):
    """The location where any run is stored.

    Built as ``<platform root>/<hashname>/replicate_<n>/level_<level>/<experiment>``.

    Raises:
        ValueError: if ``replicate`` is not a positive ``int``.
    """
    is_valid = isinstance(replicate, int) and replicate > 0
    if not is_valid:
        raise ValueError('Bad replicate: {}'.format(replicate))

    replicate_dir = f'replicate_{replicate}'
    level_dir = f'level_{pruning_level}'
    return os.path.join(get_platform().root, self.hashname, replicate_dir,
                        level_dir, experiment)
def get_sparsity_ratio(self):
    """Return the sum of all mask entries divided by the number of entries.

    Computed over the arrays in ``self.numpy()``. Non-primary processes
    return ``None``.
    """
    if get_platform().is_primary_process:
        # Get sparsity ratio
        entry_counts = [array.size for array in self.numpy().values()]
        entry_sums = [np.sum(array) for array in self.numpy().values()]

        total_weights = np.sum(entry_counts).item()
        total_unpruned = np.sum(entry_sums).item()
        return float(total_unpruned) / total_weights
    return None
def load(output_location, suffix=''):
    """Load the mask stored under ``output_location`` with optional ``suffix``.

    Raises:
        ValueError: if ``Mask.exists(output_location, suffix)`` is false.
    """
    if Mask.exists(output_location, suffix):
        mask_file = paths.mask(output_location, suffix)
        return Mask(get_platform().load_model(mask_file))

    # Only mention the suffix in the message when one was given.
    suffix_note = '' if suffix == '' else ' with suffix {}'.format(suffix)
    raise ValueError('Mask not found at {}{}'.format(output_location, suffix_note))
def get_train_set(use_augmentation):
    """Return a ``Dataset`` wrapping the torchvision MNIST training split.

    ``use_augmentation`` is accepted but ignored: no augmentation for MNIST.
    The data lives under ``<platform dataset_root>/mnist`` and is downloaded
    if necessary.
    """
    mnist_root = os.path.join(get_platform().dataset_root, 'mnist')
    mnist_train = torchvision.datasets.MNIST(train=True, root=mnist_root, download=True)
    return Dataset(mnist_train.data, mnist_train.targets)
def get_test_set(enumerate_examples=False):
    """Return a ``Dataset`` wrapping the CIFAR10 test split.

    The data lives under ``<platform dataset_root>/cifar10`` and is
    downloaded if necessary. Targets are converted to a numpy array.
    """
    cifar_root = os.path.join(get_platform().dataset_root, 'cifar10')
    cifar_test = CIFAR10(train=False, root=cifar_root, download=True)
    targets = np.array(cifar_test.targets)
    return Dataset(cifar_test.data, targets, enumerate_examples=enumerate_examples)
def _pretrain(self):
    """Run pretraining for this replicate unless its final model exists.

    A fresh model with ``pretrain_outputs`` outputs is trained via
    ``train.standard_train`` into the replicate's ``'pretrain'`` run path.
    """
    location = self.desc.run_path(self.replicate, 'pretrain')
    if models.registry.exists(location, self.desc.pretrain_end_step):
        return

    if self.verbose and get_platform().is_primary_process:
        print('-'*82 + '\nPretraining\n' + '-'*82)

    model = models.registry.get(self.desc.model_hparams,
                                outputs=self.desc.pretrain_outputs)
    train.standard_train(
        model,
        location,
        self.desc.pretrain_dataset_hparams,
        self.desc.pretrain_training_hparams,
        verbose=self.verbose,
        evaluate_every_epoch=self.evaluate_every_epoch,
        weight_save_steps=self.weight_save_steps)
def load(save_location: str, save_step: Step, model_hparams: ModelHparams, outputs=None):
    """Build a model from ``model_hparams`` and load its saved weights.

    The state dict is read through the platform from
    ``paths.model(save_location, save_step)`` before the model is constructed
    with ``get(model_hparams, outputs)``.
    """
    weights_file = paths.model(save_location, save_step)
    weights = get_platform().load_model(weights_file)

    loaded_model = get(model_hparams, outputs)
    loaded_model.load_state_dict(weights)
    return loaded_model
def get(dataset_hparams: DatasetHparams, train: bool = True):
    """Get the train or test set corresponding to the hyperparameters.

    The dataset is looked up in ``registered_datasets``, optionally
    transformed (label randomization, subsampling, blurring, unsupervised
    rotation) using ``transformation_seed`` (0 when unset), and wrapped in the
    registered ``DataLoader``.

    Raises:
        ValueError: for an unknown dataset name, an unknown label
            randomization type, an unknown ``unsupervised_labels`` value, or
            an image-only transformation requested on a non-image dataset.
    """
    seed = dataset_hparams.transformation_seed or 0

    # Get the dataset itself.
    if dataset_hparams.dataset_name not in registered_datasets:
        raise ValueError('No such dataset: {}'.format(dataset_hparams.dataset_name))
    registered = registered_datasets[dataset_hparams.dataset_name]
    if train:
        use_augmentation = not dataset_hparams.do_not_augment
        dataset = registered.Dataset.get_train_set(use_augmentation)
    else:
        dataset = registered.Dataset.get_test_set()

    # Transform the dataset.
    # Labels are randomized only when the current split ('train' or 'test')
    # is listed in the comma-separated randomization targets.
    randomize = False
    if dataset_hparams.label_randomization_targets is not None:
        rand_targets = [elem.strip()
                        for elem in dataset_hparams.label_randomization_targets.split(',')]
        randomize = ('train' if train else 'test') in rand_targets

    if randomize and dataset_hparams.label_randomization_type is not None:
        _type = dataset_hparams.label_randomization_type
        if _type == 'shuffle':
            dataset.shuffle_labels(seed=seed)
        elif _type == 'corrupt':
            dataset.corrupt_labels(seed=seed,
                                   corrupt_prob=dataset_hparams.corruption_probability)
        elif _type == 'fraction':
            dataset.randomize_labels(seed=seed,
                                     fraction=dataset_hparams.random_labels_fraction)
        else:
            raise ValueError(
                f"'{dataset_hparams.label_randomization_type}' is not a valid randomization type.")

    if train and dataset_hparams.subsample_fraction is not None:
        dataset.subsample(seed=seed, fraction=dataset_hparams.subsample_fraction)

    if train and dataset_hparams.blur_factor is not None:
        if not isinstance(dataset, base.ImageDataset):
            # Message corrected: blurring is restricted to image datasets.
            raise ValueError('Can only blur images.')
        dataset.blur(seed=seed, blur_factor=dataset_hparams.blur_factor)

    if dataset_hparams.unsupervised_labels is not None:
        if dataset_hparams.unsupervised_labels != 'rotation':
            raise ValueError('Unknown unsupervised labels: {}'.format(
                dataset_hparams.unsupervised_labels))
        elif not isinstance(dataset, base.ImageDataset):
            raise ValueError('Can only do unsupervised rotation to images.')
        else:
            dataset.unsupervised_rotation(seed=seed)

    # Create the loader.
    return registered.DataLoader(
        dataset,
        batch_size=dataset_hparams.batch_size,
        num_workers=get_platform().num_workers)
def get(dataset_hparams: DatasetHparams,
        train: bool = True,
        force_sequential: bool = False,
        mask: np.ndarray = None,
        enumerate_examples: bool = False):
    """Get the train or test set corresponding to the hyperparameters.

    The dataset is looked up in ``registered_datasets``, optionally
    transformed (random labels, subsampling, blurring, unsupervised rotation,
    filtering by ``mask``) using ``transformation_seed`` (0 when unset), and
    wrapped in the registered ``DataLoader``.

    Raises:
        ValueError: for an unknown dataset name, an unknown
            ``unsupervised_labels`` value, or an image-only transformation
            requested on a non-image dataset.
    """
    seed = dataset_hparams.transformation_seed or 0

    # Get the dataset itself.
    if dataset_hparams.dataset_name not in registered_datasets:
        raise ValueError('No such dataset: {}'.format(
            dataset_hparams.dataset_name))
    registered = registered_datasets[dataset_hparams.dataset_name]
    if train:
        use_augmentation = not dataset_hparams.do_not_augment
        dataset = registered.Dataset.get_train_set(use_augmentation,
                                                   enumerate_examples)
    else:
        dataset = registered.Dataset.get_test_set(enumerate_examples)

    # Transform the dataset.
    if train and dataset_hparams.random_labels_fraction is not None:
        dataset.randomize_labels(
            seed=seed, fraction=dataset_hparams.random_labels_fraction)

    if train and dataset_hparams.subsample_fraction is not None:
        dataset.subsample(seed=seed,
                          fraction=dataset_hparams.subsample_fraction)

    if train and dataset_hparams.blur_factor is not None:
        if not isinstance(dataset, base.ImageDataset):
            # Message corrected: blurring is restricted to image datasets.
            raise ValueError('Can only blur images.')
        dataset.blur(seed=seed, blur_factor=dataset_hparams.blur_factor)

    if dataset_hparams.unsupervised_labels is not None:
        if dataset_hparams.unsupervised_labels != 'rotation':
            raise ValueError('Unknown unsupervised labels: {}'.format(
                dataset_hparams.unsupervised_labels))
        elif not isinstance(dataset, base.ImageDataset):
            raise ValueError('Can only do unsupervised rotation to images.')
        else:
            dataset.unsupervised_rotation(seed=seed)

    if mask is not None:
        dataset.filter(mask)

    # Create the loader.
    return registered.DataLoader(
        dataset,
        batch_size=dataset_hparams.batch_size,
        num_workers=get_platform().num_workers,
        force_sequential=force_sequential)
def _estimate_theta(self, model):
    """Collect per-example correctness and ``diff`` values, then estimate theta.

    The model is run in eval mode without gradients over
    ``self._train_loader``. For every example it records whether the argmax
    prediction equals the label, alongside the ``diff`` value the loader
    yields, and passes both lists to ``scoring.calculate_theta`` with a
    third argument of 1000.
    """
    correct_flags = []
    diff_values = []

    model.eval()
    with torch.no_grad():
        for examples, labels, _idx, diff in self._train_loader:
            examples = examples.to(get_platform().torch_device)
            labels = labels.squeeze().to(get_platform().torch_device)
            output = model(examples.float())

            predictions = output.argmax(dim=1)
            correct_flags.extend(torch.eq(labels, predictions).cpu())
            diff_values.extend(diff)

    return scoring.calculate_theta(diff_values, correct_flags, 1000)