Example #1
    def update_net(self, data_src, data_obj, data_noobj, epoch, iteration,
                   outf, method):
        self.net_discriminator.zero_grad()
        self.net_encoder_sourse.zero_grad()
        # self.net_encoder_target.zero_grad()
        self.net_decoder.zero_grad()

        src_img = data_src.to(self.device)
        obj_img = data_obj.to(self.device)
        noobj_img = data_noobj.to(self.device)

        batch_size = src_img.size(0)
        if self.label_src is None:
            self.label_obj = torch.full((batch_size, ),
                                        self.obj_label,
                                        dtype=torch.long,
                                        device=self.device)
            self.label_src = torch.full((batch_size, ),
                                        self.src_label,
                                        dtype=torch.long,
                                        device=self.device)
            self.label_noobj = torch.full((batch_size, ),
                                          self.noobj_label,
                                          dtype=torch.long,
                                          device=self.device)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        errD_src, errD_obj, errD_noobj, errD = self.get_discriminator_loss(
            src_img, obj_img, noobj_img)
        self.update_discriminator(errD_src, errD_obj, errD_noobj)

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        # Generator losses: each image batch is encoded and scored by the
        # discriminator against a swapped label (object as source, no-object
        # as object, object as no-object).
        errG_obj = self.get_loss_on_label(obj_img, self.label_src,
                                          self.net_encoder_sourse,
                                          self.net_discriminator)
        errG_noobj_1 = self.get_loss_on_label(noobj_img, self.label_obj,
                                              self.net_encoder_sourse,
                                              self.net_discriminator)
        errG_noobj_2 = self.get_loss_on_label(obj_img, self.label_noobj,
                                              self.net_encoder_sourse,
                                              self.net_discriminator)

        ############################
        # (3) Update Src AutoED network:
        ###########################
        src_autoed_loss = self.get_autoed_loss(src_img)
        self.update_generator(errG_obj, errG_noobj_1, errG_noobj_2,
                              src_autoed_loss)

        print('[%d][%d] Loss: %.4f %.4f %.4f | %.4f %.4f %.4f %.4f' %
              (epoch, iteration, errD_src.item(), errD_obj.item(),
               errD_noobj.item(), errG_obj.item(), errG_noobj_1.item(),
               errG_noobj_2.item(), src_autoed_loss.item()))

        if iteration == 0:
            src_mu, src_conv = self.net_encoder_sourse.get_mu_conv(src_img)
            src_reconst = self.net_decoder(src_mu)

            obj_mu, obj_conv = self.net_encoder_sourse.get_mu_conv(obj_img)
            obj_reconst = self.net_decoder(obj_mu)

            noobj_mu, noobj_conv = self.net_encoder_sourse.get_mu_conv(
                noobj_img)
            noobj_reconst = self.net_decoder(noobj_mu)

            size = src_conv[0].shape[-1]
            tool.save_result_imgs(
                outf,
                method,
                src_img=src_img,
                src_feature=None,
                src_reconst=src_reconst,
                src_conv=src_conv[0].view(self.last_conv_channel, 1, size,
                                          size),
                obj_img=obj_img,
                obj_feature=None,
                obj_reconst=obj_reconst,
                obj_conv=obj_conv[0].view(self.last_conv_channel, 1, size,
                                          size),
                noobj_img=noobj_img,
                noobj_feature=None,
                noobj_reconst=noobj_reconst,
                noobj_conv=noobj_conv[0].view(self.last_conv_channel, 1, size,
                                              size))
        tool.save_models(outf,
                         method,
                         src_encoder=self.net_encoder_sourse,
                         obj_encoder=None,
                         decoder=self.net_decoder,
                         discriminator=self.net_discriminator)
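For context, here is a minimal sketch of how this update_net might be driven from an outer training loop. Only the update_net signature comes from the example above; the trainer object, the three data loaders, and the helper name run_training are assumptions.

def run_training(trainer, loader_src, loader_obj, loader_noobj,
                 num_epochs, outf, method):
    # One call per batch triple; update_net performs the D, G and autoencoder
    # steps itself, writes sample images when iteration == 0, and saves the
    # model checkpoints after every call.
    for epoch in range(num_epochs):
        for iteration, (data_src, data_obj, data_noobj) in enumerate(
                zip(loader_src, loader_obj, loader_noobj)):
            trainer.update_net(data_src, data_obj, data_noobj,
                               epoch, iteration, outf, method)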

Example #2
    def update_net(self, data_src, data_obj, epoch, iteration, outf, method):
        self.count += 1
        self.net_discriminator.zero_grad()
        self.net_encoder_sourse.zero_grad()
        # self.net_encoder_target.zero_grad()
        self.net_decoder.zero_grad()

        src_img = data_src['image'].to(self.device)
        obj_img = data_obj['image'].to(self.device)

        batch_size = src_img.size(0)
        if self.label_obj is None:
            self.label_obj = torch.full((batch_size, ),
                                        self.obj_label,
                                        device=self.device)
            self.label_src = torch.full((batch_size, ),
                                        self.src_label,
                                        device=self.device)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        errD = self.get_discriminator_loss(src_img, obj_img,
                                           self.net_discriminator)
        errD.backward()

        # Only step the discriminator while its loss is still above 1; the
        # update is skipped once D is already doing well.
        if errD > 1:
            self.optimizer_d.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        errG, _ = self.get_loss_on_label(obj_img, self.label_src,
                                         self.net_encoder_sourse,
                                         self.net_discriminator)

        # errG.backward()
        # self.optimizer_et.step()

        ############################
        # (3) Update Src AutoED network:
        ###########################
        src_autoed_loss = self.get_autoed_loss(self.net_encoder_sourse,
                                               src_img)
        # Joint encoder/decoder update: adversarial term down-weighted by
        # 1e-3, reconstruction term at full weight.
        autoed_loss = errG * 0.001 + src_autoed_loss * 1
        autoed_loss.backward()
        self.optimizer_autoed.step()

        if iteration == 0:
            src_mu, src_conv = self.net_encoder_sourse.get_mu_conv(src_img)
            src_reconst = self.net_decoder(src_mu)

            obj_mu, obj_conv = self.net_encoder_sourse.get_mu_conv(obj_img)
            obj_reconst = self.decode_obj(obj_mu)  #self.net_decoder(obj_mu)

            size = src_conv[0].shape[-1]
            tool.save_result_imgs(
                outf,
                method,
                src_img=src_img,
                src_feature=None,
                src_reconst=src_reconst,
                src_conv=src_conv[0].view(self.last_conv_channel, 1, size,
                                          size),
                obj_img=obj_img,
                obj_feature=None,
                obj_reconst=obj_reconst,
                obj_conv=obj_conv[0].view(self.last_conv_channel, 1, size,
                                          size))

            tool.save_models(outf,
                             method,
                             src_encoder=self.net_encoder_sourse,
                             obj_encoder=None,
                             decoder=self.net_decoder,
                             discriminator=self.net_discriminator)

        return errD, errG, src_autoed_loss
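This variant returns the three loss tensors instead of printing them, so the caller decides how to log them. A minimal sketch of such a caller follows, assuming hypothetical loaders that yield dicts with an 'image' key (as the code above expects) and a trainer exposing this update_net.

def train_epoch(trainer, loader_src, loader_obj, epoch, outf, method):
    d_hist, g_hist, ae_hist = [], [], []
    for iteration, (data_src, data_obj) in enumerate(
            zip(loader_src, loader_obj)):
        errD, errG, ae_loss = trainer.update_net(data_src, data_obj, epoch,
                                                 iteration, outf, method)
        # .item() turns the returned loss tensors into plain floats for logging.
        d_hist.append(errD.item())
        g_hist.append(errG.item())
        ae_hist.append(ae_loss.item())
    print('[%d] mean D %.4f | G %.4f | AutoED %.4f' %
          (epoch, sum(d_hist) / len(d_hist),
           sum(g_hist) / len(g_hist), sum(ae_hist) / len(ae_hist)))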

Example #3
    def update_net(self, data_src, data_obj, data_noobj, epoch, iteration,
                   outf, method):
        self.count += 1
        self.net_discriminator.zero_grad()
        self.net_discriminator_noobj.zero_grad()
        self.net_encoder_sourse.zero_grad()
        self.net_encoder_target.zero_grad()
        # self.net_sourse_mask.zero_grad()
        self.net_decoder.zero_grad()

        src_img = data_src['image'].to(self.device)
        obj_img = data_obj['image'].to(self.device)
        noobj_img = data_noobj['image'].to(self.device)
        # obj_only_img = data_obj_only.to(self.device)

        # src_img = torch.min(src_img, noobj_img)
        src_label = data_src['label'].to(self.device)

        batch_size = src_img.size(0)
        if self.label_src is None:
            # Sampled soft labels; currently unused because the hard labels
            # below are the active path (the commented-out alternative was
            # torch.from_numpy(l_obj).float().to(self.device), etc.).
            l = np.array([self.src_label, self.obj_label, self.noobj_label])
            l_obj = np.random.choice(l, batch_size, p=[0.2, 0.8, 0])
            l_src = np.random.choice(l, batch_size, p=[0.8, 0.2, 0])
            l_noobj = np.random.choice(l, batch_size, p=[0.0, 0.2, 0.8])
            self.label_obj = torch.full((batch_size, ),
                                        self.obj_label,
                                        device=self.device)
            self.label_src = torch.full((batch_size, ),
                                        self.src_label,
                                        device=self.device)
            self.label_noobj = torch.full((batch_size, ),
                                          self.noobj_label,
                                          device=self.device)

        # self.net_encoder_sourse.eval()
        _, src_conv = self.net_encoder_sourse(src_img)
        # self.net_encoder_sourse.train()
        _, obj_conv = self.net_encoder_sourse(obj_img)
        _, noobj_conv = self.net_encoder_sourse(noobj_img)

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        errG_obj, dis_obj_src = self.get_loss_on_label(obj_conv,
                                                       self.label_src,
                                                       self.net_encoder_target,
                                                       self.net_discriminator)
        errG_src, _ = self.get_loss_on_label(src_conv, self.label_obj,
                                             self.net_encoder_sourse,
                                             self.net_discriminator)

        # obj_f, _ = self.net_encoder_sourse(obj_img)
        # noobj_f, _ = self.net_encoder_sourse(noobj_img)
        errG_obj_2, _ = self.get_loss_on_label(obj_conv, self.label_obj,
                                               self.net_encoder_target,
                                               self.net_discriminator_noobj)
        errG_noobj_2, _ = self.get_loss_on_label(noobj_conv, self.label_noobj,
                                                 self.net_encoder_target,
                                                 self.net_discriminator_noobj)

        ############################
        # (3) Update Src AutoED network:
        ###########################
        src_autoed_loss = self.get_autoed_loss(self.net_encoder_sourse,
                                               src_conv, src_label, noobj_conv)
        self.update_generator(0, 0, errG_obj_2, errG_noobj_2, src_autoed_loss)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        errD_src, errD_obj, errD2_obj, errD2_noobj = self.get_discriminator_loss(
            src_conv, obj_conv, noobj_conv)
        self.update_discriminator(errD_src, errD_obj, errD2_obj, errD2_noobj)

        print('[%d][%d] Loss: %.4f %.4f %.4f %.4f | %.4f %.4f %.4f %.4f %.4f' %
              (epoch, iteration, errD_src.item(), errD_obj.item(),
               errD2_obj.item(), errD2_noobj.item(), errG_obj.item(),
               errG_src.item(), errG_obj_2.item(), errG_noobj_2.item(),
               src_autoed_loss.item()))

        # if self.count % 50 == 1:
        if iteration == 0:
            self.set_eval_mode()
            _, src_conv = self.net_encoder_sourse(src_img)
            # src_mu, src_mu_conv = self.net_encoder_sourse.get_mu_conv(src_img)
            # print(src_mu_conv.shape)

            _, noobj_conv = self.net_encoder_target(noobj_img)
            # randomized_z = torch.zeros_like(src_z)

            rand_w = 0  #0.3 + np.random.rand()*0.3
            # randomized_z = src_conv * (1-rand_w) + noobj_conv.detach()*rand_w

            # for i in range(len(src_z)):
            #     rand_w = 0.1 + np.random.rand()*0.9
            #     randomized_z[i] = src_z[i] * (1-rand_w) + encoded_noobj[i].detach()*rand_w

            key, src_feature = self.net_encoder_sourse.get_softmax_feature(
                src_img, src_conv)
            # src_reconst = self.net_decoder(key)

            # obj_mu, obj_conv = self.net_encoder_sourse.get_mu_conv(obj_img)
            # # obj_reconst = self.net_decoder(obj_mu)
            # obj_reconst = self.decode_obj(obj_conv)
            obj_z, obj_conv = self.net_encoder_target(obj_img)
            _, obj_feature = self.net_encoder_sourse.get_softmax_feature(
                obj_img, obj_conv)

            # obj_mu, obj_mu_conv = self.net_encoder_target.get_mu_conv(obj_img)
            # obj_reconst = self.net_decoder(obj_z)

            # noobj_mu, noobj_conv = self.net_encoder_sourse.get_mu_conv(noobj_img)
            # # noobj_reconst = self.net_decoder(noobj_mu)
            # noobj_reconst = self.decode_obj(noobj_conv)
            noobj_z, noobj_conv = self.net_encoder_target(noobj_img)
            _, noobj_feature = self.net_encoder_sourse.get_softmax_feature(
                noobj_img, noobj_conv)

            # noobj_mu, noobj_mu_conv = self.net_encoder_target.get_mu_conv(noobj_img)
            # noobj_reconst = self.net_decoder(noobj_z)

            self.set_train_mode()

            size = src_conv[0].shape
            # src_conv = (src_conv + obj_conv.max() - noobj_conv)
            # print(src_feature.shape)
            tool.save_result_imgs(
                outf,
                method,
                src_img=src_img,
                src_feature=src_conv[0].view(-1, 1, 23, 46),
                src_reconst=None,
                src_conv=src_feature.view(-1, 3, src_feature.shape[-2],
                                          src_feature.shape[-1]),
                obj_img=obj_img,
                obj_feature=obj_conv[0].view(-1, 1, 23, 46),
                obj_reconst=None,
                obj_conv=obj_feature.view(-1, 3, obj_feature.shape[-2],
                                          obj_feature.shape[-1]),
                noobj_img=noobj_img,
                noobj_feature=noobj_conv[0].view(-1, 1, 23, 46),
                noobj_reconst=None,
                noobj_conv=noobj_feature.view(-1, 3, noobj_feature.shape[-2],
                                              noobj_feature.shape[-1]))
            tool.save_models(outf,
                             method,
                             src_encoder=self.net_encoder_sourse,
                             obj_encoder=self.net_encoder_target,
                             decoder=self.net_decoder,
                             discriminator=self.net_discriminator,
                             discriminator_noobj=self.net_discriminator_noobj)
            # torch.save(self.net_sourse_mask.state_dict(), 'results/%s/%s/net_mask.pth' % (outf, method))

        # return errD_src, errD_obj, errD_obj, errG_obj, errG_noobj, src_autoed_loss
        return src_autoed_loss.detach().item()
Example #4
    def update_net(self, data_src, data_obj, data_noobj, epoch, iteration,
                   outf, method):
        self.count += 1
        self.net_discriminator.zero_grad()
        self.net_discriminator_noobj.zero_grad()
        self.net_encoder_sourse.zero_grad()
        self.net_encoder_target.zero_grad()
        # self.net_sourse_mask.zero_grad()
        self.net_decoder.zero_grad()

        src_img = data_src['image'].to(self.device)
        obj_img = data_obj['image'].to(self.device)
        noobj_img = data_noobj['image'].to(self.device)
        # obj_only_img = data_obj_only.to(self.device)

        src_label = data_src['label'].to(self.device)

        batch_size = src_img.size(0)
        if self.label_src is None:
            # self.label_obj = torch.full((batch_size,), self.obj_label, device=self.device)
            self.label_src = torch.full((batch_size, ),
                                        self.src_label,
                                        device=self.device)
            self.label_noobj = torch.full((batch_size, ),
                                          self.noobj_label,
                                          device=self.device)

        # self.net_encoder_sourse.eval()
        src_z, src_conv = self.net_encoder_sourse(src_img)
        # self.net_encoder_sourse.train()
        obj_z, obj_conv = self.net_encoder_sourse(obj_img)
        noobj_z, noobj_conv = self.net_encoder_sourse(noobj_img)

        _, dis_obj_src = self.get_loss_on_label(obj_conv, self.label_src,
                                                self.net_discriminator_noobj)
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        errG_obj, _ = self.get_loss_on_label(obj_conv, self.label_src,
                                             self.net_discriminator)
        # errG_src, _ = self.get_loss_on_label(src_conv, self.label_obj, self.net_discriminator)
        self.obj_mean.append(torch.mean(dis_obj_src).detach().item())

        ############################
        # (3) Update Src AutoED network:
        ###########################
        src_autoed_loss = self.get_autoed_loss(self.net_encoder_sourse, src_z,
                                               src_label, noobj_z)
        self.update_generator(errG_obj, src_autoed_loss)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        errD_src, errD2_obj, errD2_noobj = self.get_discriminator_loss(
            src_conv, obj_conv, noobj_conv, dis_obj_src)

        # if np.mean(self.obj_mean) - self.pre_updated_mean > 0.02:
        #     print('update')
        self.update_discriminator(errD_src, errD2_obj, errD2_noobj)
        # self.pre_updated_mean = np.mean(self.obj_mean)

        print('[%d][%d] Loss: %.4f %.4f %.4f | %.4f %.4f' %
              (epoch, iteration, errD_src.item(), errD2_obj.item(),
               errD2_noobj.item(), errG_obj.item(), src_autoed_loss.item()))
        print(self.pre_updated_mean, np.mean(self.obj_mean))

        # if self.count % 50 == 1:
        if iteration == 0:
            self.set_eval_mode()
            # src_mu, src_conv = self.net_encoder_sourse.get_mu_conv(src_img)
            # src_reconst = self.net_decoder(src_mu)
            # src_reconst = self.decode_obj(src_conv)
            src_z, src_conv = self.net_encoder_sourse(src_img)
            src_mu, src_mu_conv = self.net_encoder_sourse.get_mu_conv(src_img)

            encoded_noobj, conv = self.net_encoder_target(noobj_img)
            randomized_z = torch.zeros_like(src_z)

            # Blend the source latent with the detached no-object latent using
            # a random weight drawn from [0.5, 0.8).
            rand_w = 0.5 + np.random.rand() * 0.3
            randomized_z = (src_z * (1 - rand_w) +
                            encoded_noobj.detach() * rand_w)

            # for i in range(len(src_z)):
            #     rand_w = 0.1 + np.random.rand()*0.9
            #     randomized_z[i] = src_z[i] * (1-rand_w) + encoded_noobj[i].detach()*rand_w

            src_reconst = self.net_decoder(randomized_z)

            # obj_mu, obj_conv = self.net_encoder_sourse.get_mu_conv(obj_img)
            # # obj_reconst = self.net_decoder(obj_mu)
            # obj_reconst = self.decode_obj(obj_conv)
            obj_z, obj_conv = self.net_encoder_target(obj_img)
            obj_mu, obj_mu_conv = self.net_encoder_target.get_mu_conv(obj_img)
            obj_reconst = self.net_decoder(obj_mu)

            # noobj_mu, noobj_conv = self.net_encoder_sourse.get_mu_conv(noobj_img)
            # # noobj_reconst = self.net_decoder(noobj_mu)
            # noobj_reconst = self.decode_obj(noobj_conv)
            noobj_z, noobj_conv = self.net_encoder_target(noobj_img)
            noobj_mu, noobj_mu_conv = self.net_encoder_target.get_mu_conv(
                noobj_img)
            noobj_reconst = self.net_decoder(noobj_z)

            self.set_train_mode()

            size = src_conv[0].shape
            # src_conv = (src_conv + obj_conv.max() - noobj_conv)
            tool.save_result_imgs(
                outf,
                method,
                src_img=src_img,
                src_feature=(randomized_z).view(-1, 1, size[1], size[2]),
                src_reconst=src_reconst,
                src_conv=src_z.view(-1, 1, size[1], size[2]),
                obj_img=obj_img,
                obj_feature=(obj_mu).view(-1, 1, size[1], size[2]),
                obj_reconst=obj_reconst,
                obj_conv=obj_z.view(-1, 1, size[1], size[2]),
                noobj_img=noobj_img,
                noobj_feature=noobj_mu_conv,
                noobj_reconst=noobj_reconst,
                noobj_conv=noobj_conv)
            tool.save_models(outf,
                             method,
                             src_encoder=self.net_encoder_sourse,
                             obj_encoder=self.net_encoder_target,
                             decoder=self.net_decoder,
                             discriminator=self.net_discriminator,
                             discriminator_noobj=None)
            # torch.save(self.net_sourse_mask.state_dict(), 'results/%s/%s/net_mask.pth' % (outf, method))

        # return errD_src, errD_obj, errD_obj, errG_obj, errG_noobj, src_autoed_loss
        return src_autoed_loss.detach().item()
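The last two examples return only the detached autoencoder loss as a Python float. A minimal sketch of an outer loop that tracks this value across epochs is shown below; everything beyond the update_net signature (the trainer, the loaders, and the best-loss bookkeeping) is an assumption.

def train_and_track(trainer, loader_src, loader_obj, loader_noobj,
                    num_epochs, outf, method):
    best_ae = float('inf')
    for epoch in range(num_epochs):
        ae_loss = None
        for iteration, (data_src, data_obj, data_noobj) in enumerate(
                zip(loader_src, loader_obj, loader_noobj)):
            # update_net already detaches the loss, so this is a plain float.
            ae_loss = trainer.update_net(data_src, data_obj, data_noobj,
                                         epoch, iteration, outf, method)
        if ae_loss is not None and ae_loss < best_ae:
            best_ae = ae_loss
            print('epoch %d: best autoencoder loss so far %.4f' %
                  (epoch, best_ae))
    return best_ae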