Example #1
    model_config = deepcopy(ECG_CRNN_CONFIG)
    model_config.cnn.name = config.cnn_name
    model_config.rnn.name = config.rnn_name

    model = ECG_CRNN_CINC2020(
        classes=classes,
        n_leads=config.n_leads,
        input_len=config.input_len,
        config=model_config,
    )

    if torch.cuda.device_count() > 1:
        model = DP(model)
        # model = DDP(model)

    model.to(device=device)
    model.__DEBUG__ = False

    try:
        train(
            model=model,
            model_config=model_config,
            config=config,
            device=device,
            logger=logger,
            debug=config.debug,
        )
    except KeyboardInterrupt:
        torch.save(
            {
                "model_state_dict": model.state_dict(),
Example #2
class TrainPatch:
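    """Train an adversarial dot patch against a YOLOv5 detector.

    Dots are drawn onto the patch by DotApplier, the patch is blended onto image
    batches by PatchApplier, and the optimization minimizes a weighted sum of
    max-probability, correct-detection, non-printability and noise-amount losses,
    with early stopping on the validation loss.
    """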
    def __init__(self, mode):
        self.config = patch_config_types[mode]()

        self.yolov5 = attempt_load(self.config.weight_file, map_location=device)

        self.img_size = self.config.patch_size

        self.dot_applier = DotApplier(
            self.config.num_of_dots,
            self.img_size,
            self.config.alpha_max,
            self.config.beta_dropoff)

        self.patch_applier = PatchApplier()

        self.non_printability_score = NonPrintabilityScore(
            self.config.print_file,
            self.config.num_of_dots)

        self.eval_type = self.config.eval_data_type
        self.clean_img_dict = np.load('confidences/yolov5/medium/clean_img_conf_lisa_ordered.npy', allow_pickle=True).item()

        self.detections = DetectionsYolov5(
            cls_id=self.config.class_id,
            num_cls=self.config.num_classes,
            config=self.config,
            clean_img_conf=self.clean_img_dict,
            conf_threshold=self.config.conf_threshold)

        self.noise_amount = NoiseAmount(self.config.radius_lower_bound, self.config.radius_upper_bound)

        # self.set_multiple_gpu()
        self.set_to_device()

        if self.config.eval_data_type == 'ordered' or self.config.eval_data_type == 'one':
            split_dataset = SplitDataset(
                img_dir=self.config.img_dir,
                lab_dir=self.config.lab_dir,
                max_lab=self.config.max_labels_per_img,
                img_size=self.img_size,
                transform=transforms.Compose([transforms.Resize((self.img_size, self.img_size)), transforms.ToTensor()]))
        else:
            split_dataset = SplitDataset1(
                img_dir_train_val=self.config.img_dir,
                lab_dir_train_val=self.config.lab_dir,
                img_dir_test=self.config.img_dir_test,
                lab_dir_test=self.config.lab_dir_test,
                max_lab=self.config.max_labels_per_img,
                img_size=self.img_size,
                transform=transforms.Compose(
                    [transforms.Resize((self.img_size, self.img_size)), transforms.ToTensor()]))

        self.train_loader, self.val_loader, self.test_loader = split_dataset(val_split=0.2,
                                                                             test_split=0.2,
                                                                             shuffle_dataset=True,
                                                                             random_seed=seed,
                                                                             batch_size=self.config.batch_size,
                                                                             ordered=True)

        self.train_losses = []
        self.val_losses = []

        self.max_prob_losses = []
        self.cor_det_losses = []
        self.nps_losses = []
        self.noise_losses = []

        self.train_acc = []
        self.val_acc = []
        self.final_epoch_count = self.config.epochs

        my_date = datetime.datetime.now()
        month_name = my_date.strftime("%B")
        if 'SLURM_JOBID' not in os.environ.keys():
            self.current_dir = "experiments/" + month_name + '/' + time.strftime("%d-%m-%Y") + '_' + time.strftime("%H-%M-%S")
        else:
            self.current_dir = "experiments/" + month_name + '/' + time.strftime("%d-%m-%Y") + '_' + os.environ['SLURM_JOBID']
        self.create_folders()
        # self.save_config_details()

        # subprocess.Popen(['tensorboard', '--logdir=' + self.current_dir + '/runs'])
        # self.writer = SummaryWriter(self.current_dir + '/runs')
        self.writer = None

    def set_to_device(self):
        self.dot_applier = self.dot_applier.to(device)
        self.patch_applier = self.patch_applier.to(device)
        self.detections = self.detections.to(device)
        self.non_printability_score = self.non_printability_score.to(device)
        self.noise_amount = self.noise_amount.to(device)

    def set_multiple_gpu(self):
        if torch.cuda.device_count() > 1:
            print("more than 1")
            self.dot_applier = DP(self.dot_applier)
            self.patch_applier = DP(self.patch_applier)
            self.detections = DP(self.detections)

    def create_folders(self):
        Path('/'.join(self.current_dir.split('/')[:2])).mkdir(parents=True, exist_ok=True)
        Path(self.current_dir).mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/final_results').mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/saved_patches').mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/losses').mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/testing').mkdir(parents=True, exist_ok=True)

    def train(self):
        epoch_length = len(self.train_loader)
        print(f'One epoch is {epoch_length} batches', flush=True)

        optimizer = Adam([{'params': self.dot_applier.theta, 'lr': self.config.loc_lr},
                          {'params': self.dot_applier.colors, 'lr': self.config.color_lr},
                          {'params': self.dot_applier.radius, 'lr': self.config.radius_lr}],
                         amsgrad=True)
        scheduler = self.config.scheduler_factory(optimizer)
        early_stop = EarlyStopping(delta=1e-3, current_dir=self.current_dir, patience=20)

        clipper = WeightClipper(self.config.radius_lower_bound, self.config.radius_upper_bound)
        adv_patch_cpu = torch.zeros((1, 3, self.img_size, self.img_size), dtype=torch.float32)
        alpha_cpu = torch.zeros((1, 1, self.img_size, self.img_size), dtype=torch.float32)
        for epoch in range(self.config.epochs):
            train_loss = 0.0
            max_prob_loss = 0.0
            cor_det_loss = 0.0
            nps_loss = 0.0
            noise_loss = 0.0

            progress_bar = tqdm(enumerate(self.train_loader), desc=f'Epoch {epoch}', total=epoch_length)
            prog_bar_desc = 'train-loss: {:.6}, ' \
                            'maxprob-loss: {:.6}, ' \
                            'corr det-loss: {:.6}, ' \
                            'nps-loss: {:.6}, ' \
                            'noise-loss: {:.6}'
            for i_batch, (img_batch, lab_batch, img_names) in progress_bar:
                # move tensors to gpu
                img_batch = img_batch.to(device)
                lab_batch = lab_batch.to(device)
                adv_patch = adv_patch_cpu.to(device)
                alpha = alpha_cpu.to(device)

                # forward prop
                adv_patch, alpha = self.dot_applier(adv_patch, alpha)  # put dots on patch

                applied_batch = self.patch_applier(img_batch, adv_patch, alpha)  # apply patch on a batch of images

                if epoch == 0 and i_batch == 0:
                    self.save_initial_patch(adv_patch, alpha)

                output_patch = self.yolov5(applied_batch)[0]  # get yolo output with patch

                max_prob, cor_det = self.detections(lab_batch, output_patch, img_names)

                nps = self.non_printability_score(self.dot_applier.colors)

                noise = self.noise_amount(self.dot_applier.radius)

                loss, loss_arr = self.loss_function(max_prob, cor_det, nps, noise)  # calculate loss

                # save losses
                max_prob_loss += loss_arr[0].item()
                cor_det_loss += loss_arr[1].item()
                nps_loss += loss_arr[2].item()
                noise_loss += loss_arr[3].item()
                train_loss += loss.item()

                # back prop
                optimizer.zero_grad()
                loss.backward()

                # update parameters
                optimizer.step()
                self.dot_applier.apply(clipper)  # clip x,y coordinates

                progress_bar.set_postfix_str(prog_bar_desc.format(train_loss / (i_batch + 1),
                                                                  max_prob_loss / (i_batch + 1),
                                                                  cor_det_loss / (i_batch + 1),
                                                                  nps_loss / (i_batch + 1),
                                                                  noise_loss / (i_batch + 1)))

                if i_batch % 1 == 0 and self.writer is not None:
                    self.write_to_tensorboard(adv_patch, alpha, train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                                              epoch_length, epoch, i_batch, optimizer)
                # self.writer.add_image('patch', adv_patch.squeeze(0), epoch_length * epoch + i_batch)
                if i_batch + 1 == epoch_length:
                    self.last_batch_calc(adv_patch, alpha, epoch_length, progress_bar, prog_bar_desc,
                                         train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                                         optimizer, epoch, i_batch)

                # self.run_slide_show(adv_patch)

                # clear gpu
                del img_batch, lab_batch, applied_batch, output_patch, max_prob, cor_det, nps, noise, loss
                torch.cuda.empty_cache()

            # check if loss has decreased
            if early_stop(self.val_losses[-1], adv_patch.cpu(), alpha.cpu(), epoch):
                self.final_epoch_count = epoch
                break

            scheduler.step(self.val_losses[-1])

        self.adv_patch = early_stop.best_patch
        self.alpha = early_stop.best_alpha
        print("Training finished")

    def get_image_size(self):
        if isinstance(self.yolov5, nn.DataParallel):
            img_size = self.yolov5.module.height
        else:
            img_size = self.yolov5.height
        return int(img_size)

    def evaluate_loss(self, loader, adv_patch, alpha):
        total_loss = 0.0
        for img_batch, lab_batch, img_names in loader:
            with torch.no_grad():
                img_batch = img_batch.to(device)
                lab_batch = lab_batch.to(device)

                applied_batch = self.patch_applier(img_batch, adv_patch, alpha)
                output_patch = self.yolov5(applied_batch)[0]
                max_prob, cor_det = self.detections(lab_batch, output_patch, img_names)
                nps = self.non_printability_score(self.dot_applier.colors)
                noise = self.noise_amount(self.dot_applier.radius)
                batch_loss, _ = self.loss_function(max_prob, cor_det, nps, noise)
                total_loss += batch_loss.item()

                del img_batch, lab_batch, applied_batch, output_patch, max_prob, cor_det, nps, noise, batch_loss
                torch.cuda.empty_cache()
        loss = total_loss / len(loader)
        return loss

    def plot_train_val_loss(self):
        epochs = [x + 1 for x in range(len(self.train_losses))]
        plt.plot(epochs, self.train_losses, 'b', label='Training loss')
        plt.plot(epochs, self.val_losses, 'r', label='Validation loss')
        plt.title('Training and validation loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc='upper right')
        plt.savefig(self.current_dir + '/final_results/train_val_loss_plt.png')
        plt.close()

    def plot_separate_loss(self):
        epochs = [x + 1 for x in range(len(self.train_losses))]
        weights = np.array([self.config.max_prob_weight, self.config.pres_det_weight, self.config.nps_weight, self.config.noise_weight])
        number_of_subplots = int((weights > 0).sum())  # count of loss terms with a positive weight
        fig, axes = plt.subplots(nrows=1, ncols=number_of_subplots, figsize=(5 * number_of_subplots, 3 * number_of_subplots), squeeze=False)
        for idx, (weight, loss, label, color_name) in enumerate(zip([self.config.max_prob_weight, self.config.pres_det_weight, self.config.nps_weight, self.config.noise_weight],
                                                                    [self.max_prob_losses, self.cor_det_losses, self.nps_losses, self.noise_losses],
                                                                    ['Max probability loss', 'Correct detections loss', 'Non printability loss', 'Noise Amount loss'],
                                                                    'brgkyc')):
            if weight > 0:
                axes[0, idx].plot(epochs, loss, c=color_name, label=label)
                axes[0, idx].set_xlabel('Epoch')
                axes[0, idx].set_ylabel('Loss')
                axes[0, idx].legend(loc='upper right')
        fig.tight_layout()
        plt.savefig(self.current_dir + '/final_results/separate_loss_plt.png')
        plt.close()

    def plot_combined(self):
        epochs = [x + 1 for x in range(len(self.train_losses))]
        fig, ax1 = plt.subplots(ncols=1, figsize=(8, 4))
        ax1.plot(epochs, self.max_prob_losses, c='b', label='Max Probability')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.tick_params(axis='y', labelcolor='b')
        ax2 = ax1.twinx()
        ax2.plot(epochs, self.cor_det_losses, c='r', label='Correct Detections')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Loss')
        ax2.tick_params(axis='y', labelcolor='r')
        ax2.legend(loc='upper right')
        fig.tight_layout()
        plt.savefig(self.current_dir + '/final_results/combined_losses.png')
        plt.close()

    def loss_function(self, max_prob, cor_det, nps, noise):
        max_prob_loss = self.config.max_prob_weight * torch.mean(max_prob)
        cor_det_loss = self.config.pres_det_weight * torch.mean(cor_det)
        nps_loss = self.config.nps_weight * nps
        noise_loss = self.config.noise_weight * noise
        return max_prob_loss + cor_det_loss + nps_loss + noise_loss, [max_prob_loss, cor_det_loss, nps_loss, noise_loss]

    def save_final_objects(self):
        # save patch
        transforms.ToPILImage()(self.adv_patch.squeeze(0)).save(
            self.current_dir + '/final_results/final_patch_wo_alpha.png', 'PNG')
        torch.save(self.adv_patch, self.current_dir + '/final_results/final_patch_raw.pt')
        transforms.ToPILImage()(self.alpha.squeeze(0)).save(self.current_dir + '/final_results/alpha.png', 'PNG')
        torch.save(self.alpha, self.current_dir + '/final_results/alpha_raw.pt')
        final_patch = torch.cat([self.adv_patch.squeeze(0), self.alpha.squeeze(0)])
        transforms.ToPILImage()(final_patch.cpu()).save(self.current_dir + '/final_results/final_patch_w_alpha.png',
                                                        'PNG')
        # save losses
        with open(self.current_dir + '/losses/train_losses', 'wb') as fp:
            pickle.dump(self.train_losses, fp)
        with open(self.current_dir + '/losses/val_losses', 'wb') as fp:
            pickle.dump(self.val_losses, fp)
        with open(self.current_dir + '/losses/max_prob_losses', 'wb') as fp:
            pickle.dump(self.max_prob_losses, fp)
        with open(self.current_dir + '/losses/cor_det_losses', 'wb') as fp:
            pickle.dump(self.cor_det_losses, fp)
        with open(self.current_dir + '/losses/nps_losses', 'wb') as fp:
            pickle.dump(self.nps_losses, fp)
        with open(self.current_dir + '/losses/noise_losses', 'wb') as fp:
            pickle.dump(self.noise_losses, fp)

    def save_final_results(self, avg_precision):
        target_noise_ap, target_patch_ap, other_noise_ap, other_patch_ap = avg_precision
        # calculate test loss
        test_loss = self.evaluate_loss(self.test_loader, self.adv_patch.to(device),
                                       self.alpha.to(device))
        print("Test loss: " + str(test_loss))
        self.save_config_details()
        row_to_csv = \
            str(self.train_losses[-1]) + ',' + \
            str(self.val_losses[-1]) + ',' + \
            str(test_loss) + ',' + \
            str(self.max_prob_losses[-1]) + ',' + \
            str(self.cor_det_losses[-1]) + ',' + \
            str(self.nps_losses[-1]) + ',' + \
            str(self.final_epoch_count) + '/' + str(self.config.epochs) + ',' + \
            str(target_noise_ap) + ',' + \
            str(target_patch_ap) + ',' + \
            str(other_noise_ap) + ',' + \
            str(other_patch_ap) + '\n'

        # write results to csv
        with open('experiments/results.csv', 'a') as fd:
            fd.write(row_to_csv)

    def write_to_tensorboard(self, adv_patch, alpha, train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                             epoch_length, epoch, i_batch, optimizer):
        iteration = epoch_length * epoch + i_batch
        self.writer.add_scalar('train_loss', train_loss / (i_batch + 1), iteration)
        self.writer.add_scalar('loss/max_prob_loss', max_prob_loss / (i_batch + 1), iteration)
        self.writer.add_scalar('loss/cor_det_loss', cor_det_loss / (i_batch + 1), iteration)
        self.writer.add_scalar('loss/nps_loss', nps_loss / (i_batch + 1), iteration)
        self.writer.add_scalar('loss/noise_loss', noise_loss / (i_batch + 1), iteration)
        self.writer.add_scalar('misc/epoch', epoch, iteration)
        self.writer.add_scalar('misc/loc_learning_rate', optimizer.param_groups[0]["lr"], iteration)
        self.writer.add_scalar('misc/color_learning_rate', optimizer.param_groups[1]["lr"], iteration)
        self.writer.add_scalar('misc/radius_learning_rate', optimizer.param_groups[2]["lr"], iteration)
        self.writer.add_image('patch_rgb', adv_patch.squeeze(0), iteration)
        self.writer.add_image('patch_rgba', torch.cat([adv_patch.squeeze(0), alpha.squeeze(0)]), iteration)

    def last_batch_calc(self, adv_patch, alpha, epoch_length, progress_bar, prog_bar_desc,
                        train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                        optimizer, epoch, i_batch):
        # calculate epoch losses
        train_loss /= epoch_length
        max_prob_loss /= epoch_length
        cor_det_loss /= epoch_length
        nps_loss /= epoch_length
        noise_loss /= epoch_length
        self.train_losses.append(train_loss)
        self.max_prob_losses.append(max_prob_loss)
        self.cor_det_losses.append(cor_det_loss)
        self.nps_losses.append(nps_loss)
        self.noise_losses.append(noise_loss)

        # check on validation
        val_loss = self.evaluate_loss(self.val_loader, adv_patch, alpha)
        self.val_losses.append(val_loss)

        prog_bar_desc += ', val-loss: {:.6}, loc-lr: {:.6}, color-lr: {:.6}, radius-lr: {:.6}'
        progress_bar.set_postfix_str(prog_bar_desc.format(train_loss,
                                                          max_prob_loss,
                                                          cor_det_loss,
                                                          nps_loss,
                                                          noise_loss,
                                                          val_loss,
                                                          optimizer.param_groups[0]['lr'],
                                                          optimizer.param_groups[1]['lr'],
                                                          optimizer.param_groups[2]['lr']))
        if self.writer is not None:
            self.writer.add_scalar('loss/val_loss', val_loss, epoch_length * epoch + i_batch)

    def get_clean_image_conf(self):
        clean_img_dict = dict()
        for loader in [self.train_loader, self.val_loader, self.test_loader]:
            for img_batch, lab_batch, img_name in loader:
                img_batch = img_batch.to(device)
                lab_batch = lab_batch.to(device)

                output = self.yolov5(img_batch)[0]
                output = output.transpose(1, 2).contiguous()
                output_objectness, output = output[:, 4, :], output[:, 5:, :]
                batch_idx = torch.index_select(lab_batch, 2, torch.tensor([0], dtype=torch.long).to(device))
                for i in range(batch_idx.size()[0]):
                    ids = np.unique(
                        batch_idx[i][(batch_idx[i] >= 0) & (batch_idx[i] != self.config.class_id)].cpu().numpy().astype(
                            int))
                    if len(ids) == 0:
                        continue
                    clean_img_dict[img_name[i]] = dict()
                    # get relevant classes
                    confs_for_class = output[i, ids, :]
                    confs_if_object = self.config.loss_target(output_objectness[i], confs_for_class)

                    # find the max prob for each related class
                    max_conf, _ = torch.max(confs_if_object, dim=1)
                    for j, label in enumerate(ids):
                        clean_img_dict[img_name[i]][label] = max_conf[j].item()

                del img_batch, lab_batch, output, output_objectness, batch_idx
                torch.cuda.empty_cache()

        print(len(clean_img_dict))
        np.save('confidences/' + self.config.model_name + '/medium/clean_img_conf_lisa_ordered.npy', clean_img_dict)

    # def get_clean_image_conf(self):
    #     clean_img_dict = dict()
    #     for loader in [self.train_loader, self.val_loader, self.test_loader]:
    #         for img_batch, lab_batch, img_name in loader:
    #             img_batch = img_batch.to(device)
    #             lab_batch = lab_batch.to(device)
    #
    #             output = self.yolov2(img_batch)
    #             batch = output.size(0)
    #             h = output.size(2)
    #             w = output.size(3)
    #             output = output.view(batch, self.yolov2.num_anchors, 5 + self.config.num_classes,
    #                                  h * w)  # [batch, 5, 85, 361]
    #             output = output.transpose(1, 2).contiguous()  # [batch, 85, 5, 361]
    #             output = output.view(batch, 5 + self.config.num_classes,
    #                                  self.yolov2.num_anchors * h * w)  # [batch, 85, 1805]
    #             output_objectness = torch.sigmoid(output[:, 4, :])  # [batch, 1805]
    #             output = output[:, 5:5 + self.config.num_classes, :]  # [batch, 80, 1805]
    #             normal_confs = torch.nn.Softmax(dim=1)(output)  # [batch, 80, 1805]
    #             batch_idx = torch.index_select(lab_batch, 2, torch.tensor([0], dtype=torch.long).to(device))
    #             for i in range(batch_idx.size(0)):
    #                 ids = np.unique(
    #                     batch_idx[i][(batch_idx[i] >= 0) & (batch_idx[i] != self.config.class_id)].cpu().numpy().astype(
    #                         int))
    #                 if len(ids) == 0:
    #                     continue
    #                 clean_img_dict[img_name[i]] = dict()
    #                 # get relevant classes
    #                 confs_for_class = normal_confs[i, ids, :]
    #                 confs_if_object = self.config.loss_target(output_objectness[i], confs_for_class)
    #
    #                 # find the max prob for each related class
    #                 max_conf, _ = torch.max(confs_if_object, dim=1)
    #                 for j, label in enumerate(ids):
    #                     clean_img_dict[img_name[i]][label] = max_conf[j].item()
    #
    #             del img_batch, lab_batch, output, output_objectness, normal_confs, batch_idx
    #             torch.cuda.empty_cache()
    #
    #     print(len(clean_img_dict))
    #     np.save('confidences/clean_img_conf_lisa_new_color.npy', clean_img_dict)

    def save_config_details(self):
        # write results to csv
        row_to_csv = self.current_dir.split('/')[-1] + ',' + \
                     self.config.model_name + ',' + \
                     self.config.img_dir.split('/')[-2] + ',' + \
                     str(self.config.loc_lr) + '-' + str(self.config.color_lr) + '-' + str(self.config.radius_lr) + ',' + \
                     str(self.config.sched_cooldown) + ',' + \
                     str(self.config.sched_patience) + ',' + \
                     str(self.config.loss_mode) + ',' + \
                     str(self.config.conf_threshold) + ',' + \
                     str(self.config.max_prob_weight) + ',' + \
                     str(self.config.pres_det_weight) + ',' + \
                     str(self.config.nps_weight) + ',' + \
                     str(self.config.num_of_dots) + ',' + \
                     str(None) + ',' + \
                     str(self.config.alpha_max) + ',' + \
                     str(self.config.beta_dropoff) + ','
        with open('experiments/results.csv', 'a') as fd:
            fd.write(row_to_csv)

    def save_initial_patch(self, adv_patch, alpha):
        transforms.ToPILImage()(adv_patch.cpu().squeeze(0)).save(self.current_dir + '/saved_patches/initial_patch.png')
        transforms.ToPILImage()(alpha.cpu().squeeze(0)).save(self.current_dir + '/saved_patches/initial_alpha.png')

    @staticmethod
    def run_slide_show(adv_patch):
        adv_to_show = adv_patch.detach().cpu()
        adv_to_show = torch.where(adv_to_show == 0, torch.ones_like(adv_to_show), adv_to_show)
        transforms.ToPILImage()(adv_to_show.squeeze(0)).save('current_slide.jpg')
        img = cv2.imread('current_slide.jpg')
        cv2.imshow('slide show', img)
        cv2.waitKey(1)
class Model:
    """
    This class handles basic methods for handling the model:
    1. Fit the model
    2. Make predictions
    3. Make inference predictions
    4. Save
    5. Load weights
    6. Restore the model
    7. Restore the model with averaged weights
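
    A minimal usage sketch (hedged; `hparams`, `train_ds`, `valid_ds`, `pretrain_ds`
    and `test_ds` are placeholder names, not defined in this module):

        model = Model(hparams, gpu=[0])
        model.fit(train_ds, valid_ds, pretrain_ds)
        fold_score = model.predict(test_ds)
        model.save('weights/my_model')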
    """
    def __init__(self, hparams, gpu=None, inference=False):

        self.hparams = hparams
        self.gpu = gpu
        self.inference = inference

        self.start_training = time()

        # initialize model architecture
        self.__setup_model(inference=inference, gpu=gpu)
        self.postprocessing = Post_Processing()

        # define model parameters
        self.__setup_model_hparams()

        # fix random seeds for reproducibility
        self.__seed_everything(42)

    def fit(self, train, valid, pretrain):

        # setup train and val dataloaders
        train_loader = DataLoader(
            train,
            batch_size=self.hparams['batch_size'],
            shuffle=True,
            num_workers=self.hparams['num_workers'],
        )
        valid_loader = DataLoader(
            valid,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=self.hparams['num_workers'],
        )

        adv_loader = DataLoader(pretrain,
                                batch_size=self.hparams['batch_size'],
                                shuffle=True,
                                num_workers=0)

        # tensorboard
        writer = SummaryWriter(
            f"runs/{self.hparams['model_name']}_{self.start_training}")

        print('Start training the model')
        for epoch in range(self.hparams['n_epochs']):

            # training mode
            self.model.train()
            avg_loss = 0.0
            avg_adv_loss = 0.0

            for X_batch, y_batch, X_batch_adv, y_batch_adv in tqdm(
                    train_loader):

                # randomly replace roughly half of the adversarial samples in the
                # batch with samples drawn from the pretrain loader and label them
                # accordingly (1 = replaced, 0 = original)
                sample = np.round(np.random.uniform(size=X_batch.shape[0]), 2)
                X_batch_adv_train_val, _, _, _ = next(iter(adv_loader))
                X_batch_adv_train_val = X_batch_adv_train_val[:X_batch.shape[0]]
                X_batch_adv[sample >= 0.5] = X_batch_adv_train_val[sample >= 0.5]
                y_batch_adv[sample >= 0.5] = 1
                y_batch_adv[sample < 0.5] = 0

                # push the data into the GPU
                X_batch = X_batch.float().to(self.device)
                y_batch = y_batch.float().to(self.device)
                X_batch_adv = X_batch_adv.float().to(self.device)
                y_batch_adv = y_batch_adv.float().to(self.device)

                # clean gradients from the previous step
                self.optimizer.zero_grad()

                # get model predictions
                pred, pred_adv = self.model(X_batch, X_batch_adv, train=True)

                # process main loss
                pred = pred.reshape(-1)
                y_batch = y_batch.reshape(-1)
                train_loss = self.loss(pred, y_batch)

                # process loss_2
                pred_adv = pred_adv.reshape(-1)
                y_batch_adv = y_batch_adv.reshape(-1)
                adv_loss = self.loss_adv(pred_adv, y_batch_adv)

                # calc loss
                avg_loss += train_loss.item() / len(train_loader)
                avg_adv_loss += adv_loss.item() / len(train_loader)

                train_loss = train_loss + self.hparams['model']['alpha'] * adv_loss

                # remove data from GPU
                y_batch = y_batch.float().cpu().detach().numpy()
                pred = pred.float().cpu().detach().numpy()
                X_batch = X_batch.float().cpu().detach().numpy()
                X_batch_adv = X_batch_adv.float().cpu().detach().numpy()
                y_batch_adv = y_batch_adv.cpu().detach().numpy()
                pred_adv = pred_adv.cpu().detach().numpy()

                # backprop
                train_loss.backward()

                # gradient clipping (applied after backward so that the freshly
                # computed gradients are actually clipped)
                if self.apply_clipping:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                    torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.5)

                # optimizer step
                self.optimizer.step()

                y_batch = self.postprocessing.run(y_batch)
                pred = self.postprocessing.run(pred)

                # calculate a step for metrics
                self.metric.calc_running_score(labels=y_batch, outputs=pred)

            # calc train metrics
            metric_train = self.metric.compute()

            # evaluate the model
            print('Model evaluation')

            # val mode
            self.model.eval()
            self.optimizer.zero_grad()
            avg_val_loss = 0.0

            with torch.no_grad():

                for X_batch, y_batch, _, _ in tqdm(valid_loader):

                    # push the data into the GPU
                    X_batch = X_batch.float().to(self.device)
                    y_batch = y_batch.float().to(self.device)

                    # get predictions
                    pred = self.model(X_batch)

                    pred = pred.reshape(-1)
                    y_batch = y_batch.reshape(-1)
                    avg_val_loss += self.loss(
                        pred, y_batch).item() / len(valid_loader)

                    # remove data from GPU
                    X_batch = X_batch.float().cpu().detach().numpy()
                    pred = pred.float().cpu().detach().numpy()
                    y_batch = y_batch.float().cpu().detach().numpy()

                    y_batch = self.postprocessing.run(y_batch)
                    pred = self.postprocessing.run(pred)

                    # calculate a step for metrics
                    self.metric.calc_running_score(labels=y_batch,
                                                   outputs=pred)

            # calc val metrics
            metric_val = self.metric.compute()

            # early stopping for scheduler
            if self.hparams['scheduler_name'] == 'ReduceLROnPlateau':
                self.scheduler.step(metric_val)
            else:
                self.scheduler.step()

            es_result = self.early_stopping(score=metric_val,
                                            model=self.model,
                                            threshold=None)

            # print statistics
            if self.hparams['verbose_train']:
                print(
                    '| Epoch: ',
                    epoch + 1,
                    '| Train_loss: ',
                    avg_loss,
                    '| Val_loss: ',
                    avg_val_loss,
                    '| Adv_loss: ',
                    avg_adv_loss,
                    '| Metric_train: ',
                    metric_train,
                    '| Metric_val: ',
                    metric_val,
                    '| Current LR: ',
                    self.__get_lr(self.optimizer),
                )

            # add data to tensorboard
            writer.add_scalars(
                'Loss',
                {
                    'Train_loss': avg_loss,
                    'Val_loss': avg_val_loss
                },
                epoch,
            )
            writer.add_scalars('Metric', {
                'Metric_train': metric_train,
                'Metric_val': metric_val
            }, epoch)

            # early stopping procedure
            if es_result == 2:
                print("Early Stopping")
                print(
                    f'global best val_loss model score {self.early_stopping.best_score}'
                )
                break
            elif es_result == 1:
                print(f'save global val_loss model score {metric_val}')

        writer.close()

        # load the best model trained so far
        self.model = self.early_stopping.load_best_weights()

        return self.start_training

    def predict(self, X_test):
        """
        This function makes:
        1. batch-wise predictions
        2. calculation of the metric for each sample
        3. calculation of the metric for the entire dataset

        Parameters
        ----------
        X_test
            dataset yielding (X_batch, y_batch, _, _) tuples

        Returns
        -------
        fold_score
            metric computed over the entire dataset

        """

        # evaluate the model
        self.model.eval()

        test_loader = torch.utils.data.DataLoader(
            X_test,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=0)

        self.metric.reset()

        print('Getting predictions')
        with torch.no_grad():
            for i, (X_batch, y_batch, _, _) in enumerate(tqdm(test_loader)):
                X_batch = X_batch.float().to(self.device)
                y_batch = y_batch.float().to(self.device)

                pred = self.model(X_batch)

                pred = pred.reshape(-1)
                y_batch = y_batch.reshape(-1)

                pred = pred.cpu().detach().numpy()
                X_batch = X_batch.cpu().detach().numpy()
                y_batch = y_batch.cpu().detach().numpy()

                y_batch = self.postprocessing.run(y_batch)
                pred = self.postprocessing.run(pred)

                self.metric.calc_running_score(labels=y_batch, outputs=pred)

        fold_score = self.metric.compute()

        return fold_score

    def save(self, model_path):

        print('Saving the model')

        # states (weights + optimizers)
        if self.gpu is not None and len(self.gpu) > 1:
            # unwrap DataParallel before saving
            torch.save(self.model.module.state_dict(), model_path + '.pt')
        else:
            torch.save(self.model.state_dict(), model_path + '.pt')

        # hparams
        with open(f"{model_path}_hparams.yml", 'w') as file:
            yaml.dump(self.hparams, file)

        return True

    def load(self, model_name):
        self.model.load_state_dict(
            torch.load(model_name + '.pt', map_location=self.device))
        self.model.eval()
        return True

    @classmethod
    def restore(cls, model_name: str, gpu: list, inference: bool):

        if gpu is not None:
            assert all([isinstance(i, int)
                        for i in gpu]), "All gpu indexes should be integer"

        # load hparams
        hparams = yaml.load(open(model_name + "_hparams.yml"),
                            Loader=yaml.FullLoader)

        # construct class
        self = cls(hparams, gpu=gpu, inference=inference)

        # load weights + optimizer state
        self.load(model_name=model_name)

        return self

    ################## Utils #####################

    def __get_lr(self, optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def __setup_model(self, inference, gpu):

        # TODO: re-write to pure DDP
        if inference or gpu is None:
            self.device = torch.device('cpu')
            self.model = EfficientNet.from_pretrained(
                self.hparams['model']['pre_trained_model'],
                num_classes=self.hparams['model']['n_classes'])
            self.model.build_adv_model()
            self.model = self.model.to(self.device)
        else:
            if torch.cuda.device_count() > 1:
                if len(gpu) > 1:
                    print("Number of GPUs will be used: ", len(gpu))
                    self.device = torch.device(f"cuda:{gpu[0]}" if torch.cuda.
                                               is_available() else "cpu")
                    self.model = EfficientNet.from_pretrained(
                        self.hparams['model']['pre_trained_model'],
                        num_classes=self.hparams['model']['n_classes'],
                    )
                    self.model.build_adv_model()
                    self.model = self.model.to(self.device)
                    self.model = DP(self.model,
                                    device_ids=gpu,
                                    output_device=gpu[0])
                else:
                    print("Only one GPU will be used")
                    self.device = torch.device(f"cuda:{gpu[0]}" if torch.cuda.
                                               is_available() else "cpu")
                    self.model = EfficientNet.from_pretrained(
                        self.hparams['model']['pre_trained_model'],
                        num_classes=self.hparams['model']['n_classes'],
                    )
                    self.model.build_adv_model()
                    self.model = self.model.to(self.device)
            else:
                self.device = torch.device(
                    f"cuda:{gpu[0]}" if torch.cuda.is_available() else "cpu")
                self.model = EfficientNet.from_pretrained(
                    self.hparams['model']['pre_trained_model'],
                    num_classes=self.hparams['model']['n_classes'],
                )
                self.model.build_adv_model()
                self.model = self.model.to(self.device)
                print('Only one GPU is available')

        print('Cuda available: ', torch.cuda.is_available())

        return True

    def __setup_model_hparams(self):

        # 1. define losses
        self.loss = nn.L1Loss()
        self.loss_adv = nn.BCELoss()

        # 2. define model metric
        self.metric = Kappa()

        # 3. define optimizer
        self.optimizer = eval(f"torch.optim.{self.hparams['optimizer_name']}")(
            params=self.model.parameters(),
            **self.hparams['optimizer_hparams'])

        # 4. define scheduler
        self.scheduler = eval(
            f"torch.optim.lr_scheduler.{self.hparams['scheduler_name']}")(
                optimizer=self.optimizer, **self.hparams['scheduler_hparams'])
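        # e.g. (illustrative values, assumed rather than taken from the original config):
        #   'optimizer_name': 'Adam', 'optimizer_hparams': {'lr': 1e-3},
        #   'scheduler_name': 'ReduceLROnPlateau', 'scheduler_hparams': {'patience': 3}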

        # 5. define early stopping
        self.early_stopping = EarlyStopping(
            checkpoint_path=self.hparams['checkpoint_path'] +
            f'/checkpoint_{self.start_training}' + '.pt',
            patience=self.hparams['patience'],
            delta=self.hparams['min_delta'],
            is_maximize=True,
        )

        # 6. set gradient clipping
        self.apply_clipping = self.hparams['clipping']  # clipping of gradients

        # 7. Set scaler for optimizer
        self.scaler = torch.cuda.amp.GradScaler()

        return True

    def __seed_everything(self, seed):
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.manual_seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
Example #4
class ELM_Classifier_finetune:
    def __init__(self, args) -> None:
        """Use ELM with fintuned language model for sentiment classification

        Args:
            args (dict): contain all the arguments needed.
                - model_name(str): the name of the transformer model
                - bsz(int): batch size
                - epoch: epochs to train
                - type(str): fintuned type
                  - base: train only ELM
                  - finetune_elm: train transformers with ELM directly
                  - finetune_classifier: train transformers with classifier
                  - finetune_classifier_elm: train transformers with classifier,
                    and use elm replace the classifier
                  - finetune_classifier_beta: train transformers with classifier,
                    and use pinv to calculate beta in classifier
                - learning_rate(float): learning_rate for finetuning
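
        Example (a hedged sketch; the dataset variables are placeholders and this
        assumes `preprocess` returns the packaged train/test datasets):
            >>> clf = ELM_Classifier_finetune({'model_name': 'prajjwal1/bert-tiny',
            ...                                'training_type': 'base'})
            >>> train_ds, test_ds = clf.preprocess(train_text, train_label,
            ...                                    test_text, test_label)
            >>> best_acc = clf.train(train_ds, test_ds)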
        """
        # load configuration
        self.model_name = args.get('model_name', 'bert-base-uncased')
        self.bsz = args.get('batch_size', 10)
        self.epoch = args.get('epoch_num', 2)
        self.learning_rate = args.get('learning_rate', 0.001)
        self.training_type = args.get('training_type', 'base')
        self.debug = args.get('debug', True)
        self.eval_epoch = args.get('eval_epoch', 1)
        self.lr_decay = args.get('learning_rate_decay', 0.99)
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        self.device = device
        self.n_gpu = torch.cuda.device_count()

        # load pretrained model
        if (self.model_name == 'bert-base-uncased') or \
                (self.model_name == 'distilbert-base-uncased') or \
                (self.model_name == 'albert-base-v2'):
            self.pretrained_model = AutoModel.from_pretrained(self.model_name)
            self.pretrained_tokenizer = AutoTokenizer.from_pretrained(
                self.model_name)
            input_shape = 768
            output_shape = 256
        elif (self.model_name == 'prajjwal1/bert-tiny'):
            self.pretrained_model = AutoModel.from_pretrained(self.model_name)
            self.pretrained_tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, model_max_length=512)
            input_shape = 128
            output_shape = 64
        elif self.model_name == 'voidful/albert_chinese_xxlarge':
            self.pretrained_model = AlbertForMaskedLM.from_pretrained(
                self.model_name)
            self.pretrained_tokenizer = BertTokenizer.from_pretrained(
                self.model_name)
            input_shape = 768
            output_shape = 256
        else:
            raise TypeError("Unsupported model name")
        self.pretrained_model.to(device)
        device_ids = None
        if self.n_gpu > 1:
            device_ids = range(torch.cuda.device_count())
            self.pretrained_model = DP(self.pretrained_model,
                                       device_ids=device_ids)

        # load specific model
        if (self.training_type == 'finetune_classifier') or \
            (self.training_type == 'finetune_classifier_elm'):
            self.classifier = torch.nn.Sequential(
                torch.nn.Linear(input_shape, 2))
            self.loss_func = torch.nn.CrossEntropyLoss()
            self.classifier.to(device)
            if self.n_gpu > 1:
                self.classifier = DP(self.classifier, device_ids=device_ids)
        if (self.training_type == 'base') or \
            (self.training_type =='finetune_classifier_elm'):
            self.elm = classic_ELM(input_shape, output_shape)
        if (self.training_type == 'finetune_classifier_linear'):
            self.elm = classic_ELM(None, None)
            self.classifier = torch.nn.Sequential(
                OrderedDict([
                    ('w', torch.nn.Linear(input_shape, output_shape)),
                    ('act', torch.nn.Sigmoid()),
                    ('beta', torch.nn.Linear(output_shape, 2)),
                ]))
            self.loss_func = torch.nn.CrossEntropyLoss()
            self.classifier.to(device)
            if self.n_gpu > 1:
                self.classifier = DP(self.classifier, device_ids=device_ids)

        # load processor, trainer, evaluator, inferer.
        processors = {
            'base': self.__processor_base__,
            'finetune_classifier': self.__processor_base__,
            'finetune_classifier_elm': self.__processor_base__,
            'finetune_classifier_linear': self.__processor_base__,
        }
        trainers = {
            'base':
            self.__train_base__,
            'finetune_classifier':
            self.__train_finetune_classifier__,
            'finetune_classifier_elm':
            self.__train_finetune_classifier_elm__,
            'finetune_classifier_linear':
            self.__train_finetune_classifier_linear__,
        }
        evaluators = {
            'base': self.__eval_base__,
            'finetune_classifier': self.__eval_finetune_classifier__,
            'finetune_classifier_elm': self.__eval_base__,
            'finetune_classifier_linear':
            self.__eval_finetune_classifier_linear__,
        }
        inferers = {
            'base': self.__infer_base__,
            'finetune_classifier': self.__infer_finetune_classifier__,
            'finetune_classifier_elm': self.__infer_finetune_classifier_elm__,
            'finetune_classifier_linear': self.__infer_base__
        }
        self.processor = processors[self.training_type]
        self.trainer = trainers[self.training_type]
        self.evaluator = evaluators[self.training_type]
        self.inferer = inferers[self.training_type]

    def preprocess(self, *list_arg, **dict_arg):
        """
        Unified preprocess
        """
        print('Preprocessing......')
        return self.processor(*list_arg, **dict_arg)

    def train(self, *list_arg, **dict_arg):
        """
        Unified train
        """
        print('Training......')
        acc = self.trainer(*list_arg, **dict_arg)
        print('Best Accuracy:', acc)
        return acc

    def eval(self, *list_arg, **dict_arg):
        """
        Unified evaluate
        """
        print('Evaluating......')
        return self.evaluator(*list_arg, **dict_arg)

    def infer(self, *list_arg, **dict_arg):
        """
        Unified inference
        """
        print('Inferring......')
        return self.inferer(*list_arg, **dict_arg)

    def __train_base__(self, train_dataset, test_dataset, do_eval=True):
        # prepare to train
        self.pretrained_model.eval()
        batch_num = math.ceil(len(train_dataset.labels) / self.bsz)
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.bsz,
                                  shuffle=True)
        collect_out = []
        collect_label = []

        # collect outputs and train elm
        print('collecting outputs......')
        pbar = tqdm(range(batch_num))
        for batch in train_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            with torch.no_grad():
                outputs = self.pretrained_model(input_ids,
                                                attention_mask=attention_mask)
                pooler = outputs[1]
                collect_out.append(pooler.cpu().numpy())
                collect_label.append(labels.cpu().numpy())
            pbar.update()
        pbar.close()

        # train elm
        print('Train ELM......')
        # concatenate batch outputs (the last batch may be smaller than self.bsz)
        collect_out = np.concatenate(collect_out, axis=0)
        collect_label = np.concatenate(collect_label, axis=0)
        self.elm.train(collect_out, collect_label)

        # evaluate
        acc = 0
        if do_eval:
            acc = self.eval(test_dataset)
        return acc

    def __train_finetune_classifier__(self,
                                      train_dataset,
                                      test_dataset,
                                      do_eval=True):
        # set train mode
        self.pretrained_model.train()
        self.classifier.train()

        # prepare optimizer
        batch_num = math.ceil(len(train_dataset.labels) / self.bsz)
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.bsz,
                                  shuffle=True)
        params = [{
            'params': self.pretrained_model.parameters()
        }, {
            'params': self.classifier.parameters()
        }]
        optimizer = AdamW(params, lr=self.learning_rate)
        scheduler = ExponentialLR(optimizer, self.lr_decay)

        # train
        best_acc = 0
        epochs = self.epoch if do_eval else 1
        for epoch in range(epochs):
            pbar = tqdm(range(batch_num))
            losses = []
            for batch in train_loader:
                optimizer.zero_grad()
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                outputs = self.pretrained_model(input_ids,
                                                attention_mask=attention_mask)
                pooler = outputs[1]
                outputs = self.classifier(pooler)
                loss = self.loss_func(outputs, labels)
                if self.n_gpu > 1:
                    loss = loss.mean()
                loss.backward()
                optimizer.step()
                pbar.update()
                losses.append(loss.data.cpu())
                descrip = 'Train epoch:%3d   Loss:%6.3f' % (epoch,
                                                            loss.data.cpu())
                if not do_eval:
                    descrip = 'Loss:%6.3f' % loss.data.cpu()
                pbar.set_description(descrip)
            scheduler.step()
            # set average epoch description
            avg_loss = torch.mean(torch.tensor(losses))
            final_descrip = 'Epoch:%2d  Average Loss:%6.3f' % (epoch, avg_loss)
            if not do_eval:
                final_descrip = 'Average Loss:%6.3f' % avg_loss
            pbar.set_description(final_descrip)
            pbar.close()
            # eval for epochs
            if (epoch % self.eval_epoch == 0) and do_eval:
                test_acc = self.eval(test_dataset)
                best_acc = max(test_acc, best_acc)
                self.pretrained_model.train()
                self.classifier.train()
        return best_acc

    def __train_finetune_classifier_elm__(self,
                                          train_dataset,
                                          test_dataset,
                                          do_eval=True):
        best_acc = 0
        for epoch in range(self.epoch):
            print('Epoch %d' % epoch)
            self.__train_finetune_classifier__(train_dataset,
                                               test_dataset,
                                               do_eval=False)
            self.__train_base__(train_dataset, test_dataset, do_eval=False)
            if do_eval and (epoch % self.eval_epoch == 0):
                acc = self.eval(test_dataset)
                best_acc = max(best_acc, acc)
        return best_acc

    def __train_finetune_classifier_linear__(self,
                                             train_dataset,
                                             test_dataset,
                                             do_eval=True):
        best_acc = 0
        batch_num = math.ceil(len(train_dataset.labels) / self.bsz)
        for epoch in range(self.epoch):
            # train classifier
            print('Epoch %d' % epoch)
            self.__train_finetune_classifier__(train_dataset,
                                               test_dataset,
                                               do_eval=False)

            # calculate last layer with model_output
            print('collecting outputs......')
            collect_out = []
            collect_label = []
            self.pretrained_model.eval()
            self.classifier.eval()
            train_loader = DataLoader(train_dataset,
                                      batch_size=self.bsz,
                                      shuffle=True)
            pbar = tqdm(range(batch_num))
            for batch in train_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                with torch.no_grad():
                    outputs = self.pretrained_model(
                        input_ids, attention_mask=attention_mask)
                    pooler = outputs[1]
                    linear = self.classifier.w(pooler)
                    linear = self.classifier.act(linear)
                    collect_out.append(linear.cpu().numpy())
                    collect_label.append(labels.cpu().numpy())
                pbar.update()
            pbar.close()

            print('Train ELM......')
            # concatenate batch outputs (the last batch may be smaller than self.bsz)
            collect_out = np.concatenate(collect_out, axis=0)
            collect_label = np.concatenate(collect_label, axis=0)
            self.elm.train(collect_out, collect_label)

            if do_eval and (epoch % self.eval_epoch == 0):
                acc = self.eval(test_dataset)
                best_acc = max(best_acc, acc)
        return best_acc

    def __eval_base__(self, test_dataset):
        # prepare eval
        self.pretrained_model.eval()
        batch_num = math.ceil(len(test_dataset.labels) / self.bsz)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.bsz,
                                 shuffle=True)
        pbar = tqdm(range(batch_num))

        # collect tensors
        collect_out = []
        collect_label = []
        for batch in test_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            with torch.no_grad():
                outputs = self.pretrained_model(input_ids,
                                                attention_mask=attention_mask)
                pooler = outputs[1]
                collect_out.append(pooler.cpu().numpy())
                collect_label.append(labels.cpu().numpy())
            pbar.update()
        pbar.close()

        # evaluate
        # concatenate batch outputs (the last batch may be smaller than self.bsz)
        collect_out = np.concatenate(collect_out, axis=0)
        collect_label = np.concatenate(collect_label, axis=0)
        pred_labels = self.elm.infer(collect_out) > 0.5
        acc = pred_labels == collect_label
        acc = np.sum(acc) / len(collect_out)
        print('Total accuracy: ', acc)
        return acc

    def __eval_finetune_classifier__(self, test_dataset):
        # set eval mode
        self.pretrained_model.eval()
        self.classifier.eval()

        # prepare eval
        batch_num = math.ceil(len(test_dataset.labels) / self.bsz)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.bsz,
                                 shuffle=True)
        pbar = tqdm(range(batch_num))
        acc_list = []
        for batch in test_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            with torch.no_grad():
                outputs = self.pretrained_model(input_ids,
                                                attention_mask=attention_mask)
                pooler = outputs[1]
                outputs = self.classifier(pooler)
                output_label = torch.argmax(outputs, dim=1)
            acc = output_label == labels
            acc = acc.float()
            acc = torch.sum(acc) / labels.size(0)
            acc_list.append(acc.cpu())
            pbar.update()
            descrip = 'Current Accuracy:%6.3f' % acc
            pbar.set_description(descrip)
        pbar.close()
        t_acc = np.array(acc_list).mean()
        print('Total accuracy: ', t_acc)
        return t_acc
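
    # Note (added): __eval_finetune_classifier_linear__ below feeds the ELM with
    # the classifier's first linear layer + activation, whereas __eval_base__
    # above feeds it the raw pooler output.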

    def __eval_finetune_classifier_linear__(self, test_dataset):
        # prepare eval
        self.pretrained_model.eval()
        batch_num = math.ceil(len(test_dataset.labels) / self.bsz)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.bsz,
                                 shuffle=True)
        pbar = tqdm(range(batch_num))

        # collect tensors
        collect_out = []
        collect_label = []
        for batch in test_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            with torch.no_grad():
                outputs = self.pretrained_model(input_ids,
                                                attention_mask=attention_mask)
                pooler = outputs[1]
                linear = self.classifier.w(pooler)
                linear = self.classifier.act(linear)
                collect_out.append(linear.cpu().numpy())
                collect_label.append(labels.cpu().numpy())
            pbar.update()
        pbar.close()

        # evaluate
        collect_out = np.concatenate(collect_out, axis=0)
        collect_label = np.concatenate(collect_label, axis=0)
        pred_labels = self.elm.infer(collect_out) > 0.5
        acc = pred_labels == collect_label
        acc = np.sum(acc) / len(collect_out)
        print('Total accuracy: ', acc)
        return acc

    def __infer_base__(self, texts):
        self.pretrained_model.eval()
        collect_out = []
        for data in tqdm(texts):
            data = list(data)
            inputs = self.pretrained_tokenizer(
                data,
                truncation=True,
                padding=True,
                return_tensors='pt',
            ).to(self.device)
            with torch.no_grad():
                outputs = self.pretrained_model(**inputs)
            collect_out.append(outputs['pooler_output'].cpu().numpy())
        collect_out = np.concatenate(collect_out, axis=0)
        label = self.elm.infer(collect_out) > 0.5
        return label

    def __infer_finetune_classifier__(self, texts):
        raise NotImplementedError

    def __infer_finetune_classifier_elm__(self, texts):
        raise NotImplementedError

    def __processor_base__(self, train_text, train_label, test_text,
                           test_label):
        """packaging dataset use torch.Dataset

        Args:
            train_text (numpy.ndarray): (trainset_num,)
            train_label (numpy.ndarray): (trainset_num,)
            test_text (numpy.ndarray): (testset_num,)
            test_label (numpy.ndarray): (testset_num,)

        Returns:
            train_text (numpy.ndarray): (batch_num, batch_size)
            train_label (numpy.ndarray): (batch_num, batch_size)
            test_text (numpy.ndarray): (batch_num, batch_size)
            test_label (numpy.ndarray): (batch_num, batch_size)
        """

        # in debug mode, use only the first 50 sentences
        if self.debug:
            train_text = train_text[:50]
            train_label = train_label[:50]
            test_text = test_text[:50]
            test_label = test_label[:50]

        train_text = list(train_text)
        test_text = list(test_text)
        train_encodings = self.pretrained_tokenizer(train_text,
                                                    truncation=True,
                                                    padding=True)
        test_encodings = self.pretrained_tokenizer(test_text,
                                                   truncation=True,
                                                   padding=True)
        train_dataset = IMDbDataset(train_encodings, train_label)
        test_dataset = IMDbDataset(test_encodings, test_label)

        return train_dataset, test_dataset
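
# Note (added): the ELM head used above (self.elm.train / self.elm.infer) is not
# shown in this snippet. A minimal sketch of such a head, assuming a random
# projection followed by a closed-form ridge readout (the names ELMHead,
# n_hidden and reg are illustrative, not from the original code):
import numpy as np


class ELMHead:
    """Extreme-learning-machine style readout: random hidden features + ridge solve."""

    def __init__(self, in_dim, n_hidden=1024, reg=1e-3, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.standard_normal((in_dim, n_hidden)) / np.sqrt(in_dim)
        self.b = rng.standard_normal(n_hidden)
        self.reg = reg
        self.beta = None

    def _hidden(self, x):
        return np.tanh(x @ self.w + self.b)

    def train(self, x, y):
        # closed-form ridge regression: beta = (H^T H + reg * I)^-1 H^T y
        h = self._hidden(x)
        self.beta = np.linalg.solve(h.T @ h + self.reg * np.eye(h.shape[1]), h.T @ y)

    def infer(self, x):
        return self._hidden(x) @ self.beta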
Example #5
class SRFlowModel(BaseModel):
    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt)
        self.opt = opt

        self.heats = opt['val']['heats']
        self.n_sample = opt['val']['n_sample']
        self.hr_size = opt_get(opt,
                               ['datasets', 'train', 'center_crop_hr_size'])
        self.hr_size = 160 if self.hr_size is None else self.hr_size
        self.lr_size = self.hr_size // opt['scale']

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_Flow(opt, step).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()

        if opt_get(opt, ['path', 'resume_state'], 1) is not None:
            self.load()
        else:
            print(
                "WARNING: skipping initial loading, due to resume_state None")

        if self.is_train:
            self.netG.train()

            self.init_optimizer_and_scheduler(train_opt, opt, step)
            self.log_dict = OrderedDict()

    def to(self, device):
        self.device = device
        self.netG.to(device)

    def init_optimizer_and_scheduler(self, train_opt, opt, step):
        # set RRDB training to False (freeze the RRDB sub-network at initialization)
        self.netG.module.set_rrdb_training(False)
        # if opt_get(opt, ['network_G', 'train_RRDB'], False) and step >= opt_get(opt, ['network_G', 'train_RRDB_delay'], 0.5)*opt_get(opt, ['train', 'niter'], 200000):
        #     self.netG.module.set_rrdb_training(True)

        # optimizers
        self.optimizers = []
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params_RRDB = []
        optim_params_other = []
        # can optimize for a part of the model
        for k, v in self.netG.named_parameters():
            print(k, v.requires_grad)
            if v.requires_grad:
                if '.RRDB.' in k:
                    optim_params_RRDB.append(v)
                    print('opt', k)
                else:
                    optim_params_other.append(v)
                # if self.rank <= 0:
                #     logger.warning('Params [{:s}] will not optimize.'.format(k))

        # print('rrdb params', len(optim_params_RRDB))

        # Adam reads the 'betas' key; separate 'beta1'/'beta2' entries would be ignored
        self.optimizer_G = torch.optim.Adam(
            [{
                "params": optim_params_other,
                "lr": train_opt['lr_G'],
                'betas': (train_opt['beta1'], train_opt['beta2']),
                'weight_decay': wd_G
            }, {
                "params": optim_params_RRDB,
                "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
                'betas': (train_opt['beta1'], train_opt['beta2']),
                'weight_decay': wd_G
            }])
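        # Note (added): group 0 holds the non-RRDB parameters; group 1 starts
        # empty because RRDB was frozen above, and is filled later by
        # add_optimizer_and_scheduler_RRDB once RRDB training is enabled.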

        self.optimizers.append(self.optimizer_G)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        train_opt['lr_steps'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state'],
                        lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError(
                'Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')

    def add_optimizer_and_scheduler_RRDB(self, train_opt):
        # optimizers
        assert len(self.optimizers) == 1, self.optimizers
        # assert len(self.optimizer_G.param_groups[1]['params']) == 0, self.optimizer_G.param_groups[1]
        assert len(self.optimizer_G.param_groups[1]['params']) == 0
        # can optimize for a part of the model
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                if '.RRDB.' in k:
                    self.optimizer_G.param_groups[1]['params'].append(v)
        assert len(self.optimizer_G.param_groups[1]['params']) > 0

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):

        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        # if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
        #         and not self.netG.module.RRDB_training:
        if train_RRDB_delay is not None and step > int(
                train_RRDB_delay * self.opt['train']['niter']):
            if self.netG.module.set_rrdb_training(True):
                self.add_optimizer_and_scheduler_RRDB(self.opt['train'])
                # if step % 100 == 0:
                print("set RRDB trainable")

        # self.print_rrdb_state()

        # add GT noise
        add_gt_noise = opt_get(self.opt, ['train', 'add_gt_noise'], True)

        self.netG.train()
        self.log_dict = OrderedDict()
        self.optimizer_G.zero_grad()

        losses = {}
        weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
        weight_fl = 1 if weight_fl is None else weight_fl
        if weight_fl > 0:
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False,
                                         add_gt_noise=add_gt_noise)
            nll_loss = torch.mean(nll)
            losses['nll_loss'] = nll_loss * weight_fl

        weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
        if weight_l1 > 0:
            z = self.get_z(heat=0,
                           seed=None,
                           batch_size=self.var_L.shape[0],
                           lr_shape=self.var_L.shape)
            sr, logdet = self.netG(lr=self.var_L,
                                   z=z,
                                   eps_std=0,
                                   reverse=True,
                                   reverse_with_grad=True)
            l1_loss = (sr - self.real_H).abs().mean()
            losses['l1_loss'] = l1_loss * weight_l1

        total_loss = sum(losses.values())
        total_loss.backward()
        self.optimizer_G.step()

        mean = total_loss.item()
        return mean
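
    # Note (added): optimize_parameters combines the flow negative
    # log-likelihood (weighted by weight_fl) with an optional L1 loss between a
    # zero-temperature reverse sample and the ground truth (weighted by
    # weight_l1).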

    def print_rrdb_state(self):
        for name, param in self.netG.module.named_parameters():
            if "RRDB.conv_first.weight" in name:
                print(name, param.requires_grad, param.data.abs().sum())
        print('params',
              [len(p['params']) for p in self.optimizer_G.param_groups])

    def test(self):
        self.netG.eval()
        self.fake_H = {}
        for heat in self.heats:
            for i in range(self.n_sample):
                z = self.get_z(heat,
                               seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                with torch.no_grad():
                    self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L,
                                                               z=z,
                                                               eps_std=heat,
                                                               reverse=True)
        with torch.no_grad():
            _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_encode_nll(self, lq, gt):
        self.netG.eval()
        with torch.no_grad():
            _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
        return self.get_sr_with_z(lq, heat, seed, z, epses)[0]

    def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, _, _ = self.netG(gt=gt,
                                lr=lq,
                                reverse=False,
                                epses=epses,
                                add_gt_noise=add_gt_noise)
        self.netG.train()
        return z

    def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, nll, _ = self.netG(gt=gt,
                                  lr=lq,
                                  reverse=False,
                                  epses=epses,
                                  add_gt_noise=add_gt_noise)
        self.netG.train()
        return z, nll

    def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
        self.netG.eval()

        if z is None and epses is None:
            z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape)

        with torch.no_grad():
            sr, logdet = self.netG(lr=lq,
                                   z=z,
                                   eps_std=heat,
                                   reverse=True,
                                   epses=epses)
        self.netG.train()
        return sr, z

    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
        if seed is not None:
            torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)
            if heat > 0:
                z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W))
            else:
                z = torch.zeros((batch_size, C, H, W))
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2**(L - 3)
            z_size = int(self.lr_size // fac)
            z = torch.normal(mean=0,
                             std=heat,
                             size=(batch_size, 3 * 8 * 8 * fac * fac, z_size,
                                   z_size))
        return z

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        for heat in self.heats:
            for i in range(self.n_sample):
                out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        _, get_resume_model_path = get_resume_paths(self.opt)
        if get_resume_model_path is not None:
            self.load_network(get_resume_model_path,
                              self.netG,
                              strict=True,
                              submodule=None)
            return

        load_path_G = self.opt['path']['pretrain_model_G']
        load_submodule = self.opt['path'].get('load_submodule', 'RRDB')
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G,
                              self.netG,
                              self.opt['path'].get('strict_load', True),
                              submodule=load_submodule)

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
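
# Note (added): a minimal, self-contained sketch of the latent-shape arithmetic
# that get_z uses in its non-split branch (the names sample_z and lr_size are
# illustrative, not part of the model above). With L flow levels the spatial
# size of z shrinks by fac = 2**(L - 3) relative to the LR input, while the
# channel count grows by 3 * 8 * 8 * fac * fac.
import torch


def sample_z(heat, lr_size=40, L=3, batch_size=1):
    fac = 2 ** (L - 3)
    z_size = int(lr_size // fac)
    channels = 3 * 8 * 8 * fac * fac
    if heat > 0:
        return torch.normal(mean=0.0, std=float(heat),
                            size=(batch_size, channels, z_size, z_size))
    return torch.zeros((batch_size, channels, z_size, z_size))


# e.g. sample_z(0.5).shape == (1, 192, 40, 40) for L=3 and a 40x40 LR input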
Example #6
class MGANTrainer:
    def __init__(self, args, task, saver, logger, vocab):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pretrain = False
        self.saver = saver
        self.logger = logger
        self._model = MGANModel.build_model(args, task, pretrain=self.pretrain)
        self.model = DataParallel(self._model)
        self.model = self.model.to(device)
        self.opt = ClippedAdam(self.model.parameters(), lr=1e-3)
        self.opt.set_clip(clip_value=5.0)
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.opt,
                                                                   gamma=0.5)
        self.saver.load("mgan", self.model.module)
        self.step = 0
        self.vocab = vocab
        self.critic_lag_max = 50
        self.critic_lag = self.critic_lag_max

        self.args = args
        self.task = task

    def run(self, epoch, samples):
        self.model.train()
        num_rollouts = 1 if self.pretrain else self.args.num_rollouts
        self.lr_scheduler.step(epoch)
        self.rollout_discriminator(num_rollouts, samples)
        self.rollout_generator(num_rollouts, samples)
        self.rollout_critic(num_rollouts, samples)
        self.saver.checkpoint("mgan", self.model.module)
        self.step += 1

    def rollout_discriminator(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        real, fake = AverageMeter(), AverageMeter()
        batch_size, seq_len = samples[0].size()

        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'discriminator-rollout')

        for rollout in pbar:
            real_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   unmasked,
                                   tag="d-step",
                                   real=True)

            real_loss = real_loss.sum() / batch_size

            with torch.no_grad():
                net_output = self.model(masked,
                                        lengths,
                                        mask,
                                        unmasked,
                                        tag="g-step")
                generated = net_output[1]

            fake_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   generated,
                                   tag="d-step",
                                   real=False)

            fake_loss = fake_loss.sum() / batch_size

            loss = (real_loss + fake_loss) / 2
            loss.backward()

            real.update(real_loss.item())
            fake.update(fake_loss.item())

        self.opt.step()
        self.logger.log("discriminator/real", self.step, real.avg)
        self.logger.log("discriminator/fake", self.step, fake.avg)
        self.logger.log("discriminator", self.step, real.avg + fake.avg)

    def rollout_critic(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        batch_size, seq_len = samples[0].size()
        meter = AverageMeter()
        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'critic-rollout')
        for rollout in pbar:
            loss = self.model(masked, lengths, mask, unmasked, tag="c-step")
            loss = loss.sum() / batch_size
            loss.backward()
            meter.update(loss.item())

        self.opt.step()
        self.logger.log("critic/loss", self.step, meter.avg)

    def rollout_generator(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        batch_size, seq_len = samples[0].size()
        meter = AverageMeter()
        ppl_meter = defaultdict(lambda: AverageMeter())
        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'generator-rollout')

        for rollout in pbar:
            loss, generated, ppl = self.model(masked,
                                              lengths,
                                              mask,
                                              unmasked,
                                              tag="g-step")
            loss = loss.sum() / batch_size
            loss.backward()
            meter.update(-1 * loss.item())
            # for key in ppl:
            #     ppl[key] = ppl[key].sum() / batch_size
            #     ppl_meter[key].update(ppl[key].item())
        self.opt.step()
        self.logger.log("generator/advantage", self.step, meter.avg)
        # for key in ppl_meter:
        #     self.logger.log("ppl/{}".format(key), ppl_meter[key].avg)

        self.debug('train', samples, generated)

    def debug(self, key, samples, generated):
        masked, unmasked, lengths, mask = samples
        tag = 'generated/{}'.format(key)
        logger = lambda s: self.logger.log(tag, s)
        pretty_print(logger,
                     self.vocab,
                     masked,
                     unmasked,
                     generated,
                     truncate=10)

    def validate_dataset(self, loader):
        self.model.eval()
        _meters = 'generator dfake dreal critic ppl_sampled ppl_truths'
        _n_meters = len(_meters.split())
        Meters = namedtuple('Meters', _meters)
        meters_list = [AverageMeter() for i in range(_n_meters)]
        meters = Meters(*meters_list)
        for sample_batch in loader:
            self._validate(meters, sample_batch)
            # for key, value in meters._asdict().items():
            #     print(key, value.avg)

    @property
    def umodel(self):
        if isinstance(self.model, DataParallel):
            return self.model.module
        return self.model

    def aggregate(self, batch_size):
        return lambda tensor: tensor.sum() / batch_size

    def _validate(self, meters, samples):
        with torch.no_grad():
            masked, unmasked, lengths, mask = samples
            batch_size, seq_len = samples[0].size()

            agg = self.aggregate(batch_size)

            real_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   unmasked,
                                   tag="d-step",
                                   real=True)

            real_loss = agg(real_loss)

            generator_loss, generated, ppl = self.model(masked,
                                                        lengths,
                                                        mask,
                                                        unmasked,
                                                        tag="g-step",
                                                        ppl=True)

            generator_loss = agg(generator_loss)

            fake_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   generated,
                                   tag="d-step",
                                   real=False)

            fake_loss = agg(fake_loss)

            loss = (real_loss + fake_loss) / 2

            critic_loss = self.model(masked,
                                     lengths,
                                     mask,
                                     unmasked,
                                     tag="c-step")
            critic_loss = agg(critic_loss)

            meters.dreal.update(real_loss.item())
            meters.dfake.update(fake_loss.item())
            meters.generator.update(generator_loss.item())
            meters.critic.update(critic_loss.item())


            for key in ppl:
                ppl[key] = agg(ppl[key])

            meters.ppl_sampled.update(ppl['sampled'].item())
            meters.ppl_truths.update(ppl['ground-truth'].item())
            self.debug('dev', samples, generated)
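
# Note (added): AverageMeter is used throughout this trainer but is not defined
# in the snippet. A minimal sketch of the running-average helper it is assumed
# to be:
class AverageMeter:
    """Tracks the running average of a scalar value (e.g. a per-rollout loss)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count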