Beispiel #1
0
    def __init__(self, config):
        """Build the generator, its optimizer, the dataset/loader, and the
        GAN target tensors.

        Args:
            config: hyper-parameter namespace; fields read here: lr, beta1,
                beta2, dataset, dataset_dir, is_train, batch_size, cuda,
                device_ids, ...
        """
        self.generator = Generator()

        print(self.generator)
        self.l1_loss_fn = nn.L1Loss()
        # self.l1_loss_fn = nn.MSELoss()
        self.opt_g = torch.optim.Adam(self.generator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir, train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir, train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=12,
                                      shuffle=True, drop_last=True)
        # Warm up the loader with one fetch. Use the next() builtin: the
        # .next() method is Python 2 only and raises AttributeError on
        # Python 3.
        data_iter = iter(self.data_loader)
        next(data_iter)
        # Real/fake target labels for the adversarial loss.
        self.ones = Variable(torch.ones(config.batch_size), requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size), requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
Beispiel #2
0
    def __init__(self, config):
        """Single-GPU trainer setup: generator, loss functions, optimizer,
        dataset/loader, GAN targets, and optional checkpoint restore.

        Args:
            config: hyper-parameter namespace; fields read here: cuda, lr,
                beta1, beta2, dataset, dataset_dir, is_train, batch_size,
                num_thread, load_model, start_epoch, pretrained_dir,
                pretrained_epoch.
        """
        self.generator = Generator(config.cuda)

        print(self.generator)
        self.bce_loss_fn = nn.BCELoss()
        self.l1_loss_fn = nn.L1Loss()
        self.mse_loss_fn = nn.MSELoss()

        # Optimize only the parameters that require gradients (frozen
        # sub-modules are skipped).
        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      pin_memory=True,
                                      shuffle=True,
                                      drop_last=True)
        # Warm up the loader with one fetch; the Python 2-only .next()
        # method is replaced with the next() builtin (it would raise
        # AttributeError on Python 3).
        data_iter = iter(self.data_loader)
        next(data_iter)
        # Real/fake target labels for the adversarial loss.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        ######### single GPU #######################
        if config.cuda:
            self.generator = self.generator.cuda()

            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0

        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)
