Example #1
def plot_fc(sequential):
    """ Plot the weights of all linear layers from 'sequential' as rectangles on a figure and map values to colors.
    Use normalization to avoid clipping and to center all values at zero. """
    assert isinstance(sequential, nn.Sequential), f"'sequential' has invalid type {type(sequential)}"
    linear_layers = [layer for layer in sequential if isinstance(layer, nn.Linear)]

    fig, ax_list = plt.subplots(len(linear_layers), 1, figsize=(12, 4 * len(linear_layers)), constrained_layout=False)
    for ax, layer in zip(ax_list, linear_layers):
        weights = layer.weight.data.clone().numpy()
        weight_norm = plotter_evaluation.get_norm_for_tensor(weights)
        if prune.is_pruned(layer):  # mark masked weights with NaN to highlight them later
            pruning_mask = layer.weight_mask.numpy()
            weights[np.where(pruning_mask == 0)] = np.nan
        ax.imshow(weights, norm=weight_norm, cmap=get_cmap(), interpolation='none')

        cax, hax = generate_axes_for_colorbar_and_histogram(ax, 0.2, 0.4, 0.6)
        fig.colorbar(ax.images[0], cax=cax)
        cax.yaxis.set_ticks_position('left')
        generate_histogram_on_ax(hax, weights)

    return fig
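
A hedged usage sketch for the snippet above, assuming plot_fc and its plotting helpers are importable from the same module; the network shape and output file name are illustrative assumptions, not part of the original code.

import torch.nn as nn
import torch.nn.utils.prune as prune

mlp = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
for m in mlp:
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name='weight', amount=0.5)  # pruned entries will show as NaN

fig = plot_fc(mlp)
fig.savefig("fc_weights.png")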
Example #2
def mask_merger(model):
    """Remove the pruning masks but keep the weights pruned (zeroed)."""
    for name, module in model.named_modules():
        if not is_pruned(module):
            continue
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            remove(module, name='weight')
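
A sketch of a possible call site, assuming mask_merger is in scope together with the bare is_pruned/remove names imported from torch.nn.utils.prune (as the snippet suggests); the toy network and sparsity are assumptions.

import torch.nn as nn
import torch.nn.utils.prune as prune

net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
for m in net.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        prune.random_unstructured(m, name='weight', amount=0.3)

mask_merger(net)                 # folds weight_orig * weight_mask back into plain .weight
assert not prune.is_pruned(net)  # no pruning hooks remain; the zeros stay in the weights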
Example #3
def init_prune_model(model):
    if prune.is_pruned(model):
        remove_pruning(model)
    for layer in model.children():
        if isinstance(layer, (nn.Linear, nn.Conv2d)):
            prune.identity(layer, name='weight')
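
prune.identity attaches a weight_orig parameter and an all-ones weight_mask buffer without zeroing anything, which is why the helper above makes prune.is_pruned return True while leaving the weights untouched. A small standalone check (layer size is illustrative):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)
before = layer.weight.detach().clone()

prune.identity(layer, name='weight')      # same effect as one branch of init_prune_model
print(prune.is_pruned(layer))             # True
print(torch.equal(layer.weight, before))  # True: the mask is all ones, nothing is zeroed
print(layer.weight_mask.unique())         # tensor([1.])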
Example #4
    def rand_initialize_weights(self):
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                if prune.is_pruned(self):
                    torch.nn.init.xavier_normal_(layer.weight_orig)
                else:
                    torch.nn.init.xavier_normal_(layer.weight)

                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)
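
On a pruned module, 'weight' is recomputed as weight_orig * weight_mask before every forward pass, so re-initialization has to target the underlying weight_orig parameter, as the branch above does. A brief sketch of the distinction (layer size and sparsity are illustrative assumptions):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name='weight', amount=0.5)

# 'weight' is a derived tensor here; 'weight_orig' is the actual nn.Parameter to initialize.
torch.nn.init.xavier_normal_(layer.weight_orig)
_ = layer(torch.randn(1, 4))  # a forward pass refreshes the derived 'weight'
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True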
Example #5
    def setup_graph(self):
        # initialize the masks
        if self._mask_init_method == 'random':
            for i, mask in enumerate(self.masks):
                size = mask.size()
                self.masks[i] = torch.tensor(get_mask_random(size, self._default_sparsity), dtype=mask.dtype).to(device)

        # initialize masked weight
        for i, layer in enumerate(self.layers):
            if prune.is_pruned(layer):
                prune.remove(layer, 'weight')
            prune.custom_from_mask(layer, 'weight', self.masks[i])
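
The remove-then-reapply idiom above is the standard way to swap the mask on an already-pruned layer: prune.remove folds the current mask into the weight, and prune.custom_from_mask installs the new one. A minimal standalone sketch (layer shape and mask are illustrative assumptions):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(8, 8)
new_mask = (torch.rand_like(layer.weight) > 0.5).float()  # hypothetical random mask

if prune.is_pruned(layer):
    prune.remove(layer, 'weight')                  # fold any previous mask into the weight
prune.custom_from_mask(layer, 'weight', new_mask)
print(float((layer.weight == 0).float().mean()))   # roughly the sparsity of new_mask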
Example #6
def plot_kernels(conv_2d, num_cols=8):
    """ Plot the weights of all kernels from 'conv_2d' as rectangles on a new figure and map values to colors.
    Create one normalized image per channel and kernel to avoid clipping and to center all values at zero.
    Show a colorbar with integrated histogram on the right. """
    assert isinstance(conv_2d, nn.Conv2d), f"'conv_2d' has invalid type {type(conv_2d)}"

    weights = conv_2d.weight.data.clone().numpy()  # shape is [kernels, channels, height, width]
    if prune.is_pruned(conv_2d):  # mark masked weights with NaN to highlight them later
        weights[np.where(conv_2d.weight_mask.numpy() == 0)] = np.nan
    if (weights.shape[0] * weights.shape[1]) > 512:  # restrict number of images to 512, do not plot partial kernels
        last_kernel = ceil(512 / weights.shape[1])
        weights = weights[:last_kernel]
        warnings.warn(f"Too many kernels to plot, only plotting the first {last_kernel} kernels.")

    weight_norm = plotter_evaluation.get_norm_for_tensor(weights)
    num_cols, num_rows = plotter_evaluation.get_row_and_col_num(weights.shape, num_cols)

    fig = plt.figure(figsize=(num_cols + 1, num_rows), constrained_layout=False)
    gs = fig.add_gridspec(num_rows, num_cols + 1, wspace=0.1, hspace=0.4)  # extra column for colorbar and histogram
    for kernel_counter, kernel in enumerate(weights):
        for channel_counter, channel in enumerate(kernel):
            ax_counter = kernel_counter * kernel.shape[0] + channel_counter
            ax = fig.add_subplot(gs[ax_counter // num_cols, ax_counter % num_cols])
            ax.imshow(channel, cmap=get_cmap(), norm=weight_norm)
            ax.set_title(f"K{kernel_counter + 1}.{channel_counter + 1}").set_position([.5, 0.95])
            ax.axis('off')

    tmp_ax = fig.add_subplot(gs[:, -1])  # empty column to stick colorbar and histogram on
    tmp_ax.axis('off')
    cax, hax = generate_axes_for_colorbar_and_histogram(tmp_ax, "40%", "60%", 0)

    fig.colorbar(fig.axes[0].images[0], cax=cax)
    cax.yaxis.set_ticks_position('left')
    generate_histogram_on_ax(hax, weights)

    return fig
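
As with plot_fc, a hedged usage sketch assuming plot_kernels and its helpers are importable from the module above; the convolution shape, pruning scheme, and file name are illustrative assumptions.

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
prune.ln_structured(conv, name='weight', amount=0.5, n=2, dim=0)  # prune half of the kernels; they appear as NaN
fig = plot_kernels(conv, num_cols=8)
fig.savefig("conv_kernels.png")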
Example #7
    def rand_initialize_weights(self):
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                if prune.is_pruned(self):
                    torch.nn.init.kaiming_normal_(layer.weight_orig,
                                                  mode='fan_out',
                                                  nonlinearity='relu')
                else:
                    torch.nn.init.kaiming_normal_(layer.weight,
                                                  mode='fan_out',
                                                  nonlinearity='relu')
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                if prune.is_pruned(self):
                    torch.nn.init.normal_(layer.weight_orig, 0, 0.01)
                else:
                    torch.nn.init.normal_(layer.weight, 0, 0.01)
                torch.nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                torch.nn.init.constant_(layer.weight, 1)
                torch.nn.init.constant_(layer.bias, 0)
Example #8
    def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
        layer_weight = layer.weight
        layer_grad = self.grad_dict[layer]

        # Remove weight smaller than adaptive threshold
        layer_mask_dropped, drop_n = sparsify_weight(layer_weight.abs(), layer_mask, self._drop_fraction)

        # Grow weight whose gradient larger than adaptive threshold
        score_grow = layer_grad * (~layer_mask_dropped)
        layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, int(drop_n*self._grow_fraction))

        # update the weight
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', layer_mask)

        return layer_mask, new_mask
Example #9
    def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
        score = self.grad_dict[layer].abs() + noise_std

        # Add noise for slight bit of randomness.
        score_drop = score * layer_mask
        layer_mask_dropped, n_prune = self.drop_minimum(score_drop, layer_mask)

        # Randomly revive n_prune many connections from non-existing connections.
        score_grow = score * (~layer_mask_dropped)
        layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, n_prune)

        # update the weight
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', layer_mask)

        return layer_mask, new_mask
Example #10
    def rand_initialize_weights(self):
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                if prune.is_pruned(self):
                    torch.nn.init.xavier_normal_(layer.weight_orig)
                else:
                    torch.nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.LSTM) or isinstance(layer, nn.GRU):
                for name, param in layer.named_parameters():
                    if 'weight_ih' in name:
                        torch.nn.init.xavier_normal_(param.data)
                    elif 'weight_hh' in name:
                        torch.nn.init.xavier_normal_(param.data)
                    elif 'bias' in name:
                        param.data.fill_(0)
Example #11
    def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
        layer_weight = layer.weight

        # Add noise for slight bit of randomness and drop
        masked_weights = layer_mask * layer_weight
        score_drop = masked_weights.abs() + self._random_normal(layer_weight.size(), std=noise_std)
        layer_mask_dropped, n_prune = self.drop_minimum(score_drop, layer_mask)

        # Randomly revive n_prune many connections from non-existing connections.
        score_grow = self.grad_dict[layer].abs() * (~layer_mask_dropped) + self._random_uniform(layer_weight.size()) * noise_std
        layer_mask, new_mask = self.grow_maximum(score_grow, layer_mask_dropped, n_prune)

        # update the weight
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', layer_mask)

        return layer_mask, new_mask
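
The three update_layer_mask variants above (Examples #8, #9, and #11) follow the same prune-and-regrow recipe: score the active connections, drop the lowest-scoring ones, regrow the same number of inactive connections with the highest scores, and reinstall the result with prune.remove plus prune.custom_from_mask. The helpers drop_minimum, grow_maximum, and sparsify_weight are class-internal and not shown here; the sketch below reimplements only the drop/grow idea with plain tensor operations, with fractions and shapes as illustrative assumptions.

import torch

def drop_and_grow(weight, grad, mask, drop_fraction=0.3):
    """Toy magnitude-drop / gradient-grow mask update in the spirit of update_layer_mask."""
    mask = mask.bool()
    n_drop = int(mask.sum().item() * drop_fraction)

    # Drop: deactivate the n_drop active weights with the smallest magnitude.
    score_drop = weight.abs().masked_fill(~mask, float('inf')).flatten()
    mask_dropped = mask.flatten().clone()
    mask_dropped[torch.topk(score_drop, n_drop, largest=False).indices] = False

    # Grow: reactivate the n_drop inactive positions with the largest gradient magnitude.
    score_grow = grad.abs().flatten().masked_fill(mask_dropped, float('-inf'))
    new_mask = mask_dropped.clone()
    new_mask[torch.topk(score_grow, n_drop, largest=True).indices] = True
    return new_mask.view_as(weight).float()

weight, grad = torch.randn(16, 16), torch.randn(16, 16)
mask = (torch.rand(16, 16) > 0.5).float()
new_mask = drop_and_grow(weight, grad, mask)  # keeps the number of active weights unchanged

The real implementations additionally add small random noise to the scores to break ties, as the noise_std argument suggests.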
Example #12
    def compute_gradients(self, criterion, inputs, labels):
        """Wraps the compute gradient of passed optimizer."""
        # remove the prune and forward
        for layer in self.layers:
            if prune.is_pruned(layer):
                prune.remove(layer, 'weight')

        # forward
        outputs = self.model(inputs)
        loss = criterion(outputs, labels)
        loss = loss.sum()

        # backward and calculate all gradients
        loss.backward()

        # get gradient
        gradient = []
        for layer in self.layers:
            gradient.append(layer.weight.grad.data)

        self.grad_dict = dict(zip(self.layers, gradient))
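
The call to prune.remove before the forward pass matters because, on a pruned module, gradients accumulate on weight_orig rather than on the derived weight tensor; removing the reparametrization makes layer.weight.grad available, which is what the final loop reads. A small sketch of the difference (layer and data are assumptions):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)
prune.l1_unstructured(layer, name='weight', amount=0.5)

layer(torch.randn(3, 4)).sum().backward()
print(layer.weight_orig.grad is not None)  # True: while pruned, gradients land on weight_orig

prune.remove(layer, 'weight')              # fold the mask in; 'weight' is a plain parameter again
layer.zero_grad()
layer(torch.randn(3, 4)).sum().backward()
print(layer.weight.grad is not None)       # True: layer.weight.grad can now be collected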
Example #13
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        # if args.prune_train > 0:
        #     print('Pruning {} %'.format(args.prune_train*100))
        #     if args.prune == 'global': print('Global Pruning')
        #     elif args.prune == 'l1': print('L1 Pruning')
        #     elif args.prune == 'random': print('Random Pruning')
        #     parameters_to_prune = []
        #     for mod_name, module in list(model.named_modules()):
        #         for name, value in list(module.named_parameters()):
        #             if name in ['weight']:
        #                 print(mod_name, name)
        #                 print('weights before {:.3f}%'.format(float(torch.sum(module.weight == 0)) * 100 / float(module.weight.nelement())))
        #                 if prune.is_pruned(module):
        #                     prune.remove(module, 'weight')
        #                     print('removed',mod_name)
        #                 if args.prune == 'global': parameters_to_prune.append((module, 'weight'))
        #                 elif args.prune == 'l1': module = prune.l1_unstructured(module, name='weight', amount=args.prune_train)
        #                 elif args.prune == 'random': module = prune.random_unstructured(module, name='weight', amount=args.prune_train)
        #                 print('weights after {:.3f}%'.format(float(torch.sum(module.weight == 0)) * 100 / float(module.weight.nelement())))
        #     if args.prune == 'global': prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=args.prune_train)

        # for mod_name, module in list(model.named_modules()):
        #     for name, value in list(module.named_parameters()):
        #         print(mod_name, name)

        prune_model(model, args, 'train')

        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:

                    countZeroWeights(model)

                    for mod_name, module in model.named_modules():
                        # only remove the reparametrization on modules whose 'weight' is actually pruned
                        if prune.is_pruned(module) and hasattr(module, 'weight_orig'):
                            prune.remove(module, 'weight')
                            print('removed', mod_name)

                    countZeroWeights(model)

                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training

                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
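
Example #13 strips the pruning reparametrization right before checkpointing so the saved state dict contains plain 'weight' tensors (with pruned entries hard-zeroed) instead of weight_orig/weight_mask pairs. A hedged sketch of that save pattern for an ordinary module; the model, sparsity, and path are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
for m in model.modules():
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name='weight', amount=0.5)

# Make the pruning permanent so the state_dict holds plain 'weight' tensors.
for m in model.modules():
    if isinstance(m, nn.Linear) and prune.is_pruned(m):
        prune.remove(m, 'weight')

torch.save(model.state_dict(), "checkpoint.pt")  # keys are 'weight'/'bias', zeros preserved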