def __init__(self, config):
    """Build a tester whose settings mirror every attribute of ``config``.

    All entries of ``config.__dict__`` are copied onto ``self``. The test
    output folder ``<testRoot>/<test_version>/`` receives a ``report.log``
    reporter and an ``images/`` sub-folder. The models are created by
    ``build_model()``, and two PrettyTables are prepared for printing results.

    Args:
        config: namespace-like object (presumably argparse) holding the
            test options.
    """
    # Mirror the whole config so later code can read self.<option>.
    for option_name, option_value in vars(config).items():
        setattr(self, option_name, option_value)

    project_dir = os.path.join(config.testRoot, config.test_version)
    self.curProjectPath = project_dir
    self.reporter = Reporter(os.path.join(project_dir, 'report.log'))
    self.reporter.writeConfig(config)
    self.device = torch.device('cuda:%d' % config.cuda)

    image_dir = os.path.join(project_dir, 'images')
    self.testimagepath = image_dir
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    # Parameters derived from the copied config.
    self.n_classes = len(self.selected_attrs)
    self.version = config.test_version
    # NOTE(review): with EnableThresIntSetting the value is hard-coded to 1.5
    # (a commented-out line previously read config.TestThresInt) - confirm
    # this override is intended.
    self.thres_int = 1.5 if config.EnableThresIntSetting else self.ThresInt
    self.reporter.writeInfo("thres_int:" + str(self.thres_int))
    self.batch_size = self.test_batch_size

    self.build_model()

    summary_table = pt.PrettyTable()
    summary_table.field_names = self.selected_simple_attrs + ["Mean"]
    self.printTable = summary_table

    batch_table = pt.PrettyTable()
    batch_table.field_names = list(map(str, range(self.batch_size)))
    self.testTable = batch_table

    self.ModelFigureName = config.ModelFigureName
def __init__(self, data_loader, config):
    """Set up a GAN trainer: hyper-parameters, output folders, models and logging.

    Args:
        data_loader: iterable yielding ``(images, labels)`` batches.
        config: namespace-like object (presumably argparse) with the training
            options; ``log_path``, ``sample_path`` and ``model_save_path`` are
            base folders, and a ``<version>`` sub-folder is used under each.
    """
    # The reporter writes into <log_path>/<version>/, so make sure it exists
    # before anything is written there.
    report_dir = os.path.join(config.log_path, config.version)
    os.makedirs(report_dir, exist_ok=True)
    self.report_file = os.path.join(report_dir, config.version + "_report.log")
    self.reporter = Reporter(self.report_file)

    # Data loader
    self.data_loader = data_loader

    # Model / loss selection
    self.cGAN = config.cGAN
    self.adv_loss = config.adv_loss

    # Model hyper-parameters
    self.imsize = config.imsize
    self.z_dim = config.z_dim
    self.g_conv_dim = config.g_conv_dim
    self.d_conv_dim = config.d_conv_dim
    # Unconditional GANs are built with zero classes.
    self.n_classes = config.n_class if config.cGAN else 0
    self.parallel = config.parallel
    self.seed = config.seed
    self.device = torch.device('cuda:%d' % config.cuda)
    self.GPUs = config.GPUs
    self.gen_distribution = config.gen_distribution
    self.gen_bottom_width = config.gen_bottom_width
    self.total_step = config.total_step
    self.batch_size = config.batch_size
    self.num_workers = config.num_workers
    self.g_lr = config.g_lr
    self.d_lr = config.d_lr
    self.lr_decay = config.lr_decay
    self.beta1 = config.beta1
    self.beta2 = config.beta2
    self.use_pretrained_model = config.use_pretrained_model
    self.chechpoint_step = config.chechpoint_step
    self.dataset = config.dataset
    self.image_path = config.image_path
    self.log_step = config.log_step
    self.sample_step = config.sample_step
    self.model_save_step = config.model_save_step
    self.version = config.version
    self.caculate_FID = config.caculate_FID
    self.DStep = config.D_step
    self.GStep = config.G_step
    self.metric_caculation_step = config.metric_caculation_step

    # Output paths: every artefact of this run lives in a <version> sub-folder.
    # These sub-folders are not created by the caller, and torch.save /
    # save_image fail on a missing directory, so create them here.
    self.log_path = os.path.join(config.log_path, self.version)
    self.summary_path = self.log_path
    self.sample_path = os.path.join(config.sample_path, self.version)
    self.model_save_path = os.path.join(config.model_save_path, self.version)
    for out_dir in (self.log_path, self.sample_path, self.model_save_path):
        os.makedirs(out_dir, exist_ok=True)

    self.build_model()
    self.reporter.writeConfig(config)
    self.reporter.writeModel(str(self.G))
    self.reporter.writeModel(str(self.D))

    if self.caculate_FID:
        z_sampler, c_sampler = Sampler.prepare_z_c(
            self.batch_size, self.z_dim, self.n_classes, device=self.device)
        gsampler = functools.partial(Sampler.sampleG, G=self.G, z_=z_sampler,
                                     c_=c_sampler, parallel=self.parallel)
        self.get_inception_metrics = FIDCaculator.prepare_inception_metrics(
            config.FID_mean_cov, gsampler, config.metric_images_num)

    self.writer = SummaryWriter(log_dir=self.summary_path)

    # Render the autograd graphs of G and D with dummy single-sample inputs.
    # NOTE(review): .view() also tries to open the rendered file in a viewer;
    # on a headless machine .render() may be preferable - confirm.
    z = torch.zeros(1, self.z_dim).to(self.device)
    c = torch.zeros(1).long().to(self.device)
    y = torch.zeros(1, 3, self.imsize, self.imsize).to(self.device)
    vise_graph = make_dot(self.G(z, c), params=dict(self.G.named_parameters()))
    vise_graph.view(self.log_path + "/Generator")
    vise_graph = make_dot(self.D(y, c), params=dict(self.D.named_parameters()))
    vise_graph.view(self.log_path + "/Discriminator", quiet=False,
                    quiet_view=False)
    del z, c, y

    # Resume from a saved checkpoint if requested.
    if self.use_pretrained_model:
        self.load_pretrained_model()
