Example #1
    def __init__(self, config):
        super(InpaintingModel, self).__init__()
        self.name = 'InpaintingModel'
        self.config = config
        self.iteration = 0
        self.gen_weights_path = os.path.join(config.save_model_dir, 'InpaintingModel_gen.pth')
        self.dis_weights_path = os.path.join(config.save_model_dir, 'InpaintingModel_dis.pth')

        self.generator = FRRNet()
        self.discriminator = Discriminator(in_channels=3, use_sigmoid=True)

        if torch.cuda.device_count() > 1:
            device_ids = range(torch.cuda.device_count())
            self.generator = nn.DataParallel(self.generator, device_ids)
            self.discriminator = nn.DataParallel(self.discriminator, device_ids)

        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()
        self.style_loss = StyleLoss()
        self.adversarial_loss = AdversarialLoss()  

        self.gen_optimizer = optim.Adam(
            params=self.generator.parameters(),
            lr=float(config.LR),
            betas=(0.0, 0.9)
        )

        self.dis_optimizer = optim.Adam(
            params=self.discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(0.0, 0.9)
        )
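Example #1 only builds the modules, losses, and optimizers. A minimal sketch of how one adversarial update could use them is shown below; the tensors `images`/`masks`, the FRRNet forward signature, and the EdgeConnect-style call `adversarial_loss(outputs, is_real, is_disc)` are assumptions, not part of the original code.

    def train_step(self, images, masks):
        # sketch only: tensor names, generator signature, and loss-call convention are assumptions
        self.iteration += 1

        # discriminator update on real images vs. detached generator outputs
        self.dis_optimizer.zero_grad()
        outputs = self.generator(images, masks)
        dis_real = self.discriminator(images)
        dis_fake = self.discriminator(outputs.detach())
        dis_loss = (self.adversarial_loss(dis_real, True, True) +
                    self.adversarial_loss(dis_fake, False, True)) / 2
        dis_loss.backward()
        self.dis_optimizer.step()

        # generator update: adversarial term plus pixel-wise L1
        self.gen_optimizer.zero_grad()
        gen_fake = self.discriminator(outputs)
        gen_loss = self.adversarial_loss(gen_fake, True, False) + self.l1_loss(outputs, images)
        gen_loss.backward()
        self.gen_optimizer.step()
        return gen_loss.item(), dis_loss.item()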
Example #2
    def __init__(self, config):
        super().__init__()
        self.config = config
        # generator input: [rgb(3) + edge(1)]
        # discriminator input: rgb(3)
        self.generator = DCGANGenerator(use_spectral_norm=False, net_type="sr")
        self.discriminator = PatchGANDiscriminator(
            in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')

        self.l1_loss = nn.L1Loss()
        self.adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        self.style_content_loss = StyleContentLoss()

        kernel = np.zeros((self.config.SCALE, self.config.SCALE))
        kernel[0, 0] = 1

        # (out_channels, in_channels/groups, height, width)
        scale_kernel = torch.FloatTensor(np.tile(kernel, (
            3,
            1,
            1,
            1,
        )))
        self.register_buffer('scale_kernel', scale_kernel)

        self.gen_optimizer = optim.Adam(params=self.generator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))

        self.dis_optimizer = optim.Adam(params=self.discriminator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))
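The buffer is only registered here; `kernel[0, 0] = 1` builds a SCALE×SCALE filter that picks a single pixel, and tiling it to shape (3, 1, SCALE, SCALE) makes it usable as a depthwise strided convolution over RGB input. A minimal method sketch of that assumed use (assuming `torch.nn.functional` is imported as `F`; the real forward pass is not shown in the snippet):

    def downscale(self, hr_images):
        # assumed use of `scale_kernel`: keep the top-left pixel of every
        # SCALE x SCALE block, i.e. nearest-style downsampling by config.SCALE
        return F.conv2d(hr_images, self.scale_kernel,
                        stride=self.config.SCALE, groups=3)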
Example #3
    def __init__(self, config):
        super().__init__('EdgeModel', config)

        # generator input: [rgb(3) + edge(1)]
        # discriminator input: 2 channels (reduced from the original rgb(3) + edge(1))
        generator = EdgeGenerator(use_spectral_norm=True)
        discriminator = Discriminator(
            in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge')  # in_channels reduced from 4 to 2

        if len(config.GPU) > 1:
            generator = nn.DataParallel(generator, config.GPU)
            discriminator = nn.DataParallel(discriminator, config.GPU)

        l1_loss = nn.L1Loss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(params=generator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))

        self.dis_optimizer = optim.Adam(params=discriminator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))
Example #4
    def __init__(self, config):
        super().__init__()
        self.config = config
        # generator input: [rgb(3) + edge(1)]
        # discriminator input: (rgb(3) + edge(1))
        generator = DCGANGenerator(use_spectral_norm=True, net_type="edge")
        discriminator = PatchGANDiscriminator(
            in_channels=4, use_sigmoid=config.GAN_LOSS != 'hinge')

        l1_loss = nn.L1Loss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(params=generator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))

        self.dis_optimizer = optim.Adam(params=discriminator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))
Example #5
    def __init__(self, config):
        super(InpaintModel, self).__init__('InpaintModel', config)

        #generator = Generator()
        generator = Generator_SE()
        discriminator = Discriminator(in_channels=3, use_sigmoid=False)

        # data = torch.load('ablation_v0/InpaintModel121_gen.pth')
        # generator.load_state_dict(data['generator'])
        # self.iteration = data['iteration']
        # data = torch.load('ablation_v0/InpaintModel121_dis.pth')
        # discriminator.load_state_dict(data['discriminator'])

        l1_loss = nn.L1Loss()
        adversarial_loss = AdversarialLoss(type='hinge')
        perceptual_loss = PerceptualLoss()
        style_loss = StyleLoss()

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)
        self.add_module('l1_loss', l1_loss)
        self.add_module('adversarial_loss', adversarial_loss)
        self.add_module('perceptual_loss', perceptual_loss)
        self.add_module('style_loss', style_loss)

        self.optimizer = optim.Adam(params=generator.parameters(),
                                    lr=config.LR)
        self.dis_optimizer = optim.Adam(params=discriminator.parameters(),
                                        lr=config.LR * config.D2G_LR)
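Example #5 builds AdversarialLoss(type='hinge') and a discriminator with use_sigmoid=False, i.e. the discriminator emits raw scores. For reference, a common hinge-GAN formulation looks like the sketch below (assuming the usual torch/nn imports used by the examples above; this is not necessarily the exact implementation behind AdversarialLoss).

class HingeAdversarialLoss(nn.Module):
    """Common hinge criterion; sketch only, may differ from the AdversarialLoss used above."""
    def forward(self, outputs, is_real, is_disc):
        if is_disc:
            # discriminator: E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))]
            if is_real:
                return torch.relu(1.0 - outputs).mean()
            return torch.relu(1.0 + outputs).mean()
        # generator: -E[D(fake)]
        return -outputs.mean()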
Example #6
    def __init__(self, config):
        super().__init__('SRModel', config)

        # generator input: [gray(1) + edge(1)]
        # discriminator input: [gray(1)]
        generator = SRGenerator()
        discriminator = Discriminator(
            in_channels=1, use_sigmoid=config.GAN_LOSS != 'hinge')  # in_channels reduced from 3 to 1

        if len(config.GPU) > 1:
            generator = nn.DataParallel(generator, config.GPU)
            discriminator = nn.DataParallel(discriminator, config.GPU)

        l1_loss = nn.L1Loss()
        content_loss = ContentLoss()
        style_loss = StyleLoss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        kernel = np.zeros((self.config.SCALE, self.config.SCALE))
        kernel[0, 0] = 1
        #kernel_weight = torch.tensor(np.tile(kernel, (3, 1, 1, 1))).float().to(config.DEVICE)     # (out_channels, in_channels/groups, height, width)

        #self.add_module('scale_kernel', kernel_weight)
        #self.scale_kernel = torch.tensor(np.tile(kernel, (1, 1, 1, 1))).float().to(config.DEVICE)  #3-->1

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('content_loss', content_loss)
        self.add_module('style_loss', style_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(params=generator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))

        self.dis_optimizer = optim.Adam(params=discriminator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))
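The commented-out scale_kernel lines in this example would not work as written: add_module expects an nn.Module, not a tensor. Registering the tiled kernel as a buffer, as Example #2 does, is the working pattern; a one-channel sketch for this grayscale model (reusing the `kernel` array built above):

        # sketch: register the single-channel picking kernel as a buffer so it
        # follows the module across devices (same pattern as Example #2)
        scale_kernel = torch.from_numpy(np.tile(kernel, (1, 1, 1, 1))).float()
        self.register_buffer('scale_kernel', scale_kernel)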
Example #7
File: train.py  Project: raahii/dcvgan
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        "-c",
        default="configs/default.yml",
        help="training configuration file",
    )
    args = parser.parse_args()

    # parse config yaml
    with open(args.config) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)
    configs["config_path"] = args.config

    # fix seed
    fix_seed(configs["seed"])

    # initialize logger
    log_path = Path(configs["log_dir"]) / configs["experiment_name"]
    tb_path = Path(configs["tensorboard_dir"]) / configs["experiment_name"]
    logger = Logger(log_path, tb_path)
    logger.debug("(experiment)")
    logger.debug(f"name: {configs['experiment_name']}", 1)
    logger.debug(f"directory {configs['log_dir']}", 1)
    logger.debug(f"tensorboard: {configs['tensorboard_dir']}", 1)
    logger.debug(f"geometric_info: {configs['geometric_info']}", 1)
    logger.debug(f"log_interval: {configs['log_interval']}", 1)
    logger.debug(f"log_samples_interval: {configs['log_samples_interval']}", 1)
    logger.debug(f"snapshot_interval: {configs['snapshot_interval']}", 1)
    logger.debug(f"evaluation_interval: {configs['evaluation_interval']}", 1)

    # loss
    loss: Loss
    if configs["loss"] == "adversarial-loss":
        loss = AdversarialLoss()
    elif configs["loss"] == "hinge-loss":
        loss = HingeLoss()
    else:
        logger.error(f"Specified loss is not supported {configs['loss']}")
        sys.exit(1)
    logger.debug(f"loss: {configs['loss']}", 1)

    # prepare dataset
    dataset = VideoDataset(
        configs["dataset"]["name"],
        Path(configs["dataset"]["path"]),
        eval(f'preprocess_{configs["dataset"]["name"]}_dataset'),
        configs["video_length"],
        configs["image_size"],
        configs["dataset"]["number_limit"],
        geometric_info=configs["geometric_info"]["name"],
    )
    dataloader = VideoDataLoader(
        dataset,
        batch_size=configs["batchsize"],
        num_workers=configs["dataset"]["n_workers"],
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        worker_init_fn=_worker_init_fn,
    )
    logger.debug("(dataset)")
    logger.debug(f"name: {dataset.name}", 1)
    logger.debug(f"size: {len(dataset)}", 1)
    logger.debug(f"batchsize: {dataloader.batch_size}", 1)
    logger.debug(f"workers: {dataloader.num_workers}", 1)

    # prepare models
    ggen = GeometricVideoGenerator(
        configs["ggen"]["dim_z_content"],
        configs["ggen"]["dim_z_motion"],
        configs["geometric_info"]["channel"],
        configs["geometric_info"]["name"],
        configs["ggen"]["ngf"],
        configs["video_length"],
    )

    cgen = ColorVideoGenerator(
        ggen.channel,
        configs["cgen"]["dim_z_color"],
        configs["geometric_info"]["name"],
        configs["cgen"]["ngf"],
        configs["video_length"],
    )

    idis = ImageDiscriminator(
        ggen.channel,
        cgen.channel,
        configs["idis"]["use_noise"],
        configs["idis"]["noise_sigma"],
        configs["idis"]["ndf"],
    )

    vdis = VideoDiscriminator(
        ggen.channel,
        cgen.channel,
        configs["vdis"]["use_noise"],
        configs["vdis"]["noise_sigma"],
        configs["vdis"]["ndf"],
    )

    gdis = GradientDiscriminator(
        ggen.channel,
        cgen.channel,
        configs["gdis"]["use_noise"],
        configs["gdis"]["noise_sigma"],
        configs["gdis"]["ndf"],
    )
    models = {
        "ggen": ggen,
        "cgen": cgen,
        "idis": idis,
        "vdis": vdis,
        "gdis": gdis
    }

    logger.debug("(models)")
    for m in models.values():
        logger.debug(str(m), 1)

    # init weights
    for m in models.values():
        m.apply(util.init_weights)

    # optimizers
    logger.debug("(optimizers)")
    optimizers = {}
    for name, model in models.items():
        lr = configs[name]["optimizer"]["lr"]
        betas = (0.5, 0.999)
        decay = configs[name]["optimizer"]["decay"]
        optimizers[name] = optim.Adam(model.parameters(),
                                      lr=lr,
                                      betas=betas,
                                      weight_decay=decay)
        logger.debug(
            json.dumps(
                {name: {
                    "betas": betas,
                    "lr": lr,
                    "weight_decay": decay
                }}), 1)

    # start training
    trainer = Trainer(dataloader, logger, models, optimizers, loss, configs)
    trainer.train()
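main() passes a `_worker_init_fn` to the VideoDataLoader that is not shown in this excerpt. For reproducible data loading it is typically a small per-worker seeding helper like the hypothetical sketch below (assuming `random`, `numpy`, and `torch` are imported; the project's real implementation may differ).

def _worker_init_fn(worker_id):
    # derive a distinct, reproducible seed for each DataLoader worker
    seed = (torch.initial_seed() + worker_id) % 2 ** 32
    np.random.seed(seed)
    random.seed(seed)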
Example #8
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
traindata = TrainDataset(train_data_path, transform)
traindata_loader = DataLoader(traindata, batch_size=batchsize, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate)
optimizerD = optim.Adam(netD.parameters(), lr=learning_rate)
bce = nn.BCELoss()
contentLoss = ContentLoss().to(device)
adversarialLoss = AdversarialLoss()
# print(netG)
# print(netD)

if not os.path.exists(checkpoint_path):
    os.mkdir(checkpoint_path)

torch.save(netG, 'netG-epoch_000.pth')
for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lr = lr.to(device)
        hr = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lr).detach())