Esempio n. 1
0
def get_model(args: argparse.Namespace) -> nn.Module:
    """Build the network selected by ``args.network``.

    Recognised values are any name containing ``"waveone"`` and the exact
    names ``"cae"``, ``"unet"``, ``"opt"`` and ``"prednet"``.

    Raises:
        ValueError: if ``args.network`` matches none of the above.
    """
    name = args.network

    if "waveone" in name:
        # Components are created in the order encoder -> decoder -> binarizer;
        # keep it that way in case the constructors draw from the global RNG.
        waveone_encoder = Encoder(6, args.bits, use_context=False)
        waveone_decoder = BitToFlowDecoder(args.bits, 3)
        waveone_binarizer = Binarizer(
            args.bits,
            args.bits,
            not args.binarize_off,
        )
        return WaveoneModel(
            waveone_encoder,
            waveone_binarizer,
            waveone_decoder,
            args.flow_off,
        )
    elif name == "cae":
        return CAE()
    elif name == "unet":
        return AutoencoderUNet(6, shrink=1)
    elif name == "opt":
        # Parameter-free pipeline: the "encoder" is the plain frame
        # difference, the binarizer is the identity, and the "decoder"
        # returns that difference between two scalar zero tensors.
        difference_encoder = LambdaModule(lambda f1, f2, _: f2 - f1)
        identity_binarizer = nn.Identity()  # type: ignore
        passthrough_decoder = LambdaModule(
            lambda t: (torch.tensor(0.), t[0], torch.tensor(0.)))
        return WaveoneModel(
            difference_encoder,
            identity_binarizer,
            passthrough_decoder,
            flow_off=True,
        )
    elif name == "prednet":
        channels = (3, 48, 96, 192)
        return PredNet(R_channels=channels, A_channels=channels)

    raise ValueError(f"No model type named {args.network}.")
Esempio n. 2
0
def test_forward_model_zero_residual():
    """Passing the same frame twice must give zero residuals.

    The reconstruction returned by ``forward_model`` must then also match
    the input frame (zero mean-squared error).
    """
    frame = torch.rand((24, 3, 255, 255)) - 0.5

    # The wrapped callable subtracts channels [:3] from channels [3:].
    difference_net = LambdaModule(lambda x: x[:, 3:] - x[:, :3])
    residuals, reconstructed = forward_model(difference_net, frame, frame)

    residual_norm = residuals.norm().item()
    assert residual_norm == pytest.approx(0.)

    mse = nn.MSELoss()
    assert mse(reconstructed, frame).item() == pytest.approx(0.)
Esempio n. 3
0
def test_forward_model_exact_residual():
    """With an exact-difference network the second frame is recovered.

    Checks the reconstruction against the true second frame with both
    MS-SSIM (expected 1.0) and mean-squared error (expected 0.0).
    """
    dims = (32, 3, 64, 64)
    first = torch.rand(dims) - 0.5
    second = torch.rand(dims) - 0.5

    # The wrapped callable subtracts channels [:3] from channels [3:].
    difference_net = LambdaModule(lambda x: x[:, 3:] - x[:, :3])
    _, rebuilt_second = forward_model(difference_net, first, second)

    similarity = MSSSIM(val_range=1)
    assert similarity(second, rebuilt_second).item() == pytest.approx(1.0)

    mse = nn.MSELoss()
    assert mse(second, rebuilt_second).item() == pytest.approx(0.)
