Пример #1
0
def stylize(args):
    """Stylize ``args.content_image`` and save the result to ``args.output_image``.

    The content image is loaded with ``utils.load_image`` (scaled by
    ``args.content_scale``), converted to a tensor and multiplied by 255, then
    given a leading batch axis.

    If ``args.model`` ends with ``.onnx`` the work is delegated to
    ``stylize_onnx_caffe2``. Otherwise a ``TransformerNet`` is built from the
    checkpoint at ``args.model`` (dropping obsolete InstanceNorm
    ``running_mean``/``running_var`` entries) and run without gradients. When
    ``args.export_onnx`` is set, the model is exported there via
    ``torch.onnx._export`` and that call's returned output is saved instead.
    """
    target_device = torch.device("cuda" if args.cuda else "cpu")

    to_tensor_0_255 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    pil_image = utils.load_image(args.content_image, scale=args.content_scale)
    batch = to_tensor_0_255(pil_image).unsqueeze(0).to(target_device)

    if args.model.endswith(".onnx"):
        result = stylize_onnx_caffe2(batch, args)
    else:
        with torch.no_grad():
            net = TransformerNet()
            weights = torch.load(args.model)
            # Older checkpoints carry InstanceNorm running stats the current
            # model no longer defines; strip them so loading is strict-clean.
            stale_keys = [key for key in weights
                          if re.search(r'in\d+\.running_(mean|var)$', key)]
            for key in stale_keys:
                del weights[key]
            net.load_state_dict(weights)
            net.to(target_device)
            if not args.export_onnx:
                result = net(batch).cpu()
            else:
                assert args.export_onnx.endswith(".onnx"), \
                    "Export model file should end with .onnx"
                result = torch.onnx._export(net, batch,
                                            args.export_onnx).cpu()
    utils.save_image(args.output_image, result[0])
Пример #2
0
def stylize(has_cuda,
            content_image,
            model,
            output_image_path=None,
            content_scale=None):
    """Stylize an in-memory image and return the un-normalised result.

    Args:
        has_cuda: run on ``"cuda"`` when true, otherwise on the CPU.
        content_image: an image accepted by ``transforms.ToTensor``; it is
            ImageNet-normalised before being fed to the network.
        model: path of a ``TransformerNet`` state dict.
        output_image_path: unused; kept for call-site compatibility.
        content_scale: unused; kept for call-site compatibility.

    Returns:
        The first element of ``utils.un_normalize_batch`` applied to the
        network output (moved to the CPU).
    """
    target_device = torch.device("cuda" if has_cuda else "cpu")

    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose([transforms.ToTensor(), imagenet_normalize])
    batch = preprocess(content_image).unsqueeze(0).to(target_device)

    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(model)
        # Drop deprecated InstanceNorm running stats from old checkpoints.
        stale_keys = [key for key in weights
                      if re.search(r'in\d+\.running_(mean|var)$', key)]
        for key in stale_keys:
            del weights[key]
        net.load_state_dict(weights)
        net.to(target_device)
        stylized = net(batch).cpu()

    return utils.un_normalize_batch(stylized)[0]
Пример #3
0
def stylize(has_cuda, left_image, right_image, model):
    """Stylize a stereo pair with a two-input ``TransformerNet`` checkpoint.

    Both images are ImageNet-normalised, batched and passed together to the
    network, whose forward returns a (left, right) pair of outputs.

    Args:
        has_cuda: run on ``"cuda"`` when true, otherwise on the CPU.
        left_image, right_image: images accepted by ``transforms.ToTensor``.
        model: path of a ``TransformerNet`` state dict.

    Returns:
        ``(left, right)``: the first element of each output batch after
        ``utils.un_normalize_batch``. No explicit ``.cpu()`` is applied.
    """
    compute_device = torch.device("cuda" if has_cuda else "cpu")

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def as_batch(image):
        # Add a leading batch axis and move to the compute device.
        return preprocess(image).unsqueeze(0).to(compute_device)

    left_batch = as_batch(left_image)
    right_batch = as_batch(right_image)

    with torch.no_grad():
        net = TransformerNet()
        checkpoint = torch.load(model)
        # Drop deprecated InstanceNorm running stats from old checkpoints.
        for key in [k for k in checkpoint
                    if re.search(r'in\d+\.running_(mean|var)$', k)]:
            del checkpoint[key]
        net.load_state_dict(checkpoint)
        net.to(compute_device)
        styled_left, styled_right = net(left_batch, right_batch)

    styled_left = utils.un_normalize_batch(styled_left)
    styled_right = utils.un_normalize_batch(styled_right)
    return styled_left[0], styled_right[0]
