def __init__(self, in_channels=3, n_sequence=3, out_channels=3, n_resblock=3, n_feat=32,
             load_flow_net=False, load_recons_net=False, flow_pretrain_fn='', recons_pretrain_fn='',
             is_mask_filter=False, device='cuda', lbd=4, hbd=16):
    """Build the VBDE_QM network: a PWC flow net plus a reconstruction net.

    Args:
        in_channels: forwarded to ``RECONS_VIDEO`` as ``in_channels``.
        n_sequence: number of input frames; must be 3 (asserted below).
        out_channels: forwarded to ``RECONS_VIDEO`` as ``out_channels``.
        n_resblock: forwarded to ``RECONS_VIDEO`` as ``n_resblock``.
        n_feat: forwarded to ``RECONS_VIDEO`` as ``n_feat``.
        load_flow_net: forwarded to ``Flow_PWC`` as ``load_pretrain``.
        load_recons_net: when true, load ``recons_pretrain_fn`` into the
            reconstruction net.
        flow_pretrain_fn: forwarded to ``Flow_PWC`` as ``pretrain_fn``.
        recons_pretrain_fn: path of the reconstruction-net checkpoint.
        is_mask_filter: stored on the instance as ``self.is_mask_filter``.
        device: stored on the instance and forwarded to ``Flow_PWC``.
        lbd: stored as ``self.lbd`` (presumably the low bit depth -- confirm).
        hbd: stored as ``self.hbd`` (presumably the high bit depth -- confirm).

    Raises:
        AssertionError: if ``n_sequence`` is not 3.
    """
    super(VBDE_QM, self).__init__()
    print("Creating VBDE_QM Net")

    self.n_sequence = n_sequence
    self.device = device
    self.is_mask_filter = is_mask_filter
    self.lbd = lbd
    self.hbd = hbd

    assert n_sequence == 3, \
        "Only support args.n_sequence=3; but get args.n_sequence={}".format(n_sequence)

    # One additional input channel is requested from the reconstruction net
    # (per the message below, for the concatenated quantization mask).
    extra_channels = 1
    print('Concat quantization mask')

    self.flow_net = flow_pwc.Flow_PWC(
        load_pretrain=load_flow_net,
        pretrain_fn=flow_pretrain_fn,
        device=device,
    )
    # The reconstruction net is always built for 3 frames, independent of
    # the n_sequence argument (which is asserted to be 3 anyway).
    self.recons_net = recons_video.RECONS_VIDEO(
        in_channels=in_channels,
        n_sequence=3,
        out_channels=out_channels,
        n_resblock=n_resblock,
        n_feat=n_feat,
        extra_channels=extra_channels,
    )

    if load_recons_net:
        # Deserialize onto the CPU so a checkpoint that was saved from a GPU
        # can be read on hosts without that device; load_state_dict then
        # copies the values into the network's existing parameters.
        state_dict = torch.load(recons_pretrain_fn, map_location='cpu')
        self.recons_net.load_state_dict(state_dict)
        print('Loading reconstruction pretrain model from {}'.format(recons_pretrain_fn))
