Example #1
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            print(style_model)
            # map to the active device so CPU-only runs can load GPU-trained checkpoints
            state_dict = torch.load(args.model, map_location=device)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            for name, param in style_model.named_parameters():
                if param.requires_grad:
                    print(name)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                # torch.onnx._export is a private API; unlike the public torch.onnx.export,
                # it returns the traced model output, which doubles as the stylized image here
                output = torch.onnx._export(style_model, content_image, args.export_onnx, verbose=False)
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
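
This example leans on project-local helpers that are not shown: utils.load_image, utils.save_image, and stylize_onnx_caffe2 (presumably running the ONNX graph through the legacy Caffe2 backend). As a point of reference, here is a minimal sketch of the two image helpers, inferred from their call sites above; the exact implementations in the source repository may differ:

from PIL import Image


def load_image(filename, size=None, scale=None):
    # Load an RGB image, optionally resized to a square `size` or downscaled by `scale`.
    img = Image.open(filename).convert("RGB")
    if size is not None:
        img = img.resize((size, size), Image.LANCZOS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS)
    return img


def save_image(filename, data):
    # `data` is a CHW tensor in the 0-255 range; clamp, convert to HWC uint8, and save.
    img = data.clone().clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype("uint8")
    Image.fromarray(img).save(filename)
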
Example #2
def train(args):
    """Meta train the model"""

    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # first move parameters to GPU
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)
    # shared with helper functions such as loss_fn and meta_updates
    global optimizer
    optimizer = Adam(transformer.parameters(), args.meta_lr)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_loader, style_loader, query_loader = get_data_loader(args)

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    writer = SummaryWriter(args.log_dir)

    for iteration in trange(args.max_iter):
        transformer.train()
        
        # bookkeeping
        # state_dict() returns detached tensors, so track fast weights via named_parameters() instead
        all_meta_grads = []
        avg_train_c_loss = 0.0
        avg_train_s_loss = 0.0
        avg_train_loss = 0.0
        avg_eval_c_loss = 0.0
        avg_eval_s_loss = 0.0
        avg_eval_loss = 0.0

        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))
        querys = next(query_loader)[0].to(device)
        features_querys = vgg(utils.normalize_batch(querys))

        # learning rate scheduling
        lr = args.lr / (1.0 + iteration * 2.5e-5)
        meta_lr = args.meta_lr / (1.0 + iteration * 2.5e-5)
        for param_group in optimizer.param_groups:
            param_group['lr'] = meta_lr

        for i in range(args.meta_batch_size):
            # sample a style
            style = next(style_loader)[0].to(device)
            style = style.repeat(args.iter_batch_size, 1, 1, 1)
            features_style = vgg(utils.normalize_batch(style))
            gram_style = [utils.gram_matrix(y) for y in features_style]

            # fast weights: only the instance-norm parameters (in1, in2, ...) are adapted per style
            fast_weights = OrderedDict((name, param) for (name, param) in transformer.named_parameters() if re.search(r'in\d+\.', name))
            for j in range(args.meta_step):
                # run forward transformation on contents
                transformed = transformer(contents, fast_weights)

                # compute loss
                features_transformed = vgg(utils.standardize_batch(transformed))
                loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

                # compute grad
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

                # update fast weights
                fast_weights = OrderedDict((name, param - lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))
            
            avg_train_c_loss += c_loss.item()
            avg_train_s_loss += s_loss.item()
            avg_train_loss += loss.item()

            # run forward transformation on querys
            transformed = transformer(querys, fast_weights)
            
            # compute loss
            features_transformed = vgg(utils.standardize_batch(transformed))
            loss, c_loss, s_loss = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)
            
            grads = torch.autograd.grad(loss / args.meta_batch_size, transformer.parameters())
            all_meta_grads.append({name: g for ((name, _), g) in zip(transformer.named_parameters(), grads)})

            avg_eval_c_loss += c_loss.item()
            avg_eval_s_loss += s_loss.item()
            avg_eval_loss += loss.item()
        
        writer.add_scalar("Avg_Train_C_Loss", avg_train_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_S_Loss", avg_train_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_Loss", avg_train_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_C_Loss", avg_eval_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_S_Loss", avg_eval_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_Loss", avg_eval_loss / args.meta_batch_size, iteration + 1)

        # compute a dummy loss whose backward pass lets meta_updates write the
        # accumulated meta-gradients into the .grad buffers before stepping
        transformed = transformer(querys)
        features_transformed = vgg(utils.standardize_batch(transformed))
        dummy_loss, _, _ = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)

        meta_updates(transformer, dummy_loss, all_meta_grads)

        if args.checkpoint_model_dir is not None and (iteration + 1) % args.checkpoint_interval == 0:
            transformer.eval().cpu()
            ckpt_model_filename = "iter_" + str(iteration + 1) + ".pth"
            ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "Final_iter_" + str(args.max_iter) + "_" + \
                          str(args.content_weight) + "_" + \
                          str(args.style_weight) + "_" + \
                          str(args.lr) + "_" + \
                          str(args.meta_lr) + "_" + \
                          str(args.meta_batch_size) + "_" + \
                          str(args.meta_step) + "_" + \
                          time.ctime() + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print "Done, trained model saved at {}".format(save_model_path)
Example #3
def fast_train(args):
    """Fast training"""

    device = torch.device("cuda" if args.cuda else "cpu")

    transformer = TransformerNet().to(device)
    if args.model:
        transformer.load_state_dict(torch.load(args.model))
    vgg = Vgg16(requires_grad=False).to(device)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    content_loader = DataLoader(content_dataset, 
                                batch_size=args.iter_batch_size, 
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    content_loader = iter(content_loader)
    style_transform = transforms.Compose([
            transforms.Resize((args.style_size, args.style_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

    style_image = utils.load_image(args.style_image)
    style_image = style_transform(style_image)
    style_image = style_image.unsqueeze(0).to(device)
    features_style = vgg(utils.normalize_batch(style_image.repeat(args.iter_batch_size, 1, 1, 1)))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    if args.only_in:
        # optimize only the instance-norm parameters (names like "in1.weight"),
        # matching the fast-weight selection used in meta-training
        optimizer = Adam([param for (name, param) in transformer.named_parameters() if re.search(r'in\d+\.', name)], lr=lr)
    else:
        optimizer = Adam(transformer.parameters(), lr=lr)

    for i in trange(args.update_step):
        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))

        transformed = transformer(contents)
        features_transformed = vgg(utils.standardize_batch(transformed))
        loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # save model
    transformer.eval().cpu()
    style_name = os.path.splitext(os.path.basename(args.style_image))[0]
    save_model_filename = style_name + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
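
InfiniteSamplerWrapper is not defined in the snippet; its job is to let next(content_loader) be called indefinitely without raising StopIteration. A common implementation (this one follows the pattern popularized by the pytorch-AdaIN repository, so treat it as an assumption) wraps an endless stream of shuffled indices:

import numpy as np
import torch.utils.data


def _infinite_indices(n):
    # Yield a reshuffled permutation of range(n), forever.
    order = np.random.permutation(n)
    i = 0
    while True:
        yield order[i]
        i += 1
        if i >= n:
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(_infinite_indices(self.num_samples))

    def __len__(self):
        return 2 ** 31
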
Example #4
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    # log content and style weight parameters
    if hvd.rank() == 0:
        run.log('content_weight', float(args.content_weight))
        run.log('style_weight', float(args.style_weight))

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_dataset = datasets.ImageFolder(args.dataset, transform)

    # Horovod: partition dataset among workers using DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              **kwargs)

    transformer = TransformerNet().to(device)

    # Horovod: broadcast parameters from rank 0 to all other processes
    hvd.broadcast_parameters(transformer.state_dict(), root_rank=0)
    # Horovod: scale learning rate by the number of GPUs
    optimizer = Adam(transformer.parameters(), args.lr * hvd.size())
    # Horovod: wrap optimizer with DistributedOptimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=transformer.named_parameters())
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    print("starting training...")
    for e in range(args.epochs):
        print("epoch {}...".format(e))
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                avg_content_loss = agg_content_loss / (batch_id + 1)
                avg_style_loss = agg_style_loss / (batch_id + 1)
                avg_total_loss = (agg_content_loss +
                                  agg_style_loss) / (batch_id + 1)
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_sampler),
                    avg_content_loss, avg_style_loss, avg_total_loss)
                print(mesg)

                # log the losses to the run history
                run.log('avg_content_loss', float(avg_content_loss))
                run.log('avg_style_loss', float(avg_style_loss))
                run.log('avg_total_loss', float(avg_total_loss))

            if hvd.rank() == 0 and args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    if hvd.rank() == 0:
        transformer.eval().cpu()
        if args.export_to_onnx:
            # export model to ONNX format
            dummy_input = torch.randn(1, 3, 1024, 1024, device='cpu')
            save_model_path = os.path.join(args.save_model_dir,
                                           '{}.onnx'.format(args.model_name))
            torch.onnx.export(transformer, dummy_input, save_model_path)
        else:
            save_model_path = os.path.join(args.save_model_dir,
                                           '{}.pth'.format(args.model_name))
            torch.save(transformer.state_dict(), save_model_path)

        print("\nDone, trained model saved at", save_model_path)