# Shared imports assumed by the snippets below; project-local helpers such as
# plotter_evaluation, get_cmap, generate_axes_for_colorbar_and_histogram and
# generate_histogram_on_ax are defined elsewhere in the repository.
import warnings
from math import ceil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def plot_fc(sequential):
    """ Plot the weights of all linear layers from 'sequential' as rectangles on a figure and map
    values to colors. Use normalization to avoid clipping and to center all values at zero. """
    assert isinstance(sequential, nn.Sequential), f"'sequential' has invalid type {type(sequential)}"

    linear_layers = [layer for layer in sequential if isinstance(layer, nn.Linear)]
    fig, ax_list = plt.subplots(len(linear_layers), 1, figsize=(12, 4 * len(linear_layers)),
                                constrained_layout=False)
    ax_list = np.atleast_1d(ax_list)  # plt.subplots returns a bare Axes for a single layer
    for ax, layer in zip(ax_list, linear_layers):
        weights = layer.weight.data.clone().numpy()
        weight_norm = plotter_evaluation.get_norm_for_tensor(weights)
        if prune.is_pruned(layer):  # mark masked weights with NaN to highlight them later
            pruning_mask = layer.weight_mask.numpy()
            weights[np.where(pruning_mask == 0)] = np.nan

        ax.imshow(weights, norm=weight_norm, cmap=get_cmap(), interpolation='none')
        cax, hax = generate_axes_for_colorbar_and_histogram(ax, 0.2, 0.4, 0.6)

        fig.colorbar(ax.images[0], cax=cax)
        cax.yaxis.set_ticks_position('left')
        generate_histogram_on_ax(hax, weights)
    return fig
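# A minimal sketch (assumption, not the repository's actual helper) of what
# plotter_evaluation.get_norm_for_tensor is expected to do: build a symmetric,
# zero-centered norm so positive and negative weights share one color scale.
from matplotlib.colors import TwoSlopeNorm

def get_norm_for_tensor_sketch(values):
    limit = max(float(np.nanmax(np.abs(values))), 1e-8)  # ignore NaNs marking pruned weights
    return TwoSlopeNorm(vmin=-limit, vcenter=0.0, vmax=limit)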
def mask_merger(model):
    """Remove the pruning masks but keep the affected weights zeroed (pruned)."""
    for name, module in model.named_modules():
        if not prune.is_pruned(module):
            continue
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            prune.remove(module, name='weight')
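# A hedged usage sketch of the semantics above: after prune.remove, 'weight'
# keeps its zeros but the 'weight_mask'/'weight_orig' reparametrization is gone.
model = nn.Sequential(nn.Linear(4, 4))
prune.l1_unstructured(model[0], name='weight', amount=0.5)
mask_merger(model)
assert not prune.is_pruned(model)
assert int((model[0].weight == 0).sum()) == 8  # half of the 16 weights stay zero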
def init_prune_model(model):
    if prune.is_pruned(model):
        remove_pruning(model)
    for layer in model.children():
        if isinstance(layer, (nn.Linear, nn.Conv2d)):
            prune.identity(layer, name='weight')
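# Hedged sketch of the effect of prune.identity used above: it attaches an
# all-ones weight_mask (no weights removed), so the module counts as pruned
# and later mask updates can build on a uniform parametrized state.
layer = nn.Linear(3, 3)
prune.identity(layer, name='weight')
assert prune.is_pruned(layer)
assert bool(torch.all(layer.weight_mask == 1))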
def rand_initialize_weights(self):
    for layer in self.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            # On pruned modules 'weight' is recomputed from 'weight_orig', so
            # re-initialization has to target 'weight_orig' directly.
            if prune.is_pruned(self):
                torch.nn.init.xavier_normal_(layer.weight_orig)
            else:
                torch.nn.init.xavier_normal_(layer.weight)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0)
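# Hedged demonstration of the 'weight_orig' detail: once pruned, 'weight' is a
# derived tensor (weight_orig * weight_mask) that is refreshed on each forward.
layer = nn.Conv2d(1, 1, 3)
prune.random_unstructured(layer, name='weight', amount=0.5)
torch.nn.init.xavier_normal_(layer.weight_orig)
layer(torch.randn(1, 1, 5, 5))  # the forward pre-hook recomputes 'weight'
assert bool(torch.all(layer.weight == layer.weight_orig * layer.weight_mask))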
def setup_graph(self):
    # initialize the masks
    if self._mask_init_method == 'random':
        for i, mask in enumerate(self.masks):
            size = mask.size()
            self.masks[i] = torch.tensor(get_mask_random(size, self._default_sparsity),
                                         dtype=mask.dtype).to(device)  # 'device' is a module-level global

    # initialize the masked weights
    for i, layer in enumerate(self.layers):
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', self.masks[i])
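# A minimal sketch (assumption) of the 'get_mask_random' helper used above: a
# Bernoulli mask whose expected density is (1 - sparsity).
def get_mask_random_sketch(size, sparsity):
    return (np.random.random_sample(tuple(size)) >= sparsity).astype(np.float32)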
def plot_kernels(conv_2d, num_cols=8):
    """ Plot the weights of all kernels from 'conv_2d' as rectangles on a new figure and map values
    to colors. Create one normalized image per channel and kernel to avoid clipping and to center
    all values at zero. Show a colorbar with an integrated histogram on the right. """
    assert isinstance(conv_2d, nn.Conv2d), f"'conv_2d' has invalid type {type(conv_2d)}"

    weights = conv_2d.weight.data.clone().numpy()  # shape is [kernels, channels, height, width]
    if prune.is_pruned(conv_2d):  # mark masked weights with NaN to highlight them later
        weights[np.where(conv_2d.weight_mask.numpy() == 0)] = np.nan
    if (weights.shape[0] * weights.shape[1]) > 512:  # restrict the number of images to 512, do not plot partial kernels
        last_kernel = ceil(512 / weights.shape[1])
        weights = weights[:last_kernel]
        warnings.warn(f"Too many kernels to plot, only plotting the first {last_kernel} kernels.")

    weight_norm = plotter_evaluation.get_norm_for_tensor(weights)
    num_cols, num_rows = plotter_evaluation.get_row_and_col_num(weights.shape, num_cols)

    fig = plt.figure(figsize=(num_cols + 1, num_rows), constrained_layout=False)
    gs = fig.add_gridspec(num_rows, num_cols + 1, wspace=0.1, hspace=0.4)  # extra column for colorbar and histogram
    for kernel_counter, kernel in enumerate(weights):
        for channel_counter, channel in enumerate(kernel):
            ax_counter = kernel_counter * kernel.shape[0] + channel_counter
            ax = fig.add_subplot(gs[ax_counter // num_cols, ax_counter % num_cols])
            ax.imshow(channel, cmap=get_cmap(), norm=weight_norm)
            ax.set_title(f"K{kernel_counter + 1}.{channel_counter + 1}").set_position([.5, 0.95])
            ax.axis('off')

    tmp_ax = fig.add_subplot(gs[:, -1])  # empty column to stick the colorbar and histogram on
    tmp_ax.axis('off')
    cax, hax = generate_axes_for_colorbar_and_histogram(tmp_ax, "40%", "60%", 0)

    fig.colorbar(fig.axes[0].images[0], cax=cax)
    cax.yaxis.set_ticks_position('left')
    generate_histogram_on_ax(hax, weights)
    return fig
def rand_initialize_weights(self):
    for layer in self.modules():
        if isinstance(layer, nn.Conv2d):
            if prune.is_pruned(self):
                torch.nn.init.kaiming_normal_(layer.weight_orig, mode='fan_out', nonlinearity='relu')
            else:
                torch.nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0)
        elif isinstance(layer, nn.Linear):
            if prune.is_pruned(self):
                torch.nn.init.normal_(layer.weight_orig, 0, 0.01)
            else:
                torch.nn.init.normal_(layer.weight, 0, 0.01)
            torch.nn.init.constant_(layer.bias, 0)
        elif isinstance(layer, nn.BatchNorm2d):
            torch.nn.init.constant_(layer.weight, 1)
            torch.nn.init.constant_(layer.bias, 0)
def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
    layer_weight = layer.weight
    layer_grad = self.grad_dict[layer]

    # Drop weights smaller than an adaptive threshold.
    layer_mask_dropped, drop_n = sparsify_weight(layer_weight.abs(), layer_mask, self._drop_fraction)

    # Grow weights whose gradient is larger than an adaptive threshold.
    score_grow = layer_grad * (~layer_mask_dropped)
    layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped,
                                             int(drop_n * self._grow_fraction))

    # Update the weight reparametrization with the new mask.
    if prune.is_pruned(layer):
        prune.remove(layer, 'weight')
    prune.custom_from_mask(layer, 'weight', layer_mask)
    return layer_mask, new_mask
def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
    score = self.grad_dict[layer].abs() + noise_std  # Add noise for a slight bit of randomness.

    # Drop the lowest-scoring active connections.
    score_drop = score * layer_mask
    layer_mask_dropped, n_prune = self.drop_minimum(score_drop, layer_mask)

    # Revive n_prune connections among the currently inactive ones, by gradient score.
    score_grow = score * (~layer_mask_dropped)
    layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, n_prune)

    # Update the weight reparametrization with the new mask.
    if prune.is_pruned(layer):
        prune.remove(layer, 'weight')
    prune.custom_from_mask(layer, 'weight', layer_mask)
    return layer_mask, new_mask
def rand_initialize_weights(self):
    for layer in self.modules():
        if isinstance(layer, nn.Linear):
            if prune.is_pruned(self):
                torch.nn.init.xavier_normal_(layer.weight_orig)
            else:
                torch.nn.init.xavier_normal_(layer.weight)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0)
        elif isinstance(layer, (nn.LSTM, nn.GRU)):
            for name, param in layer.named_parameters():
                if 'weight_ih' in name or 'weight_hh' in name:
                    torch.nn.init.xavier_normal_(param.data)
                elif 'bias' in name:
                    param.data.fill_(0)
def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
    layer_weight = layer.weight

    # Add noise for a slight bit of randomness, then drop the weakest weights.
    masked_weights = layer_mask * layer_weight
    score_drop = masked_weights.abs() + self._random_normal(layer_weight.size(), std=noise_std)
    layer_mask_dropped, n_prune = self.drop_minimum(score_drop, layer_mask)

    # Revive n_prune connections among the currently inactive ones, scored by
    # gradient magnitude plus a small uniform noise term.
    score_grow = (self.grad_dict[layer].abs() * (~layer_mask_dropped)
                  + self._random_uniform(layer_weight.size()) * noise_std)
    layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, n_prune)

    # Update the weight reparametrization with the new mask.
    if prune.is_pruned(layer):
        prune.remove(layer, 'weight')
    prune.custom_from_mask(layer, 'weight', layer_mask)
    return layer_mask, new_mask
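# A minimal sketch (assumption, not the repository's actual helpers) of the
# drop/grow primitives shared by the 'update_layer_mask' variants above:
# drop_minimum zeroes the lowest-scoring active connections, grow_maximum
# activates the highest-scoring inactive ones. Masks are boolean tensors.
def drop_minimum_sketch(score, mask, drop_fraction=0.3):
    n_prune = int(mask.sum().item() * drop_fraction)
    # Only active connections compete; inactive ones get an infinite score.
    score = torch.where(mask.bool(), score, torch.full_like(score, float('inf')))
    dropped = mask.clone().flatten()
    dropped[torch.topk(score.flatten(), n_prune, largest=False).indices] = False
    return dropped.view_as(mask), n_prune

def grow_maximum_sketch(score, mask, n_grow):
    # Only inactive connections compete; active ones get -inf.
    score = torch.where(mask.bool(), torch.full_like(score, float('-inf')), score)
    idx = torch.topk(score.flatten(), n_grow, largest=True).indices
    grown = torch.zeros_like(mask.flatten())
    grown[idx] = True
    new_mask = mask.clone().flatten()
    new_mask[idx] = True
    return new_mask.view_as(mask), grown.view_as(mask)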
def compute_gradients(self, criterion, inputs, labels):
    """Wraps the gradient computation of the passed optimizer."""
    # Remove the pruning reparametrization so gradients land on 'weight' itself.
    for layer in self.layers:
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')

    # forward
    outputs = self.model(inputs)
    loss = criterion(outputs, labels)
    loss = loss.sum()

    # backward: calculate all gradients
    loss.backward()

    # collect the gradient per layer
    gradient = []
    for layer in self.layers:
        gradient.append(layer.weight.grad.data)
    self.grad_dict = dict(zip(self.layers, gradient))
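# Hedged sketch of why the reparametrization is removed before backward: while
# a module is pruned, gradients accumulate on 'weight_orig' and 'weight' is a
# derived (non-leaf) tensor, so 'weight.grad' would not be populated.
layer = nn.Linear(2, 2)
prune.l1_unstructured(layer, name='weight', amount=0.5)
layer(torch.randn(1, 2)).sum().backward()
assert layer.weight_orig.grad is not None  # gradients land here, not on 'weight'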
# Additional imports assumed for the training snippet below (transformers v2.x
# style run_glue.py; set_seed, evaluate, prune_model and countZeroWeights are
# project-local helpers).
import json
import logging
import os

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange
from transformers import AdamW, get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)


def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
                                                num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
            os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])

        # The inline pruning experiments that used to live here (applying global,
        # l1 or random unstructured pruning module by module, with prune.remove
        # on already-pruned modules first) are handled by prune_model below.
        prune_model(model, args, 'train')

        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always a tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(epoch_iterator) <= args.gradient_accumulation_steps
                and (step + 1) == len(epoch_iterator)
            ):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    countZeroWeights(model)
                    # Bake the pruning masks into the weights before saving. Note that
                    # prune.is_pruned(module) is also true for a module whose *submodules*
                    # are pruned, so guard on the module's own reparametrized 'weight'.
                    for mod_name, module in list(model.named_modules()):
                        if hasattr(module, 'weight_orig'):
                            prune.remove(module, 'weight')
                            print('removed', mod_name)
                    countZeroWeights(model)

                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
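# Hedged sketch (assumption) of the project-local 'countZeroWeights' helper
# used above: report the global fraction of exactly-zero weights in the model.
def countZeroWeights_sketch(model):
    zeros = sum(int(torch.sum(p == 0).item()) for p in model.parameters())
    total = sum(p.nelement() for p in model.parameters())
    print(f"zero weights: {zeros}/{total} ({100.0 * zeros / total:.2f}%)")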