def __init__(self, in_channels=3, n_sequence=3, out_channels=3, n_resblock=3, n_feat=32,
             load_flow_net=False, load_recons_net=False, flow_pretrain_fn='', recons_pretrain_fn='',
             is_mask_filter=False, device='cuda'):
    """Build the CDVD-TSP network: a PWC flow net plus a reconstruction net.

    Args:
        in_channels: forwarded to ``RECONS_VIDEO`` as ``in_channels``.
        n_sequence: number of input frames; must be 5 (asserted below).
        out_channels: forwarded to ``RECONS_VIDEO`` as ``out_channels``.
        n_resblock: forwarded to ``RECONS_VIDEO`` as ``n_resblock``.
        n_feat: forwarded to ``RECONS_VIDEO`` as ``n_feat``.
        load_flow_net: forwarded to ``Flow_PWC`` as ``load_pretrain``.
        load_recons_net: when true, load ``recons_pretrain_fn`` into the
            reconstruction net.
        flow_pretrain_fn: forwarded to ``Flow_PWC`` as ``pretrain_fn``.
        recons_pretrain_fn: path of the reconstruction-net checkpoint.
        is_mask_filter: stored on the instance as ``self.is_mask_filter``
            and echoed in the start-up message.
        device: stored on the instance and forwarded to ``Flow_PWC``.

    Raises:
        AssertionError: if ``n_sequence`` is not 5.
    """
    super(CDVD_TSP, self).__init__()
    print("Creating CDVD-TSP Net")

    self.n_sequence = n_sequence
    self.device = device

    # NOTE(review): the default n_sequence=3 does not satisfy this check, so
    # callers must pass n_sequence=5 explicitly -- confirm this is intended.
    assert n_sequence == 5, \
        "Only support args.n_sequence=5; but get args.n_sequence={}".format(n_sequence)

    self.is_mask_filter = is_mask_filter
    print('Is meanfilter image when process mask:', 'True' if is_mask_filter else 'False')

    # One additional input channel is requested from the reconstruction net.
    extra_channels = 1
    print('Select mask mode: concat, num_mask={}'.format(extra_channels))

    self.flow_net = flow_pwc.Flow_PWC(
        load_pretrain=load_flow_net,
        pretrain_fn=flow_pretrain_fn,
        device=device,
    )
    # The reconstruction net is built for 3 frames even though this model
    # requires n_sequence == 5.
    self.recons_net = recons_video.RECONS_VIDEO(
        in_channels=in_channels,
        n_sequence=3,
        out_channels=out_channels,
        n_resblock=n_resblock,
        n_feat=n_feat,
        extra_channels=extra_channels,
    )

    if load_recons_net:
        # Deserialize onto the CPU so a checkpoint that was saved from a GPU
        # can be read on hosts without that device; load_state_dict then
        # copies the values into the network's existing parameters.
        state_dict = torch.load(recons_pretrain_fn, map_location='cpu')
        self.recons_net.load_state_dict(state_dict)
        print('Loading reconstruction pretrain model from {}'.format(recons_pretrain_fn))
def __init__(self, in_channels=3, n_sequence=3, out_channels=3, n_resblock=3, n_feat=32,
             load_flow_net=False, load_recons_net=False, flow_pretrain_fn='', recons_pretrain_fn='',
             is_mask_filter=False, device='cuda'):
    """Build the VBDE network: a PWC flow net plus a reconstruction net.

    Args:
        in_channels: forwarded to ``RECONS_VIDEO`` as ``in_channels``.
        n_sequence: number of input frames; must be 3 (asserted below).
        out_channels: forwarded to ``RECONS_VIDEO`` as ``out_channels``.
        n_resblock: forwarded to ``RECONS_VIDEO`` as ``n_resblock``.
        n_feat: forwarded to ``RECONS_VIDEO`` as ``n_feat``.
        load_flow_net: forwarded to ``Flow_PWC`` as ``load_pretrain``.
        load_recons_net: when true, load ``recons_pretrain_fn`` into the
            reconstruction net.
        flow_pretrain_fn: forwarded to ``Flow_PWC`` as ``pretrain_fn``.
        recons_pretrain_fn: path of the reconstruction-net checkpoint.
        is_mask_filter: accepted but not used by this constructor.
        device: stored on the instance and forwarded to ``Flow_PWC``.

    Raises:
        AssertionError: if ``n_sequence`` is not 3.
    """
    super(VBDE, self).__init__()
    print("Creating VBDE Net")

    self.n_sequence = n_sequence
    self.device = device

    assert n_sequence == 3, \
        "Only support args.n_sequence=3; but get args.n_sequence={}".format(n_sequence)

    # No additional input channels are requested from the reconstruction net.
    extra_channels = 0
    print('Select mask mode: concat, num_mask={}'.format(extra_channels))

    # Factor-2 rearrangement modules, created here for use elsewhere in the class.
    self.space2depth = SpaceToDepth(2)
    self.depth2space = nn.PixelShuffle(2)

    self.flow_net = flow_pwc.Flow_PWC(
        load_pretrain=load_flow_net,
        pretrain_fn=flow_pretrain_fn,
        device=device,
    )
    # The reconstruction net is always built for 3 frames, independent of
    # the n_sequence argument (which is asserted to be 3 anyway).
    self.recons_net = recons_video.RECONS_VIDEO(
        in_channels=in_channels,
        n_sequence=3,
        out_channels=out_channels,
        n_resblock=n_resblock,
        n_feat=n_feat,
        extra_channels=extra_channels,
    )

    if load_recons_net:
        # Deserialize onto the CPU so a checkpoint that was saved from a GPU
        # can be read on hosts without that device; load_state_dict then
        # copies the values into the network's existing parameters.
        state_dict = torch.load(recons_pretrain_fn, map_location='cpu')
        self.recons_net.load_state_dict(state_dict)
        print('Loading reconstruction pretrain model from {}'.format(recons_pretrain_fn))