class Trainer(object):
    """Hinge-loss GAN trainer for a ResNet generator and an SN-ResNet projection discriminator.

    The discriminator is updated ``D_step`` times, then the generator
    ``G_step`` times, and the cycle repeats. Every ``log_step`` steps the
    losses are printed and written to TensorBoard. Every ``sample_step`` steps
    a fixed-noise sample grid is saved, and every ``model_save_step`` steps a
    checkpoint is written. FID is optionally computed every
    ``metric_caculation_step`` steps.
    """

    def __init__(self, data_loader, config):
        """Set up hyper-parameters, output folders, models and logging.

        Args:
            data_loader: iterable yielding ``(images, labels)`` batches.
            config: namespace-like object (presumably argparse) with the
                training options; ``log_path``, ``sample_path`` and
                ``model_save_path`` are base folders, and a ``<version>``
                sub-folder is used under each.
        """
        # The reporter writes into <log_path>/<version>/, so make sure it
        # exists before anything is written there.
        report_dir = os.path.join(config.log_path, config.version)
        os.makedirs(report_dir, exist_ok=True)
        self.report_file = os.path.join(report_dir,
                                        config.version + "_report.log")
        self.reporter = Reporter(self.report_file)

        # Data loader
        self.data_loader = data_loader

        # Model / loss selection
        self.cGAN = config.cGAN
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        # Unconditional GANs are built with zero classes.
        self.n_classes = config.n_class if config.cGAN else 0
        self.parallel = config.parallel
        self.seed = config.seed
        self.device = torch.device('cuda:%d' % config.cuda)
        self.GPUs = config.GPUs
        self.gen_distribution = config.gen_distribution
        self.gen_bottom_width = config.gen_bottom_width
        self.total_step = config.total_step
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.use_pretrained_model = config.use_pretrained_model
        self.chechpoint_step = config.chechpoint_step
        self.dataset = config.dataset
        self.image_path = config.image_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        self.caculate_FID = config.caculate_FID
        self.DStep = config.D_step
        self.GStep = config.G_step
        self.metric_caculation_step = config.metric_caculation_step

        # Output paths: every artefact of this run lives in a <version>
        # sub-folder. The caller does not create these sub-folders, and
        # torch.save / save_image fail on a missing directory.
        self.log_path = os.path.join(config.log_path, self.version)
        self.summary_path = self.log_path
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        for out_dir in (self.log_path, self.sample_path,
                        self.model_save_path):
            os.makedirs(out_dir, exist_ok=True)

        self.build_model()
        self.reporter.writeConfig(config)
        self.reporter.writeModel(str(self.G))
        self.reporter.writeModel(str(self.D))

        if self.caculate_FID:
            z_sampler, c_sampler = Sampler.prepare_z_c(
                self.batch_size, self.z_dim, self.n_classes,
                device=self.device)
            gsampler = functools.partial(Sampler.sampleG, G=self.G,
                                         z_=z_sampler, c_=c_sampler,
                                         parallel=self.parallel)
            self.get_inception_metrics = \
                FIDCaculator.prepare_inception_metrics(
                    config.FID_mean_cov, gsampler, config.metric_images_num)

        self.writer = SummaryWriter(log_dir=self.summary_path)
        self._export_model_graphs()

        # Resume from a saved checkpoint if requested.
        if self.use_pretrained_model:
            self.load_pretrained_model()

    def _export_model_graphs(self):
        """Render the autograd graphs of G and D (via ``make_dot``) into the log folder."""
        z = torch.zeros(1, self.z_dim).to(self.device)
        c = torch.zeros(1).long().to(self.device)
        y = torch.zeros(1, 3, self.imsize, self.imsize).to(self.device)
        # NOTE(review): .view() also tries to open the rendered file in a
        # viewer; on a headless machine .render() may be preferable.
        vise_graph = make_dot(self.G(z, c),
                              params=dict(self.G.named_parameters()))
        vise_graph.view(self.log_path + "/Generator")
        vise_graph = make_dot(self.D(y, c),
                              params=dict(self.D.named_parameters()))
        vise_graph.view(self.log_path + "/Discriminator", quiet=False,
                        quiet_view=False)

    def build_model(self):
        """Create G and D on ``self.device`` (optionally DataParallel) and their Adam optimizers."""
        self.G = ResNetGenerator(self.g_conv_dim, self.z_dim,
                                 self.gen_bottom_width,
                                 num_classes=self.n_classes).to(self.device)
        self.D = SNResNetProjectionDiscriminator(
            self.d_conv_dim, self.n_classes).to(self.device)
        if self.parallel:
            self.G = nn.DataParallel(self.G, device_ids=self.GPUs)
            self.D = nn.DataParallel(self.D, device_ids=self.GPUs)

        # Optimise only trainable parameters.
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()),
            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()),
            self.d_lr, [self.beta1, self.beta2])

    def _start_step(self):
        """Return the first loop index to run.

        A checkpoint named ``K_G.pth`` is written after loop index ``K - 1``
        has finished (files are named ``step + 1``), so a resumed run
        continues at index ``K``.
        """
        return self.chechpoint_step if self.use_pretrained_model else 0

    @staticmethod
    def _fmt_loss(loss):
        """Format a scalar loss tensor, or ``"n/a"`` if it was not computed yet."""
        return "n/a" if loss is None else "{:.4f}".format(loss.item())

    def _log_losses(self, step_no, start_time, d_loss_real, d_loss_fake,
                    d_loss, g_loss_fake):
        """Print the latest losses and write those already computed to TensorBoard."""
        elapsed = str(datetime.timedelta(seconds=time.time() - start_time))
        print("Elapsed [{}], Step [{}/{}], d_loss_real: {}, "
              "d_loss_fake: {}, g_loss_fake: {}".format(
                  elapsed, step_no, self.total_step,
                  self._fmt_loss(d_loss_real), self._fmt_loss(d_loss_fake),
                  self._fmt_loss(g_loss_fake)))
        for tag, loss in (('log/d_loss_real', d_loss_real),
                          ('log/d_loss_fake', d_loss_fake),
                          ('log/d_loss', d_loss),
                          ('log/g_loss_fake', g_loss_fake)):
            if loss is not None:
                self.writer.add_scalar(tag, loss.item(), step_no)

    def train(self):
        """Run the alternating D/G hinge-loss training loop up to ``total_step``."""
        data_iter = iter(self.data_loader)
        model_save_step = self.model_save_step

        # Fixed noise/labels (sampleBatch images per class) reused for every
        # sample grid so that grids are comparable across steps.
        sampleBatch = 10
        fixed_z = torch.randn(self.n_classes * sampleBatch,
                              self.z_dim).to(self.device)
        fixed_c = Sampler.sampleFixedLabels(self.n_classes, sampleBatch,
                                            self.device)
        # Noise/label buffers refreshed with .sample_() before each use.
        runingZ, runingLabel = Sampler.prepare_z_c(
            self.batch_size, self.z_dim, self.n_classes, device=self.device)

        start = self._start_step()
        start_time = time.time()
        self.reporter.writeInfo("Start to train the model")

        dstepCounter = 0
        gstepCounter = 0
        # Latest losses; None until the corresponding network has been
        # updated at least once (logging used to hit unbound locals).
        d_loss_real = d_loss_fake = d_loss = g_loss_fake = None
        for step in range(start, self.total_step):
            self.D.train()
            self.G.train()
            if dstepCounter < self.DStep:
                # ================== Train D ================== #
                try:
                    realImages, realLabel = next(data_iter)
                except StopIteration:
                    # Epoch finished: restart the loader.
                    data_iter = iter(self.data_loader)
                    realImages, realLabel = next(data_iter)

                realImages = realImages.to(self.device)
                realLabel = realLabel.to(self.device).long()
                d_out_real = self.D(realImages, realLabel)
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

                runingZ.sample_()
                runingLabel.sample_()
                fake_images = self.G(runingZ, runingLabel)
                # detach(): the D update needs no gradients through G.
                d_out_fake = self.D(fake_images.detach(), runingLabel)
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

                d_loss = d_loss_real + d_loss_fake
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
                dstepCounter += 1
            else:
                # ================== Train G ================== #
                runingZ.sample_()
                runingLabel.sample_()
                fake_images = self.G(runingZ, runingLabel)
                g_out_fake = self.D(fake_images, runingLabel)
                g_loss_fake = -g_out_fake.mean()
                self.reset_grad()
                g_loss_fake.backward()
                self.g_optimizer.step()
                gstepCounter += 1
                if gstepCounter == self.GStep:
                    # Cycle finished: go back to D updates.
                    dstepCounter = 0
                    gstepCounter = 0

            step_no = step + 1
            if step_no % self.log_step == 0:
                self._log_losses(step_no, start_time, d_loss_real,
                                 d_loss_fake, d_loss, g_loss_fake)

            if step_no % self.sample_step == 0:
                # Sampling only: no autograd graph is needed.
                with torch.no_grad():
                    fake_images = self.G(fixed_z, fixed_c)
                save_image(denorm(fake_images.data),
                           os.path.join(self.sample_path,
                                        '{}_fake.png'.format(step_no)),
                           nrow=self.n_classes)

            if step_no % model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path,
                                        '{}_G.pth'.format(step_no)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path,
                                        '{}_D.pth'.format(step_no)))

            # Check the flag first so metric_caculation_step is only used
            # when FID is enabled.
            if self.caculate_FID and \
                    step_no % self.metric_caculation_step == 0:
                print("start to caculate the FID")
                FID = self.get_inception_metrics()
                print("FID is %.3f" % FID)
                self.writer.add_scalar('metric/FID', FID, step_no)
                self.reporter.writeTrainLog(step_no,
                                            "Current FID is %.4f" % FID)

    def load_pretrained_model(self):
        """Load G and D weights saved at step ``chechpoint_step``."""
        self.G.load_state_dict(
            torch.load(os.path.join(self.model_save_path,
                                    '{}_G.pth'.format(self.chechpoint_step)),
                       map_location=self.device))
        self.D.load_state_dict(
            torch.load(os.path.join(self.model_save_path,
                                    '{}_D.pth'.format(self.chechpoint_step)),
                       map_location=self.device))
        print('loaded trained models (step: {})..!'.format(
            self.chechpoint_step))

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()
class Tester(object):
    """Evaluate a trained attribute-editing generator.

    For every test batch and every selected attribute, the generator receives
    a label delta that flips that attribute (see
    ``getNonConflictingLablels``). A pretrained attribute classifier then
    checks whether the edit succeeded. An extra pass with an all-zero delta
    measures reconstruction quality (PSNR / SSIM).
    """

    # Attribute names written to the CSV result files (an "Average" column is
    # appended). Used only when the number of results matches.
    _CSV_ATTR_NAMES = [
        "Bald", "Bangs", "Black", "Blond", "Brown", "Eyebrows", "Glass",
        "Male", "Mouth", "Mustache", "NoBeard", "Pale", "Young"
    ]
    # When attribute i is switched ON, these attributes are switched OFF if
    # present (e.g. going bald removes bangs; indices 2-4 presumably are
    # mutually exclusive hair colours - confirm against selected_attrs).
    _ON_CONFLICTS = {0: (1,), 1: (0,), 2: (3, 4), 3: (2, 4), 4: (2, 3)}
    # When attribute i is switched OFF, this attribute is switched ON instead.
    _OFF_REPLACEMENT = {2: 3, 3: 4, 4: 3}

    def __init__(self, config):
        """Copy the config onto ``self``, prepare output folders and load the models.

        Args:
            config: namespace-like object (presumably argparse) holding the
                test options.
        """
        # Mirror the whole config so later code can read self.<option>.
        for key, value in vars(config).items():
            setattr(self, key, value)
        self.curProjectPath = os.path.join(config.testRoot,
                                           config.test_version)
        self.reporter = Reporter(
            os.path.join(self.curProjectPath, 'report.log'))
        self.reporter.writeConfig(config)
        self.device = torch.device('cuda:%d' % config.cuda)
        self.testimagepath = os.path.join(self.curProjectPath, 'images')
        if not os.path.exists(self.testimagepath):
            os.makedirs(self.testimagepath)

        # Load parameters
        self.n_classes = len(self.selected_attrs)
        self.version = config.test_version
        if config.EnableThresIntSetting:
            # NOTE(review): hard-coded override; a commented-out line used to
            # read config.TestThresInt here - confirm this is intended.
            self.thres_int = 1.5
        else:
            self.thres_int = self.ThresInt
        self.reporter.writeInfo("thres_int:" + str(self.thres_int))
        self.batch_size = self.test_batch_size
        self.build_model()

        self.printTable = pt.PrettyTable()
        self.printTable.field_names = self.selected_simple_attrs + ["Mean"]
        self.testTable = pt.PrettyTable()
        self.testTable.field_names = [str(i) for i in range(self.batch_size)]
        self.ModelFigureName = config.ModelFigureName

    def build_model(self):
        """Import the generator class dynamically, load its checkpoint, and load the classifier."""
        device = self.device
        package = __import__("%s" % self.GeneratorScriptName, fromlist=True)
        genClass = getattr(package, 'Generator')
        self.GlobalG = genClass(self.g_conv_dim, self.gLayerNum, self.resNum,
                                self.n_classes, self.skipNum, self.skipRatio,
                                self.GEncActName, self.GSkipActName,
                                self.GDecActName, self.GOutActName,
                                self.LSTUScriptName).to(device)
        chechpoint_step = self.test_chechpoint_step
        self.GlobalG.load_state_dict(
            torch.load(os.path.join(
                self.model_save_path,
                'Epoch{}_LocalG.pth'.format(chechpoint_step)),
                       map_location=device))
        print('loaded trained models (step: {})..!'.format(chechpoint_step))

        self.Classifier = FaceAttributesClassifier(
            attr_dim=self.n_classes).to(device)
        self.Classifier.load_state_dict(
            torch.load(self.ACCModel, map_location=device))
        print("load classifer model successful!")

    def getNonConflictingLablels(self, orignalLabel, class_n):
        """Build the label delta that flips attribute ``class_n`` for every sample.

        Args:
            orignalLabel: 2-D tensor ``(batch, n_attrs)`` of the original
                labels (the code treats 0 as "absent").
            class_n: index of the attribute to flip.

        Returns:
            ``(delta, target)``. ``delta`` is +1/-1 where an attribute is
            switched on/off, including the companion changes from
            ``_ON_CONFLICTS`` / ``_OFF_REPLACEMENT``. ``target`` holds
            +1/-1 for attribute ``class_n`` only.
        """
        fixeDelta = torch.zeros_like(orignalLabel)
        outlabel = torch.zeros_like(orignalLabel)
        for index in range(orignalLabel.size()[0]):
            if orignalLabel[index, class_n] == 0:
                # Switch the attribute on ...
                fixeDelta[index, class_n] = 1
                outlabel[index, class_n] = 1
                # ... and switch off conflicting attributes that are present.
                for other in self._ON_CONFLICTS.get(class_n, ()):
                    if orignalLabel[index, other] == 1:
                        fixeDelta[index, other] = -1
            else:
                # Switch the attribute off ...
                fixeDelta[index, class_n] = -1
                outlabel[index, class_n] = -1
                # ... and, for some attributes, switch another one on.
                replacement = self._OFF_REPLACEMENT.get(class_n)
                if replacement is not None:
                    fixeDelta[index, replacement] = 1
        return fixeDelta, outlabel

    @staticmethod
    def _format_accuracy(acc_cnt, num_images):
        """Return per-attribute success rates plus their mean, formatted as ``"%.3f"``."""
        rates = [float(count) / num_images for count in acc_cnt]
        mean = sum(rates) / len(rates) if rates else 0.0
        return ["%.3f" % rate for rate in rates] + ["%.3f" % mean]

    def _evaluate(self, data_loader):
        """Run editing + reconstruction over ``total_images // batch_size`` batches.

        Returns:
            ``(acc_cnt, psnr_sum, ssim_sum, num_images)``: per-attribute
            success counts, and PSNR/SSIM sums weighted by the real batch
            size. Dividing by ``num_images`` gives means.
        """
        device = self.device
        num_batches = self.total_images // self.batch_size
        self.GlobalG.eval()
        self.Classifier.eval()
        acc_cnt = np.zeros([self.n_classes])
        psnr_sum = 0.0
        ssim_sum = 0.0
        num_images = 0
        data_iter = iter(data_loader)
        with torch.no_grad():
            for batch_idx in tqdm(range(num_batches)):
                try:
                    realImages, labelOriginal = next(data_iter)
                except StopIteration:
                    # Loader exhausted: start another pass over the test set.
                    data_iter = iter(data_loader)
                    realImages, labelOriginal = next(data_iter)
                batch_n = realImages.size(0)
                num_images += batch_n
                grid = [realImages]  # CPU images for the optional sample grid
                realImages = realImages.to(device)
                labelOriginal = labelOriginal.to(device)

                # Attribute-editing pass: flip each attribute in turn.
                for index in range(self.n_classes):
                    templabel, gt_label = self.getNonConflictingLablels(
                        labelOriginal, index)
                    templabel = templabel.to(device) * self.thres_int
                    xFake = self.GlobalG(realImages, templabel)
                    # The classifier is fed at most 128px images.
                    if self.imsize > 128:
                        xFakelow = F.interpolate(xFake, size=128)
                    else:
                        xFakelow = xFake
                    # Map classifier outputs to -1/+1 like gt_label.
                    pred_label = torch.round(self.Classifier(xFakelow)) * 2 - 1
                    hits = pred_label[:, index] == gt_label[:, index].to(device)
                    acc_cnt[index] += hits.sum().item()
                    if self.save_testimages:
                        grid.append(xFake.cpu())

                # Reconstruction pass: a zero label delta.
                xFake = self.GlobalG(realImages,
                                     torch.zeros_like(labelOriginal))
                _, meanPSNR = PSNR(xFake.cpu(), realImages.cpu())
                psnr_sum += meanPSNR * float(batch_n)
                # SSIM is computed on images mapped from [-1, 1] to [0, 1].
                ssim_value = pytorch_ssim.ssim(
                    ((realImages + 1) / 2).cpu(),
                    ((xFake + 1) / 2).cpu()).item()
                ssim_sum += ssim_value * batch_n

                if self.save_testimages:
                    grid.append(xFake.cpu())
                    print("Save test data")
                    save_image(denorm(torch.cat(grid, 0).data),
                               os.path.join(self.testimagepath,
                                            '{}_fake.png'.format(batch_idx + 1)),
                               nrow=batch_n)
        return acc_cnt, psnr_sum, ssim_sum, num_images

    def _result_names(self, count):
        """Return the CSV row names for ``count`` results (attributes + average)."""
        names = list(self._CSV_ATTR_NAMES) + ["Average"]
        if len(names) != count:
            # Different attribute set: fall back to the configured short names.
            names = list(self.selected_simple_attrs) + ["Average"]
        return names

    def _write_result_csv(self, path, names, scores):
        """Write ``Name,Arch,Score`` rows to ``path``."""
        import csv
        headers = ['Name', 'Arch', 'Score']
        rows = [{"Name": name, "Arch": self.ModelFigureName, "Score": score}
                for name, score in zip(names, scores)]
        with open(path, 'w', newline='') as f:
            f_csv = csv.DictWriter(f, headers)
            f_csv.writeheader()
            f_csv.writerows(rows)

    def _save_results(self, tabledata, acc, checkpoint_step):
        """Always write the per-step CSV; update bestResult.json/.csv when the mean accuracy improves."""
        bestResultPath = os.path.join(self.curProjectPath, 'bestResult.json')
        bestResultCsv = os.path.join(self.curProjectPath, 'bestResult.csv')
        ResultCsv = os.path.join(self.curProjectPath,
                                 'step%d_Result.csv' % checkpoint_step)
        names = self._result_names(len(acc))
        self._write_result_csv(ResultCsv, names, acc)

        previous = None
        if os.path.exists(bestResultPath):
            with open(bestResultPath, 'r') as cf:
                previous = json.load(cf)
        # The mean accuracy is the last entry of "acc".
        if previous is None or float(previous['acc'][-1]) < float(acc[-1]):
            with open(bestResultPath, 'w') as cf:
                json.dump(tabledata, cf)
            self._write_result_csv(bestResultCsv, names, acc)

    def test(self):
        """Evaluate the loaded generator, print/report metrics and store result files.

        Raises:
            ValueError: if ``total_images`` is smaller than the batch size,
                so that no batch would be evaluated.
        """
        start_time = time.time()
        data_loader = getLoader(self.image_path, self.attributes_path,
                                self.selected_attrs, self.imCropSize,
                                self.imsize, self.batch_size,
                                dataset=self.dataset, num_workers=0,
                                toPatch=False, mode='test',
                                microPatchSize=0)
        # Report the step of the generator that build_model() actually loaded.
        checkpoint_step = self.test_chechpoint_step

        acc_cnt, psnr_sum, ssim_sum, num_images = self._evaluate(data_loader)
        if num_images == 0:
            raise ValueError(
                "total_images (%d) < test_batch_size (%d): nothing evaluated"
                % (self.total_images, self.batch_size))

        # Average over the images actually evaluated.
        acc = self._format_accuracy(acc_cnt, num_images)
        total_psnr = psnr_sum / num_images
        total_ssim = ssim_sum / num_images
        tabledata = {"acc": acc, 'psnr': total_psnr, 'ssim': total_ssim,
                     'step': checkpoint_step}

        self.printTable.add_row(acc)
        print("The %s model logs:" % self.version)
        print(self.printTable)
        self.reporter.writeInfo("\n" + self.printTable.__str__())
        print("The mean PSNR is %.2f" % total_psnr)
        print("The average SSIM is %.4f" % total_ssim)
        self.reporter.writeInfo("The mean PSNR is %.2f" % total_psnr)
        self.reporter.writeInfo("The average SSIM is %.4f" % total_ssim)

        self._save_results(tabledata, acc, checkpoint_step)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
