예제 #1
0
    def __init__(self,
                 in_channels=3,
                 n_sequence=3,
                 out_channels=3,
                 n_resblock=3,
                 n_feat=32,
                 load_flow_net=False,
                 load_recons_net=False,
                 flow_pretrain_fn='',
                 recons_pretrain_fn='',
                 is_mask_filter=False,
                 device='cuda',
                 lbd=4,
                 hbd=16):
        """Build the VBDE_QM network: a PWC flow net plus a video reconstruction net.

        The reconstruction net is built with one extra input channel, which
        the start-up message describes as a concatenated quantization mask.

        Args:
            in_channels: Input channel count forwarded to ``RECONS_VIDEO``.
            n_sequence: Sequence length; only 3 is supported.
            out_channels: Output channel count forwarded to ``RECONS_VIDEO``.
            n_resblock: Residual block count forwarded to ``RECONS_VIDEO``.
            n_feat: Feature width forwarded to ``RECONS_VIDEO``.
            load_flow_net: Forwarded to ``Flow_PWC`` as ``load_pretrain``.
            load_recons_net: If true, load ``recons_net`` weights from
                ``recons_pretrain_fn``.
            flow_pretrain_fn: Forwarded to ``Flow_PWC`` as ``pretrain_fn``.
            recons_pretrain_fn: Checkpoint path read with ``torch.load`` when
                ``load_recons_net`` is true.
            is_mask_filter: Stored as ``self.is_mask_filter`` (not used here).
            device: Stored as ``self.device`` and forwarded to ``Flow_PWC``.
            lbd: Stored as ``self.lbd``.
            hbd: Stored as ``self.hbd``.

        Raises:
            AssertionError: If ``n_sequence`` is not 3.
        """
        super(VBDE_QM, self).__init__()
        print("Creating VBDE_QM Net")

        self.n_sequence = n_sequence
        self.device = device
        self.is_mask_filter = is_mask_filter

        # NOTE(review): lbd/hbd look like low/high bit depths -- confirm
        # against the code that consumes them.
        self.lbd = lbd
        self.hbd = hbd

        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for existing callers.
        if n_sequence != 3:
            raise AssertionError(
                "Only support args.n_sequence=3; but get args.n_sequence={}".format(
                    n_sequence))

        # One extra input channel for the reconstruction net (the
        # quantization mask announced below).
        extra_channels = 1
        print('Concat quantization mask')

        self.flow_net = flow_pwc.Flow_PWC(load_pretrain=load_flow_net,
                                          pretrain_fn=flow_pretrain_fn,
                                          device=device)
        self.recons_net = recons_video.RECONS_VIDEO(
            in_channels=in_channels,
            n_sequence=3,
            out_channels=out_channels,
            n_resblock=n_resblock,
            n_feat=n_feat,
            extra_channels=extra_channels)
        if load_recons_net:
            # map_location='cpu' lets a checkpoint saved from a GPU be read on
            # a CPU-only host; load_state_dict then copies the tensors into
            # recons_net's existing parameters.
            # NOTE(review): torch.load unpickles the file -- only load
            # trusted checkpoints.
            self.recons_net.load_state_dict(
                torch.load(recons_pretrain_fn, map_location='cpu'))
            print('Loading reconstruction pretrain model from {}'.format(
                recons_pretrain_fn))