Beispiel #3
0
    def __init__(self, config):
        """Full multi-network trainer setup: generator, discriminator, a
        frozen embedding encoder, flow/audio correlation encoders, and a
        FlowNet-based flow generator, plus their optimizers and data.

        Args:
            config: hyper-parameter namespace; fields read here: load_model,
                flownet_pth, lr, beta1, beta2, lr_corr, lr_flownet, dataset,
                dataset_dir, is_train, batch_size, num_thread, cuda,
                device_ids, start_epoch, model_dir.
        """
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.flow_encoder = FlowEncoder(flow_type='flownet')
        self.audio_deri_encoder = AudioDeriEncoder()
        self.encoder = Encoder()
        # Frozen, pre-trained embedding encoder (used as a fixed feature
        # extractor; its parameters are excluded from optimization).
        self.encoder.load_state_dict(
            torch.load(
                '/mnt/disk0/dat/lchen63/grid/model/model_embedding/encoder_6.pth'
            ))
        for param in self.encoder.parameters():
            param.requires_grad = False
        # NOTE(review): the branch looks inverted -- the pretrained FlowNet
        # weights are loaded only when NOT resuming from a checkpoint
        # (presumably self.load() below restores them). Confirm intent.
        if config.load_model:
            flownet = flownets()
        else:
            flownet = flownets(config.flownet_pth)
        self.flows_gen = FlowsGen(flownet)

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()
        self.l1_loss_fn = nn.L1Loss()
        # self.cosine_loss_fn = CosineSimiLoss(config.batch_size)
        self.cosine_loss_fn = nn.CosineEmbeddingLoss()

        # Generator optimizer skips frozen parameters.
        self.opt_g = torch.optim.Adam(
            [p for p in self.generator.parameters() if p.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        # Flow and audio-derivative encoders share one optimizer since they
        # are trained jointly on the correlation loss.
        self.opt_corr = torch.optim.Adam(itertools.chain(
            self.flow_encoder.parameters(),
            self.audio_deri_encoder.parameters()),
                                         lr=config.lr_corr,
                                         betas=(0.9, 0.999),
                                         weight_decay=0.0005)
        self.opt_flownet = torch.optim.Adam(self.flows_gen.parameters(),
                                            lr=config.lr_flownet,
                                            betas=(0.9, 0.999),
                                            weight_decay=4e-4)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)

        # Adversarial targets and the +1 target for CosineEmbeddingLoss.
        self.ones = Variable(torch.ones(config.batch_size))
        self.zeros = Variable(torch.zeros(config.batch_size))
        self.one_corr = Variable(torch.ones(config.batch_size))

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.flow_encoder = nn.DataParallel(self.flow_encoder.cuda(),
                                                device_ids=device_ids)
            self.audio_deri_encoder = nn.DataParallel(
                self.audio_deri_encoder.cuda(), device_ids=device_ids)
            self.encoder = nn.DataParallel(self.encoder.cuda(),
                                           device_ids=device_ids)
            self.flows_gen = nn.DataParallel(self.flows_gen.cuda(),
                                             device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.cosine_loss_fn = self.cosine_loss_fn.cuda()
            # `async=True` was removed in PyTorch 0.4 and is a SyntaxError
            # on Python 3.7+ (`async` is a keyword); the replacement kwarg
            # is non_blocking=True.
            self.ones = self.ones.cuda(non_blocking=True)
            self.zeros = self.zeros.cuda(non_blocking=True)
            self.one_corr = self.one_corr.cuda(non_blocking=True)

        self.config = config
        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.model_dir, self.start_epoch - 1)
        else:
            self.start_epoch = 0
Beispiel #4
0
    def __init__(self, config):
        """Single-GPU GAN trainer setup with Kaiming weight initialization.

        Args:
            config: hyper-parameter namespace; fields read here: lr, beta1,
                beta2, dataset, dataset_dir, is_train, batch_size,
                num_thread, cuda, load_model, start_epoch, pretrained_dir,
                pretrained_epoch.
        """
        self.generator = Generator()
        self.discriminator = Discriminator2()

        print(self.generator)
        self.bce_loss_fn = nn.BCELoss()
        self.l1_loss_fn = nn.L1Loss()
        self.mse_loss_fn = nn.MSELoss()

        # Optimize only the generator parameters that require gradients.
        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      pin_memory=True,
                                      shuffle=True,
                                      drop_last=True)
        # Warm up the loader with one fetch; the Python 2-only .next()
        # method is replaced with the next() builtin (it would raise
        # AttributeError on Python 3).
        data_iter = iter(self.data_loader)
        next(data_iter)
        # Real/fake target labels for the adversarial loss.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        ######### single GPU #######################
        if config.cuda:
            self.generator = self.generator.cuda()
            self.discriminator = self.discriminator.cuda()
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
        init_weights(self.generator, init_type='kaiming')
        init_weights(self.discriminator, init_type='kaiming')
        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)
