Ejemplo n.º 1
0
def train_ofb(args):
    """Benchmark the per-batch cost of the pixel OFB temporal loss.

    For every DAVIS clip the style transformer is run forward (untimed). The
    timed region covers three steps:

    * warping output frames ``1..N-1`` with the clip's flow;
    * the confidence-weighted MSE against output frames ``0..N-2``;
    * the backward pass of that loss.

    No optimizer step is taken, so the weights never change. The loop stops
    after batch index 1001, i.e. after 1002 batches. It prints the mean wall
    time (seconds) of the timed region per processed batch.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``pad_type``,
            ``lr``, ``vgg_model_dir``, ``cuda``, ``style_image``,
            ``style_size`` and ``time_strength``.
    """
    train_dataset = dataset.DAVISDataset(args.dataset, use_flow=True)
    train_loader = DataLoader(train_dataset, batch_size=1)

    transformer = transformer_net.TransformerNet(args.pad_type)
    transformer.train()
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)

    vgg = Vgg16()
    # map_location="cpu" lets the weights load on machines without CUDA;
    # load_state_dict copies them onto the module's own device.
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight"),
                   map_location="cpu"))
    vgg.eval()

    # Only move modules to the GPU when requested, so CPU-only runs work.
    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))
    print("=> Pixel OFB loss weight: %f" % args.time_strength)

    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    # NOTE(review): gram_style is not used by the timed loop below.
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    train_loader.dataset.reset()
    elapsed_time = 0.0
    num_batches = 0
    for batch_id, (x, flow, conf) in enumerate(tqdm(train_loader)):
        # The loader uses batch_size=1; strip that leading batch dimension.
        x, flow, conf = x[0], flow[0], conf[0]
        num_batches += 1

        optimizer.zero_grad()
        x = utils.preprocess_batch(x)
        if args.cuda:
            x = x.cuda()
            flow = flow.cuda()
            conf = conf.cuda()
        y = transformer(x)

        # CUDA kernels execute asynchronously: synchronise on both edges of
        # the timed region so it measures completed GPU work (and excludes
        # the still-running forward pass), not just kernel launches.
        if args.cuda:
            torch.cuda.synchronize()
        begin_time = time.perf_counter()
        warped_y, warped_y_mask = warp(y[1:], flow)
        # Detach the warped target so no gradient flows through it.
        warped_y = warped_y.detach()
        warped_y_mask *= conf
        pixel_ofb_loss = args.time_strength * weighted_mse(
            y[:-1], warped_y, warped_y_mask)
        pixel_ofb_loss.backward()
        if args.cuda:
            torch.cuda.synchronize()
        elapsed_time += time.perf_counter() - begin_time

        if batch_id > 1000:
            break

    # Guard the empty-loader case, where no timing was collected.
    if num_batches == 0:
        print("=> No batches were processed; nothing to time")
        return
    print(elapsed_time / float(num_batches))
Ejemplo n.º 2
0
def create_style_model(style_number=0, device=0):
    """Load a pretrained style-transfer network from ``styles/*.pth``.

    Matching checkpoint paths are sorted before indexing. ``glob`` returns
    entries in arbitrary, filesystem-dependent order, so without sorting a
    given ``style_number`` could pick different models on different machines.

    Args:
        style_number: Index into the sorted list of ``styles/*.pth`` files
            (negative indices count from the end, as with any list).
        device: Target passed to ``Module.to``. Defaults to ``0`` (CUDA
            device 0), matching the previous hard-coded behaviour.

    Returns:
        A ``transformer_net.TransformerNet`` in eval mode on ``device``.

    Raises:
        IndexError: If there is no checkpoint at ``style_number`` (including
            when no ``styles/*.pth`` files exist).
    """
    model_files = sorted(glob.glob('styles/*.pth'))
    try:
        model_file = model_files[style_number]
    except IndexError:
        raise IndexError(
            "style_number %r out of range: found %d model(s) matching "
            "styles/*.pth" % (style_number, len(model_files))) from None

    transformer = transformer_net.TransformerNet()
    # Load onto the CPU first; load_state_dict copies into the module, which
    # is then moved to `device`. This avoids depending on the device the
    # checkpoint was saved from.
    state_dict = torch.load(model_file, map_location="cpu")
    # Remove deprecated running_* keys saved for InstanceNorm layers by older
    # checkpoints; the current module does not expect them.
    for key in list(state_dict):
        if re.search(r'in\d+\.running_(mean|var)$', key):
            del state_dict[key]
    transformer.load_state_dict(state_dict)
    transformer.to(device)
    transformer.eval()
    return transformer