def main(config):
    """Run training or testing depending on ``config.mode``.

    Training (any mode other than ``"test"``):
      * resolves the dataset paths from ``./dataTool/dataPath.json``;
      * builds the data loader;
      * creates ``<logRootPath>/<version>/{summary,checkpoint,sample}``;
      * dumps the config to ``config.json``;
      * runs ``TrainingScripts.trainer_<TrainScriptName>.Trainer``.

    Testing: creates ``<testRoot>/<test_version>/`` and runs
    ``TestScripts.tester_<TestScriptsName>.Tester``.

    Args:
        config: namespace-like object (presumably argparse). It is mutated in
            place: ``cuda``, the data/log paths and ``reporter`` are assigned.
    """
    if config.cuda > -1:
        # Expose only the requested GPU; inside the process it is cuda:0.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config.cuda)
        config.cuda = 0

    # Enable the cuDNN autotuner (benchmark mode) for faster training.
    cudnn.benchmark = True

    if config.mode != "test":
        # Load the dataset paths.
        datapath = "./dataTool/dataPath.json"
        with open(datapath, 'r') as cf:
            dataobj = json.load(cf)
        config.image_path = dataobj[config.dataset.lower()]
        config.attributes_path = dataobj[config.dataset.lower() + 'att']
        print("Get data path: %s" % config.image_path)

        # Create the data loader.
        data_loader = getLoader(config.image_path, config.attributes_path,
                                config.selected_attrs, config.imCropSize,
                                config.imsize, config.batch_size,
                                dataset=config.dataset,
                                num_workers=config.num_workers)

        # Create the training log dirs.
        if not os.path.exists(config.logRootPath):
            os.makedirs(config.logRootPath)
        makeFolder(config.logRootPath, config.version)
        currentProjectPath = os.path.join(config.logRootPath, config.version)
        makeFolder(currentProjectPath, "summary")
        config.log_path = os.path.join(currentProjectPath, "summary")
        makeFolder(currentProjectPath, "checkpoint")
        config.model_save_path = os.path.join(currentProjectPath,
                                              "checkpoint")
        makeFolder(currentProjectPath, "sample")
        config.sample_path = os.path.join(currentProjectPath, "sample")

        # Reset before dumping the config (presumably because a Reporter
        # instance is not JSON-serialisable); the real one is attached below.
        config.reporter = ''
        configjson = json.dumps(config.__dict__)
        with open(os.path.join(currentProjectPath, 'config.json'), 'w') as cf:
            cf.write(configjson)

        report_file = os.path.join(currentProjectPath,
                                   config.version + "_report")
        config.reporter = Reporter(report_file)

        moduleName = "TrainingScripts.trainer_" + config.TrainScriptName
        tempstr = "Start to run training script: {}".format(moduleName)
        print(tempstr)
        print("Traning version: %s" % config.version)
        print("Training Script Name: %s" % config.TrainScriptName)
        print("Image Size: %d" % config.imsize)
        print("Image Crop Size: %d" % config.imCropSize)
        # %s rather than %d: ThresInt may be fractional.
        print("ThresInt: %s" % config.ThresInt)
        print("D : G = %d : %d" % (config.D_step, config.G_step))
        config.reporter.writeInfo(tempstr)

        package = __import__(moduleName, fromlist=True)
        trainerClass = getattr(package, 'Trainer')
        trainer = trainerClass(data_loader, config)
        trainer.train()
    else:
        # Make a dir for saving test results.
        makeFolder(config.testRoot, config.test_version)
        moduleName = "TestScripts.tester_" + config.TestScriptsName
        tempstr = "Start to run test script: {}".format(moduleName)
        print(tempstr)
        package = __import__(moduleName, fromlist=True)
        testerClass = getattr(package, 'Tester')
        tester = testerClass(config)
        tester.test()
