Exemple #1
0
                        default="/mnt/disk1/dat/lchen63/grid/data/pickle/")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="/mnt/disk1/dat/lchen63/grid/data/flownets_pytorch.pth")
    parser.add_argument(
        "--sample_dir",
        type=str,
        default="/mnt/disk1/dat/lchen63/grid/test_result/vis_flow_no_train/")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_thread", type=int, default=40)
    return parser.parse_args()


# Script body: build the flow network from the --model_dir checkpoint and
# iterate over the first few batches of the test split.
config = parse_args()

# flownets() receives the checkpoint path; the network is switched to eval
# mode and moved to the GPU unconditionally.
net = flownets(config.model_dir)
net.eval()
net = net.cuda()

# Test split, served in dataset order; an incomplete last batch is dropped.
testset = VaganFlowDataset(config.dataset_dir, train=False)
test_dataloader = DataLoader(testset,
                             batch_size=config.batch_size,
                             num_workers=config.num_thread,
                             shuffle=False,
                             drop_last=True)

img_idx = 0

# Only the first 10 batches of the loader are consumed.
for real_im, real_flow in itertools.islice(test_dataloader, 0, 10):
    if config.cuda:
        # `async` is a reserved keyword since Python 3.7, so the old
        # `.cuda(async=True)` spelling no longer parses. PyTorch >= 0.4 names
        # the same option `non_blocking`.
        real_im = Variable(real_im).cuda(non_blocking=True)
        real_flow = Variable(real_flow).cuda(non_blocking=True)
