def __call__(self, module, module_in):
    print('Module inputs before')
    print(module_in)
    for p_name, a in zip(self.param_names, self.amount):
        prune.random_unstructured(module, p_name, amount=a)
    print('Module buffers')
    print(list(module.named_buffers()))
    print('Module params')
    print(list(module.named_parameters()))
    print('Module inputs')
    print(module_in)
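# The __call__ above has the signature of a forward pre-hook. A minimal sketch of
# how such a callable could be wired up; the class name PruningHook, the toy layer,
# and the amounts are illustrative assumptions, not part of the original code.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class PruningHook:
    def __init__(self, param_names, amount):
        self.param_names = param_names  # e.g. ['weight']
        self.amount = amount            # one pruning fraction per parameter name

    def __call__(self, module, module_in):
        # prune each listed parameter once, the first time an input arrives
        if not prune.is_pruned(module):
            for p_name, a in zip(self.param_names, self.amount):
                prune.random_unstructured(module, p_name, amount=a)

layer = nn.Linear(8, 4)
handle = layer.register_forward_pre_hook(PruningHook(['weight'], [0.5]))
layer(torch.randn(2, 8))   # triggers the hook, adds weight_orig / weight_mask
handle.remove()            # detach the hook once pruning has been applied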
def load(self, filename):
    """
    Loads a trained model from files.

    Parameters
    ----------
    filename : str
        Path to the files. '_settings.json' and '_state_dict.pt' will be added.

    Returns
    -------
        None
    """
    logger.info("Loading model from %s", filename)

    # Load settings and create model
    logger.debug("Loading settings from %s_settings.json", filename)
    with open(filename + "_settings.json", "r") as f:
        settings = json.load(f)
    self._unwrap_settings(settings)
    self._create_model()

    # Load scaling
    try:
        self.x_scaling_means = np.load(filename + "_x_means.npy")
        self.x_scaling_stds = np.load(filename + "_x_stds.npy")
        logger.debug(
            "  Found input scaling information: means %s, stds %s", self.x_scaling_means, self.x_scaling_stds
        )
    except FileNotFoundError:
        logger.warning("Scaling information not found in %s", filename)
        self.x_scaling_means = None
        self.x_scaling_stds = None

    # Apply the pruning reparametrization to ll1 before loading the checkpoint
    module = self.model.ll1
    print("before pruning")
    print(list(module.named_parameters()))
    print(list(module.named_buffers()))
    prune.random_unstructured(module, name="weight", amount=0.3)
    print("after pruning")
    print(list(module.named_parameters()))
    print(list(module.named_buffers()))
    print(self.model.state_dict().keys())

    # Load state dict
    logger.debug("Loading state dictionary from %s_state_dict.pt", filename)
    print(self.model.state_dict().keys())
    self.model.load_state_dict(torch.load(filename + "_state_dict.pt", map_location="cpu"))
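# Why load() above prunes ll1 before calling load_state_dict(): once a parameter
# is pruned, the state dict stores `weight_orig` and `weight_mask` instead of
# `weight`, so the same reparametrization must exist on the module that loads it.
# Self-contained sketch with a toy layer (the layer itself is an assumption).
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

src = nn.Linear(4, 2)
prune.random_unstructured(src, name="weight", amount=0.5)
state = src.state_dict()
print(sorted(state.keys()))      # ['bias', 'weight_mask', 'weight_orig']

dst = nn.Linear(4, 2)
# dst.load_state_dict(state)     # would fail: dst only has 'weight' and 'bias'
prune.random_unstructured(dst, name="weight", amount=0.5)  # recreate the keys
dst.load_state_dict(state)       # now the keys line up and loading succeeds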
def random_prune_model(model, ratio):
    module = model.fc[0]
    # prune.random_unstructured prunes in place and returns the module itself
    pruned_w = prune.random_unstructured(module, name='weight', amount=ratio)
    # prune.custom_from_mask(module, name='weight', mask=mag_w_mask)
    model.fc[0] = pruned_w
    return model
def pruner(model, amount, random=False):
    """ (amount) total amount of desired sparsity """
    for name, module in model.named_modules():
        # prune declared amount of connections in all 2D-conv & Linear layers
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
            if random:
                prune.random_unstructured(module, name='weight', amount=amount)
            else:
                prune.l1_unstructured(module, name='weight', amount=amount)
            # prune.remove(module, 'weight')  # make it permanent
    return model
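# Sketch of how a helper like pruner() could be used and verified. The toy network
# below and the 0.3 amount are assumptions for illustration only.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
net = pruner(net, amount=0.3, random=True)   # random unstructured, 30% per layer

# Per-layer sparsity after pruning: fraction of exactly-zero weights
for name, module in net.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        zeros = float(torch.sum(module.weight == 0))
        print(f"{name}: {zeros / module.weight.nelement():.1%} sparse")

# Optionally bake the masks in so `weight` becomes a plain parameter again
for module in net.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        prune.remove(module, 'weight')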
def prune_darts(model, pruning_percentage):
    for modules in model.children():
        if not isinstance(modules, nn.AdaptiveAvgPool3d):
            for module in modules:
                if not isinstance(module, Cell):
                    # print(list(module.named_parameters()))
                    prune.random_unstructured(module, name="weight", amount=pruning_percentage)
                    # print(list(module.named_parameters()))
                    # print("Not Cell")
                else:
                    # print(module)
                    for cell_module in module.children():
                        # print("Cell")
                        if isinstance(cell_module, ReLUConvBN):
                            # print("ReLUConvBN")
                            for ReLU_List in cell_module.children():
                                for ReLU_module in ReLU_List:
                                    if isinstance(ReLU_module, nn.Conv2d):
                                        # if ReLU_module.kernel_size[0] == 1:
                                        #     continue
                                        prune.l1_unstructured(
                                            ReLU_module, name="weight", amount=pruning_percentage)
                        elif isinstance(cell_module, nn.ModuleList):
                            # print("nn.ModuleList")
                            for innerModules in cell_module.children():
                                for seqItem in innerModules.children():
                                    for layer in seqItem:
                                        if isinstance(layer, nn.Conv2d):
                                            # if layer.kernel_size[0] == 1:
                                            #     continue
                                            prune.l1_unstructured(
                                                layer, name="weight", amount=pruning_percentage)
def prune_random_unstructured(net_creator, imagenet_path, batch_size):
    for i in range(1, 5):
        for idx in range(2, 18):
            net = net_creator()
            module = net.features[idx].conv
            amount = i / 10
            prune.random_unstructured(module[0][0], name='weight', amount=amount)
            prune.random_unstructured(module[1][0], name='weight', amount=amount)
            prune.remove(module[0][0], 'weight')
            prune.remove(module[1][0], 'weight')
            time0 = time.time()
            result = test_imagenet(net, imagenet_path, batch_size)
            speed = 1000 / (time.time() - time0)
            with open("log.txt", 'a+') as file:
                file.write(
                    "method: rand_unstr - module: {} - prune amount: {:.0%} - accuracy: {:.2f} - speed: {:.2f} \n"
                    .format(idx, amount, result, speed))
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
x_train = torch.FloatTensor(x_train)
y_train = torch.FloatTensor(blob_label(y_train, 0, [0]))
y_train = torch.FloatTensor(blob_label(y_train, 1, [1, 2, 3]))
x_test = torch.FloatTensor(x_test)
y_test = torch.FloatTensor(blob_label(y_test, 0, [0]))
y_test = torch.FloatTensor(blob_label(y_test, 1, [1, 2, 3]))

from torch.nn.utils import prune

model = Feedforward(1024, 512)

# Prune weight
prune.random_unstructured(model.fc1, name="weight", amount=0.95)

criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

model.eval()
y_pred = model(x_test)
before_train = criterion(y_pred.squeeze(), y_test)
print('Test loss before training', before_train.item())
model.train()
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)
module = model.conv1
print(list(module.named_parameters()))
prune.random_unstructured(module, name="weight", amount=0.3)
print(list(module.named_parameters()))
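# A short follow-up sketch (assuming the LeNet snippet above has run) of what
# random_unstructured changes on conv1: the original tensor moves to `weight_orig`,
# a binary `weight_mask` buffer is added, and `weight` becomes the masked product
# recomputed by a forward pre-hook.
print([name for name, _ in module.named_parameters()])   # ['bias', 'weight_orig']
print([name for name, _ in module.named_buffers()])      # ['weight_mask']
print(module.weight.shape)                                # masked view, same shape as before

sparsity = float(torch.sum(module.weight == 0)) / module.weight.nelement()
print(f"conv1 weight sparsity: {sparsity:.1%}")           # roughly 30% with amount=0.3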
def perform_pruning(self):
    prune.random_unstructured(module=self.fc1, name='weight', amount=0.2)
# NOTE:
# prune.random_unstructured
# prune.l1_unstructured
# prune.random_structured
# prune.ln_structured
# unstructured for weight pruning and structured for channel pruning

# Iteration over named_modules
for name, mod in pruned_net.named_modules():
    # if hasattr(mod, 'bias'):
    #     prune.random_unstructured(mod, name="bias", amount=0.5)
    # elif hasattr(mod, 'weight'):
    #     prune.random_unstructured(mod, name="weight", amount=0.5)
    # else:
    if isinstance(mod, torch.nn.Conv2d):
        prune.random_unstructured(mod, name="weight", amount=0.5)
    elif isinstance(mod, torch.nn.BatchNorm2d):
        prune.random_unstructured(mod, name="weight", amount=0.5)
        prune.random_unstructured(mod, name="bias", amount=0.5)
    elif isinstance(mod, torch.nn.Linear):
        prune.random_unstructured(mod, name="weight", amount=0.5)
        prune.random_unstructured(mod, name="bias", amount=0.5)

# Apply prune (make the pruning permanent)
for name, mod in pruned_net.named_modules():
    if isinstance(mod, torch.nn.Linear):
        prune.remove(mod, name='weight')
        prune.remove(mod, name='bias')
    elif isinstance(mod, torch.nn.BatchNorm2d):
        prune.remove(mod, name='weight')
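# The NOTE above distinguishes unstructured pruning (individual weights) from
# structured pruning (whole channels). A minimal sketch of the structured variants
# on a toy Conv2d; the layer and the amounts here are assumptions.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 8, kernel_size=3)

# Zero out 50% of output channels by smallest L2 norm (dim=0 is the out-channel axis)
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# Channels that were pruned have an all-zero mask slice
mask = conv.weight_mask
print([int(mask[c].sum()) == 0 for c in range(mask.shape[0])])

# Random structured pruning works the same way, just without the norm criterion
conv2 = nn.Conv2d(3, 8, kernel_size=3)
prune.random_structured(conv2, name="weight", amount=0.5, dim=0)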
def prune_model(model, args, type): if type == 'train': if args.prune_train > 0: print('Pruning {} %'.format(args.prune_train * 100)) if args.prune == 'global': print('Global Pruning') elif args.prune == 'l1': print('L1 Pruning') elif args.prune == 'random': print('Random Pruning') parameters_to_prune = [] for mod_name, module in list(model.named_modules()): # for name, value in list(module.named_parameters()): if hasattr(module, 'weight') or hasattr(module, 'weight_mask'): print(mod_name) name = 'weight' print('weights before {:.3f}%'.format( float(torch.sum(module.weight == 0)) * 100 / float(module.weight.nelement()))) if args.prune == 'global': parameters_to_prune.append((module, name)) elif args.prune == 'l1': prune.l1_unstructured(module, name=name, amount=args.prune_train) elif args.prune == 'random': prune.random_unstructured(module, name=name, amount=args.prune_train) print('weights after {:.3f}%'.format( float(torch.sum(module.weight == 0)) * 100 / float(module.weight.nelement()))) # if prune.is_pruned(module): # prune.remove(module, 'weight') # print('removed',mod_name) if args.prune == 'global': prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=args.prune_train) elif type == 'eval': if args.prune_eval > 0: print('Pruning {} %'.format(args.prune_eval * 100)) if args.prune == 'global': print('Global Pruning') elif args.prune == 'l1': print('L1 Pruning') elif args.prune == 'random': print('Random Pruning') parameters_to_prune = [] for mod_name, module in list(model.named_modules()): # for name, value in list(module.named_parameters()): if hasattr(module, 'weight') or hasattr(module, 'weight_mask'): print(mod_name) name = 'weight' print('weights before {:.3f}%'.format( float(torch.sum(module.weight == 0)) * 100 / float(module.weight.nelement()))) if args.prune == 'global': parameters_to_prune.append((module, name)) elif args.prune == 'l1': prune.l1_unstructured(module, name=name, amount=args.prune_eval) elif args.prune == 'random': prune.random_unstructured(module, name=name, amount=args.prune_eval) print('weights after {:.3f}%'.format( float(torch.sum(module.weight == 0)) * 100 / float(module.weight.nelement()))) # if prune.is_pruned(module): # prune.remove(module, 'weight') # print('removed',mod_name) if args.prune == 'global': prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=args.prune_eval) countZeroWeights(model)
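# prune_model() above collects (module, 'weight') pairs and hands them to
# prune.global_unstructured when args.prune == 'global'. A self-contained sketch of
# that pattern on a toy two-layer net (layer sizes and the 0.5 amount are
# assumptions): global pruning ranks all weights together, so per-layer sparsity
# can differ even though the overall fraction matches the requested amount.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
parameters_to_prune = [(m, 'weight') for m in net.modules() if isinstance(m, nn.Linear)]

prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=0.5)

for i, (m, _) in enumerate(parameters_to_prune):
    zeros = float(torch.sum(m.weight == 0))
    print(f"layer {i}: {zeros / m.weight.nelement():.1%} zero")  # layers differ, ~50% overall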
def train_step(self, samples, raise_oom=False): # prune accordingly def to_prune_or_not_to_prune(): if self.get_num_updates() >= self.prune_start_step and \ (self.get_num_updates() - self.prune_start_step) % self.pruning_interval == 0 and \ abs(self.sparsity - self.target_sparsity) > 5e-3 and \ self.sparsity < self.target_sparsity: return True return False def get_pruning_amount(train_step): n = self.num_pruning_steps dt = self.pruning_interval t_0 = self.prune_start_step s_i = 0. s_f = self.target_sparsity s_t = s_f + (s_i - s_f) * (1 - (train_step - t_0) / (n * dt))**3 return s_t if to_prune_or_not_to_prune(): # Determine how much to prune target_sparsity = get_pruning_amount(self.get_num_updates()) pruning_amount = (target_sparsity - self.sparsity) / (1. - self.sparsity) if pruning_amount > 0: for module, _ in self.get_modules_to_prune(): if self.prune_type == 'magnitude': prune.l1_unstructured(module, name="weight", amount=pruning_amount) else: prune.random_unstructured(module, name="weight", amount=pruning_amount) """ prune.global_unstructured( self.get_modules_to_prune(), pruning_method=prune.L1Unstructured if self.prune_type == 'magnitude' \ else prune.RandomUnstructured, amount=pruning_amount) """ logger.info( "NOTE: Weights pruned, type: {}, amount: {}, sparsity: {}". format(self.prune_type, pruning_amount, self.sparsity)) """Do forward, backward and parameter update.""" if self._dummy_batch == "DUMMY": self._dummy_batch = samples[0] self._set_seed() self.model.train() self.criterion.train() self.zero_grad() metrics.log_start_time("train_wall", priority=800, round=0) # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 for i, sample in enumerate(samples): sample = self._prepare_sample(sample) if sample is None: # when sample is None, run forward/backward on a dummy batch # and ignore the resulting gradients sample = self._prepare_sample(self._dummy_batch) is_dummy_batch = True else: is_dummy_batch = False def maybe_no_sync(): """ Whenever *samples* contains more than one mini-batch, we want to accumulate gradients locally and only call all-reduce in the last backwards pass. """ if (self.data_parallel_world_size > 1 and hasattr(self.model, "no_sync") and i < len(samples) - 1): return self.model.no_sync() else: return contextlib.ExitStack() # dummy contextmanager try: with maybe_no_sync(): # forward and backward loss, sample_size_i, logging_output = self.task.train_step( sample=sample, model=self.model, criterion=self.criterion, optimizer=self.optimizer, update_num=self.get_num_updates(), ignore_grad=is_dummy_batch) del loss logging_outputs.append(logging_output) sample_size += sample_size_i # emptying the CUDA cache after the first step can # reduce the chance of OOM if self.cuda and self.get_num_updates() == 0: torch.cuda.empty_cache() except RuntimeError as e: if "out of memory" in str(e): self._log_oom(e) if raise_oom: raise e logger.warning( "attempting to recover from OOM in forward/backward pass" ) ooms += 1 self.zero_grad() if self.cuda: torch.cuda.empty_cache() if self.args.distributed_world_size == 1: return None else: raise e if self.tpu and i < len(samples) - 1: # tpu-comment: every XLA operation before marking step is # appended to the IR graph, and processing too many batches # before marking step can lead to OOM errors. 
# To handle gradient accumulation use case, we explicitly # mark step here for every forward pass without a backward pass import torch_xla.core.xla_model as xm xm.mark_step() if is_dummy_batch: if torch.is_tensor(sample_size): sample_size.zero_() else: sample_size *= 0. if torch.is_tensor(sample_size): sample_size = sample_size.float() else: sample_size = float(sample_size) # gather logging outputs from all replicas if self._sync_stats(): train_time = self._local_cumulative_training_time() logging_outputs, ( sample_size, ooms, total_train_time) = self._aggregate_logging_outputs( logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, ) self._cumulative_training_time = total_train_time / self.data_parallel_world_size overflow = False try: if self.tpu and self.data_parallel_world_size > 1: import torch_xla.core.xla_model as xm gradients = xm._fetch_gradients(self.optimizer.optimizer) xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size) with torch.autograd.profiler.record_function("multiply-grads"): # multiply gradients by (# GPUs / sample_size) since DDP # already normalizes by the number of GPUs. Thus we get # (sum_of_gradients / sample_size). if not self.args.use_bmuf: self.optimizer.multiply_grads( self.data_parallel_world_size / sample_size) elif sample_size > 0: # BMUF needs to check sample size num = self.data_parallel_world_size if self._sync_stats( ) else 1 self.optimizer.multiply_grads(num / sample_size) with torch.autograd.profiler.record_function("clip-grads"): # clip grads grad_norm = self.clip_grad_norm(self.args.clip_norm) # check that grad norms are consistent across workers if (not self.args.use_bmuf and self.args.distributed_wrapper != 'SlowMo' and not self.tpu): self._check_grad_norms(grad_norm) with torch.autograd.profiler.record_function("optimizer"): # take an optimization step self.optimizer.step() except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails with NanDetector(self.model): self.task.train_step(sample, self.model, self.criterion, self.optimizer, self.get_num_updates(), ignore_grad=False) raise except OverflowError as e: overflow = True logger.info("NOTE: overflow detected, " + str(e)) grad_norm = torch.tensor(0.).cuda() self.zero_grad() except RuntimeError as e: if "out of memory" in str(e): self._log_oom(e) logger.error("OOM during optimization, irrecoverable") raise e # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step if hasattr(self.model, 'perform_additional_optimizer_actions'): if hasattr(self.optimizer, 'fp32_params'): self.model.perform_additional_optimizer_actions( self.optimizer.optimizer, self.optimizer.fp32_params) else: self.model.perform_additional_optimizer_actions( self.optimizer.optimizer) if not overflow or self.args.distributed_wrapper == 'SlowMo': self.set_num_updates(self.get_num_updates() + 1) if self.tpu: # mark step on TPUs import torch_xla.core.xla_model as xm xm.mark_step() # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 logging_output = {} if self.get_num_updates() % self.args.log_interval == 0: logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) # log whenever there's an XLA compilation, since these # slow down training and may indicate opportunities for # optimization self._check_xla_compilation() else: # log stats logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) # clear CUDA 
cache to reduce memory fragmentation if (self.cuda and self.args.empty_cache_freq > 0 and ( (self.get_num_updates() + self.args.empty_cache_freq - 1) % self.args.empty_cache_freq) == 0): torch.cuda.empty_cache() if self.args.fp16: metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0) # Log pruning stuff metrics.log_scalar("sparsity", self.sparsity, priority=0, round=3) metrics.log_stop_time("train_wall") return logging_output
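# The gradual pruning logic in train_step() above follows a cubic sparsity schedule
# and converts each target sparsity into the incremental `amount` that
# l1_unstructured / random_unstructured expects (a fraction of the weights that are
# still unpruned). Stand-alone sketch of that arithmetic with assumed hyperparameters:
def cubic_target_sparsity(step, start_step=1000, interval=100, num_steps=10, final_sparsity=0.8):
    # s_t = s_f + (s_i - s_f) * (1 - (t - t_0) / (n * dt))**3, with s_i = 0
    frac = (step - start_step) / (num_steps * interval)
    return final_sparsity * (1.0 - (1.0 - frac) ** 3)

current = 0.0
for step in range(1000, 2001, 100):
    target = cubic_target_sparsity(step)
    # fraction of the *remaining* weights to prune at this step
    amount = (target - current) / (1.0 - current)
    print(f"step {step}: target {target:.3f}, incremental amount {amount:.3f}")
    current = target   # after pruning, the model's sparsity equals the target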
# %%
# w2 = snim0.l2.weight.detach().cpu().numpy()
w2 = cmod.readout.weight.detach().cpu().numpy()
plt.imshow(w2)
# plt.plot(w2)
# plt.plot(np.sum(w2, axis=0))
plt.figure()
f = plt.plot(w2)

#%%
import torch.nn.utils.prune as prune

a = prune.random_unstructured(sgqm0.linear1, name="weight", amount=0.3)
sgqm1 = deepcopy(sgqm0)

w = a.weight.detach().cpu().numpy()
nfilt = w.shape[0]
sx, sy = U.get_subplot_dims(nfilt)
plt.figure(figsize=(10, 10))
for cc in range(nfilt):
    plt.subplot(sx, sy, cc + 1)
    wtmp = np.reshape(w[cc, :], (gd.num_lags, gd.NX * gd.NY))
    bestlag = np.argmax(np.std(wtmp, axis=1))
    # bestlag = 5
    plt.imshow(np.reshape(wtmp[bestlag, :], (gd.NY, gd.NX)))

# %% cnim plot
def main(args): ### config global noise_multiplier dataset = args.dataset num_discriminators = args.num_discriminators noise_multiplier = args.noise_multiplier z_dim = args.z_dim model_dim = args.model_dim batchsize = args.batchsize L_gp = args.L_gp L_epsilon = args.L_epsilon critic_iters = args.critic_iters latent_type = args.latent_type load_dir = args.load_dir save_dir = args.save_dir if_dp = (args.dp > 0.) gen_arch = args.gen_arch num_gpus = args.num_gpus ### CUDA use_cuda = torch.cuda.is_available() devices = [ torch.device("cuda:%d" % i if use_cuda else "cpu") for i in range(num_gpus) ] device0 = devices[0] if use_cuda: torch.set_default_tensor_type('torch.cuda.FloatTensor') ### Random seed random.seed(args.random_seed) np.random.seed(args.random_seed) torch.manual_seed(args.random_seed) ### Fix noise for visualization if latent_type == 'normal': fix_noise = torch.randn(10, z_dim) elif latent_type == 'bernoulli': p = 0.5 bernoulli = torch.distributions.Bernoulli(torch.tensor([p])) fix_noise = bernoulli.sample((10, z_dim)).view(10, z_dim) else: raise NotImplementedError ### Set up models print('gen_arch:' + gen_arch) if dataset == 'mnist': netG = GeneratorDCGAN(z_dim=z_dim, model_dim=model_dim, num_classes=10) elif dataset == 'cifar_10': netG = GeneratorDCGAN_cifar(z_dim=z_dim, model_dim=model_dim, num_classes=10) netGS = copy.deepcopy(netG) ##prune if dataset == 'mnist': prune.random_unstructured(netG.fc, name="weight", amount=0.5) prune.random_unstructured(netG.deconv1, name="weight", amount=0.5) prune.random_unstructured(netG.deconv2, name="weight", amount=0.5) prune.random_unstructured(netG.deconv3, name="weight", amount=0.5) prune.random_unstructured(netGS.fc, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv1, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv2, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv3, name="weight", amount=0.5) elif dataset == 'cifar_10': prune.random_unstructured(netG.fc, name="weight", amount=0.5) prune.random_unstructured(netG.deconv1, name="weight", amount=0.5) prune.random_unstructured(netG.deconv2, name="weight", amount=0.5) prune.random_unstructured(netG.deconv3, name="weight", amount=0.5) prune.random_unstructured(netG.deconv4, name="weight", amount=0.5) prune.random_unstructured(netGS.fc, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv1, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv2, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv3, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv4, name="weight", amount=0.5) netD_list = [] for i in range(num_discriminators): if dataset == 'mnist': netD = DiscriminatorDCGAN() elif dataset == 'cifar_10': netD = DiscriminatorDCGAN_cifar() netD_list.append(netD) ### Load pre-trained discriminators print("load pre-training...") if load_dir is not None: for netD_id in range(num_discriminators): print('Load NetD ', str(netD_id)) network_path = os.path.join(load_dir, 'netD_%d' % netD_id, 'netD.pth') netD = netD_list[netD_id] netD.load_state_dict(torch.load(network_path)) netG = netG.to(device0) for netD_id, netD in enumerate(netD_list): device = devices[get_device_id(netD_id, num_discriminators, num_gpus)] netD.to(device) ### Set up optimizers optimizerD_list = [] for i in range(num_discriminators): netD = netD_list[i] optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.99)) optimizerD_list.append(optimizerD) optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.99)) ### 
Data loaders if dataset == 'mnist' or dataset == 'fashionmnist': transform_train = transforms.Compose([ transforms.CenterCrop((28, 28)), transforms.ToTensor(), #transforms.Grayscale(), ]) elif dataset == 'cifar_100' or dataset == 'cifar_10': transform_train = transforms.Compose([ transforms.Grayscale(), transforms.ToTensor(), ]) if dataset == 'mnist': dataloader = datasets.MNIST trainset = dataloader(root=os.path.join(DATA_ROOT, 'MNIST'), train=True, download=True, transform=transform_train) IMG_DIM = 784 NUM_CLASSES = 10 elif dataset == 'fashionmnist': dataloader = datasets.FashionMNIST trainset = dataloader(root=os.path.join(DATA_ROOT, 'FashionMNIST'), train=True, download=True, transform=transform_train) elif dataset == 'cifar_100': dataloader = datasets.CIFAR100 trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR100'), train=True, download=True, transform=transform_train) IMG_DIM = 3072 NUM_CLASSES = 100 elif dataset == 'cifar_10': IMG_DIM = 1024 NUM_CLASSES = 10 dataloader = datasets.CIFAR10 trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR10'), train=True, download=True, transform=transform_train) else: raise NotImplementedError ###fix sub-training set (fix to 10000 training samples) if args.update_train_dataset: indices_full = np.arange(len(trainset)) np.random.shuffle(indices_full) indices_10000 = indices_full[:10000] np.savetxt('index_10000.txt', indices_10000, fmt='%i') indices = np.loadtxt('index_10000.txt', dtype=np.int_) trainset = torch.utils.data.Subset(trainset, indices) print('creat indices file') indices_full = np.arange(len(trainset)) np.random.shuffle(indices_full) #indices_full.dump(os.path.join(save_dir, 'indices.npy')) trainset_size = int(len(trainset) / num_discriminators) print('Size of the dataset: ', trainset_size) input_pipelines = [] for i in range(num_discriminators): start = i * trainset_size end = (i + 1) * trainset_size indices = indices_full[start:end] trainloader = DataLoader(trainset, batch_size=args.batchsize, drop_last=False, num_workers=args.num_workers, sampler=SubsetRandomSampler(indices)) #input_data = inf_train_gen(trainloader) input_pipelines.append(trainloader) if if_dp: ### Register hook global dynamic_hook_function for netD in netD_list: netD.conv1.register_backward_hook(master_hook_adder) prg_bar = tqdm(range(args.iterations + 1)) for iters in prg_bar: ######################### ### Update D network ######################### netD_id = np.random.randint(num_discriminators, size=1)[0] device = devices[get_device_id(netD_id, num_discriminators, num_gpus)] netD = netD_list[netD_id] optimizerD = optimizerD_list[netD_id] input_data = input_pipelines[netD_id] for p in netD.parameters(): p.requires_grad = True for iter_d in range(critic_iters): real_data, real_y = next(iter(input_data)) real_data = real_data.view(-1, IMG_DIM) real_data = real_data.to(device) real_y = real_y.to(device) real_data_v = autograd.Variable(real_data) ### train with real dynamic_hook_function = dummy_hook netD.zero_grad() D_real_score = netD(real_data_v, real_y) D_real = -D_real_score.mean() ### train with fake batchsize = real_data.shape[0] if latent_type == 'normal': noise = torch.randn(batchsize, z_dim).to(device0) elif latent_type == 'bernoulli': noise = bernoulli.sample( (batchsize, z_dim)).view(batchsize, z_dim).to(device0) else: raise NotImplementedError noisev = autograd.Variable(noise) fake = autograd.Variable(netG(noisev, real_y.to(device0)).data) inputv = fake.to(device) D_fake = netD(inputv, real_y.to(device)) D_fake = D_fake.mean() ### train with 
gradient penalty gradient_penalty = netD.calc_gradient_penalty( real_data_v.data, fake.data, real_y, L_gp, device) D_cost = D_fake + D_real + gradient_penalty ### train with epsilon penalty logit_cost = L_epsilon * torch.pow(D_real_score, 2).mean() D_cost += logit_cost ### update D_cost.backward() Wasserstein_D = -D_real - D_fake optimizerD.step() del real_data, real_y, fake, noise, inputv, D_real, D_fake, logit_cost, gradient_penalty torch.cuda.empty_cache() ############################ # Update G network ########################### if if_dp: ### Sanitize the gradients passed to the Generator dynamic_hook_function = dp_conv_hook else: ### Only modify the gradient norm, without adding noise dynamic_hook_function = modify_gradnorm_conv_hook for p in netD.parameters(): p.requires_grad = False netG.zero_grad() ### train with sanitized discriminator output if latent_type == 'normal': noise = torch.randn(batchsize, z_dim).to(device0) elif latent_type == 'bernoulli': noise = bernoulli.sample( (batchsize, z_dim)).view(batchsize, z_dim).to(device0) else: raise NotImplementedError label = torch.randint(0, NUM_CLASSES, [batchsize]).to(device0) noisev = autograd.Variable(noise) fake = netG(noisev, label) #summary(netG, input_data=[noisev,label]) fake = fake.to(device) label = label.to(device) G = netD(fake, label) G = -G.mean() ### update G.backward() G_cost = G optimizerG.step() ### update the exponential moving average exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters) ############################ ### Results visualization ############################ prg_bar.set_description( 'iter:{}, G_cost:{:.2f}, D_cost:{:.2f}, Wasserstein:{:.2f}'.format( iters, G_cost.cpu().data, D_cost.cpu().data, Wasserstein_D.cpu().data)) if iters % args.vis_step == 0: if dataset == 'mnist': generate_image_mnist(iters, netGS, fix_noise, save_dir, device0) elif dataset == 'cifar_100': generate_image_cifar100(iters, netGS, fix_noise, save_dir, device0) elif dataset == 'cifar_10': generate_image_cifar10(iters, netGS, fix_noise, save_dir, device0) if iters in [1000, 5000, 10000, 20000]: ### save model ##prune if dataset == 'mnist': prune.remove(netGS.fc, name="weight") prune.remove(netGS.deconv1, name="weight") prune.remove( netGS.deconv2, name="weight", ) prune.remove(netGS.deconv3, name="weight") elif dataset == 'cifar_10': prune.remove(netGS.fc, name="weight") prune.remove(netGS.deconv1, name="weight") prune.remove(netGS.deconv2, name="weight") prune.remove(netGS.deconv3, name="weight") prune.remove(netGS.deconv4, name="weight") ### save model torch.save(netGS.state_dict(), os.path.join(save_dir, 'netGS_%d.pth' % iters)) torch.save(netD.state_dict(), os.path.join(save_dir, 'netD_%d.pth' % iters)) ##prune if dataset == 'mnist': prune.random_unstructured(netGS.fc, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv1, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv2, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv3, name="weight", amount=0.5) elif dataset == 'cifar_10': prune.random_unstructured(netGS.fc, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv1, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv2, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv3, name="weight", amount=0.5) prune.random_unstructured(netGS.deconv4, name="weight", amount=0.5) del label, fake, noisev, noise, G, G_cost, D_cost torch.cuda.empty_cache()
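# main() above removes the pruning reparametrization before torch.save() so the
# checkpoint holds plain `weight` tensors, then prunes again to keep training
# sparse. Minimal sketch of that save pattern on a toy layer (the layer and the
# file name are assumptions; note that re-applying random_unstructured draws a
# fresh random mask rather than restoring the old one).
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

gen_layer = nn.Linear(8, 8)
prune.random_unstructured(gen_layer, name="weight", amount=0.5)

# ... training happens here ...

prune.remove(gen_layer, "weight")                     # fold the mask into `weight`
torch.save(gen_layer.state_dict(), "netGS_toy.pth")   # checkpoint has only 'weight', 'bias'

prune.random_unstructured(gen_layer, name="weight", amount=0.5)  # resume sparse training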
def train( self, method, x, y, x0=None, x1=None, x_val=None, y_val=None, alpha=1.0, optimizer="amsgrad", n_epochs=50, batch_size=128, initial_lr=0.001, final_lr=0.0001, nesterov_momentum=None, validation_split=0.25, early_stopping=True, scale_inputs=True, prune_network = False, limit_samplesize=None, memmap=False, verbose="some", scale_parameters=False, n_workers=8, clip_gradient=None, early_stopping_patience=None, ): """ Trains the network. Parameters ---------- method : str The inference method used for training. Allowed values are 'alice', 'alices', 'carl', 'cascal', 'rascal', and 'rolr'. x : ndarray or str Observations, or filename of a pickled numpy array. y : ndarray or str Class labels (0 = numeerator, 1 = denominator), or filename of a pickled numpy array. alpha : float, optional Default value: 1. optimizer : {"adam", "amsgrad", "sgd"}, optional Optimization algorithm. Default value: "amsgrad". n_epochs : int, optional Number of epochs. Default value: 50. batch_size : int, optional Batch size. Default value: 128. initial_lr : float, optional Learning rate during the first epoch, after which it exponentially decays to final_lr. Default value: 0.001. final_lr : float, optional Learning rate during the last epoch. Default value: 0.0001. nesterov_momentum : float or None, optional If trainer is "sgd", sets the Nesterov momentum. Default value: None. validation_split : float or None, optional Fraction of samples used for validation and early stopping (if early_stopping is True). If None, the entire sample is used for training and early stopping is deactivated. Default value: 0.25. early_stopping : bool, optional Activates early stopping based on the validation loss (only if validation_split is not None). Default value: True. scale_inputs : bool, optional Scale the observables to zero mean and unit variance. Default value: True. memmap : bool, optional. If True, training files larger than 1 GB will not be loaded into memory at once. Default value: False. verbose : {"all", "many", "some", "few", "none}, optional Determines verbosity of training. Default value: "some". 
Returns ------- None """ logger.info("Starting training") logger.info(" Method: %s", method) logger.info(" Batch size: %s", batch_size) logger.info(" Optimizer: %s", optimizer) logger.info(" Epochs: %s", n_epochs) logger.info(" Learning rate: %s initially, decaying to %s", initial_lr, final_lr) if optimizer == "sgd": logger.info(" Nesterov momentum: %s", nesterov_momentum) logger.info(" Validation split: %s", validation_split) logger.info(" Early stopping: %s", early_stopping) logger.info(" Scale inputs: %s", scale_inputs) if limit_samplesize is None: logger.info(" Samples: all") else: logger.info(" Samples: %s", limit_samplesize) # Load training data logger.info("Loading training data") memmap_threshold = 1.0 if memmap else None x = load_and_check(x, memmap_files_larger_than_gb=memmap_threshold) y = load_and_check(y, memmap_files_larger_than_gb=memmap_threshold) x0 = load_and_check(x0, memmap_files_larger_than_gb=memmap_threshold) x1 = load_and_check(x1, memmap_files_larger_than_gb=memmap_threshold) # Infer dimensions of problem n_samples = x.shape[0] n_observables = x.shape[1] logger.info("Found %s samples with %s observables", n_samples, n_observables) external_validation = x_val is not None and y_val is not None if external_validation: x_val = load_and_check(x_val, memmap_files_larger_than_gb=memmap_threshold) y_val = load_and_check(y_val, memmap_files_larger_than_gb=memmap_threshold) logger.info("Found %s separate validation samples", x_val.shape[0]) assert x_val.shape[1] == n_observables # Scale features if scale_inputs: self.initialize_input_transform(x, overwrite=False) x = self._transform_inputs(x) if external_validation: x_val = self._transform_inputs(x_val) else: self.initialize_input_transform(x, False, overwrite=False) # Features if self.features is not None: x = x[:, self.features] logger.info("Only using %s of %s observables", x.shape[1], n_observables) n_observables = x.shape[1] if external_validation: x_val = x_val[:, self.features] # Check consistency of input with model if self.n_observables is None: self.n_observables = n_observables if n_observables != self.n_observables: raise RuntimeError( "Number of observables does not match model: {} vs {}".format(n_observables, self.n_observables) ) # Data data = self._package_training_data(method, x, y) if external_validation: data_val = self._package_training_data(method, x_val, y_val) else: data_val = None # Create model if self.model is None: logger.info("Creating model") self._create_model() if prune_network: module = self.model.ll1 print("before pruning") print(list(module.named_parameters())) print(list(module.named_buffers())) prune.random_unstructured(module, name="weight", amount=0.3) print("before pruning") print(list(module.named_parameters())) print(list(module.named_buffers())) print(self.model.state_dict().keys()) # Losses w = len(x0)/len(x1) logger.info("Passing weight %s to the loss function to account for imbalanced dataset: ", w) loss_functions, loss_labels, loss_weights = get_loss(method + "2", alpha, w) # Optimizer opt, opt_kwargs = get_optimizer(optimizer, nesterov_momentum) # Train model logger.info("Training model") trainer = RatioTrainer(self.model, n_workers=n_workers) result = trainer.train( data=data, data_val=data_val, loss_functions=loss_functions, loss_weights=loss_weights, loss_labels=loss_labels, epochs=n_epochs, batch_size=batch_size, optimizer=opt, optimizer_kwargs=opt_kwargs, initial_lr=initial_lr, final_lr=final_lr, validation_split=validation_split, early_stopping=early_stopping, 
verbose=verbose, clip_gradient=clip_gradient, early_stopping_patience=early_stopping_patience, ) return result