def __init__(self, config):
    """Initialise a tester from the saved ``config.json`` of a training run.

    Steps:
      1. Resolve the node that holds the trained model. This comes from
         ``nodesInfo/modelpositions.json`` when ``config.read_node_local``
         is set, otherwise from ``config.model_node``.
      2. Create ``<testRoot>/<test_version>/images`` and a ``report.log``
         reporter.
      3. If the node is not ``localhost``, fetch the run's ``config.json``
         via ``fileUploaderClass.sshScpGet`` into
         ``<logRootPath>/<test_version>/``.
      4. Read model hyper-parameters from that ``config.json`` and test
         options from ``config``, then call ``getModel()`` and
         ``build_model()``.

    Args:
        config: namespace-like object (presumably argparse) holding the
            test options.
    """
    # Resolve which node holds the trained model.
    if config.read_node_local:
        with open('nodesInfo/modelpositions.json', 'r') as cf:
            nodelocaltioninf = json.load(cf)
        self.model_node = nodelocaltioninf[config.test_version]['server']
        print("model node %s" % self.model_node)
    else:
        self.model_node = config.model_node

    self.version = config.test_version
    self.reporter = Reporter(
        os.path.join(config.testRoot, self.version, 'report.log'))
    self.reporter.writeInfo("version:" + str(self.version))
    self.logRootPath = config.logRootPath
    self.testimagepath = os.path.join(config.testRoot, self.version, 'images')
    self.total_images = config.total_images
    self.testRoot = config.testRoot
    if not os.path.exists(self.testimagepath):
        os.makedirs(self.testimagepath)

    if self.model_node.lower() != "localhost":
        # Fetch the training config.json from the remote node.
        with open('nodesInfo/nodes.json', 'r') as cf:
            nodeinf = json.load(cf)[self.model_node.lower()]
        currentProjectPath = os.path.join(self.logRootPath, self.version)
        # The local destination folder may not exist yet on this machine.
        os.makedirs(currentProjectPath, exist_ok=True)
        uploader = fileUploaderClass(nodeinf["ip"], nodeinf["user"],
                                     nodeinf["passwd"])
        remoteFile = nodeinf['basePath'] + self.version + "/config.json"
        localFile = os.path.join(currentProjectPath, "config.json")
        uploader.sshScpGet(remoteFile, localFile)
        print("success get the config file from server %s" % nodeinf['ip'])

    with open(os.path.join(config.logRootPath, self.version, 'config.json'),
              'r') as cf:
        configObj = json.load(cf)

    # Model hyper-parameters (from the training run's config.json).
    self.g_conv_dim = configObj['g_conv_dim']
    self.reporter.writeInfo("g_conv_dim:" + str(self.g_conv_dim))
    self.GEncActName = configObj['GEncActName']
    self.reporter.writeInfo("GEncActName:" + str(self.GEncActName))
    self.GDecActName = configObj['GDecActName']
    self.reporter.writeInfo("GDecActName:" + str(self.GDecActName))
    self.GSkipActName = configObj['GSkipActName']
    # Log the skip activation itself (GDecActName was logged by mistake).
    self.reporter.writeInfo("GSkipActName:" + str(self.GSkipActName))
    self.resNum = configObj['resNum']
    self.reporter.writeInfo("resNum:" + str(self.resNum))
    self.skipNum = configObj['skipNum']
    self.reporter.writeInfo("skipNum:" + str(self.skipNum))
    self.gLayerNum = configObj['gLayerNum']
    self.reporter.writeInfo("gLayerNum:" + str(self.gLayerNum))
    self.skipRatio = configObj['skipRatio']
    self.reporter.writeInfo("skipRatio:" + str(self.skipRatio))
    self.GScriptName = configObj['GeneratorScriptName']
    self.reporter.writeInfo("GScriptName:" + str(self.GScriptName))
    self.GOutActName = configObj['GOutActName']
    self.reporter.writeInfo("GOutActName:" + str(self.GOutActName))

    # Training information
    self.save_testimages = config.save_testimages
    self.imsize = configObj['imsize']
    self.reporter.writeInfo("imsize:" + str(self.imsize))
    self.imCropSize = configObj['imCropSize']
    self.reporter.writeInfo("imCropSize:" + str(self.imCropSize))
    self.device = torch.device('cuda:%d' % config.cuda)
    # The attribute list comes from the test config, not the training one.
    self.selected_attrs = config.selected_attrs
    # Log the attribute list itself (imCropSize was logged by mistake).
    self.reporter.writeInfo("selected_attrs:" + str(self.selected_attrs))
    self.simple_attrs = configObj['selected_simple_attrs']
    self.n_classes = len(self.selected_attrs)
    self.specifiedImages = config.specifiedImages
    self.SampleImgNum = len(config.specifiedImages)
    if config.EnableThresIntSetting:
        self.thres_int = config.TestThresInt
    else:
        self.thres_int = configObj['ThresInt']
    self.reporter.writeInfo("thres_int:" + str(self.thres_int))
    self.batch_size = config.test_batch_size
    self.ACCModel = config.ACCModel

    # Steps
    self.chechpoint_step = config.test_chechpoint_step
    self.reporter.writeInfo("chechpoint_step:" + str(self.chechpoint_step))

    # Paths
    self.dataset = config.dataset
    self.model_save_path = config.model_save_path
    self.attributes_path = config.attributes_path
    self.image_path = config.image_path

    self.getModel()
    self.build_model()

    self.printTable = pt.PrettyTable()
    self.printTable.field_names = self.simple_attrs + ["Mean"]
    self.testTable = pt.PrettyTable()
    self.testTable.field_names = [str(i) for i in range(self.batch_size)]
    self.ModelFigureName = config.ModelFigureName
