criterion = nn.MSELoss(size_average=True).cuda()
    recon_loss = 0
    refl_loss = 0
    depth_loss = 0
    shape_loss = 0
    lights_loss = 0
    shad_loss = 0
    depth_normals_loss = 0

    masks = []

    for ind, tensors in enumerate(loader):
        tensors = [Variable(t.float().cuda(async=True)) for t in tensors]
        inp, mask, refl_targ, depth_targ, shape_targ, lights_targ, shad_targ = tensors
        depth_normals_targ = pipeline.depth_to_normals(depth_targ.unsqueeze(1),
                                                       mask=mask)
        # depth_normals_targ

        depth_targ = depth_targ.unsqueeze(1).repeat(1, 3, 1, 1)
        shad_targ = shad_targ.unsqueeze(1).repeat(1, 3, 1, 1)

        recon, refl_pred, depth_pred, shape_pred, lights_pred, shad_pred = model.forward(
            inp, mask)
        # relit = pipeline.relight(model.shader, shape_pred, lights_pred, 6)
        # relit_mean = relit.mean(0).squeeze()

        depth_normals_pred = pipeline.depth_to_normals(depth_pred, mask=mask)

        depth_pred = depth_pred.repeat(1, 3, 1, 1)
        shad_pred = shad_pred.repeat(1, 3, 1, 1)
