Example #1
# Assumed imports (option, net and utils are repo-local modules of the
# MSG-Net implementation; names are inferred from the calls below):
import torch
from torch.autograd import Variable
from option import Options
from net import Net
import utils

def main():
    args = Options().parse()
    style_model = Net(ngf=args.ngf)
    model_dict = torch.load(args.model)
    model_dict_clone = model_dict.copy()  # copy so we can delete keys while iterating
    for key, value in model_dict_clone.items():
        if key.endswith(('running_mean', 'running_var')):
            del model_dict[key]
    # strict=False: the BatchNorm running statistics were deleted above.
    style_model.load_state_dict(model_dict, strict=False)

    style_loaders = utils.StyleLoader(args.style_folder, args.style_size)
    content_image = utils.tensor_load_rgbimage(args.content_image,
                                               size=args.style_size,
                                               keep_asp=True)
    content_image = content_image.unsqueeze(0)
    if args.cuda:
        style_model.cuda()
        content_image = content_image.cuda()

    content_image = Variable(utils.preprocess_batch(content_image))

    # for i, style_loader in enumerate(style_loaders):
    for i in range(style_loaders.size()):
        print(i)
        style_v = style_loaders.get(i)
        style_model.setTarget(style_v)
        output = style_model(content_image)
        filepath = "out/output" + str(i + 1) + '.jpg'
        print(filepath)
        utils.tensor_save_bgrimage(output.data[0], filepath, args.cuda)
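
The StyleLoader helper used throughout these examples indexes a folder of style images and hands back preprocessed tensors. Below is a minimal sketch of that interface, assuming the utils.tensor_load_rgbimage and utils.preprocess_batch helpers seen above; this is a hypothetical reconstruction, not the repo's actual implementation.

import os

class StyleLoaderSketch(object):
    def __init__(self, style_folder, style_size, cuda=True):
        self.folder = style_folder
        self.style_size = style_size
        self.cuda = cuda
        self.files = sorted(os.listdir(style_folder))

    def size(self):
        return len(self.files)

    def get(self, i):
        # Any index maps onto a style image by cycling through the folder.
        path = os.path.join(self.folder, self.files[i % len(self.files)])
        style = utils.tensor_load_rgbimage(path, size=self.style_size)
        style = utils.preprocess_batch(style.unsqueeze(0))
        return style.cuda() if self.cuda else style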
Example #2
# Assumed imports (data, net and utils are repo-local modules of the Gluon
# MSG-Net implementation; F aliasing mxnet.ndarray is inferred from the
# broadcast calls below):
import os
import time
import numpy as np
import mxnet as mx
from mxnet import autograd, gluon
from mxnet import ndarray as F
import data
import net
import utils

def train(args):
    np.random.seed(args.seed)
    if args.cuda:
        ctx = mx.gpu(0)
    else:
        ctx = mx.cpu(0)
    # dataloader
    transform = utils.Compose([utils.Scale(args.image_size),
                               utils.CenterCrop(args.image_size),
                               utils.ToTensor(ctx),
                               ])
    train_dataset = data.ImageFolder(args.dataset, transform)
    train_loader = gluon.data.DataLoader(train_dataset, batch_size=args.batch_size, last_batch='discard')
    style_loader = utils.StyleLoader(args.style_folder, args.style_size, ctx=ctx)
    print('len(style_loader):', style_loader.size())
    # models
    vgg = net.Vgg16()
    utils.init_vgg_params(vgg, 'models', ctx=ctx)
    style_model = net.Net(ngf=args.ngf)
    style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(args.resume))
        style_model.collect_params().load(args.resume, ctx=ctx)
    print('style_model:', style_model)
    # optimizer and loss
    trainer = gluon.Trainer(style_model.collect_params(), 'adam',
                            {'learning_rate': args.lr})
    mse_loss = gluon.loss.L2Loss()

    for e in range(args.epochs):
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            # prepare data
            style_image = style_loader.get(batch_id)
            style_v = utils.subtract_imagenet_mean_preprocess_batch(style_image.copy())
            style_image = utils.preprocess_batch(style_image)

            features_style = vgg(style_v)
            gram_style = [net.gram_matrix(y) for y in features_style]

            xc = utils.subtract_imagenet_mean_preprocess_batch(x.copy())
            f_xc_c = vgg(xc)[1]
            with autograd.record():
                style_model.setTarget(style_image)
                y = style_model(x)

                y = utils.subtract_imagenet_mean_batch(y)
                features_y = vgg(y)

                content_loss = 2 * args.content_weight * mse_loss(features_y[1], f_xc_c)

                style_loss = 0.
                for m in range(len(features_y)):
                    gram_y = net.gram_matrix(features_y[m])
                    _, C, _ = gram_style[m].shape
                    gram_s = F.expand_dims(gram_style[m], 0).broadcast_to((args.batch_size, 1, C, C))
                    style_loss = style_loss + 2 * args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])

                total_loss = content_loss + style_loss
                total_loss.backward()
                
            trainer.step(args.batch_size)
            mx.nd.waitall()

            agg_content_loss += content_loss[0]
            agg_style_loss += style_loss[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.3f}\tstyle: {:.3f}\ttotal: {:.3f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                agg_content_loss.asnumpy()[0] / (batch_id + 1),
                                agg_style_loss.asnumpy()[0] / (batch_id + 1),
                                (agg_content_loss + agg_style_loss).asnumpy()[0] / (batch_id + 1)
                )
                print(mesg)

            
            if (batch_id + 1) % (4 * args.log_interval) == 0:
                # save model
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".params"
                save_model_path = os.path.join(args.save_model_dir, save_model_filename)
                style_model.collect_params().save(save_model_path)
                print("\nCheckpoint, trained model saved at", save_model_path)

    # save model
    save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".params"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    style_model.collect_params().save(save_model_path)
    print("\nDone, trained model saved at", save_model_path)
Example #3
# Assumed imports (Net, Vgg16 and utils are repo-local; upper-case names
# such as IMAGE_SIZE, EPOCHS and device are module-level configuration
# defined elsewhere in the original file):
import os
import time
import torch
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import utils
from net import Net, Vgg16

def train():
    check_point_path = ''

    transform = transforms.Compose([transforms.Scale(IMAGE_SIZE),
                                    transforms.CenterCrop(IMAGE_SIZE),
                                    transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.mul(255))])

    train_dataset = datasets.ImageFolder(DATASET_FOLDER, transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    style_model = Net(ngf=FILTER_CHANNEL, dv=device).to(device)
    if RESUME is not None:
        print('Resuming, initializing using weight from {}.'.format(RESUME))
        style_model.load_state_dict(torch.load(RESUME))
    print(style_model)
    optimizer = Adam(style_model.parameters(), LEARNING_RATE)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(VGG_DIR)
    vgg.load_state_dict(torch.load(os.path.join(VGG_DIR, "vgg16.weight")))
    vgg.to(device)

    style_loader = utils.StyleLoader(STYLE_FOLDER, IMAGE_SIZE, device)
    
    tbar = tqdm(range(EPOCHS))
    for e in tbar:
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x)).to(device)

            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            style_v = utils.subtract_imagenet_mean_batch(style_v, device)
            features_style = vgg(style_v)
            gram_style = [utils.gram_matrix(y) for y in features_style]

            y = style_model(x)
            xc = Variable(x.data.clone())

            y = utils.subtract_imagenet_mean_batch(y, device)
            xc = utils.subtract_imagenet_mean_batch(xc, device)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = CONT_WEIGHT * mse_loss(features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_y = utils.gram_matrix(features_y[m])
                gram_s = Variable(gram_style[m].data, requires_grad=False).repeat(BATCH_SIZE, 1, 1, 1)
                style_loss += STYLE_WEIGHT * mse_loss(gram_y.unsqueeze_(1), gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()  # .data[0] fails on 0-dim tensors in PyTorch >= 0.4
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % 100 == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                agg_content_loss / (batch_id + 1),
                                agg_style_loss / (batch_id + 1),
                                (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                tbar.set_description(mesg)

            
            if (batch_id + 1) % (4 * 100) == 0:
                # save model
                style_model.eval()
                style_model.cpu()
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" +                     str(time.ctime()).replace(' ', '_') + "_" + str(
                    CONT_WEIGHT) + "_" + str(STYLE_WEIGHT) + ".model"
                save_model_path = os.path.join(SAVE_MODEL_DIR, save_model_filename)
                torch.save(style_model.state_dict(), save_model_path)
                if check_point_path:
                    os.remove(check_point_path)
                check_point_path = save_model_path
                style_model.train()
                style_model.to(device)  # move back to the training device (the original called .cuda())
                tbar.set_description("Checkpoint, trained model saved at " + save_model_path)

    # save model
    style_model.eval()
    style_model.cpu()
    save_model_filename = "Final_epoch_" + str(EPOCHS) + "_" +         str(time.ctime()).replace(' ', '_') + "_" + str(
        CONT_WEIGHT) + "_" + str(STYLE_WEIGHT) + ".model"
    save_model_path = os.path.join(SAVE_MODEL_DIR, save_model_filename)
    torch.save(style_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #4
# Assumed imports (check_paths, Net, Vgg16 and utils come from the
# surrounding file; this example targets the 0.3-era PyTorch API):
import os
import time
import numpy as np
import torch
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import trange
import utils
from net import Net, Vgg16

def train(args):
    check_paths(args)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    style_model = Net(ngf=args.ngf)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(
            args.resume))
        style_model.load_state_dict(torch.load(args.resume))
    print(style_model)
    optimizer = Adam(style_model.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        style_model.cuda()
        vgg.cuda()

    style_loader = utils.StyleLoader(args.style_folder, args.style_size)

    tbar = trange(args.epochs)
    for e in tbar:
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            style_v = utils.subtract_imagenet_mean_batch(style_v)
            features_style = vgg(style_v)
            gram_style = [utils.gram_matrix(y) for y in features_style]

            y = style_model(x)
            xc = Variable(x.data.clone())

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = args.content_weight * mse_loss(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_y = utils.gram_matrix(features_y[m])
                gram_s = Variable(gram_style[m].data,
                                  requires_grad=False).repeat(
                                      args.batch_size, 1, 1, 1)
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                tbar.set_description(mesg)

            if (batch_id + 1) % (4 * args.log_interval) == 0:
                # save model
                style_model.eval()
                style_model.cpu()
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" + \
                    str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir,
                                               save_model_filename)
                torch.save(style_model.state_dict(), save_model_path)
                style_model.train()
                if args.cuda:
                    style_model.cuda()  # only move back to GPU when CUDA was requested
                tbar.set_description("Checkpoint, trained model saved at " + save_model_path)

    # save model
    style_model.eval()
    style_model.cpu()
    save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + \
        str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(style_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #5
# Assumed imports (Net and utils are repo-local modules):
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from net import Net
import utils

def run_demo(args, mirror=False):
    # Load the checkpoint and drop the BatchNorm running statistics that the
    # Net definition below does not expect.
    model_dict = torch.load(args.model)
    model_dict_clone = model_dict.copy()  # we can't mutate while iterating
    for key, value in model_dict_clone.items():
        if key.endswith(('running_mean', 'running_var')):
            del model_dict[key]

    style_model = Net(ngf=128)  # to run with torch-0.3.0.post4

    style_model.load_state_dict(model_dict, strict=False)
    style_model.eval()
    if args.cuda:
        style_loader = utils.StyleLoader(args.style_folder, args.style_size)
        style_model.cuda()
    else:
        style_loader = utils.StyleLoader(args.style_folder, args.style_size,
                                         False)

    # Define the codec and create VideoWriter object
    height = args.demo_size
    width = int(4.0 / 3 * args.demo_size)
    swidth = int(width / 4)
    sheight = int(height / 4)
    if args.record:
        fourcc = cv2.VideoWriter_fourcc('F', 'M', 'P', '4')
        out = cv2.VideoWriter('output.mp4', fourcc, 20.0, (2 * width, height))
    cam = cv2.VideoCapture(0)
    cam.set(3, width)   # 3 == cv2.CAP_PROP_FRAME_WIDTH
    cam.set(4, height)  # 4 == cv2.CAP_PROP_FRAME_HEIGHT
    key = 0
    idx = 0
    while True:
        # read frame
        idx += 1
        ret_val, img = cam.read()
        if not ret_val:
            break  # camera read failed
        if mirror:
            img = cv2.flip(img, 1)
        cimg = img.copy()
        img = np.array(img).transpose(2, 0, 1)
        # changing style
        if idx % 20 == 1:
            style_v = style_loader.get(int(idx / 20))
            style_v = Variable(style_v.data)
            style_model.setTarget(style_v)

        img = torch.from_numpy(img).unsqueeze(0).float()
        if args.cuda:
            img = img.cuda()

        img = Variable(img)
        img = style_model(img)

        if args.cuda:
            simg = style_v.cpu().data[0].numpy()
            img = img.cpu().clamp(0, 255).data[0].numpy()
        else:
            simg = style_v.data[0].numpy()  # avoid hard-coding a 512x512 style size
            img = img.clamp(0, 255).data[0].numpy()
        img = img.transpose(1, 2, 0).astype('uint8')
        simg = simg.transpose(1, 2, 0).astype('uint8')

        # display
        simg = cv2.resize(simg, (swidth, sheight),
                          interpolation=cv2.INTER_CUBIC)
        cimg[0:sheight, 0:swidth, :] = simg
        img = np.concatenate((cimg, img), axis=1)
        cv2.imshow('MSG Demo', img)
        #cv2.imwrite('stylized/%i.jpg'%idx,img)
        key = cv2.waitKey(1)
        if args.record:
            out.write(img)
        if key == 27:
            break
    cam.release()
    if args.record:
        out.release()
    cv2.destroyAllWindows()
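
A minimal sketch of an entry point for this demo, assuming the same Options class as Example #1; this is hypothetical wiring, as the original module's __main__ block is not shown.

def main():
    args = Options().parse()
    run_demo(args, mirror=True)

if __name__ == "__main__":
    main()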