Esempio n. 4
0
def get_model(args: argparse.Namespace) -> nn.Module:
    """Construct the model requested by ``args.network``.

    Supported names are ``"cae"``, ``"unet"``, ``"opt"``, ``"small"`` and
    any name containing ``"resnet"``; a resnet name that also contains
    ``"ctx"`` passes ``use_context=True`` to the ResNet encoder/decoder.

    Both loss functions are created and moved to CUDA before the name is
    inspected, so that work happens on every call, including calls that end
    in ``ValueError``.

    Raises:
        ValueError: if ``args.network`` is not a supported name.
    """
    # NOTE: the former "waveone" branch is disabled in this version.
    flow_loss_fn = get_loss_fn(args.flow_loss).cuda()
    reconstructed_loss_fn = get_loss_fn(args.reconstructed_loss).cuda()

    def assemble(encoder, binarizer, decoder, train_type, use_context):
        # Every WaveoneModel built here ends with the same two loss modules.
        return WaveoneModel(
            encoder,
            binarizer,
            decoder,
            train_type,
            use_context,
            flow_loss_fn,
            reconstructed_loss_fn,
        )

    name = args.network

    if name == "cae":
        return CAE()

    if name == "unet":
        return AutoencoderUNet(6, shrink=1)

    if name == "opt":
        # Parameter-free pipeline: the encoder is the raw frame difference,
        # the binarizer is the identity, and the decoder forwards that
        # difference as "residuals" next to placeholder tensors.
        difference_encoder = LambdaModule(lambda f1, f2, _: f2 - f1)
        identity_binarizer = nn.Identity()  # type: ignore
        passthrough_decoder = LambdaModule(
            lambda t: {
                "flow": torch.zeros(1),
                "flow_grid": torch.zeros(1),
                "residuals": t[0],
                "context_vec": torch.zeros(1),
                "loss": torch.tensor(0.),
            })
        passthrough_decoder.num_flows = 1  # type: ignore
        return assemble(
            difference_encoder,
            identity_binarizer,
            passthrough_decoder,
            "residual",
            False,
        )

    if name == "small":
        encoder = SmallEncoder(6, args.bits)
        binarizer = SmallBinarizer(not args.binarize_off)
        decoder = SmallDecoder(args.bits, 3)
        return assemble(encoder, binarizer, decoder, args.train_type, False)

    if "resnet" in name:
        with_context = "ctx" in name
        encoder = ResNetEncoder(
            6,
            args.bits,
            resblocks=args.resblocks,
            use_context=with_context,
        )
        binarizer = SmallBinarizer(not args.binarize_off)
        decoder = ResNetDecoder(
            args.bits,
            3,
            resblocks=args.resblocks,
            use_context=with_context,
            num_flows=args.num_flows,
        )
        return assemble(
            encoder,
            binarizer,
            decoder,
            args.train_type,
            with_context,
        )

    raise ValueError(f"No model type named {args.network}.")
