def get_model(args: argparse.Namespace) -> nn.Module:
    """Instantiate the network requested via ``args.network``.

    Raises:
        ValueError: if ``args.network`` matches no known model type.
    """
    name = args.network
    if "waveone" in name:
        # Full pipeline: conv encoder -> binarizer -> bit-to-flow decoder.
        waveone_encoder = Encoder(6, args.bits, use_context=False)
        waveone_binarizer = Binarizer(args.bits, args.bits,
                                      not args.binarize_off)
        waveone_decoder = BitToFlowDecoder(args.bits, 3)
        return WaveoneModel(waveone_encoder, waveone_binarizer,
                            waveone_decoder, args.flow_off)
    if name == "cae":
        return CAE()
    if name == "unet":
        return AutoencoderUNet(6, shrink=1)
    if name == "opt":
        # Baseline: the "code" is simply the frame difference and the decoder
        # passes it straight through — nothing is learned or compressed.
        diff_encoder = LambdaModule(lambda f1, f2, _: f2 - f1)
        identity_binarizer = nn.Identity()  # type: ignore
        passthrough_decoder = LambdaModule(
            lambda t: (torch.tensor(0.), t[0], torch.tensor(0.)))
        return WaveoneModel(diff_encoder, identity_binarizer,
                            passthrough_decoder, flow_off=True)
    if name == "prednet":
        return PredNet(R_channels=(3, 48, 96, 192),
                       A_channels=(3, 48, 96, 192))
    raise ValueError(f"No model type named {args.network}.")
def test_forward_model_zero_residual():
    """Identical input frames must yield zero residuals and an exact reconstruction."""
    frame = torch.rand((24, 3, 255, 255)) - 0.5
    diff_net = LambdaModule(lambda x: x[:, 3:] - x[:, :3])
    residuals, reconstructed = forward_model(diff_net, frame, frame)
    assert residuals.norm().item() == pytest.approx(0.)
    assert nn.MSELoss()(reconstructed, frame).item() == pytest.approx(0.)
def test_forward_model_exact_residual():
    """A difference network must reconstruct the second frame perfectly."""
    previous = torch.rand((32, 3, 64, 64)) - 0.5
    current = torch.rand((32, 3, 64, 64)) - 0.5
    diff_net = LambdaModule(lambda x: x[:, 3:] - x[:, :3])
    _, recovered = forward_model(diff_net, previous, current)
    assert MSSSIM(val_range=1)(current, recovered).item() == pytest.approx(1.0)
    assert nn.MSELoss()(current, recovered).item() == pytest.approx(0.)
def get_model(args: argparse.Namespace) -> nn.Module:
    """Instantiate the network requested via ``args.network``.

    Both loss functions are built up front, moved to the GPU, and handed to
    every ``WaveoneModel`` variant.

    Raises:
        ValueError: if ``args.network`` matches no known model type.
    """
    flow_loss_fn = get_loss_fn(args.flow_loss).cuda()
    reconstructed_loss_fn = get_loss_fn(args.reconstructed_loss).cuda()
    name = args.network

    if name == "cae":
        return CAE()
    if name == "unet":
        return AutoencoderUNet(6, shrink=1)
    if name == "opt":
        # Baseline: the "code" is the raw frame difference; the decoder just
        # passes it through with zero flow and zero auxiliary loss.
        diff_encoder = LambdaModule(lambda f1, f2, _: f2 - f1)
        identity_binarizer = nn.Identity()  # type: ignore
        passthrough_decoder = LambdaModule(
            lambda t: {
                "flow": torch.zeros(1),
                "flow_grid": torch.zeros(1),
                "residuals": t[0],
                "context_vec": torch.zeros(1),
                "loss": torch.tensor(0.),
            })
        passthrough_decoder.num_flows = 1  # type: ignore
        return WaveoneModel(
            diff_encoder,
            identity_binarizer,
            passthrough_decoder,
            "residual",
            False,
            flow_loss_fn,
            reconstructed_loss_fn,
        )
    if name == "small":
        return WaveoneModel(
            SmallEncoder(6, args.bits),
            SmallBinarizer(not args.binarize_off),
            SmallDecoder(args.bits, 3),
            args.train_type,
            False,
            flow_loss_fn,
            reconstructed_loss_fn,
        )
    if "resnet" in name:
        # "ctx" in the network name switches on the context vector.
        use_context = "ctx" in name
        return WaveoneModel(
            ResNetEncoder(6, args.bits, resblocks=args.resblocks,
                          use_context=use_context),
            SmallBinarizer(not args.binarize_off),
            ResNetDecoder(args.bits, 3, resblocks=args.resblocks,
                          use_context=use_context,
                          num_flows=args.num_flows),
            args.train_type,
            use_context,
            flow_loss_fn,
            reconstructed_loss_fn,
        )
    raise ValueError(f"No model type named {args.network}.")
