Exemplo n.º 1
0
class ModelStage(ModelBase):
    """Train a single stage (0, 1 or 2) of a staged generator with a pixel loss.

    The ``stage0`` / ``stage1`` / ``stage2`` flags choose which network
    ``define_G`` builds, which pretrained path is loaded, which keys are read
    from each data batch and which suffix (``G0`` / ``G1`` / ``G2``) is used
    when saving.

    NOTE(review): the flags look mutually exclusive but nothing enforces it.
    ``load`` / ``save`` / ``test`` / ``current_visuals`` give priority to the
    lowest enabled stage, whereas ``optimize_parameters`` keeps the loss of
    the highest enabled stage. Confirm that callers only ever set one flag.

    ``self.device``, ``self.opt``, ``self.schedulers``, ``self.save_dir`` and
    the ``load_network`` / ``save_network`` / ``describe_*`` helpers are not
    defined here; presumably they come from ``ModelBase``.
    """
    def __init__(self, opt, stage0=False, stage1=False, stage2=False):
        """Build the generator for the selected stage and wrap it in DataParallel.

        Args:
            opt: option dictionary forwarded to ``ModelBase`` and ``define_G``.
            stage0, stage1, stage2: stage selectors forwarded to ``define_G``.
        """
        super(ModelStage, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.stage0 = stage0
        self.stage1 = stage1
        self.stage2 = stage2
        self.netG = define_G(opt, self.stage0, self.stage1,
                             self.stage2).to(self.device)
        self.netG = DataParallel(self.netG)

    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        """Load weights and create the loss, optimizer, scheduler and log dict."""
        self.opt_train = self.opt['train']  # training option
        self.load()  # load model
        self.netG.train()  # set training mode,for BN
        self.define_loss()  # define loss
        self.define_optimizer()  # define optimizer
        self.define_scheduler()  # define scheduler
        self.log_dict = OrderedDict()  # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):
        """Load the pretrained generator for the active stage, if a path is set.

        Reads ``opt['path']['pretrained_netG{0,1,2}']`` for the first enabled
        stage. Does nothing when the path is ``None`` or no stage is enabled.
        """
        if self.stage0:
            path_key = 'pretrained_netG0'
        elif self.stage1:
            path_key = 'pretrained_netG1'
        elif self.stage2:
            path_key = 'pretrained_netG2'
        else:
            # No stage selected: nothing to load. The previous code fell
            # through to an unbound local and raised UnboundLocalError.
            return

        load_path_G = self.opt['path'][path_key]
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        """Save the generator under the label of the first enabled stage.

        Args:
            iter_label: label forwarded to ``save_network`` (identifies the
                checkpoint, e.g. the current iteration).
        """
        if self.stage0:
            self.save_network(self.save_dir, self.netG, 'G0', iter_label)
        elif self.stage1:
            self.save_network(self.save_dir, self.netG, 'G1', iter_label)
        elif self.stage2:
            self.save_network(self.save_dir, self.netG, 'G2', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        """Create ``self.G_lossfn`` from ``opt_train['G_lossfn_type']``.

        Supported types: ``'l1'``, ``'l2'``, ``'l2sum'`` and ``'ssim'``.

        Raises:
            NotImplementedError: for any other loss type.
        """
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError(
                'Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create an Adam optimizer over the trainable generator parameters.

        Parameters with ``requires_grad == False`` are skipped and reported.
        """
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params,
                                lr=self.opt_train['G_optimizer_lr'],
                                weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        """Append a MultiStepLR scheduler for the generator optimizer."""
        self.schedulers.append(
            lr_scheduler.MultiStepLR(self.G_optimizer,
                                     self.opt_train['G_scheduler_milestones'],
                                     self.opt_train['G_scheduler_gamma']))

    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data):
        """Move one batch to ``self.device`` and store it on the model.

        Keys read from ``data``:
            stage0: ``'ls'`` and ``'hs'`` (sequences, moved with ``util.tos``)
            stage1: ``'L0'`` and ``'H'``
            stage2: ``'L'`` (sequence, moved with ``util.tos``) and ``'H'``
        """
        # NOTE(review): stage0 is tested with an independent `if`, so with
        # both stage0 and stage1/2 enabled two branches run. Kept as in the
        # original; the flags are presumably exclusive.
        if self.stage0:
            Ls = data['ls']
            self.Ls = util.tos(*Ls, device=self.device)
            Hs = data['hs']
            self.Hs = util.tos(*Hs, device=self.device)
        if self.stage1:
            self.L0 = data['L0'].to(self.device)
            self.H = data['H'].to(self.device)
        elif self.stage2:
            Ls = data['L']
            self.Ls = util.tos(*Ls, device=self.device)
            self.H = data['H'].to(self.device)  # hide for test

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Run one optimisation step on the batch stored by ``feed_data``.

        Performs forward, weighted loss, backward, optional gradient-norm
        clipping, the optimizer step and the optional periodic regularizers,
        then records the loss in ``self.log_dict['G_loss']``.

        Args:
            current_step: global training step; used to decide whether the
                periodic regularizers run on this step.
        """
        self.G_optimizer.zero_grad()

        # Independent `if`s are kept from the original: when several stage
        # flags are enabled the loss of the last enabled stage is the one
        # that is back-propagated.
        if self.stage0:
            self.Es = self.netG(self.Ls)
            # stage0 produces several outputs; sum the per-output losses.
            _loss = [
                self.G_lossfn(Es_i, Hs_i)
                for (Es_i, Hs_i) in zip(self.Es, self.Hs)
            ]
            G_loss = sum(_loss) * self.G_lossfn_weight

        if self.stage1:
            self.E = self.netG(self.L0)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)

        if self.stage2:
            self.E = self.netG(self.Ls)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)

        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        # A missing/falsy option value disables clipping.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] or 0
        if G_optimizer_clipgrad > 0:
            # Clip the generator's parameters: they are the only ones the
            # optimizer updates, and this wrapper class itself does not
            # define `parameters()` in the code shown here.
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(),
                                           max_norm=G_optimizer_clipgrad,
                                           norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        # Each regularizer runs every N steps, except on steps that are
        # multiples of `checkpoint_save`.
        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] or 0
        if (G_regularizer_orthstep > 0
                and current_step % G_regularizer_orthstep == 0
                and current_step % self.opt['train']['checkpoint_save'] != 0):
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] or 0
        if (G_regularizer_clipstep > 0
                and current_step % G_regularizer_clipstep == 0
                and current_step % self.opt['train']['checkpoint_save'] != 0):
            self.netG.apply(regularizer_clip)

        # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0]  # if `reduction='sum'`
        self.log_dict['G_loss'] = G_loss.item()

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run the generator in eval mode without gradients, then restore train mode.

        Stores the output in ``self.Es`` (stage0) or ``self.E`` (stage1/2).
        """
        self.netG.eval()
        with torch.no_grad():
            if self.stage0:
                self.Es = self.netG(self.Ls)
            elif self.stage1:
                self.E = self.netG(self.L0)
            elif self.stage2:
                self.E = self.netG(self.Ls)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        """Return the log dictionary filled by ``optimize_parameters``."""
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self):
        """Return the first sample of the current input/estimate/target on CPU.

        Returns:
            OrderedDict with keys ``L``/``Es0``/``Hs0`` for stage0 and
            ``L``/``E``/``H`` for stage1 and stage2. For stage0 and stage2
            the first element of the stored sequences is used.
        """
        out_dict = OrderedDict()
        if self.stage0:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['Es0'] = self.Es[0].detach()[0].float().cpu()
            out_dict['Hs0'] = self.Hs[0].detach()[0].float().cpu()
        elif self.stage1:
            out_dict['L'] = self.L0.detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  # hide for test

        elif self.stage2:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  # hide for test
        return out_dict

    # ----------------------------------------
    # Information of netG
    # ----------------------------------------

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        """Print the description returned by ``describe_network`` for netG."""
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        """Print the description returned by ``describe_params`` for netG."""
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        """Return the description produced by ``describe_network`` for netG."""
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        """Return the description produced by ``describe_params`` for netG."""
        msg = self.describe_params(self.netG)
        return msg
Exemplo n.º 2
0
class ModelPlain4(ModelBase):
    """Train a generator that takes four inputs (L, k, sf, sigma) with a pixel loss.

    ``self.device``, ``self.opt``, ``self.schedulers``, ``self.save_dir`` and
    the ``load_network`` / ``save_network`` / ``describe_*`` helpers are not
    defined here; presumably they come from ``ModelBase``.
    """
    def __init__(self, opt):
        """Build the generator with ``define_G`` and wrap it in DataParallel.

        Args:
            opt: option dictionary forwarded to ``ModelBase`` and ``define_G``.
        """
        super(ModelPlain4, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.netG = define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        """Load weights and create the loss, optimizer, scheduler and log dict."""
        self.opt_train = self.opt['train']    # training option
        self.load()                           # load model
        self.netG.train()                     # set training mode,for BN
        self.define_loss()                    # define loss
        self.define_optimizer()               # define optimizer
        self.define_scheduler()               # define scheduler
        self.log_dict = OrderedDict()         # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):
        """Load ``opt['path']['pretrained_netG']`` into netG when it is not None."""
        load_path_G = self.opt['path']['pretrained_netG']
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        """Save the generator under the label ``'G'``.

        Args:
            iter_label: label forwarded to ``save_network``.
        """
        self.save_network(self.save_dir, self.netG, 'G', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        """Create ``self.G_lossfn`` from ``opt_train['G_lossfn_type']``.

        Supported types: ``'l1'``, ``'l2'``, ``'l2sum'`` and ``'ssim'``.

        Raises:
            NotImplementedError: for any other loss type.
        """
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create an Adam optimizer over the trainable generator parameters.

        Parameters with ``requires_grad == False`` are skipped and reported.
        """
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        """Append a MultiStepLR scheduler for the generator optimizer."""
        self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
                                                        self.opt_train['G_scheduler_milestones'],
                                                        self.opt_train['G_scheduler_gamma']
                                                        ))

    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        """Move one batch to ``self.device`` and store it on the model.

        Args:
            data: mapping with keys ``'L'``, ``'k'``, ``'sf'``, ``'sigma'``
                and, when ``need_H`` is true, ``'H'``.
            need_H: whether the target ``'H'`` is present and should be stored.
        """
        self.L = data['L'].to(self.device)  # low-quality image
        self.k = data['k'].to(self.device)  # blur kernel
        # Scale factor is taken from the first batch entry and kept as a
        # plain Python int (`np.int` no longer exists in current NumPy).
        self.sf = int(data['sf'][0, ...].squeeze().cpu().numpy())  # scale factor
        self.sigma = data['sigma'].to(self.device)  # noise level
        if need_H:
            self.H = data['H'].to(self.device)  # H

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Run one optimisation step on the batch stored by ``feed_data``.

        Performs forward, weighted loss, backward, optional gradient-norm
        clipping, the optimizer step and the optional periodic regularizers,
        then records the loss in ``self.log_dict['G_loss']``.

        Args:
            current_step: global training step; used to decide whether the
                periodic regularizers run on this step.
        """
        self.G_optimizer.zero_grad()
        # Same four-input call as `test()`. The previous code passed
        # `self.C`, an attribute that `feed_data` never sets.
        self.E = self.netG(self.L, self.k, self.sf, self.sigma)
        G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        # A missing/falsy option value disables clipping.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] or 0
        if G_optimizer_clipgrad > 0:
            # Clip the generator's parameters: they are the only ones the
            # optimizer updates, and this wrapper class itself does not
            # define `parameters()` in the code shown here.
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=G_optimizer_clipgrad, norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        # Each regularizer runs every N steps, except on steps that are
        # multiples of `checkpoint_save`.
        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] or 0
        if (G_regularizer_orthstep > 0
                and current_step % G_regularizer_orthstep == 0
                and current_step % self.opt['train']['checkpoint_save'] != 0):
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] or 0
        if (G_regularizer_clipstep > 0
                and current_step % G_regularizer_clipstep == 0
                and current_step % self.opt['train']['checkpoint_save'] != 0):
            self.netG.apply(regularizer_clip)

        self.log_dict['G_loss'] = G_loss.item()  #/self.E.size()[0]

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run the generator in eval mode without gradients, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.E = self.netG(self.L, self.k, self.sf, self.sigma)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        """Return the log dictionary filled by ``optimize_parameters``."""
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self, need_H=True):
        """Return the first sample of L, E (and H) as float CPU tensors.

        Args:
            need_H: include the target ``'H'`` in the result.
        """
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach()[0].float().cpu()
        out_dict['E'] = self.E.detach()[0].float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach()[0].float().cpu()
        return out_dict

    # ----------------------------------------
    # get L, E, H batch images
    # ----------------------------------------
    def current_results(self, need_H=True):
        """Return the whole batch of L, E (and H) as float CPU tensors.

        Args:
            need_H: include the target ``'H'`` in the result.
        """
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach().float().cpu()
        out_dict['E'] = self.E.detach().float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach().float().cpu()
        return out_dict

    # ----------------------------------------
    # Information of netG
    # ----------------------------------------

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        """Print the description returned by ``describe_network`` for netG."""
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        """Print the description returned by ``describe_params`` for netG."""
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        """Return the description produced by ``describe_network`` for netG."""
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        """Return the description produced by ``describe_params`` for netG."""
        msg = self.describe_params(self.netG)
        return msg
Exemplo n.º 3
0
class TrainPatch:
    def __init__(self, mode):
        """Assemble the detector, patch modules, data loaders and output folders.

        Args:
            mode: key into ``patch_config_types``; the selected config class
                is instantiated and stored as ``self.config``.
        """
        cfg = patch_config_types[mode]()
        self.config = cfg

        # Detector under attack.
        self.yolov5 = attempt_load(cfg.weight_file, map_location=device)

        self.img_size = cfg.patch_size

        # Modules that build the patch, paste it and score it.
        self.dot_applier = DotApplier(cfg.num_of_dots, self.img_size, cfg.alpha_max, cfg.beta_dropoff)
        self.patch_applier = PatchApplier()
        self.non_printability_score = NonPrintabilityScore(cfg.print_file, cfg.num_of_dots)

        self.eval_type = cfg.eval_data_type
        self.clean_img_dict = np.load('confidences/yolov5/medium/clean_img_conf_lisa_ordered.npy', allow_pickle=True).item()

        self.detections = DetectionsYolov5(cls_id=cfg.class_id,
                                           num_cls=cfg.num_classes,
                                           config=cfg,
                                           clean_img_conf=self.clean_img_dict,
                                           conf_threshold=cfg.conf_threshold)

        self.noise_amount = NoiseAmount(cfg.radius_lower_bound, cfg.radius_upper_bound)

        # Multi-GPU wrapping (set_multiple_gpu) is intentionally not called.
        self.set_to_device()

        # Both dataset variants share the same resize + to-tensor transform.
        resize_to_tensor = transforms.Compose(
            [transforms.Resize((self.img_size, self.img_size)), transforms.ToTensor()])
        if cfg.eval_data_type in ('ordered', 'one'):
            split_dataset = SplitDataset(img_dir=cfg.img_dir,
                                         lab_dir=cfg.lab_dir,
                                         max_lab=cfg.max_labels_per_img,
                                         img_size=self.img_size,
                                         transform=resize_to_tensor)
        else:
            split_dataset = SplitDataset1(img_dir_train_val=cfg.img_dir,
                                          lab_dir_train_val=cfg.lab_dir,
                                          img_dir_test=cfg.img_dir_test,
                                          lab_dir_test=cfg.lab_dir_test,
                                          max_lab=cfg.max_labels_per_img,
                                          img_size=self.img_size,
                                          transform=resize_to_tensor)

        loaders = split_dataset(val_split=0.2,
                                test_split=0.2,
                                shuffle_dataset=True,
                                random_seed=seed,
                                batch_size=cfg.batch_size,
                                ordered=True)
        self.train_loader, self.val_loader, self.test_loader = loaders

        # Per-epoch histories.
        self.train_losses = []
        self.val_losses = []

        self.max_prob_losses = []
        self.cor_det_losses = []
        self.nps_losses = []
        self.noise_losses = []

        self.train_acc = []
        self.val_acc = []
        self.final_epoch_count = cfg.epochs

        # Output folder: experiments/<Month>/<dd-mm-YYYY>_<SLURM job id or HH-MM-SS>
        month_name = datetime.datetime.now().strftime("%B")
        day_stamp = time.strftime("%d-%m-%Y")
        if 'SLURM_JOBID' in os.environ:
            run_tag = os.environ['SLURM_JOBID']
        else:
            run_tag = time.strftime("%H-%M-%S")
        self.current_dir = "experiments/" + month_name + '/' + day_stamp + '_' + run_tag
        self.create_folders()

        # TensorBoard logging is disabled; train() checks for a None writer.
        self.writer = None

    def set_to_device(self):
        """Move the patch-related modules to the module-level ``device``.

        Each attribute is replaced by the result of its ``.to(device)`` call.
        """
        module_names = (
            'dot_applier',
            'patch_applier',
            'detections',
            'non_printability_score',
            'noise_amount',
        )
        for attr_name in module_names:
            moved = getattr(self, attr_name).to(device)
            setattr(self, attr_name, moved)

    def set_multiple_gpu(self):
        if torch.cuda.device_count() > 1:
            print("more than 1")
            self.dot_applier = DP(self.dot_applier)
            self.patch_applier = DP(self.patch_applier)
            self.detections = DP(self.detections)

    def create_folders(self):
        Path('/'.join(self.current_dir.split('/')[:2])).mkdir(parents=True, exist_ok=True)
        Path(self.current_dir).mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/final_results').mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/saved_patches').mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/losses').mkdir(parents=True, exist_ok=True)
        Path(self.current_dir + '/testing').mkdir(parents=True, exist_ok=True)

    def train(self):
        """Optimise the dot patch parameters over ``self.config.epochs`` epochs.

        For every training batch the patch and its alpha mask are rebuilt by
        ``self.dot_applier`` from blank tensors, pasted on the images, passed
        through the detector and scored by ``self.loss_function``. Adam
        (amsgrad) updates ``dot_applier.theta``, ``dot_applier.colors`` and
        ``dot_applier.radius``, each with its own learning rate.

        After each epoch the latest entry of ``self.val_losses`` drives early
        stopping and the scheduler. NOTE(review): ``self.val_losses`` is not
        appended to in this method; presumably ``last_batch_calc`` does it —
        confirm, otherwise ``self.val_losses[-1]`` raises IndexError.

        On exit, ``self.adv_patch`` and ``self.alpha`` hold the best patch and
        alpha kept by the ``EarlyStopping`` helper.
        """
        epoch_length = len(self.train_loader)
        print(f'One epoch is {epoch_length} batches', flush=True)

        # One parameter group per learnable tensor so each gets its own lr.
        optimizer = Adam([{'params': self.dot_applier.theta, 'lr': self.config.loc_lr},
                          {'params': self.dot_applier.colors, 'lr': self.config.color_lr},
                          {'params': self.dot_applier.radius, 'lr': self.config.radius_lr}],
                         amsgrad=True)
        scheduler = self.config.scheduler_factory(optimizer)
        early_stop = EarlyStopping(delta=1e-3, current_dir=self.current_dir, patience=20)

        clipper = WeightClipper(self.config.radius_lower_bound, self.config.radius_upper_bound)
        # Blank patch (3 channels) and blank alpha mask (1 channel) that the
        # dot applier draws on at every step.
        adv_patch_cpu = torch.zeros((1, 3, self.img_size, self.img_size), dtype=torch.float32)
        alpha_cpu = torch.zeros((1, 1, self.img_size, self.img_size), dtype=torch.float32)
        for epoch in range(self.config.epochs):
            # Running sums over the batches of this epoch.
            train_loss = 0.0
            max_prob_loss = 0.0
            cor_det_loss = 0.0
            nps_loss = 0.0
            noise_loss = 0.0

            progress_bar = tqdm(enumerate(self.train_loader), desc=f'Epoch {epoch}', total=epoch_length)
            prog_bar_desc = 'train-loss: {:.6}, ' \
                            'maxprob-loss: {:.6}, ' \
                            'corr det-loss: {:.6}, ' \
                            'nps-loss: {:.6}, ' \
                            'noise-loss: {:.6}'
            for i_batch, (img_batch, lab_batch, img_names) in progress_bar:
                # move tensors to the target device
                img_batch = img_batch.to(device)
                lab_batch = lab_batch.to(device)
                adv_patch = adv_patch_cpu.to(device)
                alpha = alpha_cpu.to(device)

                # forward prop
                adv_patch, alpha = self.dot_applier(adv_patch, alpha)  # put dots on patch

                applied_batch = self.patch_applier(img_batch, adv_patch, alpha)  # apply patch on a batch of images

                # Keep a copy of the very first patch of the run.
                if epoch == 0 and i_batch == 0:
                    self.save_initial_patch(adv_patch, alpha)

                output_patch = self.yolov5(applied_batch)[0]  # get yolo output with patch

                max_prob, cor_det = self.detections(lab_batch, output_patch, img_names)

                nps = self.non_printability_score(self.dot_applier.colors)

                noise = self.noise_amount(self.dot_applier.radius)

                loss, loss_arr = self.loss_function(max_prob, cor_det, nps, noise)  # calculate loss

                # accumulate losses; loss_arr order is
                # [max_prob, cor_det, nps, noise] as returned by loss_function
                max_prob_loss += loss_arr[0].item()
                cor_det_loss += loss_arr[1].item()
                nps_loss += loss_arr[2].item()
                noise_loss += loss_arr[3].item()
                train_loss += loss.item()

                # back prop
                optimizer.zero_grad()
                loss.backward()

                # update parameters
                optimizer.step()
                # NOTE(review): the original comment said "clip x,y coordinates",
                # but the clipper is built from the radius bounds — confirm
                # what WeightClipper actually clamps.
                self.dot_applier.apply(clipper)

                # Show the running means for this epoch.
                progress_bar.set_postfix_str(prog_bar_desc.format(train_loss / (i_batch + 1),
                                                                  max_prob_loss / (i_batch + 1),
                                                                  cor_det_loss / (i_batch + 1),
                                                                  nps_loss / (i_batch + 1),
                                                                  noise_loss / (i_batch + 1)))

                # `i_batch % 1 == 0` is always true, so this logs every batch
                # whenever a writer is configured (it is None by default).
                if i_batch % 1 == 0 and self.writer is not None:
                    self.write_to_tensorboard(adv_patch, alpha,train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                                              epoch_length, epoch, i_batch, optimizer)
                # self.writer.add_image('patch', adv_patch.squeeze(0), epoch_length * epoch + i_batch)
                # End-of-epoch bookkeeping on the last batch.
                if i_batch + 1 == epoch_length:
                    self.last_batch_calc(adv_patch, alpha, epoch_length, progress_bar, prog_bar_desc,
                                         train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                                         optimizer, epoch, i_batch)

                # self.run_slide_show(adv_patch)

                # release per-batch tensors and cached GPU memory
                del img_batch, lab_batch, applied_batch, output_patch, max_prob, cor_det, nps, noise, loss
                torch.cuda.empty_cache()

            # check if loss has decreased; stop early when the helper says so
            if early_stop(self.val_losses[-1], adv_patch.cpu(), alpha.cpu(), epoch):
                self.final_epoch_count = epoch
                break

            # The scheduler is stepped with the validation loss as its metric.
            scheduler.step(self.val_losses[-1])

        self.adv_patch = early_stop.best_patch
        self.alpha = early_stop.best_alpha
        print("Training finished")

    def get_image_size(self):
        if type(self.yolov2) == nn.DataParallel:
            img_size = self.yolov2.module.height
        else:
            img_size = self.yolov2.height
        return int(img_size)

    def evaluate_loss(self, loader, adv_patch, alpha):
        """Return the mean per-batch loss of a fixed patch over ``loader``.

        Args:
            loader: iterable yielding ``(img_batch, lab_batch, img_names)``.
            adv_patch: patch tensor passed to ``self.patch_applier``.
            alpha: alpha mask passed to ``self.patch_applier``.

        Returns:
            Sum of the batch losses divided by ``len(loader)``. No gradients
            are tracked during the evaluation.
        """
        running_total = 0.0
        for images, labels, names in loader:
            with torch.no_grad():
                images = images.to(device)
                labels = labels.to(device)

                patched = self.patch_applier(images, adv_patch, alpha)
                predictions = self.yolov5(patched)[0]
                max_prob, cor_det = self.detections(labels, predictions, names)
                nps = self.non_printability_score(self.dot_applier.colors)
                noise = self.noise_amount(self.dot_applier.radius)
                # Only the combined loss is needed; the components are dropped.
                combined = self.loss_function(max_prob, cor_det, nps, noise)[0]
                running_total += combined.item()

                # release per-batch tensors and cached GPU memory
                del images, labels, patched, predictions, max_prob, cor_det, nps, noise, combined
                torch.cuda.empty_cache()
        return running_total / len(loader)

    def plot_train_val_loss(self):
        """Plot training and validation loss per epoch and save it as a PNG.

        The figure is written to
        ``<current_dir>/final_results/train_val_loss_plt.png``.
        """
        epochs = list(range(1, len(self.train_losses) + 1))
        curves = (
            (self.train_losses, 'b', 'Training loss'),
            (self.val_losses, 'r', 'Validation loss'),
        )
        for values, line_format, legend_label in curves:
            plt.plot(epochs, values, line_format, label=legend_label)
        plt.title('Training and validation loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc='upper right')
        plt.savefig(self.current_dir + '/final_results/train_val_loss_plt.png')
        plt.close()

    def plot_separate_loss(self):
        """Plot each loss component with a positive weight in its own subplot.

        One subplot is created per component whose config weight is > 0 and
        the figure is saved to
        ``<current_dir>/final_results/separate_loss_plt.png``. Nothing is
        plotted or saved when no weight is positive.

        Fixes over the previous version:
          * axes are indexed by position among the *active* components, so a
            zero weight in front of a positive one no longer indexes past
            the last subplot;
          * the count of active components no longer relies on ``np.bool``,
            which has been removed from NumPy;
          * no subplot grid with zero columns is requested.
        """
        epochs = list(range(1, len(self.train_losses) + 1))
        # (weight, history, legend label, colour) — colours follow the
        # original 'brgk' assignment per component.
        components = [
            (self.config.max_prob_weight, self.max_prob_losses, 'Max probability loss', 'b'),
            (self.config.pres_det_weight, self.cor_det_losses, 'Correct detections loss', 'r'),
            (self.config.nps_weight, self.nps_losses, 'Non printability loss', 'g'),
            (self.config.noise_weight, self.noise_losses, 'Noise Amount loss', 'k'),
        ]
        active = [component for component in components if component[0] > 0]
        if not active:
            # No weighted component: there is nothing to draw.
            return

        number_of_subplots = len(active)
        fig, axes = plt.subplots(nrows=1, ncols=number_of_subplots,
                                 figsize=(5 * number_of_subplots, 3 * number_of_subplots),
                                 squeeze=False)
        for axis, (_, loss, label, color_name) in zip(axes[0], active):
            axis.plot(epochs, loss, c=color_name, label=label)
            axis.set_xlabel('Epoch')
            axis.set_ylabel('Loss')
            axis.legend(loc='upper right')
        fig.tight_layout()
        plt.savefig(self.current_dir + '/final_results/separate_loss_plt.png')
        plt.close()

    def plot_combined(self):
        """Plot max-probability and correct-detection losses on twin y-axes.

        The max-probability curve uses the left axis (blue), the
        correct-detections curve the right axis (red). The figure is saved to
        ``<current_dir>/final_results/combined_losses.png``.
        """
        epochs = list(range(1, len(self.train_losses) + 1))

        def label_axis(axis, color):
            # Shared axis decoration; tick labels take the curve's colour.
            axis.set_xlabel('Epoch')
            axis.set_ylabel('Loss')
            axis.tick_params(axis='y', labelcolor=color)

        fig, left_axis = plt.subplots(ncols=1, figsize=(8, 4))
        left_axis.plot(epochs, self.max_prob_losses, c='b', label='Max Probability')
        label_axis(left_axis, 'b')

        right_axis = left_axis.twinx()
        right_axis.plot(epochs, self.cor_det_losses, c='r', label='Correct Detections')
        label_axis(right_axis, 'r')
        # Only the right axis gets a legend, as before.
        right_axis.legend(loc='upper right')

        fig.tight_layout()
        plt.savefig(self.current_dir + '/final_results/combined_losses.png')
        plt.close()

    def loss_function(self, max_prob, cor_det, nps, noise):
        max_prob_loss = self.config.max_prob_weight * torch.mean(max_prob)
        cor_det_loss = self.config.pres_det_weight * torch.mean(cor_det)
        nps_loss = self.config.nps_weight * nps
        noise_loss = self.config.noise_weight * noise
        return max_prob_loss + cor_det_loss + nps_loss + noise_loss, [max_prob_loss, cor_det_loss, nps_loss, noise_loss]

    def save_final_objects(self):
        """Write the final patch, its alpha mask and all loss histories to disk.

        Images and raw tensors go to ``<current_dir>/final_results``; each
        loss history is pickled to ``<current_dir>/losses/<name>``.
        """
        to_image = transforms.ToPILImage()
        patch = self.adv_patch.squeeze(0)
        mask = self.alpha.squeeze(0)

        # save patch
        to_image(patch).save(self.current_dir + '/final_results/final_patch_wo_alpha.png', 'PNG')
        torch.save(self.adv_patch, self.current_dir + '/final_results/final_patch_raw.pt')
        to_image(mask).save(self.current_dir + '/final_results/alpha.png', 'PNG')
        torch.save(self.alpha, self.current_dir + '/final_results/alpha_raw.pt')
        # Patch channels followed by the alpha channel in a single image.
        patch_with_alpha = torch.cat([patch, mask])
        to_image(patch_with_alpha.cpu()).save(self.current_dir + '/final_results/final_patch_w_alpha.png', 'PNG')

        # save losses: each file is named after the attribute it stores
        history_names = (
            'train_losses',
            'val_losses',
            'max_prob_losses',
            'cor_det_losses',
            'nps_losses',
            'noise_losses',
        )
        for history_name in history_names:
            with open(self.current_dir + '/losses/' + history_name, 'wb') as fp:
                pickle.dump(getattr(self, history_name), fp)

    def save_final_results(self, avg_precision):
        """Evaluate the test loss and finish this run's row in the results CSV.

        ``save_config_details`` is called first; it appends the configuration
        columns (ending in a comma, without a newline) to
        ``experiments/results.csv``. This method then appends the remaining
        columns and the terminating newline to the same file.

        Args:
            avg_precision: iterable of exactly four values, unpacked as
                ``(target_noise_ap, target_patch_ap, other_noise_ap,
                other_patch_ap)``.
        """
        target_noise_ap, target_patch_ap, other_noise_ap, other_patch_ap = avg_precision

        # Loss of the final patch/alpha on the test loader.
        patch_on_device = self.adv_patch.to(device)
        alpha_on_device = self.alpha.to(device)
        test_loss = self.evaluate_loss(self.test_loader, patch_on_device, alpha_on_device)
        print("Test loss: " + str(test_loss))

        self.save_config_details()

        # str() is applied explicitly so every value is rendered exactly as
        # plain string conversion would render it.
        columns = [
            str(self.train_losses[-1]),
            str(self.val_losses[-1]),
            str(test_loss),
            str(self.max_prob_losses[-1]),
            str(self.cor_det_losses[-1]),
            str(self.nps_losses[-1]),
            str(self.final_epoch_count) + '/' + str(self.config.epochs),
            str(target_noise_ap),
            str(target_patch_ap),
            str(other_noise_ap),
            str(other_patch_ap),
        ]
        row_to_csv = ','.join(columns) + '\n'

        # Append the result columns to the shared CSV.
        with open('experiments/results.csv', 'a') as fd:
            fd.write(row_to_csv)

    def write_to_tensorboard(self, adv_patch, alpha, train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                             epoch_length, epoch, i_batch, optimizer):
        """Log the current training state to ``self.writer``.

        The global step is ``epoch_length * epoch + i_batch``. Every loss
        argument is divided by ``i_batch + 1`` before logging (presumably the
        arguments are running sums over the current epoch -- confirm against
        the caller). The learning rates of the first three optimizer parameter
        groups and the current patch images are logged as well.

        Args:
            adv_patch: patch tensor; logged after ``squeeze(0)``.
            alpha: alpha tensor; concatenated to the patch for the RGBA image.
            train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss:
                accumulated loss values.
            epoch_length: number of batches per epoch.
            epoch: current epoch index.
            i_batch: index of the current batch within the epoch.
            optimizer: optimizer with at least three parameter groups.
        """
        step = epoch_length * epoch + i_batch
        batches_seen = i_batch + 1
        writer = self.writer

        accumulated = (
            ('train_loss', train_loss),
            ('loss/max_prob_loss', max_prob_loss),
            ('loss/cor_det_loss', cor_det_loss),
            ('loss/nps_loss', nps_loss),
            ('loss/noise_loss', noise_loss),
        )
        for tag, running_total in accumulated:
            writer.add_scalar(tag, running_total / batches_seen, step)

        writer.add_scalar('misc/epoch', epoch, step)

        # Parameter groups 0, 1 and 2 are logged as location, color and radius.
        lr_tags = ('misc/loc_learning_rate', 'misc/color_learning_rate', 'misc/radius_learning_rate')
        for group_idx, tag in enumerate(lr_tags):
            writer.add_scalar(tag, optimizer.param_groups[group_idx]["lr"], step)

        patch_img = adv_patch.squeeze(0)
        writer.add_image('patch_rgb', patch_img, step)
        writer.add_image('patch_rgba', torch.cat([patch_img, alpha.squeeze(0)]), step)

    def last_batch_calc(self, adv_patch, alpha, epoch_length, progress_bar, prog_bar_desc,
                        train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                        optimizer, epoch, i_batch):
        """Close an epoch: store averaged losses, validate, update the progress bar.

        Each loss argument is divided by ``epoch_length`` and appended to its
        history list on ``self``. The validation loss is then computed with
        ``self.evaluate_loss`` on ``self.val_loader`` and appended to
        ``self.val_losses``. Finally the progress-bar postfix is refreshed and,
        when ``self.writer`` is set, the validation loss is logged.

        Args:
            adv_patch, alpha: tensors forwarded to ``self.evaluate_loss``.
            epoch_length: number of batches per epoch (the divisor).
            progress_bar: object exposing ``set_postfix_str``.
            prog_bar_desc: format string; presumably holds the placeholders
                for the five training losses -- confirm against the caller.
            train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss:
                accumulated losses for the epoch.
            optimizer: optimizer with at least three parameter groups.
            epoch: current epoch index.
            i_batch: index of the last batch within the epoch.
        """
        # Augmented division is kept deliberately: for tensor arguments it
        # divides in place, for plain numbers it simply rebinds the local.
        train_loss /= epoch_length
        max_prob_loss /= epoch_length
        cor_det_loss /= epoch_length
        nps_loss /= epoch_length
        noise_loss /= epoch_length

        # Record the epoch averages in their respective histories.
        epoch_records = (
            (self.train_losses, train_loss),
            (self.max_prob_losses, max_prob_loss),
            (self.cor_det_losses, cor_det_loss),
            (self.nps_losses, nps_loss),
            (self.noise_losses, noise_loss),
        )
        for history, epoch_value in epoch_records:
            history.append(epoch_value)

        # Validation loss for the current patch.
        val_loss = self.evaluate_loss(self.val_loader, adv_patch, alpha)
        self.val_losses.append(val_loss)

        # Extend the caller's template with validation loss and learning rates.
        template = prog_bar_desc + ', val-loss: {:.6}, loc-lr: {:.6}, color-lr: {:.6}, radius-lr: {:.6}'
        learning_rates = [optimizer.param_groups[group_idx]['lr'] for group_idx in range(3)]
        postfix = template.format(train_loss, max_prob_loss, cor_det_loss, nps_loss, noise_loss,
                                  val_loss, *learning_rates)
        progress_bar.set_postfix_str(postfix)

        if self.writer is None:
            return
        self.writer.add_scalar('loss/val_loss', val_loss, epoch_length * epoch + i_batch)

    def get_clean_image_conf(self):
        """Record the detector's confidence on unpatched images and save it.

        Every batch of the train, validation and test loaders is passed
        through ``self.yolov5``. For each image, the class ids found in
        column 0 of its labels are collected, ignoring negative ids and
        ``config.class_id``. For each remaining class the maximum of
        ``config.loss_target(objectness, class_scores)`` over all predictions
        is stored.

        The result is a dict ``{image_name: {class_id: max_confidence}}`` that
        is written with ``np.save`` to
        ``confidences/<model_name>/medium/clean_img_conf_lisa_ordered.npy``.
        Images without any remaining class id get no entry. The number of
        recorded images is printed.

        NOTE: ``np.save`` stores a dict as a pickled object array, so reading
        it back requires ``np.load(..., allow_pickle=True)``.
        """
        clean_img_dict = dict()
        # Index tensor selecting column 0 of the label axis; it is constant,
        # so build it and move it to the device once instead of per batch.
        class_column = torch.tensor([0], dtype=torch.long).to(device)

        # This pass only reads model outputs; disabling autograd avoids
        # building a graph for every forward pass.
        with torch.no_grad():
            for loader in [self.train_loader, self.val_loader, self.test_loader]:
                for img_batch, lab_batch, img_name in loader:
                    img_batch = img_batch.to(device)
                    lab_batch = lab_batch.to(device)

                    output = self.yolov5(img_batch)[0]
                    # assumes the result is (batch, 5 + num_classes,
                    # num_predictions) after the transpose -- TODO confirm
                    output = output.transpose(1, 2).contiguous()
                    # Row 4 is used as objectness, rows 5+ as per-class scores.
                    output_objectness, output = output[:, 4, :], output[:, 5:, :]
                    batch_idx = torch.index_select(lab_batch, 2, class_column)
                    for i in range(batch_idx.size()[0]):
                        image_ids = batch_idx[i]
                        # Keep non-negative ids (negative ones presumably mark
                        # padding rows -- confirm) other than config.class_id.
                        keep = (image_ids >= 0) & (image_ids != self.config.class_id)
                        ids = np.unique(image_ids[keep].cpu().numpy().astype(int))
                        if len(ids) == 0:
                            continue
                        clean_img_dict[img_name[i]] = dict()
                        # get relevant classes
                        confs_for_class = output[i, ids, :]
                        confs_if_object = self.config.loss_target(output_objectness[i], confs_for_class)

                        # find the max prob for each related class
                        max_conf, _ = torch.max(confs_if_object, dim=1)
                        for j, label in enumerate(ids):
                            clean_img_dict[img_name[i]][label] = max_conf[j].item()

                    # Drop the batch tensors before the next iteration to keep
                    # peak GPU memory down.
                    del img_batch, lab_batch, output, output_objectness, batch_idx
                    torch.cuda.empty_cache()

        print(len(clean_img_dict))
        np.save('confidences/' + self.config.model_name + '/medium/clean_img_conf_lisa_ordered.npy', clean_img_dict)

    # def get_clean_image_conf(self):
    #     clean_img_dict = dict()
    #     for loader in [self.train_loader, self.val_loader, self.test_loader]:
    #         for img_batch, lab_batch, img_name in loader:
    #             img_batch = img_batch.to(device)
    #             lab_batch = lab_batch.to(device)
    #
    #             output = self.yolov2(img_batch)
    #             batch = output.size(0)
    #             h = output.size(2)
    #             w = output.size(3)
    #             output = output.view(batch, self.yolov2.num_anchors, 5 + self.config.num_classes,
    #                                  h * w)  # [batch, 5, 85, 361]
    #             output = output.transpose(1, 2).contiguous()  # [batch, 85, 5, 361]
    #             output = output.view(batch, 5 + self.config.num_classes,
    #                                  self.yolov2.num_anchors * h * w)  # [batch, 85, 1805]
    #             output_objectness = torch.sigmoid(output[:, 4, :])  # [batch, 1805]
    #             output = output[:, 5:5 + self.config.num_classes, :]  # [batch, 80, 1805]
    #             normal_confs = torch.nn.Softmax(dim=1)(output)  # [batch, 80, 1805]
    #             batch_idx = torch.index_select(lab_batch, 2, torch.tensor([0], dtype=torch.long).to(device))
    #             for i in range(batch_idx.size(0)):
    #                 ids = np.unique(
    #                     batch_idx[i][(batch_idx[i] >= 0) & (batch_idx[i] != self.config.class_id)].cpu().numpy().astype(
    #                         int))
    #                 if len(ids) == 0:
    #                     continue
    #                 clean_img_dict[img_name[i]] = dict()
    #                 # get relevant classes
    #                 confs_for_class = normal_confs[i, ids, :]
    #                 confs_if_object = self.config.loss_target(output_objectness[i], confs_for_class)
    #
    #                 # find the max prob for each related class
    #                 max_conf, _ = torch.max(confs_if_object, dim=1)
    #                 for j, label in enumerate(ids):
    #                     clean_img_dict[img_name[i]][label] = max_conf[j].item()
    #
    #             del img_batch, lab_batch, output, output_objectness, normal_confs, batch_idx
    #             torch.cuda.empty_cache()
    #
    #     print(len(clean_img_dict))
    #     np.save('confidences/clean_img_conf_lisa_new_color.npy', clean_img_dict)

    def save_config_details(self):
        """Append this run's configuration columns to ``experiments/results.csv``.

        The columns are: the last path component of ``current_dir``, the model
        name, the second-to-last path component of ``config.img_dir``, the
        three learning rates joined by ``-``, followed by further config values
        and one literal ``None`` placeholder column. The text ends with a comma
        and has no newline, so whatever is appended next continues the same
        CSV row.
        """
        cfg = self.config
        columns = [
            self.current_dir.split('/')[-1],
            cfg.model_name,
            cfg.img_dir.split('/')[-2],
            '-'.join([str(cfg.loc_lr), str(cfg.color_lr), str(cfg.radius_lr)]),
            str(cfg.sched_cooldown),
            str(cfg.sched_patience),
            str(cfg.loss_mode),
            str(cfg.conf_threshold),
            str(cfg.max_prob_weight),
            str(cfg.pres_det_weight),
            str(cfg.nps_weight),
            str(cfg.num_of_dots),
            str(None),  # placeholder column, always written as 'None'
            str(cfg.alpha_max),
            str(cfg.beta_dropoff),
        ]
        # The trailing comma leaves the row open for the columns written later.
        row_to_csv = ','.join(columns) + ','
        with open('experiments/results.csv', 'a') as fd:
            fd.write(row_to_csv)

    def save_initial_patch(self, adv_patch, alpha):
        """Save the starting patch and alpha mask as PNG images.

        Both tensors are moved to the CPU, stripped of a leading singleton
        dimension (if present) and written below
        ``<current_dir>/saved_patches``.

        Args:
            adv_patch: patch tensor, saved as ``initial_patch.png``.
            alpha: alpha tensor, saved as ``initial_alpha.png``.
        """
        outputs = (
            (adv_patch, '/saved_patches/initial_patch.png'),
            (alpha, '/saved_patches/initial_alpha.png'),
        )
        for tensor, suffix in outputs:
            image = transforms.ToPILImage()(tensor.cpu().squeeze(0))
            image.save(self.current_dir + suffix)

    @staticmethod
    def run_slide_show(adv_patch):
        """Show the current patch in an OpenCV window named 'slide show'.

        The patch is detached and copied to the CPU, entries that are exactly
        zero are replaced by one, and the result is written to
        ``current_slide.jpg`` in the working directory. That file is then read
        back with OpenCV and displayed.

        Args:
            adv_patch: patch tensor; a leading singleton dimension, if
                present, is dropped before conversion to an image.
        """
        slide_path = 'current_slide.jpg'

        frame = adv_patch.detach().cpu()
        is_zero = frame == 0
        frame = torch.where(is_zero, torch.ones_like(frame), frame)

        pil_image = transforms.ToPILImage()(frame.squeeze(0))
        pil_image.save(slide_path)

        cv2.imshow('slide show', cv2.imread(slide_path))
        # Wait 1 ms for a key event so the window gets a chance to refresh.
        cv2.waitKey(1)