Esempio n. 5
0
def train(args) -> List[nn.Module]:
    """Train the network selected by ``args.network`` and return it.

    The function creates the log/output/model directories, builds the train
    and eval loaders, optionally resumes from ``args.load_model_name``, then
    runs ``args.max_train_epochs`` epochs. Scalars are written to a
    TensorBoard ``SummaryWriter`` under ``runs/<save_model_name>``.

    Attributes read from ``args``: ``log_dir``, ``out_dir``, ``model_dir``,
    ``save_model_name``, ``load_model_name``, ``train``, ``eval``,
    ``network``, ``lr``, ``weight_decay``, ``loss``, ``save_max_l2``,
    ``save_out_img``, ``max_train_epochs``, ``checkpoint_epochs`` and
    ``eval_epochs`` (``args`` is also forwarded to the loader, evaluation and
    image-saving helpers).

    Worst-case frames are only tracked when ``args.save_max_l2`` is set.
    If ``args.save_out_img`` is set without it, no worst-case image is
    written.

    Returns:
        A one-element list holding the trained network.
    """
    log_dir = os.path.join(args.log_dir, args.save_model_name)
    output_dir = os.path.join(args.out_dir, args.save_model_name)
    model_dir = os.path.join(args.model_dir, args.save_model_name)
    create_directories((output_dir, model_dir, log_dir))

    # logging.basicConfig(
    #     filename=os.path.join(log_dir, args.save_model_name + ".out"),
    #     filemode="w",
    #     level=logging.DEBUG,
    # )

    print(args)
    ############### Data ###############

    train_loader = get_master_loader(is_train=True,
                                     root=args.train,
                                     frame_len=4,
                                     sampling_range=12,
                                     args=args)
    eval_loader = get_master_loader(
        is_train=False,
        root=args.eval,
        frame_len=1,
        sampling_range=0,
        args=args,
    )
    writer = SummaryWriter(f"runs/{args.save_model_name}")

    ############### Model ###############
    # Any name other than 'unet' / 'cae' falls back to the parameter-free
    # channel-difference module.
    network = AutoencoderUNet(6, shrink=1) if args.network == 'unet' \
        else CAE() if args.network == 'cae' \
        else LambdaModule(lambda x: x[:, 3:] - x[:, :3])
    network = network.cuda()
    nets: List[nn.Module] = [network]
    names = [args.network]
    # For 'opt' the optimizer is given a dummy tensor instead of
    # network.parameters(); it is never stepped in that mode.
    solver = optim.Adam(network.parameters()
                        if args.network != 'opt' else [torch.zeros((1, ))],
                        lr=args.lr,
                        weight_decay=args.weight_decay)
    milestones = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    scheduler = LS.MultiStepLR(solver, milestones=milestones, gamma=0.5)
    msssim_fn = MSSSIM(val_range=1, normalize=True).cuda()
    l1_loss_fn = nn.L1Loss(reduction="mean").cuda()
    l2_loss_fn = nn.MSELoss(reduction="mean").cuda()
    # Any loss name other than 'l2' / 'l1' selects negated MS-SSIM.
    loss_fn = l2_loss_fn if args.loss == 'l2' else l1_loss_fn if args.loss == 'l1' \
        else LambdaModule(lambda a, b: -msssim_fn(a, b))

    ############### Checkpoints ###############

    def resume() -> None:
        """Load every network's weights from ``args.load_model_name``."""
        for name, net in zip(names, nets):
            if net is not None:
                checkpoint_path = os.path.join(
                    args.model_dir,
                    args.load_model_name,
                    f"{name}.pth",
                )

                print('Loading %s from %s...' % (name, checkpoint_path))
                # NOTE(review): torch.load unpickles the file; only point
                # this at checkpoints from a trusted source.
                net.load_state_dict(torch.load(checkpoint_path))

    def save() -> None:
        """Write every network's state dict to ``model_dir``."""
        for name, net in zip(names, nets):
            if net is not None:
                checkpoint_path = os.path.join(
                    model_dir,
                    f'{name}.pth',
                )
                torch.save(net.state_dict(), checkpoint_path)

    def log_flow_context_residuals(
        writer: SummaryWriter,
        residuals: torch.Tensor,
    ) -> None:
        """Log mean / max / min of ``residuals`` at the current iteration."""
        writer.add_scalar("mean_input_residuals",
                          residuals.mean().item(), train_iter)
        writer.add_scalar("max_input_residuals",
                          residuals.max().item(), train_iter)
        writer.add_scalar("min_input_residuals",
                          residuals.min().item(), train_iter)

    ############### Training ###############

    train_iter = 0
    just_resumed = False
    if args.load_model_name:
        print(f'Loading {args.load_model_name}')
        resume()
        just_resumed = True

    def train_loop(
        frames: List[torch.Tensor],
    ) -> Iterator[Tuple[float, Tuple[torch.Tensor, torch.Tensor,
                                     torch.Tensor]]]:
        """Run one optimisation step over ``frames``.

        This is a generator: the step only executes while it is iterated
        and is complete once it is exhausted. When ``args.save_max_l2`` is
        set it yields, per consecutive frame pair, the largest per-sample L2
        error and the matching ``(frame1, frame2, reconstruction)`` triple
        on the CPU; otherwise it yields nothing.
        """
        for net in nets:
            net.train()
        if args.network != 'opt':
            solver.zero_grad()

        reconstructed_frames = []

        loss: torch.Tensor = 0.  # type: ignore

        frame1 = frames[0].cuda()
        for frame2 in frames[1:]:
            frame2 = frame2.cuda()

            residuals, reconstructed_frame2 = forward_model(
                network, frame1, frame2)
            reconstructed_frames.append(reconstructed_frame2.cpu())
            loss += loss_fn(reconstructed_frame2, frame2)

            if args.save_max_l2:
                with torch.no_grad():
                    batch_l2 = ((frame2 - frame1 - residuals)**2).mean(
                        dim=-1).mean(dim=-1).mean(dim=-1).cpu()
                    max_batch_l2, max_batch_l2_idx = torch.max(batch_l2, dim=0)
                    max_batch_l2_frames = (
                        frame1[max_batch_l2_idx].cpu(),
                        frame2[max_batch_l2_idx].cpu(),
                        reconstructed_frame2[max_batch_l2_idx].detach().cpu(),
                    )
                    max_l2: float = max_batch_l2.item()  # type: ignore
                # Yield outside the no_grad block so the generator is never
                # suspended while the grad-mode context manager is active.
                yield max_l2, max_batch_l2_frames

            log_flow_context_residuals(writer, torch.abs(frame2 - frame1))

            # The next pair starts from the reconstruction, not the
            # ground-truth frame.
            frame1 = reconstructed_frame2.detach()

        scores = {
            **eval_scores(frames[:-1], frames[1:], "train_baseline"),
            **eval_scores(frames[1:], reconstructed_frames, "train_reconstructed"),
        }

        if args.network != "opt":
            loss.backward()
            solver.step()

        writer.add_scalar("training_loss", loss.item(), train_iter)
        writer.add_scalar("lr", solver.param_groups[0]["lr"],
                          train_iter)  # type: ignore
        plot_scores(writer, scores, train_iter)
        score_diffs = get_score_diffs(scores, ["reconstructed"], "train")
        plot_scores(writer, score_diffs, train_iter)

    for epoch in range(args.max_train_epochs):
        # Worst (l2, frames) entry seen so far in this epoch; stays None
        # when train_loop yields nothing or the loader is empty.
        epoch_max = None
        for frames in train_loader:
            train_iter += 1
            # default=None keeps max() from raising on an empty generator;
            # consuming the generator is what runs the training step.
            batch_max = max(train_loop(frames),
                            key=lambda x: x[0],
                            default=None)
            if batch_max is not None and (epoch_max is None
                                          or batch_max[0] > epoch_max[0]):
                epoch_max = batch_max

        if args.save_out_img and epoch_max is not None:
            max_epoch_l2, max_epoch_l2_frames = epoch_max
            save_tensor_as_img(
                max_epoch_l2_frames[1],
                f"{max_epoch_l2 :.6f}_{epoch}_max_l2_frame",
                args,
            )
            save_tensor_as_img(
                max_epoch_l2_frames[2],
                f"{max_epoch_l2 :.6f}_{epoch}_max_l2_reconstructed",
                args,
            )

        if (epoch + 1) % args.checkpoint_epochs == 0:
            save()

        if just_resumed or ((epoch + 1) % args.eval_epochs == 0):
            run_eval("TVL",
                     eval_loader,
                     network,
                     epoch,
                     args,
                     writer,
                     reuse_reconstructed=True)
            run_eval("TVL",
                     eval_loader,
                     network,
                     epoch,
                     args,
                     writer,
                     reuse_reconstructed=False)
            # NOTE(review): the scheduler only advances on evaluation
            # epochs, so the milestones count evaluation rounds rather than
            # training epochs - confirm this is intended.
            scheduler.step()  # type: ignore
            just_resumed = False

    print('Training done.')
    # Flush pending TensorBoard events and release the event file.
    writer.close()
    logging.shutdown()
    return nets