Exemple #2
0
    def __init__(self, config):
        """Build networks, losses, optimizers and the data loader.

        Args:
            config: options object. Attributes read here: ``load_model``,
                ``lr``, ``beta1``, ``beta2``, ``lr_corr``, ``lr_flownet``,
                ``dataset``, ``dataset_dir``, ``is_train``, ``batch_size``,
                ``num_thread`` and ``cuda``; additionally ``flownet_pth``
                when ``load_model`` is false, ``start_epoch`` and
                ``model_dir`` when it is true, and ``device_ids`` (a
                comma-separated string of GPU indices) when ``cuda`` is true.

        Raises:
            ValueError: if ``config.dataset`` is neither ``'grid'`` nor
                ``'lrw'``.
        """
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.flow_encoder = FlowEncoder(flow_type='flownet')
        self.audio_deri_encoder = AudioDeriEncoder()

        # The encoder is initialised from a fixed checkpoint and frozen: none
        # of the optimizers below receives its parameters.
        self.encoder = Encoder()
        self.encoder.load_state_dict(
            torch.load(
                '/mnt/disk0/dat/lchen63/grid/model/model_embedding/encoder_6.pth'
            ))
        for param in self.encoder.parameters():
            param.requires_grad = False

        # When resuming, the flow network is built without a checkpoint path
        # (presumably self.load() at the end restores its weights -- confirm);
        # otherwise it is built from config.flownet_pth.
        if config.load_model:
            flownet = flownets()
        else:
            flownet = flownets(config.flownet_pth)
        self.flows_gen = FlowsGen(flownet)

        print(self.generator)
        print(self.discriminator)

        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()
        self.l1_loss_fn = nn.L1Loss()
        # self.cosine_loss_fn = CosineSimiLoss(config.batch_size)
        self.cosine_loss_fn = nn.CosineEmbeddingLoss()

        # One Adam optimizer per trainable component. The generator optimizer
        # only receives parameters that still require gradients.
        self.opt_g = torch.optim.Adam(
            [p for p in self.generator.parameters() if p.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        # The flow encoder and audio-derivative encoder share one optimizer.
        self.opt_corr = torch.optim.Adam(itertools.chain(
            self.flow_encoder.parameters(),
            self.audio_deri_encoder.parameters()),
                                         lr=config.lr_corr,
                                         betas=(0.9, 0.999),
                                         weight_decay=0.0005)
        self.opt_flownet = torch.optim.Adam(self.flows_gen.parameters(),
                                            lr=config.lr_flownet,
                                            betas=(0.9, 0.999),
                                            weight_decay=4e-4)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # self.dataset a few lines further down.
            raise ValueError(
                "unsupported dataset {!r}; expected 'grid' or 'lrw'".format(
                    config.dataset))

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)

        # Constant vectors with one entry per batch item; drop_last=True above
        # guarantees every batch has exactly batch_size items.
        self.ones = Variable(torch.ones(config.batch_size))
        self.zeros = Variable(torch.zeros(config.batch_size))
        self.one_corr = Variable(torch.ones(config.batch_size))

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.flow_encoder = nn.DataParallel(self.flow_encoder.cuda(),
                                                device_ids=device_ids)
            self.audio_deri_encoder = nn.DataParallel(
                self.audio_deri_encoder.cuda(), device_ids=device_ids)
            self.encoder = nn.DataParallel(self.encoder.cuda(),
                                           device_ids=device_ids)
            self.flows_gen = nn.DataParallel(self.flows_gen.cuda(),
                                             device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.cosine_loss_fn = self.cosine_loss_fn.cuda()
            # `async` is a reserved keyword since Python 3.7, so the former
            # `.cuda(async=True)` is a SyntaxError there. PyTorch >= 0.4 calls
            # the same option `non_blocking`.
            self.ones = self.ones.cuda(non_blocking=True)
            self.zeros = self.zeros.cuda(non_blocking=True)
            self.one_corr = self.one_corr.cuda(non_blocking=True)

        self.config = config
        if config.load_model:
            # Resume: load the state saved for the epoch before start_epoch.
            self.start_epoch = config.start_epoch
            self.load(config.model_dir, self.start_epoch - 1)
        else:
            self.start_epoch = 0
Exemple #3
0
    def __init__(self, config):
        """Build networks, losses, optimizers and the data loader.

        Args:
            config: options object. Attributes read here: ``flownet_pth``,
                ``lr``, ``beta1``, ``beta2``, ``lr_flownet``, ``dataset``,
                ``dataset_dir``, ``is_train``, ``batch_size``,
                ``num_thread`` and ``cuda``; ``device_ids`` (a
                comma-separated string of GPU indices) is read only when
                ``cuda`` is true.

        Raises:
            ValueError: if ``config.dataset`` is neither ``'grid'`` nor
                ``'lrw'``.
        """
        self.generator = Generator()
        self.discriminator = Discriminator()
        flownet = flownets(config.flownet_pth)
        print('flownet_pth: {}'.format(config.flownet_pth))
        self.flows_gen = FlowsGen(flownet)

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        # The generator optimizer only receives parameters that still require
        # gradients.
        self.opt_g = torch.optim.Adam(
            [p for p in self.generator.parameters() if p.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_flow = torch.optim.Adam(self.flows_gen.parameters(),
                                         lr=config.lr_flownet,
                                         betas=(0.9, 0.999),
                                         weight_decay=4e-4)

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # self.dataset a few lines further down.
            raise ValueError(
                "unsupported dataset {!r}; expected 'grid' or 'lrw'".format(
                    config.dataset))

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)

        # Constant vectors with one entry per batch item; drop_last=True above
        # guarantees every batch has exactly batch_size items.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            # Parsed only when CUDA is requested, so a CPU run does not
            # depend on config.device_ids being well-formed.
            device_ids = [int(i) for i in config.device_ids.split(',')]
            # Everything is placed on the first listed device. DataParallel
            # requires the wrapped module to live on device_ids[0], and the
            # constant vectors must sit on the same device as the models.
            primary = device_ids[0]

            self.generator = self.generator.cuda(primary)
            self.discriminator = self.discriminator.cuda(primary)
            self.flows_gen = self.flows_gen.cuda(primary)
            if len(device_ids) > 1:
                self.generator = nn.DataParallel(self.generator,
                                                 device_ids=device_ids)
                self.discriminator = nn.DataParallel(self.discriminator,
                                                     device_ids=device_ids)
                self.flows_gen = nn.DataParallel(self.flows_gen,
                                                 device_ids=device_ids)

            self.bce_loss_fn = self.bce_loss_fn.cuda(primary)
            self.mse_loss_fn = self.mse_loss_fn.cuda(primary)
            self.ones = self.ones.cuda(primary)
            self.zeros = self.zeros.cuda(primary)

        self.config = config
        self.start_epoch = 0