Ejemplo n.º 3
0
def train(args):
    """Fine-tune a style-transfer network on DAVIS clips with a pixel OFB loss.

    Starts from ``args.init_model`` and optimises three terms with Adam:

    * a VGG content loss (relu feature index 2 vs. the centre-cropped input);
    * a Gram-matrix style loss against ``args.style_image``;
    * a confidence-weighted MSE between output frames ``0..N-2`` and output
      frames ``1..N-1`` warped by the dataset's flow.

    Every 100 batches, flow/input/output/warped sample images are written to
    ``args.save_model_dir``. A checkpoint is saved there after every epoch.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``cuda``,
            ``model_type``, ``pad_type``, ``dataset``, ``flow``,
            ``init_model``, ``lr``, ``vgg_model``, ``style_image``,
            ``style_size``, ``time_strength``, ``save_model_dir``, ``epochs``,
            ``content_weight`` and ``style_weight``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    # Build the model exactly once. The recurrent variant is trained on
    # longer clips than the feed-forward one.
    if args.model_type == "rnn":
        transformer = transformer_net.TransformerRNN(args.pad_type)
        seq_size = 4
    else:
        transformer = transformer_net.TransformerNet(args.pad_type)
        seq_size = 2

    train_dataset = dataset.DAVISDataset(args.dataset,
                                         seq_size=seq_size,
                                         use_flow=args.flow)
    train_loader = DataLoader(train_dataset, batch_size=1, **kwargs)

    model_path = args.init_model
    print("=> Load from model file %s" % model_path)
    # map_location="cpu" lets GPU-saved checkpoints load in CPU-only runs.
    transformer.load_state_dict(torch.load(model_path, map_location="cpu"))
    transformer.train()
    if args.model_type == "rnn":
        # Replace the loaded first layer with a freshly initialised one that
        # takes 6 input channels.
        # NOTE(review): confirm what the extra 3 channels carry in
        # TransformerRNN.
        transformer.conv1 = transformer_net.ConvLayer(6,
                                                      32,
                                                      kernel_size=9,
                                                      stride=1,
                                                      pad_type=args.pad_type)
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model, "vgg16.weight"),
                   map_location="cpu"))
    vgg.eval()

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        mse_loss.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))
    print("=> Pixel OFB loss weight: %f" % args.time_strength)

    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    utils.tensor_save_bgrimage(
        style[0].detach(), os.path.join(args.save_model_dir,
                                        'train_style.jpg'), args.cuda)
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    # Stays None if no epoch runs, so the final message never hits an
    # unbound name.
    save_model_path = None
    for e in range(args.epochs):
        train_loader.dataset.reset()
        transformer.train()
        # The model is moved to the CPU for saving at the end of each epoch;
        # move it back only when training on the GPU.
        if args.cuda:
            transformer.cuda()
        agg_content_loss = agg_style_loss = agg_pixelofb_loss = 0.
        iters = 0
        for batch_id, (x, flow, conf) in enumerate(train_loader):
            # The loader uses batch_size=1; strip that leading batch dimension.
            x, flow, conf = x[0], flow[0], conf[0]
            iters += 1

            optimizer.zero_grad()
            x = utils.preprocess_batch(x)
            if args.cuda:
                x = x.cuda()
                flow = flow.cuda()
                conf = conf.cuda()
            y = transformer(x)

            # Crop the input to the output's spatial size before comparing.
            xc = center_crop(x.detach(), y.size(2), y.size(3))

            vgg_y = utils.subtract_imagenet_mean_batch(y)
            vgg_x = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(vgg_y)
            features_xc = vgg(vgg_x)

            # Content loss on VGG feature index 2; the target is detached.
            f_xc_c = features_xc[2].detach()
            f_c = features_y[2]
            content_loss = args.content_weight * mse_loss(f_c, f_xc_c)

            # Style loss: per-frame Gram MSE against the single style image,
            # averaged over frames and summed over VGG layers.
            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                batch_style_loss = 0
                for n in range(gram_y.shape[0]):
                    batch_style_loss += args.style_weight * mse_loss(
                        gram_y[n], gram_s[0])
                style_loss += batch_style_loss / gram_y.shape[0]

            # Temporal loss: warp later output frames with the flow and
            # compare them to the preceding output frames. The warped target
            # is detached, and the mask is weighted by the flow confidence.
            warped_y, warped_y_mask = warp(y[1:], flow)
            warped_y = warped_y.detach()
            warped_y_mask *= conf
            pixel_ofb_loss = args.time_strength * weighted_mse(
                y[:-1], warped_y, warped_y_mask)

            total_loss = content_loss + style_loss + pixel_ofb_loss

            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % 100 == 0:
                prefix = args.save_model_dir + "/"
                idx = (batch_id + 1) // 100
                flow_image = flow_to_color(
                    flow[0].detach().cpu().numpy().transpose(1, 2, 0))
                utils.save_image(prefix + "forward_flow_%d.png" % idx,
                                 flow_image)
                warped_x, warped_x_mask = warp(x[1:], flow)
                warped_x = warped_x.detach()
                warped_x_mask *= conf
                for i in range(2):
                    utils.tensor_save_bgrimage(
                        y.data[i], prefix + "out_%d-%d.png" % (idx, i),
                        args.cuda)
                    utils.tensor_save_bgrimage(
                        x.data[i], prefix + "in_%d-%d.png" % (idx, i),
                        args.cuda)
                    if i < warped_y.shape[0]:
                        utils.tensor_save_bgrimage(
                            warped_y.data[i],
                            prefix + "wout_%d-%d.png" % (idx, i), args.cuda)
                        utils.tensor_save_bgrimage(
                            warped_x.data[i],
                            prefix + "win_%d-%d.png" % (idx, i), args.cuda)
                        utils.tensor_save_image(
                            prefix + "conf_%d-%d.png" % (idx, i),
                            warped_x_mask.data[i])

            agg_content_loss += content_loss.data
            agg_style_loss += style_loss.data
            agg_pixelofb_loss += pixel_ofb_loss.data

            agg_total = agg_content_loss + agg_style_loss + agg_pixelofb_loss
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\tpixel ofb: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, batch_id + 1, len(train_loader),
                agg_content_loss / iters, agg_style_loss / iters,
                agg_pixelofb_loss / iters, agg_total / iters)
            print(mesg)
            # Aggregates are reset after every print, so each line reports
            # the current batch only.
            agg_content_loss = agg_style_loss = agg_pixelofb_loss = 0.0
            iters = 0

        # Save a CPU copy of the weights at the end of every epoch.
        transformer.eval()
        transformer.cpu()
        save_model_filename = "epoch_" + str(e) + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir,
                                       save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

    if save_model_path is None:
        print("\nDone, no epochs were run; no model saved")
    else:
        print("\nDone, trained model saved at", save_model_path)
Ejemplo n.º 4
0
def train_fdb(args):
    """Benchmark the per-batch cost of the FDB temporal losses.

    For every DAVIS clip, the forward pass and the VGG feature extraction are
    untimed. The timed region covers three steps:

    * the pixel loss comparing consecutive-frame differences of the output
      against those of the centre-cropped input;
    * the same comparison on VGG feature index 2;
    * the backward pass of the pixel loss only.

    No optimizer step is taken. The loop stops after batch index 1001, i.e.
    after 1002 batches. It prints the mean wall time (seconds) of the timed
    region per processed batch.

    Args:
        args: Parsed command-line namespace. Reads ``pad_type``, ``dataset``,
            ``flow``, ``lr``, ``vgg_model_dir``, ``cuda``, ``style_image``
            and ``style_size``.
    """
    transformer = transformer_net.TransformerNet(args.pad_type)
    train_dataset = dataset.DAVISDataset(args.dataset,
                                         seq_size=2,
                                         use_flow=args.flow)
    train_loader = DataLoader(train_dataset, batch_size=1)

    transformer.train()
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    # map_location="cpu" lets the weights load on machines without CUDA;
    # load_state_dict copies them onto the module's own device.
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight"),
                   map_location="cpu"))
    vgg.eval()

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        mse_loss.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))

    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    # NOTE(review): gram_style is not used by the timed loop below.
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    train_loader.dataset.reset()
    elapsed_time = 0.0
    num_batches = 0
    # flow/conf are yielded by the dataset but not needed for FDB losses.
    for batch_id, (x, flow, conf) in enumerate(tqdm(train_loader)):
        # The loader uses batch_size=1; strip that leading batch dimension.
        x = x[0]
        num_batches += 1

        optimizer.zero_grad()
        x = utils.preprocess_batch(x)
        if args.cuda:
            x = x.cuda()
        y = transformer(x)

        # Crop the input to the output's spatial size before comparing.
        xc = center_crop(x.detach(), y.shape[2], y.shape[3])

        y = utils.subtract_imagenet_mean_batch(y)
        xc = utils.subtract_imagenet_mean_batch(xc)

        features_y = vgg(y)
        features_xc = vgg(xc)

        # CUDA kernels execute asynchronously: synchronise on both edges of
        # the timed region so it measures completed GPU work.
        if args.cuda:
            torch.cuda.synchronize()
        begin_time = time.perf_counter()
        # Pixel term: consecutive-frame differences of output vs. input.
        pixel_fdb_loss = mse_loss(y[1:] - y[:-1], xc[1:] - xc[:-1])
        # Feature term on VGG feature index 2. It is computed (and timed) but
        # not backpropagated here.
        feature_fdb_loss = mse_loss(features_y[2][1:] - features_y[2][:-1],
                                    features_xc[2][1:] - features_xc[2][:-1])
        pixel_fdb_loss.backward()
        if args.cuda:
            torch.cuda.synchronize()
        elapsed_time += time.perf_counter() - begin_time

        if batch_id > 1000:
            break

    # Guard the empty-loader case, where no timing was collected.
    if num_batches == 0:
        print("=> No batches were processed; nothing to time")
        return
    print(elapsed_time / float(num_batches))
Ejemplo n.º 5
0
import re
import argparse

import torch
import transformer_net


if __name__ == '__main__':
    # Command-line converter: strip deprecated InstanceNorm running_* keys
    # from an old style-transfer checkpoint and re-save it as a CPU state dict.
    parser = argparse.ArgumentParser(description="Utility to convert PyTorch 0.3.x style transfer models to 0.4.1+")
    parser.add_argument("--model", type=str, required=True, help="model file to convert")
    parser.add_argument("--output", type=str, required=True, help="location for new model")
    args = parser.parse_args()

    style_model = transformer_net.TransformerNet()
    # Load tensors onto the CPU so a checkpoint saved from a GPU can be
    # converted on a machine without CUDA; the output is a CPU model anyway.
    # NOTE(review): torch.load unpickles arbitrary objects; only convert
    # checkpoints from trusted sources.
    state_dict = torch.load(args.model, map_location="cpu")

    # Remove deprecated running_* keys saved for InstanceNorm layers by older
    # checkpoints; the current module does not expect them.
    for key in list(state_dict):
        if re.search(r'in\d+\.running_(mean|var)$', key):
            del state_dict[key]

    style_model.load_state_dict(state_dict)
    print(style_model)

    style_model.eval()
    style_model.cpu()

    torch.save(style_model.state_dict(), args.output)

Ejemplo n.º 6
0
def train(args):
    """Fine-tune a style-transfer network on DAVIS clips with FDB losses (CUDA only).

    Starts from ``args.init_model`` and optimises four terms with Adam:

    * a VGG content loss on feature index 2;
    * a Gram-matrix style loss against ``args.style_image``;
    * a pixel loss (MSE) comparing consecutive-frame differences of the
      output with those of the input;
    * a feature loss (L1) doing the same on VGG feature index 2.

    Every 100 batches, input/output sample frames are written to
    ``args.save_model_dir``. A checkpoint is saved there after every epoch.
    All modules and tensors are placed on CUDA unconditionally.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``model_type``,
            ``pad_type``, ``dataset``, ``interval``, ``init_model``, ``lr``,
            ``vgg_model``, ``style_image``, ``style_size``,
            ``time_strength1``, ``time_strength2``, ``save_model_dir``,
            ``epochs``, ``batch_size``, ``content_weight`` and
            ``style_weight``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    kwargs = {'num_workers': 0, 'pin_memory': False}

    # The recurrent variant is trained on longer clips than the feed-forward one.
    if args.model_type == "rnn":
        transformer = transformer_net.TransformerRNN(args.pad_type)
        seq_size = 4
    else:
        transformer = transformer_net.TransformerNet(args.pad_type)
        seq_size = 2

    train_dataset = dataset.DAVISDataset(args.dataset,
                                         "train",
                                         seq_size=seq_size,
                                         interval=args.interval,
                                         no_flow=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              shuffle=True,
                              **kwargs)

    model_path = args.init_model
    print("=> Load from model file %s" % model_path)
    transformer.load_state_dict(torch.load(model_path))
    transformer.train()
    if args.model_type == "rnn":
        # Replace the loaded first layer with a freshly initialised one that
        # takes 6 input channels.
        # NOTE(review): confirm what the extra 3 channels carry in
        # TransformerRNN.
        transformer.conv1 = transformer_net.ConvLayer(6,
                                                      32,
                                                      kernel_size=9,
                                                      stride=1,
                                                      pad_type=args.pad_type)
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    l1_loss = torch.nn.L1Loss()

    vgg = Vgg16()
    vgg.load_state_dict(torch.load(args.vgg_model))
    vgg.eval()

    transformer.cuda()
    vgg.cuda()
    mse_loss.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))
    print("=> Pixel FDB loss weight: %f" % args.time_strength1)
    print("=> Feature FDB loss weight: %f" % args.time_strength2)

    style = utils.preprocess_batch(style).cuda()
    utils.tensor_save_bgrimage(
        style[0].detach(), os.path.join(args.save_model_dir,
                                        'train_style.jpg'), True)
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    # Stays None if no epoch runs, so the final message never hits an
    # unbound name.
    save_model_path = None
    for e in range(args.epochs):
        agg_content_loss = agg_style_loss = agg_pixelfdb_loss = agg_featurefdb_loss = 0.
        iters = 0
        # Only the frames are used; the other three dataset outputs are ignored.
        for batch_id, (x, _flow, _occ, _) in enumerate(train_loader):
            # The loader uses batch_size=1; strip that leading batch dimension.
            x = x[0]
            iters += 1

            optimizer.zero_grad()
            x = utils.preprocess_batch(x).cuda()
            y = transformer(x)

            if (batch_id + 1) % 100 == 0:
                idx = (batch_id + 1) // 100
                # The leading dimension of x/y comes from one clip, not from
                # args.batch_size. Clamp so a larger batch_size cannot index
                # past the available frames.
                for i in range(min(args.batch_size, y.size(0))):
                    utils.tensor_save_bgrimage(
                        y.data[i],
                        os.path.join(args.save_model_dir,
                                     "out_%02d_%02d.png" % (idx, i)), True)
                    utils.tensor_save_bgrimage(
                        x.data[i],
                        os.path.join(args.save_model_dir,
                                     "in_%02d-%02d.png" % (idx, i)), True)

            y = utils.subtract_imagenet_mean_batch(y)
            x = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            features_xc = vgg(x)

            # Content loss on VGG feature index 2; the target is detached.
            f_xc_c = features_xc[2].detach()
            f_c = features_y[2]
            content_loss = args.content_weight * mse_loss(f_c, f_xc_c)

            # Style loss: per-frame Gram MSE against the single style image,
            # averaged over frames and summed over VGG layers.
            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                batch_style_loss = 0
                for n in range(gram_y.shape[0]):
                    batch_style_loss += args.style_weight * mse_loss(
                        gram_y[n], gram_s[0])
                style_loss += batch_style_loss / gram_y.shape[0]

            # Pixel FDB term: consecutive-frame differences of output vs. input.
            pixel_fdb_loss = args.time_strength1 * mse_loss(
                y[1:] - y[:-1], x[1:] - x[:-1])
            # Feature FDB term: the same comparison on VGG feature index 2.
            feature_fdb_loss = args.time_strength2 * l1_loss(
                features_y[2][1:] - features_y[2][:-1],
                features_xc[2][1:] - features_xc[2][:-1])

            total_loss = content_loss + style_loss + pixel_fdb_loss + feature_fdb_loss

            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data
            agg_style_loss += style_loss.data
            agg_pixelfdb_loss += pixel_fdb_loss.data
            agg_featurefdb_loss += feature_fdb_loss.data

            agg_total = agg_content_loss + agg_style_loss + agg_pixelfdb_loss + agg_featurefdb_loss
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\tpixel fdb: {:.6f}\tfeature fdb: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, batch_id + 1, len(train_loader),
                agg_content_loss / iters, agg_style_loss / iters,
                agg_pixelfdb_loss / iters, agg_featurefdb_loss / iters,
                agg_total / iters)
            print(mesg)
            # Aggregates are reset after every print, so each line reports
            # the current batch only.
            agg_content_loss = agg_style_loss = agg_pixelfdb_loss = agg_featurefdb_loss = 0.0
            iters = 0

        # Save the weights at the end of every epoch.
        save_model_filename = "epoch_" + str(e) + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir,
                                       save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

    if save_model_path is None:
        print("\nDone, no epochs were run; no model saved")
    else:
        print("\nDone, trained model saved at", save_model_path)