Beispiel #5
0
def test():
    """Run the pretrained VG_net generator over the LRW test split and dump
    fake/real frames, attention maps, and color maps under config.sample_dir.

    Reads the module-level `config` namespace; writes PNGs as a side effect.
    NOTE(review): output paths are built as "{sample_dir}fake/..." -- this
    presumes config.sample_dir ends with '/'; confirm upstream.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids
    generator = VG_net()
    if config.cuda:
        generator = generator.cuda()

    # Collapse a DataParallel checkpoint to single-GPU keys before loading.
    state_dict = multi2single(config.model_name, 1)
    generator.load_state_dict(state_dict)
    print('load pretrained [{}]'.format(config.model_name))
    generator.eval()
    dataset = LRWdataset(config.dataset_dir, train='test')
    data_loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=True,
                             drop_last=True)
    # Warm up the loader with one fetch. Use the next() builtin: the
    # .next() method is Python 2 only and raises AttributeError on Python 3.
    data_iter = iter(data_loader)
    next(data_iter)

    # Create the output directory tree if it does not exist yet.
    if not os.path.exists(config.sample_dir):
        os.mkdir(config.sample_dir)
    for sub in ('fake', 'real', 'color', 'single'):
        sub_dir = os.path.join(config.sample_dir, sub)
        if not os.path.exists(sub_dir):
            os.mkdir(sub_dir)
    generator.eval()
    for step, (example_img, example_landmark, real_im,
               landmarks) in enumerate(data_loader):
        with torch.no_grad():
            # Cap the run at 43 batches.
            if step == 43:
                break
            if config.cuda:
                real_im = Variable(real_im.float()).cuda()
                example_img = Variable(example_img.float()).cuda()
                landmarks = Variable(landmarks.float()).cuda()
                example_landmark = Variable(example_landmark.float()).cuda()
            fake_im, atts, colors, lmark_att = generator(
                example_img, landmarks, example_landmark)

            # Flatten the 16-frame sequence dimension into the batch so
            # save_image lays all frames out in one grid.
            fake_store = fake_im.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            real_store = real_im.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            atts_store = atts.data.contiguous().view(config.batch_size * 16, 1,
                                                     128, 128)
            colors_store = colors.data.contiguous().view(
                config.batch_size * 16, 3, 128, 128)
            lmark_att = bilinear_interpolate_torch_gridsample(lmark_att)
            lmark_att_store = lmark_att.data.contiguous().view(
                config.batch_size * 16, 1, 128, 128)
            # NOTE(review): `range=` was renamed `value_range=` in newer
            # torchvision; kept as-is for compatibility with the pinned
            # version this project uses.
            torchvision.utils.save_image(atts_store,
                                         "{}color/att_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=False,
                                         range=[0, 1])
            torchvision.utils.save_image(lmark_att_store,
                                         "{}color/lmark_att_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=False,
                                         range=[0, 1])
            torchvision.utils.save_image(colors_store,
                                         "{}color/color_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)
            torchvision.utils.save_image(fake_store,
                                         "{}fake/fake_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)

            torchvision.utils.save_image(real_store,
                                         "{}real/real_{}.png".format(
                                             config.sample_dir, step),
                                         normalize=True)

            # Per-frame dumps need HWC numpy arrays on the CPU.
            fake_store = fake_im.data.cpu().permute(0, 1, 3, 4, 2).view(
                config.batch_size * 16, 128, 128, 3).numpy()
            real_store = real_im.data.cpu().permute(0, 1, 3, 4, 2).view(
                config.batch_size * 16, 128, 128, 3).numpy()

            print(step)
            print(fake_store.shape)
            for inx in range(config.batch_size * 16):
                # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2;
                # requires an old pinned SciPy (or migrate to imageio).
                scipy.misc.imsave(
                    "{}single/fake_{}.png".format(
                        config.sample_dir,
                        step * config.batch_size * 16 + inx), fake_store[inx])
                scipy.misc.imsave(
                    "{}single/real_{}.png".format(
                        config.sample_dir,
                        step * config.batch_size * 16 + inx), real_store[inx])
Beispiel #6
0
    def __init__(self, config):
        """Trainer setup for generator/discriminator plus image and audio
        encoders trained jointly on a cosine-embedding correlation loss.

        Args:
            config: hyper-parameter namespace; fields read here: lr, beta1,
                beta2, lr_corr, dataset, dataset_dir, is_train, batch_size,
                num_thread, cuda, device_ids.
        """
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.images_encoder = FlowEncoder('embedding')
        self.audio_encoder = AudioDeriEncoder(need_deri=False)

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()
        # self.cosine_loss_fn = CosineSimiLoss(config.batch_size)
        self.cosine_loss_fn = nn.CosineEmbeddingLoss()

        # Generator optimizer skips frozen parameters.
        self.opt_g = torch.optim.Adam(
            [p for p in self.generator.parameters() if p.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        # Both encoders share one optimizer: they are trained jointly on
        # the correlation loss.
        self.opt_corr = torch.optim.Adam(itertools.chain(
            self.images_encoder.parameters(), self.audio_encoder.parameters()),
                                         lr=config.lr_corr,
                                         betas=(0.9, 0.999),
                                         weight_decay=0.0005)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)

        # Adversarial targets and the +1 target for CosineEmbeddingLoss.
        self.ones = Variable(torch.ones(config.batch_size))
        self.zeros = Variable(torch.zeros(config.batch_size))
        self.one_corr = Variable(torch.ones(config.batch_size))

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            # Plain .cuda() for a single device; DataParallel otherwise.
            if len(device_ids) == 1:
                self.generator = self.generator.cuda(device_ids[0])
                self.discriminator = self.discriminator.cuda(device_ids[0])
                self.images_encoder = self.images_encoder.cuda(device_ids[0])
                self.audio_encoder = self.audio_encoder.cuda(device_ids[0])
            else:
                self.generator = nn.DataParallel(self.generator.cuda(),
                                                 device_ids=device_ids)
                self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                     device_ids=device_ids)
                self.images_encoder = nn.DataParallel(
                    self.images_encoder.cuda(), device_ids=device_ids)
                self.audio_encoder = nn.DataParallel(self.audio_encoder.cuda(),
                                                     device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.cosine_loss_fn = self.cosine_loss_fn.cuda()
            # `async=True` was removed in PyTorch 0.4 and is a SyntaxError
            # on Python 3.7+ (`async` is a keyword); the replacement kwarg
            # is non_blocking=True.
            self.ones = self.ones.cuda(non_blocking=True)
            self.zeros = self.zeros.cuda(non_blocking=True)
            self.one_corr = self.one_corr.cuda(non_blocking=True)

        self.config = config
        self.start_epoch = 0
Beispiel #7
0
    def __init__(self, config):
        """Wire up the generator, discriminator, and a FlowNet-backed flow
        generator with their Adam optimizers, the dataset/loader, and the
        real/fake label tensors for the adversarial loss."""
        self.generator = Generator()
        self.discriminator = Discriminator()
        flownet = flownets(config.flownet_pth)
        print('flownet_pth: {}'.format(config.flownet_pth))
        self.flows_gen = FlowsGen(flownet)

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        # Generator and discriminator share the same Adam settings; only
        # trainable generator parameters are handed to the optimizer.
        gan_opt_kwargs = dict(lr=config.lr, betas=(config.beta1, config.beta2))
        trainable = filter(lambda p: p.requires_grad,
                           self.generator.parameters())
        self.opt_g = torch.optim.Adam(trainable, **gan_opt_kwargs)
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      **gan_opt_kwargs)
        self.opt_flow = torch.optim.Adam(self.flows_gen.parameters(),
                                         lr=config.lr_flownet,
                                         betas=(0.9, 0.999),
                                         weight_decay=4e-4)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      shuffle=True,
                                      num_workers=config.num_thread,
                                      drop_last=True)
        # Constant real/fake targets for the adversarial loss.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        device_ids = [int(tok) for tok in config.device_ids.split(',')]
        if config.cuda:
            if len(device_ids) > 1:
                # Several devices: wrap each network in DataParallel.
                self.generator = nn.DataParallel(self.generator.cuda(),
                                                 device_ids=device_ids)
                self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                     device_ids=device_ids)
                self.flows_gen = nn.DataParallel(self.flows_gen.cuda(),
                                                 device_ids=device_ids)
            else:
                # One device: move the networks onto it directly.
                only = device_ids[0]
                self.generator = self.generator.cuda(only)
                self.discriminator = self.discriminator.cuda(only)
                self.flows_gen = self.flows_gen.cuda(only)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0