Пример #4
0
def get_model(path, device="cuda:0"):
    """Build a ``TransformerNet`` from the state dict at *path* on *device*.

    Keys matching ``in<N>.running_mean`` / ``in<N>.running_var`` (obsolete
    InstanceNorm statistics) are removed before loading.
    """
    net = TransformerNet()
    weights = torch.load(path)
    obsolete = [key for key in weights
                if re.search(r'in\d+\.running_(mean|var)$', key)]
    for key in obsolete:
        weights.pop(key)
    net.load_state_dict(weights)
    net.to(device)
    return net
Пример #5
0
def transfer(content_img_stream, style):
    """Stylize an uploaded image with one of the bundled pretrained models.

    Args:
        content_img_stream: anything ``PIL.Image.open`` accepts (a path or a
            binary file-like object). The image is converted to RGB.
        style: bare model name; ``fast_neural_style/saved_models/<style>.pth``
            is loaded.

    Returns:
        ``misc.toimage`` applied to the first stylized image (values in the
        network's 0-255 input scale).
        NOTE(review): if ``misc`` is ``scipy.misc``, ``toimage`` was removed in
        SciPy 1.2 — confirm the pinned SciPy version.

    Raises:
        ValueError: if ``style`` is empty or contains a directory component.
    """
    import os

    # ``style`` is spliced into a file path that ``torch.load`` (pickle-based)
    # then deserialises; reject anything that could point outside
    # saved_models/, e.g. "../x" or "sub/dir".
    if not style or os.path.basename(style) != style:
        raise ValueError("invalid style name: {!r}".format(style))

    device = torch.device("cpu")

    # Uploads may be RGBA, greyscale or palette images; convert to RGB so the
    # tensor has 3 channels, matching what utils.load_image-based callers
    # feed the network.
    content_image = Image.open(content_img_stream).convert('RGB')
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        path = 'fast_neural_style/saved_models/' + style + '.pth'
        # map_location keeps checkpoints saved from a GPU loadable on this
        # CPU-only path.
        state_dict = torch.load(path, map_location=device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
    return misc.toimage(output[0])
Пример #6
0
def train(dataset_path,
          style_image_path,
          save_model_dir,
          has_cuda,
          epochs=2,
          image_limit=None,
          checkpoint_model_dir=None,
          image_size=256,
          style_size=None,
          seed=42,
          content_weight=1,
          style_weight=10,
          temporal_weight=10,
          tv_weight=1e-3,
          lr=1e-3,
          log_interval=500,
          checkpoint_interval=2000,
          model_filename="myModel"):
    """Train a TransformerNet on stereo video frames with a temporal loss.

    Each dataset item yields ``(frames_curr_lr, flow_lr, frames_next_lr)``;
    index 0 of each is used as the left eye and index 1 as the right eye.
    For every eye the network takes its own optimiser step on
    ``content_weight * content + style_weight * style +
    temporal_weight * temporal`` loss.

    Args:
        dataset_path: directory containing the ``frames_cleanpass`` and
            ``optical_flow_resized`` sub-directories.
        style_image_path: path of the style image.
        save_model_dir: directory receiving ``<model_filename>.pth``.
        has_cuda: train on ``"cuda"`` when true, otherwise on the CPU.
        epochs: number of passes over the (shuffled) dataset.
        image_limit: forwarded to ``MyDataSet`` to cap the number of items.
        checkpoint_model_dir: if not ``None``, a checkpoint is written here
            every ``checkpoint_interval`` batches.
        image_size: frames are resized to ``(image_size, image_size)``.
        style_size: forwarded to ``utils.load_image`` for the style image.
        seed: seed for the NumPy and PyTorch RNGs.
        content_weight, style_weight, temporal_weight: loss weights.
        tv_weight: accepted for interface compatibility; currently unused.
        lr: Adam learning rate.
        log_interval: losses are appended to ``<model_filename>_losses.txt``
            every ``log_interval`` batches.
        checkpoint_interval: checkpoint period in batches; debug frames are
            also dumped every ``checkpoint_interval / 4`` batches.
        model_filename: base name for the model, checkpoint and loss files.
    """
    device = torch.device("cuda" if has_cuda else "cpu")
    np.random.seed(seed)
    torch.manual_seed(seed)
    batch_size = 1  # needs to be 1, batch is created using MyDataSet
    loss_list = []
    loss_filename = model_filename + '_losses.txt'
    # Debug frames are dumped four times per checkpoint period (true division
    # kept so non-multiples of 4 behave as before).
    save_interval = checkpoint_interval / 4

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset_path = os.path.join(dataset_path, "frames_cleanpass")
    flow_path = os.path.join(dataset_path, "optical_flow_resized")
    train_dataset = MyDataSet(
        train_dataset_path, flow_path, transform,
        image_limit=image_limit)  # remove if using all datasets
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)

    transformer_net = TransformerNet().to(device)
    optimizer = Adam(transformer_net.parameters(), lr)

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    style_image = utils.load_image(style_image_path, size=style_size)
    style_image = style_transform(style_image)
    style_image = style_image.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]
    for e in range(epochs):
        # batch_num is 1-based (number of batches processed so far this
        # epoch), so ``batch_num % N == 0`` fires on the N-th batch. The old
        # pre-incremented counter tested ``batch_num + 1`` and fired one
        # batch early.
        for batch_num, frames_and_flow in enumerate(tqdm(train_loader),
                                                    start=1):
            frames_curr_lr, flow_lr, frames_next_lr = frames_and_flow
            to_save = batch_num % save_interval == 0
            for i in (0, 1):  # Left, Right
                frame_curr = frames_curr_lr[i]
                frame_next = frames_next_lr[i]
                flow = flow_lr[i]
                batch_size = len(frame_curr)
                optimizer.zero_grad()
                if to_save:
                    # Dump the (normalised) input frames for inspection.
                    # permute(2, 3, 1, 0).squeeze(3) moves channels last and
                    # drops the batch axis, assuming a (1, C, H, W) tensor.
                    utils.save_image_loss(
                        frame_curr.permute(2, 3, 1, 0).squeeze(3),
                        'test_images/frame_curr/frame_curr_epo' + str(e) +
                        'batch_num' + str(batch_num) + '.png')
                    utils.save_image_loss(
                        frame_next.permute(2, 3, 1, 0).squeeze(3),
                        'test_images/frame_next/frame_next_epo' + str(e) +
                        'batch_num' + str(batch_num) + '.png')
                frame_curr = frame_curr.to(device)
                frame_next = frame_next.to(device)
                frame_style = transformer_net(frame_curr)
                frame_next_style = transformer_net(frame_next)
                # TODO: input frames to net as batch (frame_curr, frame_next)
                features_frame = vgg(frame_curr)
                features_frame_style = vgg(frame_style)
                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style,
                                               gram_style, batch_size)
                if to_save:
                    temporal_loss = losses.temporal_loss(frame_style,
                                                         frame_next_style,
                                                         flow,
                                                         device,
                                                         to_save=to_save,
                                                         batch_num=batch_num,
                                                         e=e)
                else:
                    temporal_loss = losses.temporal_loss(
                        frame_style, frame_next_style, flow, device)
                # NOTE: tv_weight is not applied; the previous code computed
                # a TV loss on the *input* frame and discarded it.

                total_loss = (content_weight * content_loss +
                              style_weight * style_loss +
                              temporal_weight * temporal_loss)
                total_loss.backward()
                optimizer.step()

            if batch_num % log_interval == 0:
                # Records the losses of the last eye processed (right eye).
                # TODO: Choose between TQDM and printing
                losses_string = ",".join(
                    str(loss.item()) for loss in (content_loss, style_loss,
                                                  temporal_loss, total_loss))
                loss_list.append(losses_string)
                utils.save_loss_file(loss_list, loss_filename)

            if (checkpoint_model_dir is not None
                    and batch_num % checkpoint_interval == 0):
                # Checkpoints are saved from the CPU, then training resumes
                # on the original device.
                transformer_net.eval().cpu()
                ckpt_model_filename = (model_filename + "_ckpt_epoch_" +
                                       str(e + 1) + "_batch_id_" +
                                       str(batch_num) + ".pth")
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer_net.state_dict(), ckpt_model_path)
                utils.save_loss_file(loss_list, loss_filename)
                transformer_net.to(device).train()

    # save model
    transformer_net.eval().cpu()
    save_model_path = os.path.join(save_model_dir, model_filename + ".pth")
    torch.save(transformer_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Пример #7
0
def train(dataset_path,
          style_image_path,
          save_model_dir,
          has_cuda,
          epochs=2,
          image_limit=None,
          checkpoint_model_dir=None,
          image_size=256,
          style_size=None,
          seed=42,
          content_weight=1e5,
          style_weight=1e10,
          lr=1e-3,
          log_interval=500,
          checkpoint_interval=2000):
    """Train a TransformerNet on stereo frames with content and style losses.

    Each dataset item yields ``(frames_curr, frames_next)``; only the current
    frames are used. For each eye in ``frames_curr`` the network takes one
    optimiser step on ``content_weight * content + style_weight * style``.

    Args:
        dataset_path: root passed to ``MyDataSet``.
        style_image_path: path of the style image.
        save_model_dir: directory receiving ``myModel.pth``.
        has_cuda: train on ``"cuda"`` when true, otherwise on the CPU.
        epochs: number of passes over the dataset.
        image_limit: forwarded to ``MyDataSet`` to cap the number of items.
        checkpoint_model_dir: if not ``None``, a checkpoint is written here
            every ``checkpoint_interval`` batches.
        image_size: frames are resized and centre-cropped to this size.
        style_size: forwarded to ``utils.load_image`` for the style image.
        seed: seed for the NumPy and PyTorch RNGs.
        content_weight, style_weight: loss weights.
        lr: Adam learning rate.
        log_interval: progress is printed every ``log_interval`` batches.
        checkpoint_interval: checkpoint period in batches.
    """
    device = torch.device("cuda" if has_cuda else "cpu")

    np.random.seed(seed)
    torch.manual_seed(seed)
    batch_size = 1  # needs to be 1, batch is created using MyDataSet

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = MyDataSet(
        dataset_path, transform,
        image_limit=image_limit)  # remove if using all datasets
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer_net = TransformerNet().to(device)
    optimizer = Adam(transformer_net.parameters(), lr)

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    style_image = utils.load_image(style_image_path, size=style_size)
    style_image = style_transform(style_image)
    style_image = style_image.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]
    for e in range(epochs):
        # batch_num is 1-based, so ``batch_num % N == 0`` fires on the N-th
        # batch (the old ``batch_num + 1`` test fired one batch early).
        for batch_num, frames_curr_next in enumerate(tqdm(train_loader),
                                                     start=1):
            frames_curr, _ = frames_curr_next  # next frames unused here
            for frame in frames_curr:  # Left + Right
                batch_size = len(frame)
                optimizer.zero_grad()

                frame = frame.to(device)
                frame_style = transformer_net(frame)

                features_frame = vgg(frame)
                features_frame_style = vgg(frame_style)

                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style,
                                               gram_style, batch_size)

                total_loss = content_weight * content_loss + style_weight * style_loss
                total_loss.backward()
                optimizer.step()

            # Logging and checkpointing happen once per batch, after both
            # eyes; they previously ran inside the eye loop and fired twice.
            if batch_num % log_interval == 0:
                # TODO: Choose between TQDM and printing
                mesg = "\n{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, batch_num, len(train_loader),
                    content_loss.item(), style_loss.item(), total_loss.item())
                print(mesg)

            if checkpoint_model_dir is not None and (
                    batch_num % checkpoint_interval == 0):
                transformer_net.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_num) + ".pth"
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer_net.state_dict(), ckpt_model_path)
                transformer_net.to(device).train()

    # save model
    transformer_net.eval().cpu()
    save_model_filename = "myModel.pth"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Пример #8
