예제 #1
0
class Ganomaly(object):
    """GANomaly Class

    Adversarial anomaly-detection model: a generator ``netG``
    (encoder-decoder-encoder) reconstructs the input and yields two
    latent encodings; a discriminator ``netD`` scores real vs.
    reconstructed images.  The anomaly score of a test image is the
    mean squared distance between its two latent encodings.
    """
    @staticmethod
    def name():
        """Return name of the class.
        """
        return 'Ganomaly'

    def __init__(self, opt, dataloader=None):
        """Build networks, losses, input buffers and (optionally) optimizers.

        Args:
            opt: parsed option namespace (expects attributes such as
                batchsize, isize, lr, beta1, outf, name, device, resume,
                isTrain, w_bce, w_rec, w_enc, ...).
            dataloader (dict, optional): maps 'train'/'test' to loaders.
        """
        super(Ganomaly, self).__init__()
        ##
        # Initialize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        # Any device value other than 'cpu' selects the first CUDA device.
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

        # -- Discriminator attributes (populated in update_netd).
        self.out_d_real = None
        self.feat_real = None
        self.err_d_real = None
        self.fake = None
        self.latent_i = None
        self.latent_o = None
        self.out_d_fake = None
        self.feat_fake = None
        self.err_d_fake = None
        self.err_d = None

        # -- Generator attributes (populated in update_netg).
        self.out_g = None
        self.err_g_bce = None
        self.err_g_l1l = None
        self.err_g_enc = None
        self.err_g = None

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        ##
        # Optionally resume from checkpoints; the starting epoch is read
        # from the generator checkpoint into opt.iter.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            self.netg.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")

        ##
        # Loss Functions
        self.bce_criterion = nn.BCELoss()
        self.l1l_criterion = nn.L1Loss()
        self.l2l_criterion = l2_loss

        ##
        # Initialize input tensors.  These device buffers are resized
        # in-place to match each incoming batch (see set_input).
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        # Scalar target values used to fill self.label.
        self.real_label = 1
        self.fake_label = 0

        ##
        # Setup optimizer (training mode only).
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    ##
    def set_input(self, input):
        """ Set input and ground truth

        Copies the batch into the pre-allocated device buffers.
        NOTE(review): the parameter shadows the builtin ``input``.

        Args:
            input (FloatTensor): Input data for batch i, as an
                (images, labels) pair.
        """
        with torch.no_grad():
            # resize_ lets the buffers shrink for a smaller final batch.
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])

        # Copy the first batch as the fixed input.
        # (total_steps equals batchsize right after the first batch.)
        if self.total_steps == self.opt.batchsize:
            with torch.no_grad():
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def update_netd(self):
        """
        Update D network: Ladv = |f(real) - f(fake)|_2

        Trains the discriminator with feature matching: the L2 distance
        between intermediate discriminator features of the real and the
        reconstructed batch.  The label fills below are not consumed by
        this loss.
        """
        ##
        # Feature Matching.
        self.netd.zero_grad()
        # --
        # Train with real
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        self.out_d_real, self.feat_real = self.netd(self.input)
        # --
        # Train with fake
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.fake_label)
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)
        # detach() keeps generator gradients out of the D update.
        self.out_d_fake, self.feat_fake = self.netd(self.fake.detach())
        # --
        self.err_d = l2_loss(self.feat_real, self.feat_fake)
        # err_d_real / err_d_fake are kept only as aliases of err_d so
        # the re-initialization check in optimize() has values to read.
        self.err_d_real = self.err_d
        self.err_d_fake = self.err_d
        self.err_d.backward()
        self.optimizer_d.step()

    ##
    def reinitialize_netd(self):
        """ Re-initialize the weights of netD (used when D collapses).
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def update_netg(self):
        """
        # ============================================================ #
        # (2) Update G network: log(D(G(x)))  + ||G(x) - x||           #
        # ============================================================ #

        Combines adversarial (BCE on D's output for the fake batch),
        reconstruction (L1) and latent-encoding (L2) losses, weighted
        by opt.w_bce / opt.w_rec / opt.w_enc.
        """
        self.netg.zero_grad()
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        self.out_g, _ = self.netd(self.fake)

        self.err_g_bce = self.bce_criterion(self.out_g, self.label)
        self.err_g_l1l = self.l1l_criterion(
            self.fake, self.input)  # constrain x' to look like x
        self.err_g_enc = self.l2l_criterion(self.latent_o, self.latent_i)
        self.err_g = self.err_g_bce * self.opt.w_bce + self.err_g_l1l * self.opt.w_rec + self.err_g_enc * self.opt.w_enc

        # NOTE(review): retain_graph=True keeps the generator graph alive
        # after this backward pass -- presumably to avoid freed-graph
        # errors across the interleaved D/G updates; confirm still needed.
        self.err_g.backward(retain_graph=True)
        self.optimizer_g.step()

    ##
    def optimize(self):
        """ Optimize netD and netG  networks.
        """

        self.update_netd()
        self.update_netg()

        # If D loss is zero, then re-initialize netD
        # (err_d_real and err_d_fake are both aliases of err_d here).
        if self.err_d_real.item() < 1e-5 or self.err_d_fake.item() < 1e-5:
            self.reinitialize_netd()

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors, as Python
                floats extracted from the latest loss tensors.
        """

        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_d_real', self.err_d_real.item()),
                              ('err_d_fake', self.err_d_fake.item()),
                              ('err_g_bce', self.err_g_bce.item()),
                              ('err_g_l1l', self.err_g_l1l.item()),
                              ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Note: accesses ``.data`` (legacy Variable-style access) and runs
        netG again on the fixed input to get its latest reconstruction.

        Returns:
            [reals, fakes, fixed]
        """

        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number.  Stored as epoch + 1 so
                a resumed run starts at the following epoch.
        """

        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_epoch(self):
        """ Train the model for one epoch.

        Iterates the train loader, running one D+G optimization step per
        batch, and periodically plots errors / saves sample images.
        """

        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'],
                         leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name(), self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Runs train_epoch/test for each epoch in [opt.iter, opt.niter),
        checkpointing the weights of the best-AUC epoch.
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name())
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_epoch()
            res = self.test()
            # Checkpoint only when the AUC improves.
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name())

    ##
    def test(self):
        """ Test GANomaly model on self.dataloader['test'].

        The anomaly score of each image is the mean squared difference
        of its two latent encodings, min-max normalised over the whole
        test set before evaluation.

        Returns:
            OrderedDict: average run time (ms/batch) and the AUC.

        Raises:
            IOError: Model weights not found.
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name().lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.netg.load_state_dict(pretrained_dict)
                except IOError:
                    raise IOError("netG weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                # Per-image anomaly score: mean squared latent difference.
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Slices use error.size(0) so a final batch smaller than
                # opt.batchsize is written without overflowing the buffers.
                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = error.reshape(error.size(0))
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = self.gt.reshape(error.size(0))
                self.latent_i[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_i.reshape(
                                  error.size(0), self.opt.nz)
                self.latent_o[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_o.reshape(
                                  error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time: mean per-batch time of the first
            # 100 batches, converted to milliseconds.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
                torch.max(self.an_scores) - torch.min(self.an_scores))
            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
예제 #2
0
class Ganomaly(BaseModel):
    """GANomaly Class

    GANomaly model with optional latent-space classifiers (netC_i on
    the input encoding, netC_o on the reconstruction encoding) in
    addition to the usual generator/discriminator pair.
    """
    @property
    def name(self):
        """Model name used for logging and output paths."""
        return 'Ganomaly'

    def __init__(self, opt, dataloader):
        """Build networks, losses, input buffers and optimizers.

        Args:
            opt: option namespace (expects batchsize, isize, nz, lr,
                beta1, classifier, resume, z_resume, isTrain, ...).
            dataloader: train/test dataloaders, forwarded to BaseModel.
        """
        super(Ganomaly, self).__init__(opt, dataloader)

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)
        if self.opt.classifier:
            self.netc_i = NetC(self.opt).to(self.device)
            self.netc_o = NetC(self.opt).to(self.device)
            self.netc_i.apply(weights_init)
            self.netc_o.apply(weights_init)

        ##
        # Optionally resume G/D; the starting epoch is read from the
        # generator checkpoint into opt.iter.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            self.netg.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")
        # NOTE(review): resuming the classifiers assumes opt.classifier
        # is set, otherwise netc_i/netc_o do not exist -- confirm callers
        # guarantee this.
        if self.opt.z_resume != '':
            print("\nLoading pre-trained z_networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.z_resume, 'netC_i.pth'))['epoch']
            self.netc_i.load_state_dict(
                torch.load(os.path.join(self.opt.z_resume,
                                        'netC_i.pth'))['state_dict'])
            self.netc_o.load_state_dict(
                torch.load(os.path.join(self.opt.z_resume,
                                        'netC_o.pth'))['state_dict'])
            print("\tDone.\n")

        # Losses: adversarial (feature matching), reconstruction,
        # encoder consistency, and real/fake BCE.
        self.l_adv = l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = l2_loss
        self.l_bce = nn.BCELoss()

        ##
        # Initialize input tensors.
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        # Per-sample BCE targets for the discriminator.
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32,
                                     device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32,
                                      device=self.device)
        ##
        # Initialize input tensors for classifier.  Latent vectors of
        # size nz are fed as 1-channel square "images" of side sqrt(nz).
        # NOTE(review): i_input uses self.sqrtnz (presumably provided by
        # BaseModel) while o_input recomputes int(nz**0.5) inline --
        # confirm the two agree.
        self.i_input = torch.empty(size=(self.opt.batchsize, 1, self.sqrtnz,
                                         self.sqrtnz),
                                   dtype=torch.float32,
                                   device=self.device)
        self.o_input = torch.empty(size=(self.opt.batchsize, 1,
                                         int(self.opt.nz**0.5),
                                         int(self.opt.nz**0.5)),
                                   dtype=torch.float32,
                                   device=self.device)
        self.i_gt = torch.empty(size=(opt.batchsize, ),
                                dtype=torch.long,
                                device=self.device)
        self.o_gt = torch.empty(size=(opt.batchsize, ),
                                dtype=torch.long,
                                device=self.device)
        self.i_label = torch.empty(size=(self.opt.batchsize, ),
                                   dtype=torch.float32,
                                   device=self.device)
        self.o_label = torch.empty(size=(self.opt.batchsize, ),
                                   dtype=torch.float32,
                                   device=self.device)
        # NOTE(review): despite the "real" names, both targets are
        # zero-filled -- confirm the intended label semantics.
        self.i_real_label = torch.zeros(size=(self.opt.batchsize, ),
                                        dtype=torch.float32,
                                        device=self.device)
        self.o_real_label = torch.zeros(size=(self.opt.batchsize, ),
                                        dtype=torch.float32,
                                        device=self.device)

        ##
        # Setup optimizer (training mode only).
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            if self.opt.classifier:
                self.netc_i.train()
                self.netc_o.train()
                self.optimizer_i = optim.Adam(self.netc_i.parameters(),
                                              lr=self.opt.lr,
                                              betas=(self.opt.beta1, 0.999))
                self.optimizer_o = optim.Adam(self.netc_o.parameters(),
                                              lr=self.opt.lr,
                                              betas=(self.opt.beta1, 0.999))

    ##
    def forward_g(self):
        """ Forward propagate through netG

        Produces the reconstruction and both latent encodings.
        """
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)

    ##
    def forward_d(self):
        """ Forward propagate through netD

        The fake batch is detached so D gradients do not reach netG.
        """
        self.pred_real, self.feat_real = self.netd(self.input)
        self.pred_fake, self.feat_fake = self.netd(self.fake.detach())

    ##
    def backward_g(self):
        """ Backpropagate through netG

        Loss = w_adv * feature-matching + w_con * L1 reconstruction
               + w_enc * latent-encoding L2.
        """
        self.err_g_adv = self.l_adv(
            self.netd(self.input)[1],
            self.netd(self.fake)[1])
        self.err_g_con = self.l_con(self.fake, self.input)
        self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
        self.err_g = self.err_g_adv * self.opt.w_adv + \
                     self.err_g_con * self.opt.w_con + \
                     self.err_g_enc * self.opt.w_enc
        # retain_graph: the generator graph is reused by backward_d in
        # the same optimize_params step.
        self.err_g.backward(retain_graph=True)

    ##
    def backward_d(self):
        """ Backpropagate through netD
        """
        # Real - Fake Loss
        self.err_d_real = self.l_bce(self.pred_real, self.real_label)
        self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label)

        # NetD Loss & Backward-Pass
        self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
        self.err_d.backward()

    ##
    def reinit_d(self):
        """ Re-initialize the weights of netD (used when D collapses).
        """
        self.netd.apply(weights_init)
        if (self.opt.strengthen != 1): print('   Reloading net d')

    def optimize_params(self):
        """ Forwardpass, Loss Computation and Backwardpass.

        Updates netG first, then netD; re-initializes netD when its
        loss collapses to ~0.
        """
        # Forward-pass
        self.forward_g()
        self.forward_d()

        # Backward-pass
        # netg
        self.optimizer_g.zero_grad()
        self.backward_g()
        self.optimizer_g.step()

        # netd
        self.optimizer_d.zero_grad()
        self.backward_d()
        self.optimizer_d.step()
        if self.err_d.item() < 1e-5: self.reinit_d()

    ##
    def save_weights_z(self, epoch):
        """Save netC_i and netC_o weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number.  Stored as epoch + 1 so
                a resumed run starts at the following epoch.
        """

        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir): os.makedirs(weight_dir)

        # BUGFIX: the classifier networks are stored as self.netc_i /
        # self.netc_o (lowercase, see __init__); the previous
        # self.netC_i / self.netC_o always raised AttributeError.
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_i.state_dict()
        }, '%s/netC_i.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_o.state_dict()
        }, '%s/netC_o.pth' % (weight_dir))

    def forward_i(self):
        """ Forward propagate through netC_i
        """
        self.pred_abn_i = self.netc_i(self.i_input)

    def forward_o(self):
        """ Forward propagate through netC_o
        """
        self.pred_abn_o = self.netc_o(self.o_input)

    def backward_i(self):
        """ Backpropagate through netC_i

        BCE of the classifier prediction against i_real_label.
        """
        self.err_i = self.l_bce(self.pred_abn_i, self.i_real_label)
        self.err_i.backward()

    def backward_o(self):
        """ Backpropagate through netC_o

        BCE of the classifier prediction against o_real_label.
        """
        self.err_o = self.l_bce(self.pred_abn_o, self.o_real_label)
        self.err_o.backward()

    def reinit_i(self):
        """ Re-initialize the weights of netC_i
        """
        self.netc_i.apply(weights_init)
        if (self.opt.strengthen != 1): print('   Reloading net i')

    def reinit_o(self):
        """ Re-initialize the weights of netC_o
        """
        self.netc_o.apply(weights_init)
        if (self.opt.strengthen != 1): print('   Reloading net o')

    def z_optimize_params(self, net):
        """ Forwardpass, Loss Computation and Backwardpass.

        Args:
            net (str): which classifier to update, 'i' or 'o'.
        """
        if net == 'i':
            # Forward-pass
            self.forward_i()

            # Backward-pass
            # netc_i
            self.optimizer_i.zero_grad()
            self.backward_i()
            self.optimizer_i.step()

            # Re-initialize on loss collapse, as for netD.
            if self.err_i.item() < 1e-5: self.reinit_i()
        if net == 'o':
            # Forward-pass
            self.forward_o()

            # Backward-pass
            # netc_o
            self.optimizer_o.zero_grad()
            self.backward_o()
            self.optimizer_o.step()

            if self.err_o.item() < 1e-5: self.reinit_o()
예제 #3
0
class Ganomaly(BaseModel):
    """GANomaly Class

    Generator/discriminator pair for adversarial anomaly detection:
    netG reconstructs the input and produces two latent encodings,
    netD discriminates real from reconstructed images.
    """
    @property
    def name(self):
        """Model name used for logging and output paths."""
        return 'Ganomaly'

    def __init__(self, opt, dataloader):
        """Build networks, losses, input buffers and optimizers.

        Args:
            opt: option namespace (expects batchsize, isize, lr, beta1,
                resume, isTrain, w_adv, w_con, w_enc, ... attributes).
            dataloader: train/test dataloaders, forwarded to BaseModel.
        """
        super(Ganomaly, self).__init__(opt, dataloader)

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        ##
        # Optionally resume from checkpoints; the starting epoch is read
        # from the generator checkpoint into opt.iter.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            self.netg.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")

        # Losses: adversarial (feature matching), reconstruction,
        # encoder consistency, and real/fake BCE.
        self.l_adv = l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = l2_loss
        self.l_bce = nn.BCELoss()

        ##
        # Initialize input tensors.
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        # Per-sample BCE targets for the discriminator.
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32,
                                     device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32,
                                      device=self.device)
        ##
        # Setup optimizer (training mode only).
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    ##
    def forward_g(self):
        """ Forward propagate through netG

        Produces the reconstruction and both latent encodings.
        """
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)

    ##
    def forward_d(self):
        """ Forward propagate through netD

        The fake batch is detached so D gradients do not reach netG.
        """
        self.pred_real, self.feat_real = self.netd(self.input)
        self.pred_fake, self.feat_fake = self.netd(self.fake.detach())

    ##
    def backward_g(self):
        """ Backpropagate through netG

        Loss = w_adv * feature-matching + w_con * L1 reconstruction
               + w_enc * latent-encoding L2.
        """
        self.err_g_adv = self.l_adv(
            self.netd(self.input)[1],
            self.netd(self.fake)[1])
        self.err_g_con = self.l_con(self.fake, self.input)
        self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
        self.err_g = self.err_g_adv * self.opt.w_adv + \
                     self.err_g_con * self.opt.w_con + \
                     self.err_g_enc * self.opt.w_enc
        # retain_graph: the generator graph is reused by backward_d in
        # the same optimize_params step.
        self.err_g.backward(retain_graph=True)

    ##
    def backward_d(self):
        """ Backpropagate through netD
        """
        # Real - Fake Loss
        self.err_d_real = self.l_bce(self.pred_real, self.real_label)
        self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label)

        # NetD Loss & Backward-Pass
        self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
        self.err_d.backward()

    ##
    def reinit_d(self):
        """ Re-initialize the weights of netD (used when D collapses).
        """
        self.netd.apply(weights_init)
        print('   Reloading net d')

    def optimize_params(self):
        """ Forwardpass, Loss Computation and Backwardpass.

        Updates netG first, then netD; re-initializes netD when its
        loss collapses to ~0.
        """
        # Forward-pass
        self.forward_g()
        self.forward_d()

        # Backward-pass
        # netg
        self.optimizer_g.zero_grad()
        self.backward_g()
        self.optimizer_g.step()

        # netd
        self.optimizer_d.zero_grad()
        self.backward_d()
        self.optimizer_d.step()
        if self.err_d.item() < 1e-5: self.reinit_d()
예제 #4
0
class MNIST_UNET(nn.Module):
    """GANomaly-style trainer used as an intrinsic-reward module on MNIST.

    For each dataloader step it first measures how badly the previously
    trained generator reconstructs the newest batch (the "intrinsic
    reward"), then retrains generator and discriminator on every batch
    seen so far.

    Args:
        opt: options namespace (batchsize, nc, isize, lr, beta1, niter,
            w_adv/w_con/w_enc, use_con_reward, resume, device).
        dataloader: iterable of (images, ground_truth) batches.
    """

    def __init__(self, opt, dataloader):
        super(MNIST_UNET, self).__init__()

        self.opt = opt
        self.dataloader = dataloader
        self.total_steps = len(dataloader)
        self.device = torch.device(
            'cuda:0' if self.opt.device != 'cpu' else 'cpu')

        # Networks (generator = encoder-decoder-encoder, discriminator).
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        weights_init(self.netg)
        weights_init(self.netd)

        # Losses: adversarial and encoder terms are plain L2, the
        # reconstruction term is L1, the discriminator uses BCE.
        self.l_adv = self.l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = self.l2_loss
        self.l_bce = nn.BCELoss()

        # Pre-allocated input buffer, filled in place by set_input().
        self.input_imgs = torch.empty(size=(self.opt.batchsize, self.opt.nc,
                                            self.opt.isize, self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)

        # Fixed real/fake targets for the discriminator's BCE loss.
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32,
                                     device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32,
                                      device=self.device)

    def train(self):
        """Train the model over the whole dataloader, step by step.

        At step ``k`` the intrinsic reward is computed on batch ``k``
        using the checkpointed generator, then both networks are trained
        for ``opt.niter`` epochs on batches ``0..k``. Weights are saved
        after every epoch and the reward curve is plotted at the end.
        """
        self.netd.train()
        self.netg.train()
        optimizer_g = optim.Adam(self.netg.parameters(),
                                 lr=self.opt.lr,
                                 betas=(self.opt.beta1, 0.999))
        optimizer_d = optim.Adam(self.netd.parameters(),
                                 lr=self.opt.lr,
                                 betas=(self.opt.beta1, 0.999))

        print(">> Train model %s steps." % self.total_steps)

        self.step_reward = []
        for step in tqdm(range(self.total_steps)):
            # Collect the already-seen batches (0..step-1) and the new
            # batch (index == step). Later batches are not needed yet.
            current_dataloader = []
            next_batch = None
            for count, (input_imgs, gt) in enumerate(self.dataloader):
                if count < step:
                    current_dataloader.append(input_imgs)
                elif count == step:
                    next_batch = input_imgs
                    break

            # Reward: reconstruction error of the new batch under the
            # previously saved generator weights.
            self.set_input(next_batch)
            intrinsic_loss = self.calculate_intrinsic_loss()
            self.save_images(self.input_imgs, self.fake_imgs, step)
            print('step: %s, reward: %s.' % (step, intrinsic_loss))
            self.step_reward.append(intrinsic_loss)

            current_dataloader.append(next_batch)

            # Resume from checkpoints when both are present.
            # NOTE(review): checkpoints are loaded from opt.resume but
            # save_weights() writes to opt.resume/weights — confirm the
            # intended round-trip path.
            netG_weights_path = os.path.join(self.opt.resume, 'netG.pth')
            netD_weights_path = os.path.join(self.opt.resume, 'netD.pth')
            if os.path.exists(netG_weights_path) and os.path.exists(
                    netD_weights_path):
                self.netg.load_state_dict(
                    torch.load(netG_weights_path)['state_dict'])
                self.netd.load_state_dict(
                    torch.load(netD_weights_path)['state_dict'])

            for epoch in range(self.opt.niter):
                for input_imgs in current_dataloader:
                    self.set_input(input_imgs)
                    self.fake_imgs, self.latent_i, self.latent_o = self.netg(
                        self.input_imgs)
                    self.pred_real, self.feat_real = self.netd(self.input_imgs)
                    self.pred_fake, self.feat_fake = self.netd(
                        self.fake_imgs.detach())

                    # Generator update: adversarial + reconstruction +
                    # latent-consistency, weighted by opt.w_*.
                    optimizer_g.zero_grad()
                    self.err_g_adv = self.l_adv(
                        self.netd(self.input_imgs)[0],
                        self.netd(self.fake_imgs)[0])
                    self.err_g_con = self.l_con(self.input_imgs,
                                                self.fake_imgs)
                    self.err_g_enc = self.l_enc(self.latent_i, self.latent_o)
                    self.err_g = self.err_g_adv * self.opt.w_adv + \
                        self.err_g_con * self.opt.w_con + \
                        self.err_g_enc * self.opt.w_enc
                    self.err_g.backward()
                    optimizer_g.step()

                    # Discriminator update: BCE on real vs. fake.
                    optimizer_d.zero_grad()
                    self.err_d_real = self.l_bce(self.pred_real,
                                                 self.real_label)
                    self.err_d_fake = self.l_bce(self.pred_fake,
                                                 self.fake_label)
                    self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
                    self.err_d.backward()
                    optimizer_d.step()

                    # Re-initialize D when its loss collapses to ~0.
                    if self.err_d.item() < 1e-5:
                        weights_init(self.netd)
                        print('Reloading netd')

                self.save_weights()
        self.draw_reward()

    def calculate_intrinsic_loss(self):
        """Return the reconstruction error of the current input batch.

        Loads the saved generator checkpoint when available, otherwise
        re-initializes the generator. Also refreshes ``self.fake_imgs``
        as a side effect (used by save_images).

        Returns:
            float: encoder loss, plus the L1 reconstruction loss when
            ``opt.use_con_reward`` is set.
        """
        with torch.no_grad():
            print(">>Geting current intrinsic reward.")
            netG_weights_path = os.path.join(self.opt.resume, 'netG.pth')
            if os.path.exists(netG_weights_path):
                pretrained_dict = torch.load(netG_weights_path)['state_dict']
                self.netg.load_state_dict(pretrained_dict)
            else:
                weights_init(self.netg)

            self.fake_imgs, self.latent_i, self.latent_o = self.netg(
                self.input_imgs)
            if self.opt.use_con_reward:
                con_reward = self.l_con(self.fake_imgs, self.input_imgs)
            else:
                con_reward = 0
            enc_reward = self.l_enc(self.latent_i, self.latent_o)
            total_reward = enc_reward + con_reward
            return total_reward.to('cpu').numpy().item()

    def l2_loss(self, input, target, size_average=True):
        """Squared error between input and target.

        Returns the mean when ``size_average`` is True, otherwise the
        element-wise squared-error tensor.
        """
        if size_average:
            return torch.mean(torch.pow((input - target), 2))
        else:
            return torch.pow((input - target), 2)

    def set_input(self, input_imgs):
        """Copy a batch into the pre-allocated input buffer in place."""
        with torch.no_grad():
            self.input_imgs.resize_(input_imgs.size()).copy_(input_imgs)

    def save_weights(self):
        """Checkpoint both networks under opt.resume/weights."""
        weight_dir = os.path.join(self.opt.resume, 'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({'state_dict': self.netg.state_dict()},
                   os.path.join(weight_dir, 'netG.pth'))
        torch.save({'state_dict': self.netd.state_dict()},
                   os.path.join(weight_dir, 'netD.pth'))

    def save_images(self, real, fake, step):
        """Save a [real | fake | abs-diff] comparison strip for this step.

        Images are tiled vertically per sample and written as a grayscale
        PNG to opt.resume/images/<step+1>.png.
        """
        N, C, W, H = real.shape
        stitch_images = np.zeros((C, W * N, 3 * H))
        image_dir = os.path.join(self.opt.resume, 'images')
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)
        for i in range(N):
            # np.int was removed in NumPy 1.24; np.int64 matches the old
            # alias's behavior on typical platforms.
            real_img = (real[i, :, :, :] * 255).to('cpu').numpy().astype(
                np.int64)
            fake_img = (fake[i, :, :, :] * 255).to('cpu').numpy().astype(
                np.int64)
            mask_img = np.abs(real_img - fake_img).astype(np.uint8)
            stitch_images[:, W * i:W * i + W, :H] = real_img.astype(np.uint8)
            stitch_images[:, W * i:W * i + W,
                          H:2 * H] = fake_img.astype(np.uint8)
            stitch_images[:, W * i:W * i + W, 2 * H:] = mask_img
        stitch_images = stitch_images.squeeze(0)
        plt.imsave(os.path.join(image_dir, '%s.png' % (step + 1)),
                   stitch_images,
                   cmap='gray')

    def draw_reward(self):
        """Plot the per-step intrinsic reward curve and save it as PNG."""
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, self.total_steps + 1), self.step_reward)
        plt.xlabel("steps")
        plt.ylabel("reward")
        plt.savefig(os.path.join(self.opt.resume, 'images', 'rewards.png'))
        plt.show()