예제 #2
0
    def __init__(self,
                 in_channels=3,
                 n_sequence=5,
                 out_channels=3,
                 n_resblock=3,
                 n_feat=32,
                 load_flow_net=False,
                 load_recons_net=False,
                 flow_pretrain_fn='',
                 recons_pretrain_fn='',
                 is_mask_filter=False,
                 device='cuda'):
        """Build the CDVD-TSP network: a PWC flow net plus a video reconstruction net.

        The outer model only accepts ``n_sequence == 5``, while the internal
        reconstruction net is always built with ``n_sequence=3``.

        Args:
            in_channels: Input channel count forwarded to ``RECONS_VIDEO``.
            n_sequence: Sequence length; only 5 is supported, so the default
                is 5 (a default of 3 could never pass the check below).
            out_channels: Output channel count forwarded to ``RECONS_VIDEO``.
            n_resblock: Residual block count forwarded to ``RECONS_VIDEO``.
            n_feat: Feature width forwarded to ``RECONS_VIDEO``.
            load_flow_net: Forwarded to ``Flow_PWC`` as ``load_pretrain``.
            load_recons_net: If true, load ``recons_net`` weights from
                ``recons_pretrain_fn``.
            flow_pretrain_fn: Forwarded to ``Flow_PWC`` as ``pretrain_fn``.
            recons_pretrain_fn: Checkpoint path read with ``torch.load`` when
                ``load_recons_net`` is true.
            is_mask_filter: Stored as ``self.is_mask_filter``; its value is
                reported at start-up.
            device: Stored as ``self.device`` and forwarded to ``Flow_PWC``.

        Raises:
            AssertionError: If ``n_sequence`` is not 5.
        """
        super(CDVD_TSP, self).__init__()
        print("Creating CDVD-TSP Net")

        self.n_sequence = n_sequence
        self.device = device

        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for existing callers.
        if n_sequence != 5:
            raise AssertionError(
                "Only support args.n_sequence=5; but get args.n_sequence={}".format(
                    n_sequence))

        self.is_mask_filter = is_mask_filter
        print('Is meanfilter image when process mask:',
              'True' if is_mask_filter else 'False')
        # One extra (mask) input channel for the reconstruction net.
        extra_channels = 1
        print('Select mask mode: concat, num_mask={}'.format(extra_channels))

        self.flow_net = flow_pwc.Flow_PWC(load_pretrain=load_flow_net,
                                          pretrain_fn=flow_pretrain_fn,
                                          device=device)
        self.recons_net = recons_video.RECONS_VIDEO(
            in_channels=in_channels,
            n_sequence=3,
            out_channels=out_channels,
            n_resblock=n_resblock,
            n_feat=n_feat,
            extra_channels=extra_channels)
        if load_recons_net:
            # map_location='cpu' lets a checkpoint saved from a GPU be read on
            # a CPU-only host; load_state_dict then copies the tensors into
            # recons_net's existing parameters.
            # NOTE(review): torch.load unpickles the file -- only load
            # trusted checkpoints.
            self.recons_net.load_state_dict(
                torch.load(recons_pretrain_fn, map_location='cpu'))
            print('Loading reconstruction pretrain model from {}'.format(
                recons_pretrain_fn))
예제 #3
0
    def __init__(self,
                 in_channels=3,
                 n_sequence=3,
                 out_channels=3,
                 n_resblock=3,
                 n_feat=32,
                 load_flow_net=False,
                 load_recons_net=False,
                 flow_pretrain_fn='',
                 recons_pretrain_fn='',
                 is_mask_filter=False,
                 device='cuda'):
        """Build the VBDE network: a PWC flow net plus a video reconstruction net.

        No extra mask channel is used here (``extra_channels = 0``). The
        constructor also creates a factor-2 ``SpaceToDepth`` module and a
        factor-2 ``nn.PixelShuffle`` module.

        Args:
            in_channels: Input channel count forwarded to ``RECONS_VIDEO``.
            n_sequence: Sequence length; only 3 is supported.
            out_channels: Output channel count forwarded to ``RECONS_VIDEO``.
            n_resblock: Residual block count forwarded to ``RECONS_VIDEO``.
            n_feat: Feature width forwarded to ``RECONS_VIDEO``.
            load_flow_net: Forwarded to ``Flow_PWC`` as ``load_pretrain``.
            load_recons_net: If true, load ``recons_net`` weights from
                ``recons_pretrain_fn``.
            flow_pretrain_fn: Forwarded to ``Flow_PWC`` as ``pretrain_fn``.
            recons_pretrain_fn: Checkpoint path read with ``torch.load`` when
                ``load_recons_net`` is true.
            is_mask_filter: Stored as ``self.is_mask_filter`` (not used here).
            device: Stored as ``self.device`` and forwarded to ``Flow_PWC``.

        Raises:
            AssertionError: If ``n_sequence`` is not 3.
        """
        super(VBDE, self).__init__()
        print("Creating VBDE Net")

        self.n_sequence = n_sequence
        self.device = device
        # Stored for parity with the sibling constructors, so shared code
        # reading ``self.is_mask_filter`` also works on this model.
        self.is_mask_filter = is_mask_filter

        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for existing callers.
        if n_sequence != 3:
            raise AssertionError(
                "Only support args.n_sequence=3; but get args.n_sequence={}".format(
                    n_sequence))

        extra_channels = 0
        print('Select mask mode: concat, num_mask={}'.format(extra_channels))

        self.space2depth = SpaceToDepth(2)
        self.depth2space = nn.PixelShuffle(2)

        self.flow_net = flow_pwc.Flow_PWC(load_pretrain=load_flow_net,
                                          pretrain_fn=flow_pretrain_fn,
                                          device=device)
        self.recons_net = recons_video.RECONS_VIDEO(
            in_channels=in_channels,
            n_sequence=3,
            out_channels=out_channels,
            n_resblock=n_resblock,
            n_feat=n_feat,
            extra_channels=extra_channels)
        if load_recons_net:
            # map_location='cpu' lets a checkpoint saved from a GPU be read on
            # a CPU-only host; load_state_dict then copies the tensors into
            # recons_net's existing parameters.
            # NOTE(review): torch.load unpickles the file -- only load
            # trusted checkpoints.
            self.recons_net.load_state_dict(
                torch.load(recons_pretrain_fn, map_location='cpu'))
            print('Loading reconstruction pretrain model from {}'.format(
                recons_pretrain_fn))
