Example 1
    def __init__(self, config):
        self.encoder = Encoder()

        print(self.encoder)
        self.l1_loss_fn = nn.MSELoss()  # note: despite the name, this is MSE, not nn.L1Loss()
        self.mse_loss_fn = nn.MSELoss()
        self.opt_g = torch.optim.Adam(self.encoder.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        self.dataset = VaganDataset(config.dataset_dir, train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=4,
                                      shuffle=True,
                                      drop_last=True)
        data_iter = iter(self.data_loader)
        next(data_iter)  # fetch one batch up front to sanity-check the loader

        if config.cuda:
            self.encoder = self.encoder.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()

        self.config = config
        self.start_epoch = 0
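
All of the constructors on this page receive a `config` object; the following is a minimal, hypothetical sketch of the fields they read here and in the later examples (names inferred from these snippets, defaults illustrative only, not an official interface):

import argparse

def parse_config():
    # fields inferred from the constructors on this page; defaults are illustrative only
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--dataset', choices=['grid', 'lrw'], default='grid')
    parser.add_argument('--dataset_dir', type=str, required=True)
    parser.add_argument('--is_train', action='store_true')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--device_ids', type=str, default='0')
    parser.add_argument('--num_thread', type=int, default=4)
    return parser.parse_args()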
Example 2
    def __init__(self, config):
        self.generator = Generator()

        print(self.generator)
        self.l1_loss_fn = nn.L1Loss()
        # self.l1_loss_fn = nn.MSELoss()
        self.opt_g = torch.optim.Adam(self.generator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir, train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir, train=config.is_train)
    
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=12,
                                      shuffle=True, drop_last=True)
        data_iter = iter(self.data_loader)
        next(data_iter)  # fetch one batch up front to sanity-check the loader
        self.ones = Variable(torch.ones(config.batch_size), requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size), requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
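
The `Variable` wrapper used for the `ones`/`zeros` label tensors has been a no-op since PyTorch 0.4; a minimal modern equivalent (shapes illustrative) is simply:

import torch

batch_size = 16
ones = torch.ones(batch_size)    # requires_grad defaults to False
zeros = torch.zeros(batch_size)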
Example 3
    def __init__(self, config):
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.flow_encoder = FlowEncoder(flow_type='flownet')
        self.audio_deri_encoder = AudioDeriEncoder()
        self.encoder = Encoder()
        self.encoder.load_state_dict(
            torch.load(
                '/mnt/disk0/dat/lchen63/grid/model/model_embedding/encoder_6.pth'
            ))
        for param in self.encoder.parameters():
            param.requires_grad = False
        if config.load_model:
            # a full checkpoint is loaded below via self.load(), so the
            # pretrained FlowNet weights are not needed here
            flownet = flownets()
        else:
            flownet = flownets(config.flownet_pth)
        self.flows_gen = FlowsGen(flownet)

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()
        self.l1_loss_fn = nn.L1Loss()
        # self.cosine_loss_fn = CosineSimiLoss(config.batch_size)
        self.cosine_loss_fn = nn.CosineEmbeddingLoss()

        self.opt_g = torch.optim.Adam(
            [p for p in self.generator.parameters() if p.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_corr = torch.optim.Adam(itertools.chain(
            self.flow_encoder.parameters(),
            self.audio_deri_encoder.parameters()),
                                         lr=config.lr_corr,
                                         betas=(0.9, 0.999),
                                         weight_decay=0.0005)
        self.opt_flownet = torch.optim.Adam(self.flows_gen.parameters(),
                                            lr=config.lr_flownet,
                                            betas=(0.9, 0.999),
                                            weight_decay=4e-4)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)

        self.ones = Variable(torch.ones(config.batch_size))
        self.zeros = Variable(torch.zeros(config.batch_size))
        self.one_corr = Variable(torch.ones(config.batch_size))

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.flow_encoder = nn.DataParallel(self.flow_encoder.cuda(),
                                                device_ids=device_ids)
            self.audio_deri_encoder = nn.DataParallel(
                self.audio_deri_encoder.cuda(), device_ids=device_ids)
            self.encoder = nn.DataParallel(self.encoder.cuda(),
                                           device_ids=device_ids)
            self.flows_gen = nn.DataParallel(self.flows_gen.cuda(),
                                             device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.cosine_loss_fn = self.cosine_loss_fn.cuda()
            # `async` became a reserved word in Python 3.7; non_blocking is the replacement
            self.ones = self.ones.cuda(non_blocking=True)
            self.zeros = self.zeros.cuda(non_blocking=True)
            self.one_corr = self.one_corr.cuda(non_blocking=True)

        self.config = config
        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.model_dir, self.start_epoch - 1)
        else:
            self.start_epoch = 0
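
The `one_corr` vector of ones above is presumably the target for `nn.CosineEmbeddingLoss`, whose +1 target pulls a pair of embeddings together. A minimal, hypothetical sketch of that pairing (tensor names and sizes are illustrative, not taken from this repository):

import torch
import torch.nn as nn

cosine_loss_fn = nn.CosineEmbeddingLoss()
flow_emb = torch.randn(8, 256)   # e.g. an output of flow_encoder
audio_emb = torch.randn(8, 256)  # e.g. an output of audio_deri_encoder
target = torch.ones(8)           # +1 target: make each pair of embeddings similar
loss = cosine_loss_fn(flow_emb, audio_emb, target)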
Example 4
def _sample(config):

    dataset = VaganDataset(config.dataset_dir,
                           output_shape=[64, 64],
                           train=False)
    num_test = len(dataset)
    # num_test=100

    real_path = os.path.join(config.sample_dir, 'image/real64')
    fake_path = os.path.join(config.sample_dir, 'image/fake64')

    if not os.path.exists(config.sample_dir):
        os.mkdir(config.sample_dir)
    if not os.path.exists(os.path.join(config.sample_dir, 'image')):
        os.mkdir(os.path.join(config.sample_dir, 'image'))
    if not os.path.exists(os.path.join(config.sample_dir, 'image/real64')):
        os.mkdir(os.path.join(config.sample_dir, 'image/real64'))
    if not os.path.exists(os.path.join(config.sample_dir, 'image/fake64')):
        os.mkdir(os.path.join(config.sample_dir, 'image/fake64'))

    paths = []
    # print type(dataset)
    # dataset = dataset[0:10]
    stage1_generator = Generator()
    _load(stage1_generator, config.model_dir)
    examples, ims, landmarks, embeds, captions = [], [], [], [], []
    for idx in range(num_test):
        example, im, landmark, embed, caption = dataset[idx]
        examples.append(example)
        ims.append(im)
        embeds.append(embed)
        captions.append(caption)
        landmarks.append(landmark)
    examples = torch.stack(examples, 0)
    # landmarks = torch.stack(landmarks,0)
    ims = torch.stack(ims, 0)
    embeds = torch.stack(embeds, 0)

    if config.cuda:
        examples = Variable(examples).cuda()
        # landmarks = Variable( landmarks).cuda()
        embeds = Variable(embeds).cuda()
        stage1_generator = stage1_generator.cuda()
    else:
        examples = Variable(examples)
        embeds = Variable(embeds)
        # landmarks = Variable(landmarks)  # landmarks is still a plain list here (torch.stack above is commented out)
    # embeds = embeds.view(embeds.size()[0],embeds.size()[1],embeds.size()[2],embeds.size()[3])
    # print embeds.size()
    # print '-------'
    # # embeds = embeds.view(len(indices), -1)
    # print embeds.size()
    batch_size = config.batch_size
    for i in range(num_test // batch_size):
        example = examples[i * batch_size:i * batch_size + batch_size]
        # landmark = landmarks[i*batch_size:  i*batch_size + batch_size]
        embed = embeds[i * batch_size:i * batch_size + batch_size]

        print('---------------------' + str(i) + '/' +
              str(num_test // batch_size))
        fake_ims_stage1 = stage1_generator(example, embed)
        # fake_ims_stage1 = stage1_generator(example, embed)

        # print fake_ims_stage1
        # fake_ims_stage1 = fake_ims_stage1.view(len(dataset)*16,3,64,64)
        # ims = ims.view(len(dataset),3,64,64)
        for inx in range(batch_size):
            real_ims = ims[inx + i * batch_size]
            # real_ims = real_ims.permute(1,2, 3, 0)
            fake_ims = fake_ims_stage1[inx]
            # fake_ims = fake_ims.permute(1,2,3,0)
            real_ims = real_ims.cpu().permute(1, 2, 3, 0).numpy()
            # print fake_ims
            # print fake_ims.size()
            # print fake_ims.data.cpu()
            # print type(fake_ims.data.cpu())
            # print  fake_ims.data.cpu().size()
            fake_ims = fake_ims.data.cpu().permute(1, 2, 3, 0).numpy()
            # print '------'
            # print real_ims.shape
            # print fake_ims.shape
            fff = {}
            rp = []
            fp = []
            for j in range(real_ims.shape[0]):

                real = real_ims[j]
                fake = fake_ims[j]
                # real = real_ims[:,i,:,:].cpu().permute(2,3,0).numpy()
                # fake = fake_ims[:,i,:,:].data.cpu().permute(2,3,0).numpy()
                # print captions[inx][i]
                temp = captions[inx + i * batch_size][j].split('/')
                if not os.path.exists(os.path.join(fake_path, temp[-2])):
                    os.mkdir(os.path.join(fake_path, temp[-2]))
                if not os.path.exists(os.path.join(real_path, temp[-2])):
                    os.mkdir(os.path.join(real_path, temp[-2]))
                real_name = os.path.join(
                    real_path, temp[-2]) + '/' + temp[-1][:-4] + '.jpg'
                fake_name = os.path.join(
                    fake_path, temp[-2]) + '/' + temp[-1][:-4] + '.jpg'
                # scipy.misc.imsave was removed in SciPy >= 1.2; imageio.imwrite is the usual replacement
                scipy.misc.imsave(real_name, real)
                scipy.misc.imsave(fake_name, fake)
                rp.append(real_name)
                fp.append(fake_name)
            fff["real_path"] = rp
            fff["fake_path"] = fp
            paths.append(fff)
        # print os.path.join(fake_path,temp[-2])
        real_im = ims[i * batch_size:i * batch_size + batch_size]
        fake_store = fake_ims_stage1.data.permute(
            0, 2, 1, 3, 4).contiguous().view(config.batch_size * 16, 3, 64, 64)
        torchvision.utils.save_image(fake_store,
                                     "{}/fake_{}.png".format(
                                         os.path.join(fake_path, temp[-2]), i),
                                     nrow=16,
                                     normalize=True)
        real_store = real_im.permute(0, 2, 1, 3, 4).contiguous().view(
            config.batch_size * 16, 3, 64, 64)
        torchvision.utils.save_image(real_store,
                                     "{}/real_{}.png".format(
                                         os.path.join(real_path, temp[-2]), i),
                                     nrow=16,
                                     normalize=True)

        with open(os.path.join(config.sample_dir, 'image/test_result.pkl'),
                  'wb') as handle:
            pickle.dump(paths, handle, protocol=pickle.HIGHEST_PROTOCOL)
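
The chain of `os.path.exists`/`os.mkdir` checks at the top of `_sample` can be collapsed with `os.makedirs`; a minimal equivalent sketch (the helper name is made up for illustration):

import os

def ensure_sample_dirs(sample_dir):
    # creates sample_dir/image/real64 and sample_dir/image/fake64, parents included
    for sub in ('image/real64', 'image/fake64'):
        os.makedirs(os.path.join(sample_dir, sub), exist_ok=True)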
Example 5
    def __init__(self, config):
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.encoder = Encoder()
        self.encoder.load_state_dict(
            torch.load(
                '/mnt/disk1/dat/lchen63/grid/model/model_embedding/encoder_6.pth'
            ))
        for param in self.encoder.parameters():
            param.requires_grad = False

        print(self.generator)
        print(self.discriminator)

        self.l1_loss_fn = nn.L1Loss()
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=4,
                                      shuffle=True,
                                      drop_last=True)
        data_iter = iter(self.data_loader)
        next(data_iter)  # fetch one batch up front to sanity-check the loader
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.encoder = nn.DataParallel(self.encoder.cuda(),
                                           device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
Example 6
from torch.utils.data import DataLoader
from dataset import VaganDataset
import time

dataset = VaganDataset('/mnt/disk1/dat/lchen63/grid/data/pickle/', train=True)
data_loader = DataLoader(dataset,
                         batch_size=12,
                         num_workers=12,
                         shuffle=True,
                         drop_last=True)
data_loader = iter(data_loader)

print('dataloader created')

for _ in range(len(data_loader)):
    start_time = time.time()
    next(data_loader)
    print("--- %s seconds ---" % (time.time() - start_time))
Example 7
    def __init__(self, config):
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.images_encoder = FlowEncoder('embedding')
        self.audio_encoder = AudioDeriEncoder(need_deri=False)

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()
        # self.cosine_loss_fn = CosineSimiLoss(config.batch_size)
        self.cosine_loss_fn = nn.CosineEmbeddingLoss()

        self.opt_g = torch.optim.Adam(
            [p for p in self.generator.parameters() if p.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_corr = torch.optim.Adam(itertools.chain(
            self.images_encoder.parameters(), self.audio_encoder.parameters()),
                                         lr=config.lr_corr,
                                         betas=(0.9, 0.999),
                                         weight_decay=0.0005)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)

        self.ones = Variable(torch.ones(config.batch_size))
        self.zeros = Variable(torch.zeros(config.batch_size))
        self.one_corr = Variable(torch.ones(config.batch_size))

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            if len(device_ids) == 1:
                self.generator = self.generator.cuda(device_ids[0])
                self.discriminator = self.discriminator.cuda(device_ids[0])
                self.images_encoder = self.images_encoder.cuda(device_ids[0])
                self.audio_encoder = self.audio_encoder.cuda(device_ids[0])
            else:
                self.generator = nn.DataParallel(self.generator.cuda(),
                                                 device_ids=device_ids)
                self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                     device_ids=device_ids)
                self.images_encoder = nn.DataParallel(
                    self.images_encoder.cuda(), device_ids=device_ids)
                self.audio_encoder = nn.DataParallel(self.audio_encoder.cuda(),
                                                     device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.cosine_loss_fn = self.cosine_loss_fn.cuda()
            self.ones = self.ones.cuda(non_blocking=True)
            self.zeros = self.zeros.cuda(non_blocking=True)
            self.one_corr = self.one_corr.cuda(non_blocking=True)

        self.config = config
        self.start_epoch = 0
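
`non_blocking=True` only overlaps the host-to-device copy with compute when the source tensor sits in pinned memory; a minimal sketch of that pairing (the dataset here is synthetic, only the DataLoader arguments are standard PyTorch):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(128, 3, 64, 64))
loader = DataLoader(dataset, batch_size=16, pin_memory=True)  # page-locked host buffers

for (batch,) in loader:
    if torch.cuda.is_available():
        # with a pinned source, the copy can overlap with later kernels
        batch = batch.cuda(non_blocking=True)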
Example 8
    def __init__(self, config):
        self.generator = Generator()
        self.discriminator = Discriminator()
        flownet = flownets(config.flownet_pth)
        print('flownet_pth: {}'.format(config.flownet_pth))
        self.flows_gen = FlowsGen(flownet)

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_flow = torch.optim.Adam(self.flows_gen.parameters(),
                                         lr=config.lr_flownet,
                                         betas=(0.9, 0.999),
                                         weight_decay=4e-4)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        device_ids = [int(i) for i in config.device_ids.split(',')]
        if config.cuda:
            if len(device_ids) == 1:
                self.generator = self.generator.cuda(device_ids[0])
                self.discriminator = self.discriminator.cuda(device_ids[0])
                self.flows_gen = self.flows_gen.cuda(device_ids[0])
            else:
                self.generator = nn.DataParallel(self.generator.cuda(),
                                                 device_ids=device_ids)
                self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                     device_ids=device_ids)
                self.flows_gen = nn.DataParallel(self.flows_gen.cuda(),
                                                 device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
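
All of these constructors pair `ones`/`zeros` label vectors with `bce_loss_fn` and separate `opt_g`/`opt_d` optimizers, which suggests a standard GAN update. A minimal sketch of one such step, assuming the generator takes a conditioning input and the discriminator returns a per-sample score; this is illustrative, not the repository's actual training loop:

def gan_step(generator, discriminator, opt_g, opt_d, bce_loss_fn, real, cond, ones, zeros):
    # Discriminator update: real batch toward the `ones` labels, generated batch toward `zeros`.
    fake = generator(cond)
    opt_d.zero_grad()
    d_loss = bce_loss_fn(discriminator(real), ones) \
        + bce_loss_fn(discriminator(fake.detach()), zeros)
    d_loss.backward()
    opt_d.step()

    # Generator update: push the discriminator's output on fakes toward `ones`.
    opt_g.zero_grad()
    g_loss = bce_loss_fn(discriminator(fake), ones)
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()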