def train(args: argparse.Namespace) -> List[nn.Module]:
    """Train the network selected by ``args.network`` and return it.

    Sets up data loaders, model, optimizer, LR scheduler and TensorBoard
    logging, optionally resumes from a checkpoint, then runs the epoch loop.

    Args:
        args: parsed command-line namespace; read for paths, model choice,
            loss choice, optimizer hyper-parameters and logging flags.

    Returns:
        The list of trained modules (here a single network).
    """
    # Per-run subdirectories, all keyed by the model name.
    log_dir = os.path.join(args.log_dir, args.save_model_name)
    output_dir = os.path.join(args.out_dir, args.save_model_name)
    model_dir = os.path.join(args.model_dir, args.save_model_name)
    create_directories((output_dir, model_dir, log_dir))
    # logging.basicConfig(
    #     filename=os.path.join(log_dir, args.save_model_name + ".out"),
    #     filemode="w",
    #     level=logging.DEBUG,
    # )
    print(args)

    ############### Data ###############
    # Training samples 4-frame clips from a 12-frame window; evaluation
    # uses single frames.
    train_loader = get_master_loader(is_train=True,
                                     root=args.train,
                                     frame_len=4,
                                     sampling_range=12,
                                     args=args)
    eval_loader = get_master_loader(
        is_train=False,
        root=args.eval,
        frame_len=1,
        sampling_range=0,
        args=args,
    )
    writer = SummaryWriter(f"runs/{args.save_model_name}")

    ############### Model ###############
    # Any network name other than 'unet'/'cae' falls through to the plain
    # frame-difference module (the 'opt' baseline).
    network = AutoencoderUNet(6, shrink=1) if args.network == 'unet' \
        else CAE() if args.network == 'cae' \
        else LambdaModule(lambda x: x[:, 3:] - x[:, :3])
    network = network.cuda()
    nets: List[nn.Module] = [network]
    names = [args.network]
    # The 'opt' baseline has no parameters, so give Adam a dummy tensor.
    solver = optim.Adam(network.parameters() if args.network != 'opt'
                        else [torch.zeros((1, ))],
                        lr=args.lr,
                        weight_decay=args.weight_decay)
    milestones = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    scheduler = LS.MultiStepLR(solver, milestones=milestones, gamma=0.5)
    msssim_fn = MSSSIM(val_range=1, normalize=True).cuda()
    l1_loss_fn = nn.L1Loss(reduction="mean").cuda()
    l2_loss_fn = nn.MSELoss(reduction="mean").cuda()
    # Loss values other than 'l1'/'l2' fall through to negated MS-SSIM
    # (similarity is maximized by minimizing its negation).
    loss_fn = l2_loss_fn if args.loss == 'l2' else l1_loss_fn if args.loss == 'l1' \
        else LambdaModule(lambda a, b: -msssim_fn(a, b))

    ############### Checkpoints ###############
    def resume() -> None:
        # Load weights saved under args.load_model_name (which may differ
        # from the save name, enabling warm starts from another run).
        for name, net in zip(names, nets):
            if net is not None:
                checkpoint_path = os.path.join(
                    args.model_dir,
                    args.load_model_name,
                    f"{name}.pth",
                )
                print('Loading %s from %s...'
                      % (name, checkpoint_path))
                net.load_state_dict(torch.load(checkpoint_path))

    def save() -> None:
        # Write current weights under this run's model_dir.
        for name, net in zip(names, nets):
            if net is not None:
                checkpoint_path = os.path.join(
                    model_dir,
                    f'{name}.pth',
                )
                torch.save(net.state_dict(), checkpoint_path)

    def log_flow_context_residuals(
            writer: SummaryWriter,
            residuals: torch.Tensor,
    ) -> None:
        # Scalar stats of the raw frame-to-frame differences; train_iter is
        # read from the enclosing scope.
        writer.add_scalar("mean_input_residuals",
                          residuals.mean().item(), train_iter)
        writer.add_scalar("max_input_residuals",
                          residuals.max().item(), train_iter)
        writer.add_scalar("min_input_residuals",
                          residuals.min().item(), train_iter)

    ############### Training ###############
    train_iter = 0
    just_resumed = False
    if args.load_model_name:
        print(f'Loading {args.load_model_name}')
        resume()
        just_resumed = True

    def train_loop(
            frames: List[torch.Tensor],
    ) -> Iterator[Tuple[float,
                        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
        # One optimization step over a clip: roll the model forward frame by
        # frame, accumulate the loss, then backprop once at the end.
        # NOTE(review): this generator only yields when args.save_max_l2 is
        # set, yet the caller wraps it in max(), which raises ValueError on
        # an empty sequence — presumably save_max_l2 is always on; confirm.
        for net in nets:
            net.train()
        if args.network != 'opt':
            solver.zero_grad()
        reconstructed_frames = []
        reconstructed_frame2 = None
        loss: torch.Tensor = 0.  # type: ignore
        frame1 = frames[0].cuda()
        for frame2 in frames[1:]:
            frame2 = frame2.cuda()
            residuals, reconstructed_frame2 = forward_model(
                network, frame1, frame2)
            reconstructed_frames.append(reconstructed_frame2.cpu())
            loss += loss_fn(reconstructed_frame2, frame2)
            if args.save_max_l2:
                with torch.no_grad():
                    # Per-example L2 of the residual prediction error;
                    # keep the worst example in the batch for visualization.
                    batch_l2 = ((frame2 - frame1 - residuals)**2).mean(
                        dim=-1).mean(dim=-1).mean(dim=-1).cpu()
                    max_batch_l2, max_batch_l2_idx = torch.max(batch_l2,
                                                               dim=0)
                    max_batch_l2_frames = (
                        frame1[max_batch_l2_idx].cpu(),
                        frame2[max_batch_l2_idx].cpu(),
                        reconstructed_frame2[max_batch_l2_idx].detach().cpu(),
                    )
                    max_l2: float = max_batch_l2.item()  # type: ignore
                    yield max_l2, max_batch_l2_frames
            log_flow_context_residuals(writer, torch.abs(frame2 - frame1))
            # Feed the reconstruction (not the ground truth) forward so
            # training matches inference-time drift.
            frame1 = reconstructed_frame2.detach()
        scores = {
            **eval_scores(frames[:-1], frames[1:], "train_baseline"),
            **eval_scores(frames[1:], reconstructed_frames,
                          "train_reconstructed"),
        }
        if args.network != "opt":
            loss.backward()
            solver.step()
        writer.add_scalar("training_loss", loss.item(), train_iter)
        writer.add_scalar(
            "lr", solver.param_groups[0]["lr"], train_iter)  # type: ignore
        plot_scores(writer, scores, train_iter)
        score_diffs = get_score_diffs(scores, ["reconstructed"], "train")
        plot_scores(writer, score_diffs, train_iter)

    for epoch in range(args.max_train_epochs):
        for frames in train_loader:
            train_iter += 1
            # max() also drives the generator to completion — the actual
            # optimization step happens inside train_loop.
            max_epoch_l2, max_epoch_l2_frames = max(train_loop(frames),
                                                    key=lambda x: x[0])
            if args.save_out_img:
                save_tensor_as_img(
                    max_epoch_l2_frames[1],
                    f"{max_epoch_l2 :.6f}_{epoch}_max_l2_frame",
                    args,
                )
                save_tensor_as_img(
                    max_epoch_l2_frames[2],
                    f"{max_epoch_l2 :.6f}_{epoch}_max_l2_reconstructed",
                    args,
                )
        if (epoch + 1) % args.checkpoint_epochs == 0:
            save()
        if just_resumed or ((epoch + 1) % args.eval_epochs == 0):
            # Evaluate both with and without reusing reconstructions.
            run_eval("TVL", eval_loader, network, epoch,
                     args, writer, reuse_reconstructed=True)
            run_eval("TVL", eval_loader, network, epoch,
                     args, writer, reuse_reconstructed=False)
            # NOTE(review): the LR scheduler only steps on eval epochs —
            # confirm this coupling is intentional.
            scheduler.step()  # type: ignore
            just_resumed = False

    print('Training done.')
    logging.shutdown()
    return nets