예제 #4
0
    def __init__(self, in_channels=3, n_sequence=3, out_channels=3, n_resblock=3, n_feat=32,
                 load_flow_net=False, load_recons_net=False, flow_pretrain_fn='', recons_pretrain_fn='',
                 is_mask_filter=False, device='cuda', lbd=4, hbd=16):
        """Build the VBDE_STEPMASK network: flow net, reconstruction net and step filters.

        Also registers four fixed (non-trainable) 3x3 difference kernels and
        the scalar ``self.quantization_step = 2**(hbd - lbd) / (2**hbd - 1)``.
        The reconstruction net gets one extra input channel, which the
        start-up message describes as a concatenated quantization step mask.

        Args:
            in_channels: Input channel count forwarded to ``RECONS_VIDEO``.
            n_sequence: Sequence length; only 3 is supported.
            out_channels: Output channel count forwarded to ``RECONS_VIDEO``.
            n_resblock: Residual block count forwarded to ``RECONS_VIDEO``.
            n_feat: Feature width forwarded to ``RECONS_VIDEO``.
            load_flow_net: Forwarded to ``Flow_PWC`` as ``load_pretrain``.
            load_recons_net: If true, load ``recons_net`` weights from
                ``recons_pretrain_fn``.
            flow_pretrain_fn: Forwarded to ``Flow_PWC`` as ``pretrain_fn``.
            recons_pretrain_fn: Checkpoint path read with ``torch.load`` when
                ``load_recons_net`` is true.
            is_mask_filter: Stored as ``self.is_mask_filter`` (not used here).
            device: Stored as ``self.device`` and forwarded to ``Flow_PWC``.
            lbd: Stored as ``self.lbd``; used in ``quantization_step``.
            hbd: Stored as ``self.hbd``; used in ``quantization_step``.

        Raises:
            AssertionError: If ``n_sequence`` is not 3.
        """
        super(VBDE_STEPMASK, self).__init__()
        print("Creating VBDE_STEPMASK Net")

        self.n_sequence = n_sequence
        self.device = device
        self.is_mask_filter = is_mask_filter

        self.lbd = lbd
        self.hbd = hbd
        # NOTE(review): if lbd/hbd are low/high bit depths, this is the size of
        # one lbd-level step on a [0, 1]-normalised hbd scale -- confirm.
        self.quantization_step = 2**(self.hbd - self.lbd)/(2**self.hbd - 1)

        # Quantization step filters: each 3x3 kernel is +1 at the centre and
        # -1 at one neighbour (top / down / left / right), shaped
        # (1, 1, 3, 3). Presumably applied with a conv2d elsewhere to get
        # pixel-minus-neighbour differences -- confirm in forward().
        top_kernel = torch.tensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]], dtype=torch.float32).view(1, 1, 3, 3)
        down_kernel = torch.tensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]], dtype=torch.float32).view(1, 1, 3, 3)
        left_kernel = torch.tensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]], dtype=torch.float32).view(1, 1, 3, 3)
        right_kernel = torch.tensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]], dtype=torch.float32).view(1, 1, 3, 3)

        # Registered as frozen Parameters (not buffers), so they appear in
        # parameters()/state_dict but receive no gradient.
        self.top_filter = nn.Parameter(data=top_kernel, requires_grad=False)
        self.down_filter = nn.Parameter(data=down_kernel, requires_grad=False)
        self.left_filter = nn.Parameter(data=left_kernel, requires_grad=False)
        self.right_filter = nn.Parameter(data=right_kernel, requires_grad=False)

        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for existing callers.
        if n_sequence != 3:
            raise AssertionError(
                "Only support args.n_sequence=3; but get args.n_sequence={}".format(n_sequence))

        # One extra input channel for the quantization step mask.
        extra_channels = 1
        print('Concat quantization step mask')

        self.flow_net = flow_pwc.Flow_PWC(load_pretrain=load_flow_net, pretrain_fn=flow_pretrain_fn, device=device)
        self.recons_net = recons_video.RECONS_VIDEO(in_channels=in_channels, n_sequence=3, out_channels=out_channels,
                                                    n_resblock=n_resblock, n_feat=n_feat,
                                                    extra_channels=extra_channels)
        if load_recons_net:
            # map_location='cpu' lets a checkpoint saved from a GPU be read on
            # a CPU-only host; load_state_dict then copies the tensors into
            # recons_net's existing parameters.
            # NOTE(review): torch.load unpickles the file -- only load
            # trusted checkpoints.
            self.recons_net.load_state_dict(torch.load(recons_pretrain_fn, map_location='cpu'))
            print('Loading reconstruction pretrain model from {}'.format(recons_pretrain_fn))