def __init__(self, in_channels=3, n_sequence=3, out_channels=3, n_resblock=3, n_feat=32,
             load_flow_net=False, load_recons_net=False, flow_pretrain_fn='', recons_pretrain_fn='',
             is_mask_filter=False, device='cuda', lbd=4, hbd=16):
    """Build the VBDE_STEPMASK network and its fixed difference filters.

    Besides the flow and reconstruction sub-networks, four frozen 1x1x3x3
    kernels are registered. Each has +1 at the centre and -1 at one of the
    four direct neighbours (top, down, left, right).

    Args:
        in_channels: forwarded to ``RECONS_VIDEO`` as ``in_channels``.
        n_sequence: number of input frames; must be 3 (asserted below).
        out_channels: forwarded to ``RECONS_VIDEO`` as ``out_channels``.
        n_resblock: forwarded to ``RECONS_VIDEO`` as ``n_resblock``.
        n_feat: forwarded to ``RECONS_VIDEO`` as ``n_feat``.
        load_flow_net: forwarded to ``Flow_PWC`` as ``load_pretrain``.
        load_recons_net: when true, load ``recons_pretrain_fn`` into the
            reconstruction net.
        flow_pretrain_fn: forwarded to ``Flow_PWC`` as ``pretrain_fn``.
        recons_pretrain_fn: path of the reconstruction-net checkpoint.
        is_mask_filter: stored on the instance as ``self.is_mask_filter``.
        device: stored on the instance and forwarded to ``Flow_PWC``.
        lbd: stored as ``self.lbd`` (presumably the low bit depth -- confirm).
        hbd: stored as ``self.hbd`` (presumably the high bit depth -- confirm).

    Raises:
        AssertionError: if ``n_sequence`` is not 3.
    """
    super(VBDE_STEPMASK, self).__init__()
    print("Creating VBDE_STEPMASK Net")

    self.n_sequence = n_sequence
    self.device = device
    self.is_mask_filter = is_mask_filter
    self.lbd = lbd
    self.hbd = hbd
    # 2^(hbd - lbd) divided by the largest hbd-bit code value (2^hbd - 1).
    self.quantization_step = 2**(self.hbd - self.lbd)/(2**self.hbd - 1)

    # initialize quantization step filter
    # Each kernel is shaped (1, 1, 3, 3) by the two unsqueeze(0) calls.
    top_kernel = torch.FloatTensor([[0, -1, 0],
                                    [0, 1, 0],
                                    [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
    down_kernel = torch.FloatTensor([[0, 0, 0],
                                     [0, 1, 0],
                                     [0, -1, 0]]).unsqueeze(0).unsqueeze(0)
    left_kernel = torch.FloatTensor([[0, 0, 0],
                                     [-1, 1, 0],
                                     [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
    right_kernel = torch.FloatTensor([[0, 0, 0],
                                      [0, 1, -1],
                                      [0, 0, 0]]).unsqueeze(0).unsqueeze(0)

    # Registered as parameters with requires_grad=False so they receive no
    # gradient; registration order is kept as in the original.
    self.top_filter = nn.Parameter(data=top_kernel, requires_grad=False)
    self.down_filter = nn.Parameter(data=down_kernel, requires_grad=False)
    self.left_filter = nn.Parameter(data=left_kernel, requires_grad=False)
    self.right_filter = nn.Parameter(data=right_kernel, requires_grad=False)

    assert n_sequence == 3, \
        "Only support args.n_sequence=3; but get args.n_sequence={}".format(n_sequence)

    # One additional input channel is requested from the reconstruction net
    # (per the message below, for the concatenated quantization step mask).
    extra_channels = 1
    print('Concat quantization step mask')

    self.flow_net = flow_pwc.Flow_PWC(
        load_pretrain=load_flow_net,
        pretrain_fn=flow_pretrain_fn,
        device=device,
    )
    # The reconstruction net is always built for 3 frames, independent of
    # the n_sequence argument (which is asserted to be 3 anyway).
    self.recons_net = recons_video.RECONS_VIDEO(
        in_channels=in_channels,
        n_sequence=3,
        out_channels=out_channels,
        n_resblock=n_resblock,
        n_feat=n_feat,
        extra_channels=extra_channels,
    )

    if load_recons_net:
        # Deserialize onto the CPU so a checkpoint that was saved from a GPU
        # can be read on hosts without that device; load_state_dict then
        # copies the values into the network's existing parameters.
        state_dict = torch.load(recons_pretrain_fn, map_location='cpu')
        self.recons_net.load_state_dict(state_dict)
        print('Loading reconstruction pretrain model from {}'.format(recons_pretrain_fn))