def main(): global args args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu dataset_base_path = path.join(args.base_path, "dataset", "celeba") image_base_path = path.join(dataset_base_path, "img_align_celeba") split_dataset_path = path.join(dataset_base_path, "Eval", "list_eval_partition.txt") with open(split_dataset_path, "r") as f: split_annotation = f.read().splitlines() # create the data name list for train,test and valid train_data_name_list = [] test_data_name_list = [] valid_data_name_list = [] for item in split_annotation: item = item.split(" ") if item[1] == '0': train_data_name_list.append(item[0]) elif item[1] == '1': valid_data_name_list.append(item[0]) else: test_data_name_list.append(item[0]) attribute_annotation_dict = None if args.need_label: attribute_annotation_path = path.join(dataset_base_path, "Anno", "list_attr_celeba.txt") with open(attribute_annotation_path, "r") as f: attribute_annotation = f.read().splitlines() attribute_annotation = attribute_annotation[2:] attribute_annotation_dict = {} for item in attribute_annotation: img_name, attribute = item.split(" ", 1) attribute = tuple([eval(attr) for attr in attribute.split(" ") if attr != ""]) assert len(attribute) == 40, "the attribute of item {} is not equal to 40".format(img_name) attribute_annotation_dict[img_name] = attribute discriminator = Discriminator(num_channel=args.num_channel, num_feature=args.dnf, data_parallel=args.data_parallel).cuda() generator = Decoder(latent_dim=args.latent_dim, num_feature=args.gnf, num_channel=args.num_channel, data_parallel=args.data_parallel).cuda() input("Begin the {} time's training, the train dataset has {} images and the valid has {} images".format( args.train_time, len(train_data_name_list), len(valid_data_name_list))) d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, 0.999)) g_optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, 0.999)) writer_log_dir = "{}/DCGAN/runs/train_time:{}".format(args.base_path, args.train_time) # Here we implement the resume part if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) if not args.not_resume_arg: args = checkpoint['args'] args.start_epoch = checkpoint['epoch'] discriminator.load_state_dict(checkpoint["discriminator_state_dict"]) generator.load_state_dict(checkpoint["generator_state_dict"]) d_optimizer.load_state_dict(checkpoint['discriminator_optimizer']) g_optimizer.load_state_dict(checkpoint['generator_optimizer']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: raise FileNotFoundError("Checkpoint Resume File {} Not Found".format(args.resume)) else: if os.path.exists(writer_log_dir): flag = input("DCGAN train_time:{} will be removed, input yes to continue:".format( args.train_time)) if flag == "yes": shutil.rmtree(writer_log_dir, ignore_errors=True) writer = SummaryWriter(log_dir=writer_log_dir) # Here we just use the train dset in training train_dset = CelebADataset(base_path=image_base_path, data_name_list=train_data_name_list, image_size=args.image_size, label_dict=attribute_annotation_dict) train_dloader = DataLoader(dataset=train_dset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) criterion = nn.BCELoss() for epoch in range(args.start_epoch, args.epochs): train(train_dloader, generator, discriminator, g_optimizer, d_optimizer, criterion, writer, epoch) # save parameters save_checkpoint({ 'epoch': epoch + 1, 'args': args, "discriminator_state_dict": discriminator.state_dict(), "generator_state_dict": generator.state_dict(), 'discriminator_optimizer': d_optimizer.state_dict(), 'generator_optimizer': g_optimizer.state_dict() })
class Trainer: def __init__(self, config, data_loader): self.device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu") self.in_channels = config.in_channels self.out_channels = config.out_channels self.nf = config.nf self.storing_channels = config.storing_channels self.lr = config.lr self.b1 = config.b1 self.b2 = config.b2 self.weight_decay = config.weight_decay self.decay_epoch = config.decay_epoch self.content_loss_factor = config.content_loss_factor self.perceptual_loss_factor = config.perceptual_loss_factor self.generator_loss_factor = config.generator_loss_factor self.discriminator_loss_factor = config.discriminator_loss_factor self.penalty_loss_factor = config.penalty_loss_factor self.eta = config.eta self.checkpoint_dir = config.checkpoint_dir self.sample_dir = config.sample_dir self.epoch = config.epoch self.num_epoch = config.num_epoch self.image_size = config.image_size self.data_loader = data_loader self.content_losses = AverageMeter() self.generator_losses = AverageMeter() self.perceptual_losses = AverageMeter() self.discriminator_losses = AverageMeter() self.ae_losses = AverageMeter() self.visdom = Visdom() self.build_model() def train(self): optimizer_ae = Adam(chain(self.Encoder.parameters(), self.Decoder.parameters()), self.lr, betas=(self.b1, self.b2), weight_decay=self.weight_decay) optimizer_discriminator = Adam(self.Disciminator.parameters(), self.lr, betas=(self.b1, self.b2), weight_decay=self.weight_decay) lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer_ae, LambdaLR(self.num_epoch, self.epoch, self.decay_epoch).step) total_step = len(self.data_loader) perceptual_criterion = PerceptualLoss().to(self.device) content_criterion = nn.L1Loss().to(self.device) adversarial_criterion = nn.BCELoss().to(self.device) self.Encoder.train() self.Decoder.train() content_losses = AverageMeter() generator_losses = AverageMeter() perceptual_losses = AverageMeter() discriminator_losses = AverageMeter() ae_losses = AverageMeter() lr_window = create_vis_plot('Epoch', 'Learning rate', 'Learning rate') loss_window = create_vis_plot('Epoch', 'Loss', 'Total Loss') generator_loss_window = create_vis_plot('Epoch', 'Loss', 'Generator Loss') discriminator_loss_window = create_vis_plot('Epoch', 'Loss', 'Discriminator Loss') content_loss_window = create_vis_plot('Epoch', 'Loss', 'Content Loss') perceptual_loss_window = create_vis_plot('Epoch', 'Loss', 'Perceptual Loss') if not os.path.exists(self.sample_dir): os.makedirs(self.sample_dir) if not os.path.exists(self.checkpoint_dir): os.makedirs(self.checkpoint_dir) for epoch in range(self.epoch, self.num_epoch): content_losses.reset() perceptual_losses.reset() generator_losses.reset() ae_losses.reset() discriminator_losses.reset() for step, images in enumerate(self.data_loader): images = images.to(self.device) real_labels = torch.ones((images.size(0), 1)).to(self.device) fake_labels = torch.zeros((images.size(0), 1)).to(self.device) encoded_image = self.Encoder(images) binary_decoded_image = paq.compress( encoded_image.cpu().detach().numpy().tobytes()) # encoded_image = paq.decompress(binary_decoded_image) # # encoded_image = torch.from_numpy(np.frombuffer(encoded_image, dtype=np.float32) # .reshape(-1, self.storing_channels, self.image_size // 8, # self.image_size // 8)).to(self.device) decoded_image = self.Decoder(encoded_image) content_loss = content_criterion(images, decoded_image) perceptual_loss = perceptual_criterion(images, decoded_image) generator_loss = adversarial_criterion( self.Disciminator(decoded_image), real_labels) # generator_loss = -self.Disciminator(decoded_image).mean() ae_loss = content_loss * self.content_loss_factor + perceptual_loss * self.perceptual_loss_factor + \ generator_loss * self.generator_loss_factor content_losses.update(content_loss.item()) perceptual_losses.update(perceptual_loss.item()) generator_losses.update(generator_loss.item()) ae_losses.update(ae_loss.item()) optimizer_ae.zero_grad() ae_loss.backward(retain_graph=True) optimizer_ae.step() interpolated_image = self.eta * images + ( 1 - self.eta) * decoded_image gravity_penalty = self.Disciminator(interpolated_image).mean() real_loss = adversarial_criterion(self.Disciminator(images), real_labels) fake_loss = adversarial_criterion( self.Disciminator(decoded_image), fake_labels) discriminator_loss = (real_loss + fake_loss) * self.discriminator_loss_factor / 2 +\ gravity_penalty * self.penalty_loss_factor # discriminator_loss = self.Disciminator(decoded_image).mean() - self.Disciminator(images).mean() + \ # gravity_penalty * self.penalty_loss_factor optimizer_discriminator.zero_grad() discriminator_loss.backward(retain_graph=True) optimizer_discriminator.step() discriminator_losses.update(discriminator_loss.item()) if step % 100 == 0: print( f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] [Learning rate {get_lr(optimizer_ae)}] " f"[Content {content_loss:.4f}] [Perceptual {perceptual_loss:.4f}] [Gan {generator_loss:.4f}]" f"[Discriminator {discriminator_loss:.4f}]") save_image( torch.cat([images, decoded_image], dim=2), os.path.join(self.sample_dir, f"Sample-epoch-{epoch}-step-{step}.png")) update_vis_plot(epoch, ae_losses.avg, loss_window, 'append') update_vis_plot(epoch, generator_losses.avg, generator_loss_window, 'append') update_vis_plot(epoch, discriminator_losses.avg, discriminator_loss_window, 'append') update_vis_plot(epoch, content_losses.avg, content_loss_window, 'append') update_vis_plot(epoch, perceptual_losses.avg, perceptual_loss_window, 'append') update_vis_plot(epoch, get_lr(optimizer_ae), lr_window, 'append') lr_scheduler.step() torch.save( self.Encoder.state_dict(), os.path.join(self.checkpoint_dir, f"Encoder-{epoch}.pth")) torch.save( self.Decoder.state_dict(), os.path.join(self.checkpoint_dir, f"Decoder-{epoch}.pth")) torch.save( self.Disciminator.state_dict(), os.path.join(self.checkpoint_dir, f"Discriminator-{epoch}.pth")) def build_model(self): self.Encoder = Encoder(self.in_channels, self.storing_channels, self.nf).to(self.device) self.Decoder = Decoder(self.storing_channels, self.in_channels, self.nf).to(self.device) self.Disciminator = Discriminator().to(self.device) self.load_model() def load_model(self): print(f"[*] Load Model in {self.checkpoint_dir}") if not os.path.exists(self.checkpoint_dir): os.makedirs(self.checkpoint_dir) encoder_parameter_names = glob( os.path.join(self.checkpoint_dir, f"Encoder-{self.epoch - 1}.pth")) decoder_parameter_names = glob( os.path.join(self.checkpoint_dir, f"Decoder-{self.epoch - 1}.pth")) discriminator_parameters_names = glob( os.path.join(self.checkpoint_dir, f"Discriminator-{self.epoch - 1}.pth")) if not encoder_parameter_names or not decoder_parameter_names or not discriminator_parameters_names: print(f"[!] There is no parameter in {self.checkpoint_dir}") return self.Encoder.load_state_dict(torch.load(encoder_parameter_names[0])) self.Decoder.load_state_dict(torch.load(decoder_parameter_names[0])) self.Disciminator.load_state_dict( torch.load(discriminator_parameters_names[0]))