def __init__(self, config):
    """Initialise an attribute-editing tester from ``config``.

    Every attribute of ``config`` is mirrored onto ``self``; afterwards the
    report logger, output image directory, generator/classifier
    (``build_model``) and the result tables are set up.

    Args:
        config: options object; besides the mirrored attributes it must
            provide ``testRoot``, ``test_version``, ``cuda``,
            ``EnableThresIntSetting`` and ``ModelFigureName``.
    """
    # Mirror every config option onto the instance.
    for option_name, option_value in vars(config).items():
        setattr(self, option_name, option_value)

    project_dir = os.path.join(config.testRoot, config.test_version)
    self.curProjectPath = project_dir
    report_path = os.path.join(project_dir, 'report.log')
    self.reporter = Reporter(report_path)
    self.reporter.writeConfig(config)

    self.device = torch.device('cuda:%d' % config.cuda)

    image_dir = os.path.join(project_dir, 'images')
    self.testimagepath = image_dir
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    # One output class per selected attribute.
    self.n_classes = len(self.selected_attrs)
    self.version = config.test_version

    # NOTE(review): the override value is hard-coded; a commented-out line in
    # the original read config.TestThresInt instead — confirm intent.
    self.thres_int = 1.5 if config.EnableThresIntSetting else self.ThresInt
    self.reporter.writeInfo("thres_int:" + str(self.thres_int))

    self.batch_size = self.test_batch_size
    self.build_model()

    # Per-attribute accuracy table (one column per attribute plus the mean).
    self.printTable = pt.PrettyTable()
    self.printTable.field_names = self.selected_simple_attrs + ["Mean"]

    # Debug table with one column per sample in a batch.
    self.testTable = pt.PrettyTable()
    self.testTable.field_names = list(map(str, range(self.batch_size)))
    self.ModelFigureName = config.ModelFigureName
Example #2
0
    def __init__(self, data_loader, config):
        """Set up a GAN trainer from ``config``.

        Copies the training hyper-parameters onto the instance, builds the
        generator ``G`` / discriminator ``D`` and their optimizers
        (``build_model``), writes the config and model summaries to the report
        log, optionally prepares the FID metric, renders the autograd graphs
        of both networks with torchviz, and loads a checkpoint when
        ``config.use_pretrained_model`` is set.

        Args:
            data_loader: iterable yielding ``(images, labels)`` batches.
            config: options object; every attribute read below must exist.
        """
        self.report_file = os.path.join(config.log_path, config.version,
                                        config.version + "_report.log")
        self.reporter = Reporter(self.report_file)

        # Data loader
        self.data_loader = data_loader

        # GAN variant and adversarial loss
        self.cGAN = config.cGAN
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        # An unconditional GAN is represented by zero classes.
        self.n_classes = config.n_class if config.cGAN else 0
        self.parallel = config.parallel
        self.seed = config.seed
        self.device = torch.device('cuda:%d' % config.cuda)
        self.GPUs = config.GPUs

        self.gen_distribution = config.gen_distribution
        self.gen_bottom_width = config.gen_bottom_width

        # Optimisation schedule
        self.total_step = config.total_step
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Checkpoint resumption
        self.use_pretrained_model = config.use_pretrained_model
        self.chechpoint_step = config.chechpoint_step

        self.dataset = config.dataset
        self.image_path = config.image_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        self.caculate_FID = config.caculate_FID
        self.DStep = config.D_step
        self.GStep = config.G_step

        self.metric_caculation_step = config.metric_caculation_step

        # Output paths are per-version sub-directories of the configured roots.
        self.log_path = os.path.join(config.log_path, self.version)
        self.summary_path = self.log_path
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        self.build_model()
        self.reporter.writeConfig(config)
        self.reporter.writeModel(str(self.G))
        self.reporter.writeModel(str(self.D))
        if self.caculate_FID:
            z_sampler, c_sampler = Sampler.prepare_z_c(self.batch_size,
                                                       self.z_dim,
                                                       self.n_classes,
                                                       device=self.device)
            gsampler = functools.partial(Sampler.sampleG,
                                         G=self.G,
                                         z_=z_sampler,
                                         c_=c_sampler,
                                         parallel=self.parallel)
            self.get_inception_metrics = FIDCaculator.prepare_inception_metrics(
                config.FID_mean_cov, gsampler, config.metric_images_num)

        self.writer = SummaryWriter(log_dir=self.summary_path)

        # Render the autograd graphs of G and D with torchviz from one dummy
        # forward pass each. The passes run in eval mode so the all-zero dummy
        # batch does not update train-mode state (e.g. normalisation running
        # statistics, if the networks have any); the previous mode is
        # restored afterwards.
        g_was_training, d_was_training = self.G.training, self.D.training
        self.G.eval()
        self.D.eval()
        try:
            z = torch.zeros(1, self.z_dim).to(self.device)
            c = torch.zeros(1).long().to(self.device)
            # NOTE(review): assumes 3-channel images — confirm for the dataset.
            y = torch.zeros(1, 3, self.imsize, self.imsize).to(self.device)
            vise_graph = make_dot(self.G(z, c),
                                  params=dict(self.G.named_parameters()))
            vise_graph.view(os.path.join(self.log_path, "Generator"))
            vise_graph = make_dot(self.D(y, c),
                                  params=dict(self.D.named_parameters()))
            vise_graph.view(os.path.join(self.log_path, "Discriminator"),
                            quiet=False,
                            quiet_view=False)
            del z, c, y
        finally:
            self.G.train(g_was_training)
            self.D.train(d_was_training)

        # Start with trained model
        if self.use_pretrained_model:
            self.load_pretrained_model()