0
def train(args):
    """Train a feed-forward ``TransformerNet`` for fast neural style transfer.

    Hyper-parameters come from *args*: ``cuda``, ``seed``, ``image_size``,
    ``dataset``, ``batch_size``, ``lr``, ``style_image``, ``style_size``,
    ``epochs``, ``content_weight``, ``style_weight``, ``log_interval``,
    ``checkpoint_model_dir``, ``checkpoint_interval`` and ``save_model_dir``.

    Images are kept in the 0-255 range and passed through
    ``utils.normalize_batch`` before VGG. The content loss is the MSE between
    ``relu2_2`` features; the style loss is the summed MSE between Gram
    matrices of every VGG output and the style image's Gram matrices.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    scale_to_255 = transforms.Lambda(lambda t: t.mul(255))
    train_dataset = datasets.ImageFolder(
        args.dataset,
        transforms.Compose([
            transforms.Resize(args.image_size),
            transforms.CenterCrop(args.image_size),
            transforms.ToTensor(),
            scale_to_255,
        ]))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_batch = transforms.Compose([transforms.ToTensor(), scale_to_255])(
        utils.load_image(args.style_image, size=args.style_size))
    style_batch = style_batch.repeat(args.batch_size, 1, 1, 1).to(device)
    gram_style = [utils.gram_matrix(feat)
                  for feat in vgg(utils.normalize_batch(style_batch))]

    for epoch in range(args.epochs):
        transformer.train()
        running_content = 0.
        running_style = 0.
        seen = 0
        for step, (x, _) in enumerate(train_loader, start=1):
            n_batch = len(x)
            seen += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y_normed = utils.normalize_batch(y)
            x_normed = utils.normalize_batch(x)
            features_y = vgg(y_normed)
            features_x = vgg(x_normed)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            # The last batch may be smaller, so slice the style Grams to fit.
            style_loss = args.style_weight * sum(
                mse_loss(utils.gram_matrix(ft_y), gm_s[:n_batch, :, :])
                for ft_y, gm_s in zip(features_y, gram_style))

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            running_content += content_loss.item()
            running_style += style_loss.item()

            if step % args.log_interval == 0:
                print("{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), epoch + 1, seen, len(train_dataset),
                    running_content / step,
                    running_style / step,
                    (running_content + running_style) / step))

            wants_checkpoint = (args.checkpoint_model_dir is not None
                                and step % args.checkpoint_interval == 0)
            if wants_checkpoint:
                transformer.eval().cpu()
                ckpt_path = os.path.join(
                    args.checkpoint_model_dir,
                    "ckpt_epoch_{}_batch_id_{}.pth".format(epoch, step))
                torch.save(transformer.state_dict(), ckpt_path)
                transformer.to(device).train()

    transformer.eval().cpu()
    # NOTE: time.ctime() contains ':' characters, which are not valid in
    # Windows file names.
    stamp = str(time.ctime()).replace(' ', '_')
    save_model_filename = "_".join([
        "epoch", str(args.epochs), stamp,
        str(args.content_weight), str(args.style_weight),
    ]) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Пример #9
0
def train(dataset_path,
          style_image_path,
          save_model_dir,
          has_cuda,
          epochs=2,
          image_limit=None,
          checkpoint_model_dir=None,
          image_size=(360, 640),
          style_size=None,
          seed=42,
          content_weight=1,
          style_weight=10,
          temporal_weight=10,
          tv_weight=1e-3,
          disp_weight=1e-3,
          lr=1e-3,
          log_interval=500,
          checkpoint_interval=2000,
          model_filename="myModel",
          model_init=None):
    """Train a stereo TransformerNet with content, style, temporal and disparity losses.

    Work-in-progress variant. Each dataset item yields
    ``(frames_curr_lr, flow_lr, frames_next_lr)``, where index 0 is the left
    eye and index 1 the right eye. The network is run once on stacked
    (current, next) left and right batches to build a left/right disparity
    loss plus per-eye content/style/temporal losses. The per-eye losses are
    then recomputed by a second, single-input pass, and only that second loss
    is back-propagated. See the NOTE(review) comments for known problems.

    Args:
        dataset_path: directory containing the ``frames_cleanpass`` and
            ``optical_flow_resized`` sub-directories.
        style_image_path: path of the style image.
        save_model_dir: directory receiving ``<model_filename>.pth``.
        has_cuda: train on ``"cuda"`` when true, otherwise on the CPU.
        epochs: number of passes over the dataset.
        image_limit: forwarded to ``MyDataSet`` to cap the number of items.
        checkpoint_model_dir: if not ``None``, checkpoints are written here.
        image_size: passed to ``transforms.Resize``.
        style_size: forwarded to ``utils.load_image`` for the style image.
        seed: seed for the NumPy and PyTorch RNGs.
        content_weight, style_weight, temporal_weight, disp_weight: loss
            weights.
        tv_weight: accepted but unused (the TV loss is computed and dropped).
        lr: Adam learning rate.
        log_interval: loss-file logging period in batches.
        checkpoint_interval: checkpoint / debug-dump period in batches.
        model_filename: base name for the model, checkpoint and loss files.
        model_init: optional state-dict path to initialise the network from.
            Unlike the other loaders, the InstanceNorm running-stat cleanup
            is disabled here.
    """
    device = torch.device("cuda" if has_cuda else "cpu")
    np.random.seed(seed)
    torch.manual_seed(seed)
    batch_size = 1  # needs to be 1, batch is created using MyDataSet
    loss_list = []
    loss_filename = model_filename + '_losses.txt'

    transform = transforms.Compose([
        transforms.Resize(image_size),
        # transforms.Resize(image_size),
        # transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # videos_list = os.listdir(dataset_path)
    # train_dataset = {}
    # train_loader = {}
    # for video_name in videos_list:
    #     video_dataset_path = os.path.join(dataset_path, video_name)
    #     train_dataset[video_name] = MyDataSet(video_dataset_path, transform)
    #     train_loader[video_name] = DataLoader(train_dataset[video_name], batch_size=batch_size)

    # video_dataset_path = os.path.join(dataset_path, "Monkaa")  # dataset_path = "Data/Monkaa"
    train_dataset_path = os.path.join(dataset_path, "frames_cleanpass")
    flow_path = os.path.join(dataset_path, "optical_flow_resized")
    train_dataset = MyDataSet(
        train_dataset_path, flow_path, transform,
        image_limit=image_limit)  # remove if using all datasets
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    # Optionally warm-start from an existing state dict.
    if model_init is not None:
        transformer_net = TransformerNet()
        state_dict = torch.load(model_init)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        # for k in list(state_dict.keys()):
        #     if re.search(r'in\d+\.running_(mean|var)$', k):
        #         del state_dict[k]
        transformer_net.load_state_dict(state_dict)
        transformer_net.to(device)
    else:
        transformer_net = TransformerNet().to(device)

    optimizer = Adam(transformer_net.parameters(), lr)

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    style_image = utils.load_image(style_image_path, size=style_size)
    style_image = style_transform(style_image)
    style_image = style_image.repeat(batch_size, 1, 1, 1).to(device)

    # Target Gram matrices of the style image, computed once.
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]
    for e in range(epochs):
        batch_num = 0
        # for video_name in videos_list:
        for frames_and_flow in tqdm(train_loader):
            (frames_curr_lr, flow_lr, frames_next_lr) = frames_and_flow
            # NOTE(review): batch_num is pre-incremented, so the
            # ``(batch_num + 1) % N`` checks below fire one batch early.
            batch_num += 1
            total_loss = 0

            to_save = (batch_num + 1) % checkpoint_interval == 0
            optimizer.zero_grad()

            # Stack (current, next) frames per eye into one batch each.
            frames_left_batch = torch.cat(
                (frames_curr_lr[0], frames_next_lr[0]), 0)
            frames_right_batch = torch.cat(
                (frames_curr_lr[1], frames_next_lr[1]), 0)

            frames_left_batch = frames_left_batch.to(device)
            frames_right_batch = frames_right_batch.to(device)

            frame_style_left, frame_style_right = transformer_net(
                frames_left_batch,
                frames_right_batch)  # Two batches 2 x 3 x H x W
            # Split the stacked outputs back into (left, right) pairs for the
            # current frame (row 0) and the next frame (row 1).
            frame_curr_style_combined = (frame_style_left[0, ::].unsqueeze(0),
                                         frame_style_right[0, ::].unsqueeze(0))
            frame_next_style_combined = (frame_style_left[1, ::].unsqueeze(0),
                                         frame_style_right[1, ::].unsqueeze(0))

            # NOTE(review): ``disparity`` is not defined in this function and
            # the loader yields only frames and flow. Unless a module-level
            # ``disparity`` exists, this raises NameError. TODO: supply
            # disparity maps (e.g. from MyDataSet).
            disparity_loss_l2r = losses.disparity_loss(
                frame_curr_style_combined[0], frame_curr_style_combined[1],
                disparity[0], device)
            disparity_loss_r2l = losses.disparity_loss(
                frame_curr_style_combined[1], frame_curr_style_combined[0],
                disparity[1], device, to_save, batch_num, e)
            total_loss = disp_weight * (disparity_loss_l2r +
                                        disparity_loss_r2l)
            # total_loss = disp_weight * disparity_loss_l2r

            for i in [0, 1]:  # Left,  Right
                to_save = (batch_num + 1) % checkpoint_interval == 0
                frame_curr = frames_curr_lr[i]
                frame_next = frames_next_lr[i]
                flow = flow_lr[i]
                batch_size = len(frame_curr)
                # First pass: per-eye losses from the stereo forward above.
                frame_style = frame_curr_style_combined[i]
                frame_next_style = frame_next_style_combined[i]
                # NOTE(review): frame_curr is still the loader's tensor here;
                # it is moved to ``device`` only further down. On CUDA this
                # likely mismatches devices — confirm.
                features_frame = vgg(frame_curr)
                features_frame_style = vgg(frame_style)
                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style,
                                               gram_style, batch_size)
                if to_save:
                    temporal_loss = losses.temporal_loss(frame_style,
                                                         frame_next_style,
                                                         flow,
                                                         device,
                                                         to_save=to_save,
                                                         batch_num=batch_num,
                                                         e=e)
                else:
                    temporal_loss = losses.temporal_loss(
                        frame_style, frame_next_style, flow, device)
                # NOTE(review): tv_loss is computed on the input frame and
                # never used; tv_weight has no effect.
                tv_loss = losses.tv_loss(frame_curr)
                total_loss = total_loss + (content_weight * content_loss +
                                           style_weight * style_loss +
                                           temporal_weight * temporal_loss)

                # Save stuff:
                frame_curr_to_save = frame_curr.permute(2, 3, 1, 0).squeeze(3)
                frame_next_to_save = frame_next.permute(2, 3, 1, 0).squeeze(3)
                namefile_frame_curr = 'test_images/frame_curr/frame_curr_epo' + str(
                    e) + 'batch_num' + str(batch_num) + "eye" + str(i) + '.png'
                namefile_frame_next = 'test_images/frame_next/frame_next_epo' + str(
                    e) + 'batch_num' + str(batch_num) + "eye" + str(i) + '.png'
                namefile_frame_flow = 'test_images/frame_flow/frame_next_epo' + str(
                    e) + 'batch_num' + str(batch_num) + "eye" + str(i) + '.png'
                # Computed every iteration even though it is only used when
                # to_save is true.
                frame_flow, _ = utils.apply_flow(frame_curr, flow)
                if to_save:
                    utils.save_image_loss(frame_curr_to_save,
                                          namefile_frame_curr)
                    utils.save_image_loss(frame_next_to_save,
                                          namefile_frame_next)
                    utils.save_image_loss(frame_flow, namefile_frame_flow)

                # Second pass (apparently left over from the mono version).
                # It recomputes the per-eye losses with a separate forward.
                frame_curr = frame_curr.to(device)
                frame_next = frame_next.to(device)
                frames_batch = torch.cat((frame_curr, frame_next))
                # NOTE(review): called with one input here, but with two
                # (left, right) inputs in the stereo call above. If forward
                # requires both, this fails — confirm.
                frames_style_batch = transformer_net(frames_batch)
                # frame_style = transformer_net(frame_curr)
                # frame_next_style = transformer_net(frame_next)
                frame_style = frames_style_batch[0, ::].unsqueeze(
                    0)  # add batch dim
                frame_next_style = frames_style_batch[1, ::].unsqueeze(
                    0)  # add batch dim
                features_frame = vgg(frame_curr)
                features_frame_style = vgg(frame_style)
                # print(frame_curr.shape)
                content_loss = losses.content_loss(features_frame,
                                                   features_frame_style)
                style_loss = losses.style_loss(features_frame_style,
                                               gram_style, batch_size)
                if to_save:
                    temporal_loss = losses.temporal_loss(frame_style,
                                                         frame_next_style,
                                                         flow,
                                                         device,
                                                         to_save=to_save,
                                                         batch_num=batch_num,
                                                         e=e)
                else:
                    temporal_loss = losses.temporal_loss(
                        frame_style, frame_next_style, flow, device)
                tv_loss = losses.tv_loss(frame_curr)

                # NOTE(review): this *replaces* total_loss, discarding the
                # disparity loss and the first-pass per-eye losses
                # accumulated above. The optimiser also steps once per eye,
                # although zero_grad ran only once per batch.
                total_loss = (content_weight * content_loss +
                              style_weight * style_loss +
                              temporal_weight * temporal_loss)
                total_loss.backward()
                optimizer.step()

            if (
                    batch_num + 1
            ) % log_interval == 0:  # TODO: Choose between TQDM and printing
                # mesg is built but not printed; losses go to the loss file.
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttemporal: {:.6f}" \
                       "\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, batch_num + 1, len(train_dataset),
                    content_loss.item(),
                    style_loss.item(),
                    temporal_loss.item(),
                    total_loss.item()
                )
                # print(mesg)
                losses_string = (str(content_loss.item()) + "," +
                                 str(style_loss.item()) + "," +
                                 str(temporal_loss.item()) + "," +
                                 str(total_loss.item()))
                loss_list.append(losses_string)
                utils.save_loss_file(loss_list, loss_filename)

            # Save a CPU checkpoint, then resume training on ``device``.
            if (checkpoint_model_dir is not None
                    and (batch_num + 1) % checkpoint_interval == 0):
                transformer_net.eval().cpu()
                ckpt_model_filename = (model_filename + "_ckpt_epoch_" +
                                       str(e + 1) + "_batch_id_" +
                                       str(batch_num + 1) + ".pth")
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer_net.state_dict(), ckpt_model_path)
                transformer_net.to(device).train()

    # save model
    transformer_net.eval().cpu()
    save_model_path = os.path.join(save_model_dir, model_filename + ".pth")
    torch.save(transformer_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)