def visualize_composer_alt(model, loader, save_path, epoch, raw=False):
    """Evaluate the composer on ``loader`` and save a target/prediction grid.

    The model is switched to evaluation mode, every batch is pushed through
    ``model.forward``, and the shading prediction is recomputed with
    ``model.shader(shape_pred, lights_pred)`` (the shading returned by
    ``forward`` is discarded). For each sample, 14 images are collected in
    the order given by ``labels`` below: 7 targets followed by 7 predictions.
    With ``nrow=7`` in the grid this places each sample's targets on one row
    and its predictions on the next.

    Args:
        model: Composer network. Must provide ``train``, ``forward(inp, mask)``
            returning six outputs, and a ``shader`` callable.
        loader: Iterable of batches; each batch must unpack into seven
            tensors ``(inp, mask, refl, depth, shape, lights, shad)``.
        save_path: Directory that receives ``<epoch>.png`` and, when ``raw``
            is set, the ``raw_original`` / ``raw_trained`` sub-path.
        epoch: Used for the PNG file name; ``0`` selects ``raw_original``,
            anything else ``raw_trained``.
        raw: When true, also pass the images, masks and labels to
            ``save_raw``.

    Returns:
        A list of seven per-batch mean MSE values, in this order:
        ``[recon, refl, depth, shape, lights, shad, depth_normals]``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train(mode=False)
    render = pipeline.Render()
    images = []

    # The default reduction of MSELoss is the mean, which is what the
    # deprecated ``size_average=True`` argument used to request.
    criterion = nn.MSELoss().cuda()
    recon_loss = 0
    refl_loss = 0
    depth_loss = 0
    shape_loss = 0
    lights_loss = 0
    shad_loss = 0
    depth_normals_loss = 0

    masks = []
    num_batches = 0

    for ind, tensors in enumerate(loader):
        # Track the real batch count; ``ind`` alone is one short of it.
        num_batches = ind + 1

        tensors = [
            Variable(t.float().cuda(non_blocking=True)) for t in tensors
        ]
        inp, mask, refl_targ, depth_targ, shape_targ, lights_targ, shad_targ = tensors
        depth_normals_targ = pipeline.depth_to_normals(depth_targ.unsqueeze(1),
                                                       mask=mask)

        # Insert a channel axis and repeat it three times -- presumably so the
        # single-channel maps line up with the 3-channel images in the grid
        # (TODO confirm target shapes against the dataset).
        depth_targ = depth_targ.unsqueeze(1).repeat(1, 3, 1, 1)
        shad_targ = shad_targ.unsqueeze(1).repeat(1, 3, 1, 1)

        (
            recon,
            refl_pred,
            depth_pred,
            shape_pred,
            lights_pred,
            shad_pred,
        ) = model.forward(inp, mask)

        # Replace the shading returned by forward() with one rendered directly
        # from the predicted shape and lights.
        shad_pred = model.shader(shape_pred, lights_pred)
        print("shad_pred: ", shad_pred.size())

        depth_normals_pred = pipeline.depth_to_normals(depth_pred, mask=mask)

        depth_pred = depth_pred.repeat(1, 3, 1, 1)
        shad_pred = shad_pred.repeat(1, 3, 1, 1)

        recon_loss += criterion(recon, inp).item()
        refl_loss += criterion(refl_pred, refl_targ).item()
        depth_loss += criterion(depth_pred, depth_targ).item()
        shape_loss += criterion(shape_pred, shape_targ).item()
        lights_loss += criterion(lights_pred, lights_targ).item()
        shad_loss += criterion(shad_pred, shad_targ).item()
        # Compares the predicted shape against the normals derived from the
        # predicted depth (not against a ground-truth target).
        depth_normals_loss += criterion(shape_pred,
                                        depth_normals_pred.detach()).item()

        lights_rendered_targ = render.vis_lights(lights_targ, verbose=False)
        lights_rendered_pred = render.vis_lights(lights_pred, verbose=False)

        # Converted only after the losses above, which use the raw vectors.
        shape_targ = pipeline.vector_to_image(shape_targ)
        shape_pred = pipeline.vector_to_image(shape_pred)

        depth_normals_targ = pipeline.vector_to_image(depth_normals_targ)
        depth_normals_pred = pipeline.vector_to_image(depth_normals_pred)

        # One list of per-sample images for each of the 14 visualized tensors.
        splits = []
        for tensor in [
                inp,
                refl_targ,
                depth_targ,
                depth_normals_targ,
                shape_targ,
                shad_targ,
                lights_rendered_targ,
                recon,
                refl_pred,
                depth_pred,
                depth_normals_pred,
                shape_pred,
                shad_pred,
                lights_rendered_pred,
        ]:
            splits.append([img.squeeze() for img in tensor.data.split(1)])

        masks.append(mask)

        # Interleave so that all 14 images of sample 0 come first, then all
        # 14 of sample 1, and so on. A distinct index name keeps the batch
        # counter ``ind`` from being shadowed.
        splits = [
            sublist[sample_idx]
            for sample_idx in range(len(splits[0]))
            for sublist in splits
        ]
        images.extend(splits)

    if num_batches == 0:
        raise ValueError("loader yielded no batches; nothing to visualize")

    labels = [
        "recon_targ",
        "refl_targ",
        "depth_targ",
        "depth_normals_targ",
        "shape_targ",
        "shad_targ",
        "lights_targ",
        "recon_pred",
        "refl_pred",
        "depth_pred",
        "depth_normals_pred",
        "shape_pred",
        "shad_pred",
        "lights_pred",
    ]

    # One numpy array per sample: the first slice of the squeezed mask,
    # given a leading axis and then moved to the last axis.
    masks = [
        item.squeeze()[0].unsqueeze(0).data.cpu().numpy().transpose(1, 2, 0)
        for batch_mask in masks for item in batch_mask.split(1)
    ]

    if epoch == 0:
        raw_path = os.path.join(save_path, "raw_original")
    else:
        raw_path = os.path.join(save_path, "raw_trained")

    if raw:
        save_raw(images, masks, labels, raw_path)

    # Average over the number of batches actually processed. Dividing by the
    # last index would be off by one and fail for a single-batch loader.
    recon_loss /= float(num_batches)
    refl_loss /= float(num_batches)
    depth_loss /= float(num_batches)
    shape_loss /= float(num_batches)
    lights_loss /= float(num_batches)
    shad_loss /= float(num_batches)
    depth_normals_loss /= float(num_batches)

    grid = torchvision.utils.make_grid(images, nrow=7).cpu().numpy().transpose(
        1, 2, 0)
    grid = np.clip(grid, 0, 1)
    fullpath = os.path.join(save_path, str(epoch) + ".png")
    imageio.imsave(fullpath, grid)
    return [
        recon_loss,
        refl_loss,
        depth_loss,
        shape_loss,
        lights_loss,
        shad_loss,
        depth_normals_loss,
    ]