Example #3
0
class Trainer(object):
    """Hinge-loss GAN trainer for a ResNet generator and an SN-ResNet
    projection discriminator (class-conditional when ``config.cGAN``)."""

    def __init__(self, data_loader, config):
        """Set up the trainer from ``config``.

        Copies the training hyper-parameters onto the instance, builds the
        networks and optimizers (``build_model``), writes the config and model
        summaries to the report log, optionally prepares the FID metric,
        renders the network graphs and loads a checkpoint when
        ``config.use_pretrained_model`` is set.

        Args:
            data_loader: iterable yielding ``(images, labels)`` batches.
            config: options object; every attribute read below must exist.
        """
        self.report_file = os.path.join(config.log_path, config.version,
                                        config.version + "_report.log")
        self.reporter = Reporter(self.report_file)

        # Data loader
        self.data_loader = data_loader

        # GAN variant and adversarial loss
        self.cGAN = config.cGAN
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        # An unconditional GAN is represented by zero classes.
        self.n_classes = config.n_class if config.cGAN else 0
        self.parallel = config.parallel
        self.seed = config.seed
        self.device = torch.device('cuda:%d' % config.cuda)
        self.GPUs = config.GPUs

        self.gen_distribution = config.gen_distribution
        self.gen_bottom_width = config.gen_bottom_width

        # Optimisation schedule
        self.total_step = config.total_step
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Checkpoint resumption
        self.use_pretrained_model = config.use_pretrained_model
        self.chechpoint_step = config.chechpoint_step

        self.dataset = config.dataset
        self.image_path = config.image_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        self.caculate_FID = config.caculate_FID
        self.DStep = config.D_step
        self.GStep = config.G_step

        self.metric_caculation_step = config.metric_caculation_step

        # Output paths are per-version sub-directories of the configured roots.
        self.log_path = os.path.join(config.log_path, self.version)
        self.summary_path = self.log_path
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        self.build_model()
        self.reporter.writeConfig(config)
        self.reporter.writeModel(str(self.G))
        self.reporter.writeModel(str(self.D))
        if self.caculate_FID:
            z_sampler, c_sampler = Sampler.prepare_z_c(self.batch_size,
                                                       self.z_dim,
                                                       self.n_classes,
                                                       device=self.device)
            gsampler = functools.partial(Sampler.sampleG,
                                         G=self.G,
                                         z_=z_sampler,
                                         c_=c_sampler,
                                         parallel=self.parallel)
            self.get_inception_metrics = FIDCaculator.prepare_inception_metrics(
                config.FID_mean_cov, gsampler, config.metric_images_num)

        self.writer = SummaryWriter(log_dir=self.summary_path)
        self._render_model_graphs()

        # Start with trained model
        if self.use_pretrained_model:
            self.load_pretrained_model()

    def _render_model_graphs(self):
        """Render the autograd graphs of G and D with torchviz into ``log_path``.

        One dummy forward pass per network is traced. The passes run in eval
        mode so the all-zero dummy batch does not update train-mode state
        (e.g. normalisation running statistics, if the networks have any);
        the previous training mode is restored afterwards.
        """
        g_was_training, d_was_training = self.G.training, self.D.training
        self.G.eval()
        self.D.eval()
        try:
            z = torch.zeros(1, self.z_dim).to(self.device)
            c = torch.zeros(1).long().to(self.device)
            # NOTE(review): assumes 3-channel images — confirm for the dataset.
            y = torch.zeros(1, 3, self.imsize, self.imsize).to(self.device)
            vise_graph = make_dot(self.G(z, c),
                                  params=dict(self.G.named_parameters()))
            vise_graph.view(os.path.join(self.log_path, "Generator"))
            vise_graph = make_dot(self.D(y, c),
                                  params=dict(self.D.named_parameters()))
            vise_graph.view(os.path.join(self.log_path, "Discriminator"),
                            quiet=False,
                            quiet_view=False)
        finally:
            self.G.train(g_was_training)
            self.D.train(d_was_training)

    def build_model(self):
        """Create G and D on ``self.device`` and their Adam optimizers.

        When ``self.parallel`` is set both networks are wrapped in
        ``nn.DataParallel`` over ``self.GPUs``.
        """
        self.G = ResNetGenerator(self.g_conv_dim,
                                 self.z_dim,
                                 self.gen_bottom_width,
                                 num_classes=self.n_classes).to(self.device)
        self.D = SNResNetProjectionDiscriminator(
            self.d_conv_dim, self.n_classes).to(self.device)
        if self.parallel:
            self.G = nn.DataParallel(self.G, device_ids=self.GPUs)
            self.D = nn.DataParallel(self.D, device_ids=self.GPUs)
        # Loss and optimizer: only parameters that require grad are optimised.
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

    def _next_batch(self, data_iter):
        """Return ``(batch, data_iter)``, restarting the loader when exhausted.

        Only ``StopIteration`` triggers a restart; any other error raised by
        the loader (e.g. a worker failure) propagates to the caller.
        """
        try:
            return next(data_iter), data_iter
        except StopIteration:
            data_iter = iter(self.data_loader)
            return next(data_iter), data_iter

    def train(self):
        """Run the alternating hinge-loss D/G training loop.

        ``DStep`` discriminator updates are followed by ``GStep`` generator
        updates, repeated until ``total_step``. Logging, sample grids,
        checkpoints and (optionally) FID happen at their configured intervals.
        """
        data_iter = iter(self.data_loader)

        # Fixed noise/labels reused for every sample grid (sample_batch
        # samples per class; each grid row holds n_classes images).
        # NOTE(review): with n_classes == 0 (unconditional) this grid is empty.
        sample_batch = 10
        fixed_z = torch.randn(self.n_classes * sample_batch, self.z_dim)
        fixed_z = fixed_z.to(self.device)
        fixed_c = Sampler.sampleFixedLabels(self.n_classes, sample_batch,
                                            self.device)

        # Reusable z / label buffers, refreshed via sample_() before each use.
        running_z, running_label = Sampler.prepare_z_c(self.batch_size,
                                                       self.z_dim,
                                                       self.n_classes,
                                                       device=self.device)

        # Checkpoint files are named after ``step + 1``, so checkpoint k is
        # written after k completed steps; resume at loop index k.
        start = self.chechpoint_step if self.use_pretrained_model else 0

        start_time = time.time()
        self.reporter.writeInfo("Start to train the model")
        d_step_counter = 0
        g_step_counter = 0
        # Losses are undefined until the first D and G updates have run;
        # logging is skipped until both exist.
        d_loss_real = d_loss_fake = d_loss = g_loss_fake = None
        for step in range(start, self.total_step):
            self.D.train()
            self.G.train()
            if d_step_counter < self.DStep:
                # ================== Train D ================== #
                batch, data_iter = self._next_batch(data_iter)
                real_images, real_label = batch

                # Hinge loss on real images
                real_images = real_images.to(self.device)
                real_label = real_label.to(self.device).long()
                d_out_real = self.D(real_images, real_label)
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

                # Hinge loss on generated images
                running_z.sample_()
                running_label.sample_()
                fake_images = self.G(running_z, running_label)
                d_out_fake = self.D(fake_images, running_label)
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
                d_step_counter += 1
            else:
                # ================== Train G ================== #
                running_z.sample_()
                running_label.sample_()
                fake_images = self.G(running_z, running_label)

                g_out_fake = self.D(fake_images, running_label)
                g_loss_fake = -g_out_fake.mean()

                self.reset_grad()
                g_loss_fake.backward()
                self.g_optimizer.step()
                g_step_counter += 1

            # One D/G cycle finished: start the next one.
            if g_step_counter == self.GStep:
                d_step_counter = 0
                g_step_counter = 0

            # Print out log info
            if ((step + 1) % self.log_step == 0 and d_loss is not None
                    and g_loss_fake is not None):
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print(
                    "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_loss_real: {:.4f}, "
                    " d_loss_fake: {:.4f}, g_loss_fake: {:.4f}".format(
                        elapsed, step + 1, self.total_step,
                        (step + 1), self.total_step, d_loss_real.item(),
                        d_loss_fake.item(), g_loss_fake.item()))

                self.writer.add_scalar('log/d_loss_real', d_loss_real.item(),
                                       (step + 1))
                self.writer.add_scalar('log/d_loss_fake', d_loss_fake.item(),
                                       (step + 1))
                self.writer.add_scalar('log/d_loss', d_loss.item(), (step + 1))
                self.writer.add_scalar('log/g_loss_fake', g_loss_fake.item(),
                                       (step + 1))

            if (step + 1) % self.sample_step == 0:
                # Sampling needs no autograd graph.
                with torch.no_grad():
                    samples = self.G(fixed_z, fixed_c)
                save_image(denorm(samples),
                           os.path.join(self.sample_path,
                                        '{}_fake.png'.format(step + 1)),
                           nrow=self.n_classes)

            if (step + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

            # Check the flag first so the interval is never evaluated
            # (e.g. divided by zero) when FID is disabled.
            if (self.caculate_FID
                    and (step + 1) % self.metric_caculation_step == 0):
                print("start to caculate the FID")
                FID = self.get_inception_metrics()
                print("FID is %.3f" % FID)
                self.writer.add_scalar('metric/FID', FID, (step + 1))
                self.reporter.writeTrainLog(step + 1,
                                            "Current FID is %.4f" % FID)

    def load_pretrained_model(self):
        """Load the G/D weights saved at ``chechpoint_step`` onto ``self.device``."""
        self.G.load_state_dict(
            torch.load(os.path.join(self.model_save_path,
                                    '{}_G.pth'.format(self.chechpoint_step)),
                       map_location=self.device))
        self.D.load_state_dict(
            torch.load(os.path.join(self.model_save_path,
                                    '{}_D.pth'.format(self.chechpoint_step)),
                       map_location=self.device))
        print('loaded trained models (step: {})..!'.format(
            self.chechpoint_step))

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()
class Tester(object):
    """Evaluate a trained attribute-editing generator.

    For each test batch every selected attribute is flipped in turn and the
    edited images are scored by a pretrained attribute classifier; a final
    pass with an all-zero attribute delta is scored with PSNR and SSIM
    against the input images.
    """

    # Attributes set to -1 when attribute ``key`` is added (only for rows
    # where they are currently exactly 1). Index meaning assumes the default
    # attribute order (0 Bald, 1 Bangs, 2 Black, 3 Blond, 4 Brown) — confirm
    # against ``selected_attrs``.
    _ADD_CONFLICTS = {0: (1,), 1: (0,), 2: (3, 4), 3: (2, 4), 4: (2, 3)}
    # Attribute set to +1 when attribute ``key`` is removed
    # (black -> blond, blond -> brown, brown -> blond hair).
    _REMOVE_REPLACEMENT = {2: 3, 3: 4, 4: 3}
    # Row names for the result CSVs: the 13 default attributes plus the mean.
    _ATTR_NAMES = ("Bald", "Bangs", "Black", "Blond", "Brown", "Eyebrows",
                   "Glass", "Male", "Mouth", "Mustache", "NoBeard", "Pale",
                   "Young", "Average")

    def __init__(self, config):
        """Initialise the tester from ``config``.

        Every attribute of ``config`` is mirrored onto ``self``; then the
        report logger, output image directory, networks (``build_model``) and
        result tables are set up.
        """
        for key, value in vars(config).items():
            setattr(self, key, value)

        self.curProjectPath = os.path.join(config.testRoot,
                                           config.test_version)
        self.reporter = Reporter(
            os.path.join(self.curProjectPath, 'report.log'))
        self.reporter.writeConfig(config)

        self.device = torch.device('cuda:%d' % config.cuda)

        self.testimagepath = os.path.join(self.curProjectPath, 'images')
        if not os.path.exists(self.testimagepath):
            os.makedirs(self.testimagepath)

        # One output class per selected attribute.
        self.n_classes = len(self.selected_attrs)
        self.version = config.test_version

        # NOTE(review): the override value is hard-coded; a commented-out
        # line in the original read config.TestThresInt — confirm intent.
        if config.EnableThresIntSetting:
            self.thres_int = 1.5
        else:
            self.thres_int = self.ThresInt
        self.reporter.writeInfo("thres_int:" + str(self.thres_int))
        self.batch_size = self.test_batch_size
        self.build_model()

        # Per-attribute accuracy table (one column per attribute + mean).
        self.printTable = pt.PrettyTable()
        self.printTable.field_names = self.selected_simple_attrs + ["Mean"]

        # Debug table with one column per sample in a batch.
        self.testTable = pt.PrettyTable()
        self.testTable.field_names = [str(i) for i in range(self.batch_size)]
        self.ModelFigureName = config.ModelFigureName

    def build_model(self):
        """Build the generator from ``GeneratorScriptName`` and load weights.

        Loads ``Epoch{test_chechpoint_step}_LocalG.pth`` from
        ``model_save_path`` into the generator and ``ACCModel`` into the
        attribute classifier, both mapped onto ``self.device``.
        """
        import importlib

        device = self.device
        # import_module returns the named (leaf) module and, unlike
        # __import__(..., fromlist=True), also works when it is a package.
        package = importlib.import_module(str(self.GeneratorScriptName))
        gen_class = getattr(package, 'Generator')
        self.GlobalG = gen_class(self.g_conv_dim, self.gLayerNum,
                                 self.resNum, self.n_classes, self.skipNum,
                                 self.skipRatio, self.GEncActName,
                                 self.GSkipActName, self.GDecActName,
                                 self.GOutActName,
                                 self.LSTUScriptName).to(device)
        checkpoint_step = self.test_chechpoint_step
        self.GlobalG.load_state_dict(
            torch.load(os.path.join(
                self.model_save_path,
                'Epoch{}_LocalG.pth'.format(checkpoint_step)),
                       map_location=device))
        print('loaded trained models (step: {})..!'.format(checkpoint_step))
        self.Classifier = FaceAttributesClassifier(
            attr_dim=self.n_classes).to(device)
        self.Classifier.load_state_dict(
            torch.load(self.ACCModel, map_location=device))
        print("load classifer model successful!")

    def getNonConflictingLablels(self, orignalLabel, class_n):
        """Build the attribute-edit delta that flips attribute ``class_n``.

        Args:
            orignalLabel: 2-D tensor ``(batch, n_attrs)``; 0 means the
                attribute is absent, any other value means present.
            class_n: column index of the attribute to flip.

        Returns:
            ``(fixeDelta, outlabel)``, both shaped like ``orignalLabel``.
            ``fixeDelta[:, class_n]`` is +1 where the attribute is added and
            -1 where it is removed. When added, each conflicting attribute in
            ``_ADD_CONFLICTS`` that is exactly 1 gets -1. When removed, the
            ``_REMOVE_REPLACEMENT`` attribute gets +1. ``outlabel`` only holds
            the ±1 target in column ``class_n``. Conflict/replacement columns
            beyond ``n_attrs`` are ignored.
        """
        fixeDelta = torch.zeros_like(orignalLabel)
        outlabel = torch.zeros_like(orignalLabel)
        n_attrs = orignalLabel.size(1)
        present = orignalLabel[:, class_n] != 0
        absent = ~present

        fixeDelta[:, class_n] = 1
        fixeDelta[present, class_n] = -1
        outlabel[:, class_n] = fixeDelta[:, class_n]

        for other in self._ADD_CONFLICTS.get(class_n, ()):
            if other < n_attrs:
                fixeDelta[absent & (orignalLabel[:, other] == 1), other] = -1
        replacement = self._REMOVE_REPLACEMENT.get(class_n)
        if replacement is not None and replacement < n_attrs:
            fixeDelta[present, replacement] = 1
        return fixeDelta, outlabel

    def _result_names(self, count):
        """Return ``count`` CSV row names (attribute names + "Average")."""
        if len(self._ATTR_NAMES) == count:
            return list(self._ATTR_NAMES)
        return list(self.selected_simple_attrs) + ["Average"]

    def _write_result_csv(self, path, names, scores):
        """Write one ``Name,Arch,Score`` CSV row per (name, score) pair."""
        import csv

        rows = [{
            "Name": name,
            "Arch": self.ModelFigureName,
            "Score": score,
        } for name, score in zip(names, scores)]
        with open(path, 'w', newline='') as f:
            f_csv = csv.DictWriter(f, ['Name', 'Arch', 'Score'])
            f_csv.writeheader()
            f_csv.writerows(rows)

    @staticmethod
    def _is_new_best(best_path, mean_acc):
        """Return True if ``mean_acc`` beats the mean stored at ``best_path``.

        A missing file counts as "no previous best". Scores are compared as
        numbers (they are stored as formatted strings).
        """
        if not os.path.exists(best_path):
            return True
        with open(best_path, 'r') as cf:
            best = json.load(cf)
        return float(best['acc'][-1]) < float(mean_acc)

    def test(self):
        """Evaluate attribute accuracy, PSNR and SSIM and write result files.

        Writes ``step{chechpoint_step}_Result.csv`` every run, and updates
        ``bestResult.json`` / ``bestResult.csv`` when the mean accuracy is the
        best seen so far (including the first run).
        """
        start_time = time.time()
        data_loader = getLoader(self.image_path,
                                self.attributes_path,
                                self.selected_attrs,
                                self.imCropSize,
                                self.imsize,
                                self.batch_size,
                                dataset=self.dataset,
                                num_workers=0,
                                toPatch=False,
                                mode='test',
                                microPatchSize=0)

        # NOTE(review): build_model loads test_chechpoint_step, but results
        # are labelled with chechpoint_step — confirm these agree.
        chechpoint_step = self.chechpoint_step
        device = self.device
        total_images = self.total_images
        save_testimages = self.save_testimages
        imsize = self.imsize

        data_iter = iter(data_loader)
        self.GlobalG.eval()
        self.Classifier.eval()
        acc_cnt_generate = np.zeros([self.n_classes])
        total_psnr = 0.0
        total_ssim = 0.0
        # Images actually evaluated (batches may be smaller than batch_size).
        num_images = 0
        with torch.no_grad():
            for batch_idx in tqdm(range(total_images // self.batch_size)):
                try:
                    real_images, label_original = next(data_iter)
                except StopIteration:
                    data_iter = iter(data_loader)
                    real_images, label_original = next(data_iter)
                batch_n = real_images.size(0)
                num_images += batch_n
                # Sample grid: input images first, then one edit per row.
                res = real_images
                real_images = real_images.to(device)
                label_original = label_original.to(device)

                # Attribute editing accuracy, one attribute at a time.
                for index in range(self.n_classes):
                    delta, gt_label = self.getNonConflictingLablels(
                        label_original, index)
                    delta = delta.to(device) * self.thres_int
                    x_fake = self.GlobalG(real_images, delta)
                    # The classifier is fed at most 128x128 images.
                    if imsize > 128:
                        x_low = F.interpolate(x_fake, size=128)
                    else:
                        x_low = x_fake
                    # Map classifier outputs to {-1, +1} like gt_label.
                    pred_label = torch.round(self.Classifier(x_low)) * 2 - 1
                    correct = pred_label[:, index] == gt_label[:, index].to(
                        device)
                    acc_cnt_generate[index] += correct.sum().item()
                    if save_testimages:
                        res = torch.cat([res, x_fake.cpu()], 0)

                # Reconstruction quality with an all-zero attribute delta.
                rec_label = torch.zeros_like(label_original)
                x_fake = self.GlobalG(real_images, rec_label)
                _, mean_psnr = PSNR(x_fake.cpu(), real_images.cpu())
                total_psnr += mean_psnr * float(batch_n)
                ssim_value = pytorch_ssim.ssim(
                    ((real_images + 1) / 2).cpu(),
                    ((x_fake + 1) / 2).cpu()).item()
                total_ssim += ssim_value * batch_n
                if save_testimages:
                    res = torch.cat([res, x_fake.cpu()], 0)

                if save_testimages:
                    print("Save test data")
                    save_image(denorm(res.data),
                               os.path.join(self.testimagepath,
                                            '{}_fake.png'.format(batch_idx +
                                                                 1)),
                               nrow=self.batch_size)

        denom = max(num_images, 1)
        total_psnr = total_psnr / float(denom)
        total_ssim = total_ssim / denom
        acc_gen = acc_cnt_generate / denom
        acc = ["%.3f" % x for x in acc_gen]
        mean_acc = sum(acc_gen) / len(acc_gen) if len(acc_gen) else 0.0
        acc += ["%.3f" % mean_acc]
        tabledata = {
            "acc": acc,
            'psnr': total_psnr,
            'ssim': total_ssim,
            'step': chechpoint_step,
        }

        self.printTable.add_row(acc)
        print("The %s model logs:" % self.version)
        print(self.printTable)
        self.reporter.writeInfo("\n" + self.printTable.__str__())
        print("The mean PSNR is %.2f" % total_psnr)
        print("The average SSIM is %.4f" % total_ssim)
        self.reporter.writeInfo("The mean PSNR is %.2f" % total_psnr)
        self.reporter.writeInfo("The average SSIM is %.4f" % total_ssim)

        best_result_path = os.path.join(self.curProjectPath,
                                        'bestResult.json')
        best_result_csv = os.path.join(self.curProjectPath, 'bestResult.csv')
        result_csv = os.path.join(self.curProjectPath,
                                  'step%d_Result.csv' % chechpoint_step)
        names = self._result_names(len(acc))
        if self._is_new_best(best_result_path, acc[-1]):
            with open(best_result_path, 'w') as cf:
                cf.write(json.dumps(tabledata))
            self._write_result_csv(best_result_csv, names, acc)
        self._write_result_csv(result_csv, names, acc)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
Example #5
0
def main(config):
    """Dispatch to a training or a test script according to ``config.mode``.

    When ``config.cuda > -1`` only that GPU is exposed to the process (via
    ``CUDA_VISIBLE_DEVICES``) and ``config.cuda`` is rewritten to 0, since
    the single visible device is then ``cuda:0``.

    Any mode other than ``"test"`` trains: dataset paths are read from
    ``./dataTool/dataPath.json``, a data loader is built, the
    ``<logRootPath>/<version>/{summary,checkpoint,sample}`` folders are
    created, the config is dumped to ``<logRootPath>/<version>/config.json``
    and ``TrainingScripts.trainer_<TrainScriptName>.Trainer`` is run.

    ``"test"`` mode creates ``<testRoot>/<test_version>`` and runs
    ``TestScripts.tester_<TestScriptsName>.Tester``.
    """
    import importlib

    if config.cuda > -1:
        # Expose only the requested GPU; inside this process it is cuda:0.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config.cuda)
        config.cuda = 0
    # For fast training
    cudnn.benchmark = True

    if config.mode != "test":
        # load the dataset path
        with open("./dataTool/dataPath.json", 'r') as cf:
            dataobj = json.load(cf)
        datasetKey = config.dataset.lower()
        config.image_path = dataobj[datasetKey]
        config.attributes_path = dataobj[datasetKey + 'att']
        print("Get data path: %s" % config.image_path)

        # create dataloader
        data_loader = getLoader(config.image_path,
                                config.attributes_path,
                                config.selected_attrs,
                                config.imCropSize,
                                config.imsize,
                                config.batch_size,
                                dataset=config.dataset,
                                num_workers=config.num_workers)

        # create training log dirs
        os.makedirs(config.logRootPath, exist_ok=True)
        makeFolder(config.logRootPath, config.version)
        currentProjectPath = os.path.join(config.logRootPath, config.version)

        makeFolder(currentProjectPath, "summary")
        config.log_path = os.path.join(currentProjectPath, "summary")

        makeFolder(currentProjectPath, "checkpoint")
        config.model_save_path = os.path.join(currentProjectPath, "checkpoint")

        makeFolder(currentProjectPath, "sample")
        config.sample_path = os.path.join(currentProjectPath, "sample")

        # Put a JSON-serialisable placeholder in `reporter` before dumping;
        # the Reporter object assigned below could not go through json.dumps.
        config.reporter = ''
        with open(os.path.join(currentProjectPath, 'config.json'), 'w') as cf:
            cf.write(json.dumps(config.__dict__))

        report_file = os.path.join(currentProjectPath,
                                   config.version + "_report")
        config.reporter = Reporter(report_file)

        moduleName = "TrainingScripts.trainer_" + config.TrainScriptName
        tempstr = "Start to run training script: {}".format(moduleName)
        print(tempstr)
        print("Traning version: %s" % config.version)
        print("Training Script Name: %s" % config.TrainScriptName)
        print("Image Size: %d" % config.imsize)
        print("Image Crop Size: %d" % config.imCropSize)
        # %s rather than %d: ThresInt can be fractional (e.g. 1.5) and %d
        # would silently truncate it in the printed log.
        print("ThresInt: %s" % config.ThresInt)
        print("D : G = %d : %d" % (config.D_step, config.G_step))

        config.reporter.writeInfo(tempstr)
        trainerClass = getattr(importlib.import_module(moduleName), 'Trainer')
        trainer = trainerClass(data_loader, config)
        trainer.train()
    else:
        # make a dir for saving test results
        makeFolder(config.testRoot, config.test_version)
        moduleName = "TestScripts.tester_" + config.TestScriptsName
        tempstr = "Start to run test script: {}".format(moduleName)
        print(tempstr)

        testerClass = getattr(importlib.import_module(moduleName), 'Tester')
        tester = testerClass(config)
        tester.test()
    def __init__(self, config):
        """Prepare a tester for the trained model ``config.test_version``.

        The node holding the model comes from ``nodesInfo/modelpositions.json``
        when ``config.read_node_local`` is set, otherwise from
        ``config.model_node``. For a node other than ``localhost`` the training
        ``config.json`` is first copied over scp into
        ``<logRootPath>/<test_version>/``. Generator hyper-parameters are then
        restored from that file and combined with the test-time settings in
        *config*. Finally ``getModel`` and ``build_model`` are called.
        Restored settings are logged to
        ``<testRoot>/<test_version>/report.log``.
        """
        if config.read_node_local:
            with open('nodesInfo/modelpositions.json', 'r') as cf:
                nodePositions = json.load(cf)
            self.model_node = nodePositions[config.test_version]['server']
            print("model node %s" % self.model_node)
        else:
            self.model_node = config.model_node
        self.version = config.test_version
        self.reporter = Reporter(
            os.path.join(config.testRoot, self.version, 'report.log'))
        self.reporter.writeInfo("version:" + str(self.version))
        self.logRootPath = config.logRootPath
        self.testimagepath = os.path.join(config.testRoot, self.version,
                                          'images')
        self.total_images = config.total_images
        self.testRoot = config.testRoot
        os.makedirs(self.testimagepath, exist_ok=True)

        currentProjectPath = os.path.join(self.logRootPath, self.version)
        configPath = os.path.join(currentProjectPath, 'config.json')
        if self.model_node.lower() != "localhost":
            with open('nodesInfo/nodes.json', 'r') as cf:
                nodeinf = json.load(cf)[self.model_node.lower()]
            # The local destination directory must exist before copying.
            os.makedirs(currentProjectPath, exist_ok=True)
            uploader = fileUploaderClass(nodeinf["ip"], nodeinf["user"],
                                         nodeinf["passwd"])
            remoteFile = nodeinf['basePath'] + self.version + "/config.json"
            uploader.sshScpGet(remoteFile, configPath)
            print("success get the config file from server %s" % nodeinf['ip'])

        with open(configPath, 'r') as cf:
            configObj = json.load(cf)

        # Hyper-parameters restored from the training config, as
        # (attribute on self, key in config.json). Each one is logged as
        # "<attribute>:<value>"; a single loop keeps name and value in sync.
        restored = (
            ("g_conv_dim", "g_conv_dim"),
            ("GEncActName", "GEncActName"),
            ("GDecActName", "GDecActName"),
            ("GSkipActName", "GSkipActName"),
            ("resNum", "resNum"),
            ("skipNum", "skipNum"),
            ("gLayerNum", "gLayerNum"),
            ("skipRatio", "skipRatio"),
            ("GScriptName", "GeneratorScriptName"),
            ("GOutActName", "GOutActName"),
            ("imsize", "imsize"),
            ("imCropSize", "imCropSize"),
        )
        for attrName, key in restored:
            setattr(self, attrName, configObj[key])
            self.reporter.writeInfo("%s:%s" % (attrName, configObj[key]))

        # training information
        self.save_testimages = config.save_testimages
        self.device = torch.device('cuda:%d' % config.cuda)
        # selected_attrs is taken from the test-time config, not config.json.
        self.selected_attrs = config.selected_attrs
        self.reporter.writeInfo("selected_attrs:" + str(self.selected_attrs))
        self.simple_attrs = configObj['selected_simple_attrs']
        self.n_classes = len(self.selected_attrs)
        self.specifiedImages = config.specifiedImages
        self.SampleImgNum = len(config.specifiedImages)
        # Test-time override of the edit strength, else the training value.
        if config.EnableThresIntSetting:
            self.thres_int = config.TestThresInt
        else:
            self.thres_int = configObj['ThresInt']
        self.reporter.writeInfo("thres_int:" + str(self.thres_int))
        self.batch_size = config.test_batch_size
        self.ACCModel = config.ACCModel

        # steps
        self.chechpoint_step = config.test_chechpoint_step
        self.reporter.writeInfo("chechpoint_step:" + str(self.chechpoint_step))
        # Path
        self.dataset = config.dataset
        # NOTE(review): build_model loads weights from model_save_path, while
        # getModel downloads into <logRootPath>/<version>/checkpoint; confirm
        # the two agree for remote nodes.
        self.model_save_path = config.model_save_path
        self.attributes_path = config.attributes_path
        self.image_path = config.image_path
        self.getModel()
        self.build_model()
        self.printTable = pt.PrettyTable()
        self.printTable.field_names = self.simple_attrs + ["Mean"]

        self.testTable = pt.PrettyTable()
        self.testTable.field_names = [str(i) for i in range(self.batch_size)]
        self.ModelFigureName = config.ModelFigureName
class Tester(object):
    def __init__(self, config):
        """Prepare a tester for the trained model ``config.test_version``.

        The node holding the model comes from ``nodesInfo/modelpositions.json``
        when ``config.read_node_local`` is set, otherwise from
        ``config.model_node``. For a node other than ``localhost`` the training
        ``config.json`` is first copied over scp into
        ``<logRootPath>/<test_version>/``. Generator hyper-parameters are then
        restored from that file and combined with the test-time settings in
        *config*. Finally ``getModel`` and ``build_model`` are called.
        Restored settings are logged to
        ``<testRoot>/<test_version>/report.log``.
        """
        if config.read_node_local:
            with open('nodesInfo/modelpositions.json', 'r') as cf:
                nodePositions = json.load(cf)
            self.model_node = nodePositions[config.test_version]['server']
            print("model node %s" % self.model_node)
        else:
            self.model_node = config.model_node
        self.version = config.test_version
        self.reporter = Reporter(
            os.path.join(config.testRoot, self.version, 'report.log'))
        self.reporter.writeInfo("version:" + str(self.version))
        self.logRootPath = config.logRootPath
        self.testimagepath = os.path.join(config.testRoot, self.version,
                                          'images')
        self.total_images = config.total_images
        self.testRoot = config.testRoot
        os.makedirs(self.testimagepath, exist_ok=True)

        currentProjectPath = os.path.join(self.logRootPath, self.version)
        configPath = os.path.join(currentProjectPath, 'config.json')
        if self.model_node.lower() != "localhost":
            with open('nodesInfo/nodes.json', 'r') as cf:
                nodeinf = json.load(cf)[self.model_node.lower()]
            # The local destination directory must exist before copying.
            os.makedirs(currentProjectPath, exist_ok=True)
            uploader = fileUploaderClass(nodeinf["ip"], nodeinf["user"],
                                         nodeinf["passwd"])
            remoteFile = nodeinf['basePath'] + self.version + "/config.json"
            uploader.sshScpGet(remoteFile, configPath)
            print("success get the config file from server %s" % nodeinf['ip'])

        with open(configPath, 'r') as cf:
            configObj = json.load(cf)

        # Hyper-parameters restored from the training config, as
        # (attribute on self, key in config.json). Each one is logged as
        # "<attribute>:<value>"; a single loop keeps name and value in sync.
        restored = (
            ("g_conv_dim", "g_conv_dim"),
            ("GEncActName", "GEncActName"),
            ("GDecActName", "GDecActName"),
            ("GSkipActName", "GSkipActName"),
            ("resNum", "resNum"),
            ("skipNum", "skipNum"),
            ("gLayerNum", "gLayerNum"),
            ("skipRatio", "skipRatio"),
            ("GScriptName", "GeneratorScriptName"),
            ("GOutActName", "GOutActName"),
            ("imsize", "imsize"),
            ("imCropSize", "imCropSize"),
        )
        for attrName, key in restored:
            setattr(self, attrName, configObj[key])
            self.reporter.writeInfo("%s:%s" % (attrName, configObj[key]))

        # training information
        self.save_testimages = config.save_testimages
        self.device = torch.device('cuda:%d' % config.cuda)
        # selected_attrs is taken from the test-time config, not config.json.
        self.selected_attrs = config.selected_attrs
        self.reporter.writeInfo("selected_attrs:" + str(self.selected_attrs))
        self.simple_attrs = configObj['selected_simple_attrs']
        self.n_classes = len(self.selected_attrs)
        self.specifiedImages = config.specifiedImages
        self.SampleImgNum = len(config.specifiedImages)
        # Test-time override of the edit strength, else the training value.
        if config.EnableThresIntSetting:
            self.thres_int = config.TestThresInt
        else:
            self.thres_int = configObj['ThresInt']
        self.reporter.writeInfo("thres_int:" + str(self.thres_int))
        self.batch_size = config.test_batch_size
        self.ACCModel = config.ACCModel

        # steps
        self.chechpoint_step = config.test_chechpoint_step
        self.reporter.writeInfo("chechpoint_step:" + str(self.chechpoint_step))
        # Path
        self.dataset = config.dataset
        # NOTE(review): build_model loads weights from model_save_path, while
        # getModel downloads into <logRootPath>/<version>/checkpoint; confirm
        # the two agree for remote nodes.
        self.model_save_path = config.model_save_path
        self.attributes_path = config.attributes_path
        self.image_path = config.image_path
        self.getModel()
        self.build_model()
        self.printTable = pt.PrettyTable()
        self.printTable.field_names = self.simple_attrs + ["Mean"]

        self.testTable = pt.PrettyTable()
        self.testTable.field_names = [str(i) for i in range(self.batch_size)]
        self.ModelFigureName = config.ModelFigureName

    def getModel(self):
        server = self.model_node.lower()
        with open('nodesInfo/nodes.json', 'r') as cf:
            nodestr = cf.read()
            nodeinf = json.loads(nodestr)
        if server == "localhost":
            return
        else:
            nodeinf = nodeinf[server]
            # makeFolder(self.logRootPath, self.version)
            currentProjectPath = os.path.join(self.logRootPath, self.version)
            # makeFolder(currentProjectPath, "checkpoint")

            remoteFile = nodeinf[
                'basePath'] + self.version + "/checkpoint/" + "%d_LocalG.pth" % self.chechpoint_step
            localFile = os.path.join(currentProjectPath, "checkpoint",
                                     "%d_LocalG.pth" % self.chechpoint_step)
            if os.path.exists(localFile):
                print("checkpoint already exists")
            else:
                uploader = fileUploaderClass(nodeinf["ip"], nodeinf["user"],
                                             nodeinf["passwd"])
                uploader.sshScpGet(remoteFile, localFile)
                print("success get the model from server %s" % nodeinf['ip'])

    def build_model(self):
        """Create the generator and attribute classifier and load their weights.

        The generator class is ``Generator`` from the module
        ``components.<GScriptName>``. Its weights are read from
        ``<model_save_path>/<chechpoint_step>_LocalG.pth``. The classifier is a
        ``FaceAttributesClassifier`` with ``n_classes`` outputs, loaded from
        ``self.ACCModel``. Both networks are moved to ``self.device`` and
        stored as ``self.GlobalG`` / ``self.Classifier``.
        """
        generatorModule = __import__("components.%s" % self.GScriptName,
                                     fromlist=True)
        GeneratorCls = getattr(generatorModule, 'Generator')
        ctorArgs = (self.g_conv_dim, self.gLayerNum, self.resNum,
                    self.n_classes, self.skipNum, self.skipRatio,
                    self.GEncActName, self.GSkipActName, self.GDecActName,
                    self.GOutActName)
        self.GlobalG = GeneratorCls(*ctorArgs).to(self.device)

        ckptFile = '{}_LocalG.pth'.format(self.chechpoint_step)
        ckptPath = os.path.join(self.model_save_path, ckptFile)
        generatorState = torch.load(ckptPath, map_location=self.device)
        self.GlobalG.load_state_dict(generatorState)
        print('loaded trained models (step: {})..!'.format(
            self.chechpoint_step))

        classifierNet = FaceAttributesClassifier(attr_dim=self.n_classes)
        self.Classifier = classifierNet.to(self.device)
        classifierState = torch.load(self.ACCModel, map_location=self.device)
        self.Classifier.load_state_dict(classifierState)
        print("load classifer model successful!")

    def getNonConflictingLablels(self, orignalLabel, class_n):
        fixeDelta = torch.zeros_like(orignalLabel)
        labelSize = orignalLabel.size()[0]
        outlabel = torch.zeros_like(orignalLabel)
        for index in range(labelSize):
            if orignalLabel[index, class_n] == 0:
                fixeDelta[index, class_n] = 1
                outlabel[index, class_n] = 1
                # if index == 0 and fixedLabel[index1,0,1]==1: # from non-blad to blad, so we need to remove bangs in the same time
                #     fixeDelta[index,index1,:,1] = -1
                if class_n == 0:
                    if orignalLabel[index, 1] == 1:
                        fixeDelta[index, 1] = -1
                if class_n == 1:
                    if orignalLabel[index, 0] == 1:
                        fixeDelta[index, 0] = -1
                if class_n == 2:
                    if orignalLabel[index, 3] == 1:
                        fixeDelta[index, 3] = -1
                    if orignalLabel[index, 4] == 1:
                        fixeDelta[index, 4] = -1
                if class_n == 3:
                    if orignalLabel[index, 2] == 1:
                        fixeDelta[index, 2] = -1
                    if orignalLabel[index, 4] == 1:
                        fixeDelta[index, 4] = -1
                if class_n == 4:
                    if orignalLabel[index, 2] == 1:
                        fixeDelta[index, 2] = -1
                    if orignalLabel[index, 3] == 1:
                        fixeDelta[index, 3] = -1
            else:
                fixeDelta[index, class_n] = -1
                outlabel[index, class_n] = -1
                if class_n == 2:
                    fixeDelta[index,
                              3] = 1  # translate black hair to blond hair
                if class_n == 3:
                    fixeDelta[index,
                              4] = 1  # translate blond hair to Brown Hair
                if class_n == 4:
                    fixeDelta[index, 3] = 1
        return fixeDelta, outlabel

    def test(self):
        """Evaluate the generator on the test split and write result files.

        For each batch, each of the ``n_classes`` attributes is edited in turn
        (targets from ``getNonConflictingLablels``). The classifier then checks
        that the generated image has the requested value in that column. One
        extra pass with an all-zero label measures reconstruction PSNR.

        Per-attribute accuracy, their mean and the mean PSNR are printed and
        written to the reporter. They are also saved to
        ``<testRoot>/<version>/step<N>_Result.csv``. ``bestResult.json`` /
        ``bestResult.csv`` in the same directory are (re)written on the first
        run and whenever the mean accuracy improves. With ``save_testimages``
        an image grid per batch is saved to ``self.testimagepath``.

        Averages are taken over the images actually scored. That count is
        ``(total_images // batch_size)`` batches, not ``total_images``.
        """
        start_time = time.time()
        data_loader = getLoader(self.image_path,
                                self.attributes_path,
                                self.selected_attrs,
                                self.imCropSize,
                                self.imsize,
                                self.batch_size,
                                dataset=self.dataset,
                                num_workers=0,
                                toPatch=False,
                                mode='test',
                                microPatchSize=0)
        data_iter = iter(data_loader)
        total = self.total_images
        self.GlobalG.eval()
        self.Classifier.eval()
        acc_cnt_generate = np.zeros([self.n_classes])
        total_psnr = 0.0
        # Images actually scored; the denominator for all averages below.
        n_scored = 0
        with torch.no_grad():
            for iii in tqdm(range(total // self.batch_size)):
                try:
                    realImages, labelOriginal = next(data_iter)
                except StopIteration:
                    # Loader exhausted: start another pass over it.
                    data_iter = iter(data_loader)
                    realImages, labelOriginal = next(data_iter)
                # Images as yielded by the loader; generated images are
                # appended to this for the optional sample grid.
                res = realImages
                realImages = realImages.to(self.device)
                labelOriginal = labelOriginal.to(self.device)
                batch_n = labelOriginal.size(0)
                n_scored += batch_n

                # Edit each attribute in turn and check the classifier agrees.
                for index in range(self.n_classes):
                    templabel, gt_label = self.getNonConflictingLablels(
                        labelOriginal, index)
                    templabel = templabel.to(self.device) * 2 * self.thres_int
                    xFake = self.GlobalG(realImages, templabel)
                    # Round classifier output, then map {0, 1} -> {-1, 1} to
                    # match the target encoding of gt_label.
                    pred_label = torch.round(self.Classifier(xFake)) * 2 - 1
                    gt = gt_label[:, index].to(self.device)
                    acc_cnt_generate[index] += (
                        pred_label[:, index] == gt).sum().item()
                    if self.save_testimages:
                        res = torch.cat([res, xFake.cpu()], 0)

                # Reconstruction pass: an all-zero edit label.
                recLabel = torch.zeros_like(labelOriginal)
                xFake = self.GlobalG(realImages, recLabel)
                _, meanPSNR = PSNR(xFake.cpu(), realImages.cpu())
                total_psnr += meanPSNR * float(batch_n)
                if self.save_testimages:
                    res = torch.cat([res, xFake.cpu()], 0)
                    save_image(denorm(res.data),
                               os.path.join(self.testimagepath,
                                            '{}_fake.png'.format(iii + 1)),
                               nrow=self.batch_size)

        denom = max(n_scored, 1)  # avoid 0-division when nothing was scored
        total_psnr = total_psnr / float(denom)
        acc_gen = acc_cnt_generate / denom
        meanAcc = acc_gen.sum() / max(self.n_classes, 1)
        acc = ["%.3f" % x for x in acc_gen] + ["%.3f" % meanAcc]
        tabledata = {"acc": acc, 'psnr': total_psnr,
                     'step': self.chechpoint_step}

        self.printTable.add_row(acc)
        print("The %s model logs:" % self.version)
        self.reporter.writeInfo("Batch Size: %d" % self.batch_size)
        self.reporter.writeInfo("Total images: %d" % n_scored)
        self.reporter.writeInfo("The %s model logs:" % self.version)
        print(self.printTable)
        self.reporter.writeInfo("\n" + self.printTable.__str__())
        print("The mean PSNR is %.2f" % total_psnr)
        self.reporter.writeInfo("The mean PSNR is %.2f" % total_psnr)

        resultDir = os.path.join(self.testRoot, self.version)
        bestResultPath = os.path.join(resultDir, 'bestResult.json')
        if self._isNewBest(bestResultPath, acc[-1]):
            # First run or a better mean accuracy: refresh both best files.
            with open(bestResultPath, 'w') as cf:
                cf.write(json.dumps(tabledata))
            self._writeScoreCsv(os.path.join(resultDir, 'bestResult.csv'), acc)
        self._writeScoreCsv(
            os.path.join(resultDir, 'step%d_Result.csv' % self.chechpoint_step),
            acc)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))

    @staticmethod
    def _isNewBest(bestResultPath, meanAcc):
        """Return True if *bestResultPath* is missing or stores a lower mean.

        The stored file is the JSON written by ``test``. Its last ``acc``
        entry is the mean accuracy as a ``"%.3f"`` string, so both values are
        compared as floats.
        """
        if not os.path.exists(bestResultPath):
            return True
        with open(bestResultPath, 'r') as cf:
            best = json.load(cf)
        return float(best['acc'][-1]) < float(meanAcc)

    def _writeScoreCsv(self, path, acc):
        """Write *acc* to *path* as ``Name,Arch,Score`` CSV rows.

        Row names are the fixed 13-attribute list plus ``Average`` when the
        lengths match. Otherwise they are ``self.simple_attrs`` plus
        ``Average``.
        """
        import csv
        attrname = [
            "Bald", "Bangs", "Black", "Blond", "Brown", "Eyebrows", "Glass",
            "Male", "Mouth", "Mustache", "NoBeard", "Pale", "Young", "Average"
        ]
        if len(attrname) != len(acc):
            attrname = list(self.simple_attrs) + ["Average"]
        with open(path, 'w', newline='') as f:
            f_csv = csv.DictWriter(f, ['Name', 'Arch', 'Score'])
            f_csv.writeheader()
            f_csv.writerows({"Name": name,
                             "Arch": self.ModelFigureName,
                             "Score": score}
                            for name, score in zip(attrname, acc))