class Tester(object):
    """Evaluate a trained attribute-editing generator checkpoint.

    For every selected attribute, ``test`` builds a label delta that flips that
    attribute (see ``getNonConflictingLablels``), generates edited images with
    ``GlobalG``, and counts how often ``Classifier`` agrees with the intended
    target label. It also measures reconstruction PSNR by running the generator
    with an all-zero label delta. Results go to the reporter log, to
    ``step<N>_Result.csv``, and, when the mean accuracy improves, to
    ``bestResult.json`` / ``bestResult.csv`` under ``testRoot/<version>``.
    """

    # Column names used in the result CSVs when exactly 13 attributes (plus the
    # "Average" column) are evaluated. Any other count falls back to the
    # configured ``selected_simple_attrs`` names.
    _CSV_ATTR_NAMES = [
        "Bald", "Bangs", "Black", "Blond", "Brown", "Eyebrows", "Glass",
        "Male", "Mouth", "Mustache", "NoBeard", "Pale", "Young", "Average"
    ]
    _CSV_HEADERS = ['Name', 'Arch', 'Score']

    # Attribute-conflict rules used by getNonConflictingLablels. Indices are
    # positions in selected_attrs. With the default ordering (Bald, Bangs,
    # Black, Blond, Brown, ...) they encode the original code's intent:
    #   - adding attribute c switches OFF every listed attribute that is on;
    _ADD_CONFLICTS = {0: (1,), 1: (0,), 2: (3, 4), 3: (2, 4), 4: (2, 3)}
    #   - removing attribute c switches ON the listed replacement
    #     (e.g. black hair -> blond hair, blond hair -> brown hair).
    _REMOVE_REPLACEMENT = {2: 3, 3: 4, 4: 3}

    def __init__(self, config):
        """Read the training config of ``config.test_version`` and build models.

        If the model lives on a remote node (not "localhost"), the training
        ``config.json`` is first fetched over SSH into
        ``logRootPath/<version>/config.json``.
        """
        # Resolve which node holds the trained model.
        if config.read_node_local:
            with open('nodesInfo/modelpositions.json', 'r') as cf:
                nodelocaltioninf = json.load(cf)
            self.model_node = nodelocaltioninf[config.test_version]['server']
            print("model node %s" % self.model_node)
        else:
            self.model_node = config.model_node

        self.version = config.test_version
        self.reporter = Reporter(
            os.path.join(config.testRoot, self.version, 'report.log'))
        self.reporter.writeInfo("version:" + str(self.version))
        self.logRootPath = config.logRootPath
        self.testimagepath = os.path.join(config.testRoot, self.version,
                                          'images')
        self.total_images = config.total_images
        self.testRoot = config.testRoot
        if not os.path.exists(self.testimagepath):
            os.makedirs(self.testimagepath)

        # Fetch the training config from the remote node when needed.
        if self.model_node.lower() != "localhost":
            with open('nodesInfo/nodes.json', 'r') as cf:
                nodeinf = json.load(cf)[self.model_node.lower()]
            currentProjectPath = os.path.join(self.logRootPath, self.version)
            uploader = fileUploaderClass(nodeinf["ip"], nodeinf["user"],
                                         nodeinf["passwd"])
            remoteFile = nodeinf['basePath'] + self.version + "/config.json"
            localFile = os.path.join(currentProjectPath, "config.json")
            uploader.sshScpGet(remoteFile, localFile)
            print("success get the config file from server %s" % nodeinf['ip'])

        with open(os.path.join(config.logRootPath, self.version,
                               'config.json'), 'r') as cf:
            configObj = json.load(cf)

        # Model hyper-parameters (taken from the training config).
        self.g_conv_dim = configObj['g_conv_dim']
        self.reporter.writeInfo("g_conv_dim" + str(self.g_conv_dim))
        self.GEncActName = configObj['GEncActName']
        self.reporter.writeInfo("GEncActName" + str(self.GEncActName))
        self.GDecActName = configObj['GDecActName']
        self.reporter.writeInfo("GDecActName:" + str(self.GDecActName))
        self.GSkipActName = configObj['GSkipActName']
        # Log the skip activation itself (previously logged GDecActName).
        self.reporter.writeInfo("GSkipActName:" + str(self.GSkipActName))
        self.resNum = configObj['resNum']
        self.reporter.writeInfo("resNum:" + str(self.resNum))
        self.skipNum = configObj['skipNum']
        self.reporter.writeInfo("skipNum:" + str(self.skipNum))
        self.gLayerNum = configObj['gLayerNum']
        self.reporter.writeInfo("gLayerNum:" + str(self.gLayerNum))
        self.skipRatio = configObj['skipRatio']
        self.reporter.writeInfo("skipRatio:" + str(self.skipRatio))
        self.GScriptName = configObj['GeneratorScriptName']
        self.reporter.writeInfo("GScriptName:" + str(self.GScriptName))
        self.GOutActName = configObj['GOutActName']
        self.reporter.writeInfo("GOutActName:" + str(self.GOutActName))

        # Test-time settings.
        self.save_testimages = config.save_testimages
        self.imsize = configObj['imsize']
        self.reporter.writeInfo("imsize:" + str(self.imsize))
        self.imCropSize = configObj['imCropSize']
        self.reporter.writeInfo("imCropSize:" + str(self.imCropSize))
        self.device = torch.device('cuda:%d' % config.cuda)
        # Attributes come from the test config, not the training config.
        self.selected_attrs = config.selected_attrs
        # Log the attributes themselves (previously logged imCropSize).
        self.reporter.writeInfo("selected_attrs:" + str(self.selected_attrs))
        self.simple_attrs = configObj['selected_simple_attrs']
        self.n_classes = len(self.selected_attrs)
        self.specifiedImages = config.specifiedImages
        self.SampleImgNum = len(config.specifiedImages)
        if config.EnableThresIntSetting:
            self.thres_int = config.TestThresInt
        else:
            self.thres_int = configObj['ThresInt']
        self.reporter.writeInfo("thres_int:" + str(self.thres_int))
        self.batch_size = config.test_batch_size
        self.ACCModel = config.ACCModel

        # Checkpoint to evaluate.
        self.chechpoint_step = config.test_chechpoint_step
        self.reporter.writeInfo("chechpoint_step:" + str(self.chechpoint_step))

        # Paths.
        self.dataset = config.dataset
        self.model_save_path = config.model_save_path
        self.attributes_path = config.attributes_path
        self.image_path = config.image_path

        self.getModel()
        self.build_model()

        self.printTable = pt.PrettyTable()
        self.printTable.field_names = self.simple_attrs + ["Mean"]
        self.testTable = pt.PrettyTable()
        self.testTable.field_names = [str(i) for i in range(self.batch_size)]
        self.ModelFigureName = config.ModelFigureName

    def getModel(self):
        """Download the generator checkpoint from the remote node if missing.

        Does nothing when the model node is "localhost". The file is saved to
        ``logRootPath/<version>/checkpoint/<step>_LocalG.pth``.
        NOTE(review): build_model loads from ``model_save_path`` instead;
        confirm the two locations coincide in the test config.
        """
        server = self.model_node.lower()
        if server == "localhost":
            # Nothing to fetch; avoid requiring nodes.json on a local run.
            return
        with open('nodesInfo/nodes.json', 'r') as cf:
            nodeinf = json.load(cf)[server]
        currentProjectPath = os.path.join(self.logRootPath, self.version)
        ckpt_name = "%d_LocalG.pth" % self.chechpoint_step
        remoteFile = (nodeinf['basePath'] + self.version + "/checkpoint/" +
                      ckpt_name)
        localFile = os.path.join(currentProjectPath, "checkpoint", ckpt_name)
        if os.path.exists(localFile):
            print("checkpoint already exists")
        else:
            uploader = fileUploaderClass(nodeinf["ip"], nodeinf["user"],
                                         nodeinf["passwd"])
            uploader.sshScpGet(remoteFile, localFile)
            print("success get the model from server %s" % nodeinf['ip'])

    def build_model(self):
        """Instantiate the generator and attribute classifier and load weights.

        The generator class is ``Generator`` from ``components.<GScriptName>``.
        Its weights come from ``model_save_path/<step>_LocalG.pth``, and the
        classifier weights come from ``ACCModel``.
        """
        import importlib

        package = importlib.import_module("components.%s" % self.GScriptName)
        genClass = getattr(package, 'Generator')
        self.GlobalG = genClass(self.g_conv_dim, self.gLayerNum, self.resNum,
                                self.n_classes, self.skipNum, self.skipRatio,
                                self.GEncActName, self.GSkipActName,
                                self.GDecActName,
                                self.GOutActName).to(self.device)
        self.GlobalG.load_state_dict(
            torch.load(os.path.join(
                self.model_save_path,
                '{}_LocalG.pth'.format(self.chechpoint_step)),
                       map_location=self.device))
        print('loaded trained models (step: {})..!'.format(
            self.chechpoint_step))
        self.Classifier = FaceAttributesClassifier(attr_dim=self.n_classes).to(
            self.device)
        self.Classifier.load_state_dict(
            torch.load(self.ACCModel, map_location=self.device))
        print("load classifer model successful!")

    def getNonConflictingLablels(self, orignalLabel, class_n):
        """Build the label delta that flips attribute ``class_n`` for every row.

        Args:
            orignalLabel: 2-D label tensor (batch, n_attrs). An entry equal to
                0 means "attribute absent"; any other value is treated as
                "present". Conflict checks compare against 1.
            class_n: index of the attribute to flip.

        Returns:
            (fixeDelta, outlabel), both shaped and typed like ``orignalLabel``.
            ``fixeDelta`` is +1 where the attribute is added and -1 where it is
            removed. It also holds the conflict adjustments from
            ``_ADD_CONFLICTS`` / ``_REMOVE_REPLACEMENT``. ``outlabel`` is
            +1/-1 in column ``class_n`` only (the expected edited value) and 0
            elsewhere.
        """
        fixeDelta = torch.zeros_like(orignalLabel)
        outlabel = torch.zeros_like(orignalLabel)
        adding = orignalLabel[:, class_n] == 0
        removing = ~adding

        fixeDelta[adding, class_n] = 1
        outlabel[adding, class_n] = 1
        fixeDelta[removing, class_n] = -1
        outlabel[removing, class_n] = -1

        # When adding, switch off mutually exclusive attributes that are on.
        for other in self._ADD_CONFLICTS.get(class_n, ()):
            fixeDelta[adding & (orignalLabel[:, other] == 1), other] = -1
        # When removing, switch on the replacement attribute (if any).
        replacement = self._REMOVE_REPLACEMENT.get(class_n)
        if replacement is not None:
            fixeDelta[removing, replacement] = 1
        return fixeDelta, outlabel

    @staticmethod
    def _format_accuracy(acc_counts, num_images):
        """Return per-attribute accuracies plus their mean, formatted "%.3f".

        The mean is taken over the number of attributes actually evaluated
        (previously a hard-coded 13).
        """
        if num_images <= 0:
            raise ValueError("no images were evaluated")
        acc_gen = np.asarray(acc_counts, dtype=np.float64) / float(num_images)
        acc = ["%.3f" % x for x in acc_gen]
        acc.append("%.3f" % (sum(acc_gen) / len(acc_gen)))
        return acc

    def _result_rows(self, acc):
        """Build the CSV rows (Name, Arch, Score) for a formatted ``acc`` list."""
        names = self._CSV_ATTR_NAMES
        if len(names) != len(acc):
            # Not the default 13-attribute layout: use the configured names.
            names = list(self.simple_attrs) + ["Average"]
        return [{
            "Name": name,
            "Arch": self.ModelFigureName,
            "Score": score,
        } for name, score in zip(names, acc)]

    @classmethod
    def _write_csv(cls, path, rows):
        """Write ``rows`` to ``path`` as a CSV with the standard result header."""
        import csv

        with open(path, 'w', newline='') as f:
            f_csv = csv.DictWriter(f, cls._CSV_HEADERS)
            f_csv.writeheader()
            f_csv.writerows(rows)

    def _save_results(self, tabledata, acc):
        """Write this step's CSV and update the best-result files if improved.

        ``step<N>_Result.csv`` is always written. ``bestResult.json`` and
        ``bestResult.csv`` are (re)written when no best exists yet or when this
        step's mean accuracy (last entry of ``acc``) beats the stored one.
        """
        version_root = os.path.join(self.testRoot, self.version)
        bestResultPath = os.path.join(version_root, 'bestResult.json')
        bestResultCsv = os.path.join(version_root, 'bestResult.csv')
        ResultCsv = os.path.join(version_root,
                                 'step%d_Result.csv' % self.chechpoint_step)

        rows = self._result_rows(acc)
        self._write_csv(ResultCsv, rows)

        is_new_best = True
        if os.path.exists(bestResultPath):
            with open(bestResultPath, 'r') as cf:
                score = json.load(cf)
            # Compare numerically; the last entry is the mean accuracy.
            is_new_best = float(score['acc'][-1]) < float(acc[-1])
        if is_new_best:
            with open(bestResultPath, 'w') as cf:
                json.dump(tabledata, cf)
            self._write_csv(bestResultCsv, rows)

    def test(self):
        """Run the evaluation over ``total_images // batch_size`` batches.

        Raises:
            ValueError: if ``total_images`` is smaller than one batch.
        """
        start_time = time.time()
        total = self.total_images
        num_batches = total // self.batch_size
        if num_batches == 0:
            raise ValueError(
                "total_images (%d) must be at least test_batch_size (%d)" %
                (total, self.batch_size))

        data_loader = getLoader(self.image_path,
                                self.attributes_path,
                                self.selected_attrs,
                                self.imCropSize,
                                self.imsize,
                                self.batch_size,
                                dataset=self.dataset,
                                num_workers=0,
                                toPatch=False,
                                mode='test',
                                microPatchSize=0)
        data_iter = iter(data_loader)
        self.GlobalG.eval()
        self.Classifier.eval()

        acc_cnt_generate = np.zeros([self.n_classes])
        total_psnr = 0.0
        num_images = 0  # images actually evaluated (last batch may be short)
        with torch.no_grad():
            for iii in tqdm(range(num_batches)):
                try:
                    realImages, labelOriginal = next(data_iter)
                except StopIteration:
                    # Loader exhausted before `total` images: start over.
                    data_iter = iter(data_loader)
                    realImages, labelOriginal = next(data_iter)
                res = realImages
                batch_n = labelOriginal.size(0)
                num_images += batch_n
                realImages = realImages.to(self.device)
                labelOriginal = labelOriginal.to(self.device)

                # Indices 0..n_classes-1 edit one attribute each; the extra
                # last pass reconstructs with a zero label delta for PSNR.
                for index in range(self.n_classes + 1):
                    if index < self.n_classes:
                        templabel, gt_label = self.getNonConflictingLablels(
                            labelOriginal, index)
                        templabel = templabel.to(
                            self.device) * 2 * self.thres_int
                        xFake = self.GlobalG(realImages, templabel)
                        # Rounded classifier output mapped 0/1 -> -1/+1 to be
                        # comparable with gt_label (presumably probabilities).
                        pred_label = torch.round(self.Classifier(xFake)) * 2 - 1
                        pre = pred_label[:, index]
                        gt = gt_label[:, index].to(self.device)
                        acc_cnt_generate[index] += (pre == gt).sum().item()
                    else:
                        recLabel = torch.zeros_like(labelOriginal)
                        xFake = self.GlobalG(realImages, recLabel)
                        _, meanPSNR = PSNR(xFake.cpu(), realImages.cpu())
                        total_psnr += meanPSNR * float(batch_n)
                    if self.save_testimages:
                        res = torch.cat([res, xFake.cpu()], 0)
                if self.save_testimages:
                    save_image(denorm(res.data),
                               os.path.join(self.testimagepath,
                                            '{}_fake.png'.format(iii + 1)),
                               nrow=self.batch_size)

        acc = self._format_accuracy(acc_cnt_generate, num_images)
        total_psnr = total_psnr / float(num_images)
        tabledata = {
            "acc": acc,
            'psnr': total_psnr,
            'step': self.chechpoint_step
        }

        self.printTable.add_row(acc)
        print("The %s model logs:" % self.version)
        self.reporter.writeInfo("Batch Size: %d" % self.batch_size)
        self.reporter.writeInfo("Total images: %d" % num_images)
        self.reporter.writeInfo("The %s model logs:" % self.version)
        print(self.printTable)
        self.reporter.writeInfo("\n" + self.printTable.__str__())
        print("The mean PSNR is %.2f" % total_psnr)
        self.reporter.writeInfo("The mean PSNR is %.2f" % total_psnr)

        self._save_results(tabledata, acc)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))