Example #1
def get_data_loader(args):
    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    style_transform = transforms.Compose([
        transforms.Resize((args.style_size, args.style_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])

    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    style_dataset = datasets.ImageFolder(args.style_dataset, style_transform)

    content_loader = DataLoader(content_dataset, 
                                batch_size=args.iter_batch_size, 
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    style_loader = DataLoader(style_dataset, batch_size=1, 
                              sampler=InfiniteSamplerWrapper(style_dataset),
                              num_workers=args.n_workers)
    query_loader = DataLoader(content_dataset,
                              batch_size=args.iter_batch_size,
                              sampler=InfiniteSamplerWrapper(content_dataset),
                              num_workers=args.n_workers)

    return iter(content_loader), iter(style_loader), iter(query_loader)
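
Every example on this page drives its DataLoader with an InfiniteSamplerWrapper so that next() can be called indefinitely, without epoch boundaries or StopIteration. None of the snippets include the wrapper itself; below is a minimal sketch of the implementation commonly used in AdaIN-style repositories (the names match the snippets, but treat the body as an assumption rather than the exact code of any one project):

import numpy as np
from torch.utils import data


def InfiniteSampler(n):
    # Yield indices forever, reshuffling after each full pass over the dataset.
    order = np.random.permutation(n)
    i = 0
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(data.sampler.Sampler):
    # A sampler that never stops, which makes iter(loader) endless.
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31  # effectively infinite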
Example #2
File: train.py, Project: diyiiyiii/MCCNet
def load_dataset(content_dir, style_dir):
    content_tf = train_transform()
    style_tf = train_transform()

    content_dataset = FlatFolderDataset(content_dir, content_tf)
    style_dataset = FlatFolderDataset(style_dir, style_tf)

    content_iter = iter(data.DataLoader(
        content_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(content_dataset),
        num_workers=args.n_threads))
    style_iter = iter(data.DataLoader(
        style_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(style_dataset),
        num_workers=args.n_threads))

    return content_iter, style_iter
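
Examples #2, #3, #5, and #7 read images through a FlatFolderDataset rather than torchvision.datasets.ImageFolder, since style-transfer datasets are typically one flat directory of images with no class subfolders. A minimal sketch under that assumption (the class body is illustrative, not the exact code of the projects above):

from pathlib import Path
from PIL import Image
from torch.utils import data


class FlatFolderDataset(data.Dataset):
    # Dataset over a flat directory of images; returns transformed images only,
    # with no labels.
    def __init__(self, root, transform):
        super().__init__()
        self.root = root
        self.paths = sorted(Path(self.root).glob('*'))
        self.transform = transform

    def __getitem__(self, index):
        img = Image.open(str(self.paths[index])).convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.paths)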
Example #3
decoder.load_state_dict(torch.load(args.decoder))

network = net.Net(vgg, decoder)
network.train()
network.to(device)

content_tf = train_transform()
style_tf = train_transform()

content_dataset = FlatFolderDataset(args.content_dir, content_tf)
style_dataset = FlatFolderDataset(args.style_dir, style_tf)

content_iter = iter(
    data.DataLoader(content_dataset,
                    batch_size=args.batch_size,
                    sampler=InfiniteSamplerWrapper(content_dataset),
                    num_workers=args.n_threads))
style_iter = iter(
    data.DataLoader(style_dataset,
                    batch_size=args.batch_size,
                    sampler=InfiniteSamplerWrapper(style_dataset),
                    num_workers=args.n_threads))

optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr)

for i in tqdm(range(args.max_iter)):
    adjust_learning_rate(optimizer, iteration_count=i)
    content_images = next(content_iter).to(device)
    style_images = next(style_iter).to(device)
    loss_c, loss_s = network(content_images, style_images)
    loss_c = args.content_weight * loss_c
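
train_transform() is likewise assumed but never shown. In these repositories it typically resizes to a common working size and then takes a random fixed-size crop; a sketch (the exact sizes are an assumption):

from torchvision import transforms


def train_transform():
    # Resize to a common working size, then take a random 256x256 training crop.
    return transforms.Compose([
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ])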
Example #4
def train(args):

    # Device, save and log configuration

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(os.path.join(args.save_dir, args.name))
    save_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path(os.path.join(args.log_dir, args.name))
    log_dir.mkdir(exist_ok=True, parents=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    # Prepare datasets

    content_dataset = TrainDataset(args.content_dir, args.img_size)
    texture_dataset = TrainDataset(args.texture_dir,
                                   args.img_size,
                                   gray_only=True)
    color_dataset = TrainDataset(args.color_dir, args.img_size)

    content_iter = iter(
        data.DataLoader(content_dataset,
                        batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(content_dataset),
                        num_workers=args.n_threads))
    texture_iter = iter(
        data.DataLoader(texture_dataset,
                        batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(texture_dataset),
                        num_workers=args.n_threads))
    color_iter = iter(
        data.DataLoader(color_dataset,
                        batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(color_dataset),
                        num_workers=args.n_threads))

    # Prepare network

    network = Net(args)
    network.train()
    network.to(device)

    # Training options

    opt_L = torch.optim.Adam(network.L_path.parameters(), lr=args.lr)
    opt_AB = torch.optim.Adam(network.AB_path.parameters(), lr=args.lr)

    opts = [opt_L, opt_AB]

    # Start Training

    for i in tqdm(range(args.max_iter)):
        # S1: Adjust lr and prepare data

        adjust_learning_rate(opts, iteration_count=i, args=args)

        content_l, content_ab = [x.to(device) for x in next(content_iter)]
        texture_l = next(texture_iter).to(device)
        color_l, color_ab = [x.to(device) for x in next(color_iter)]

        # S2: Forward

        l_pred, ab_pred = network(content_l, content_ab, texture_l, color_ab)

        # S3: Calculate loss

        loss_ct, loss_t = network.ct_t_loss(l_pred, content_l, texture_l)
        loss_cr = network.cr_loss(ab_pred, color_ab)

        loss_ctw = args.content_weight * loss_ct
        loss_tw = args.texture_weight * loss_t
        loss_crw = args.color_weight * loss_cr

        loss = loss_ctw + loss_tw + loss_crw

        # S4: Backward

        for opt in opts:
            opt.zero_grad()
        loss.backward()
        for opt in opts:
            opt.step()

        # S5: Summary loss and save subnets

        writer.add_scalar('loss_content', loss_ct.item(), i + 1)
        writer.add_scalar('loss_texture', loss_t.item(), i + 1)
        writer.add_scalar('loss_color', loss_cr.item(), i + 1)

        if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
            state_dict = network.state_dict()
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(torch.device('cpu'))
            torch.save(state_dict,
                       save_dir / 'network_iter_{:d}.pth.tar'.format(i + 1))
    writer.close()
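
Note that this example passes a list of optimizers to adjust_learning_rate, unlike the single-optimizer version defined in Example #7. A sketch of the list variant, assuming the same inverse-decay schedule and the same args.lr / args.lr_decay fields:

def adjust_learning_rate(optimizers, iteration_count, args):
    # The inverse-decay schedule from Example #7, applied to every optimizer.
    lr = args.lr / (1.0 + args.lr_decay * iteration_count)
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr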
Example #5
def run_train(config):
    print('come!')
    #visualizer = Visualizer(config)  # create a visualizer that display/save images and plots
    device = 'cpu' if config.cpu or not torch.cuda.is_available() else 'cuda:0'
    device = torch.device(device)

    transfer_at = set()
    if config.transfer_at_encoder:
        transfer_at.add('encoder')
    if config.transfer_at_decoder:
        transfer_at.add('decoder')
    if config.transfer_at_skip:
        transfer_at.add('skip')
    save_dir = Path(config.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path(config.log_dir)
    log_dir.mkdir(exist_ok=True, parents=True)
    writer = SummaryWriter(log_dir=str(log_dir))
    vgg = net.vgg
    wct2 = Lap_Sob_Gaus(transfer_at=transfer_at,
                        option_unpool=config.option_unpool,
                        device=device,
                        verbose=config.verbose,
                        vgg=vgg)

    encoder = Lap_Sob_GausEncoder(config.option_unpool).to(device)
    decoder = Lap_Sob_GausDecoder(config.option_unpool).to(device)
    vgg.load_state_dict(torch.load(config.vgg))
    vgg = nn.Sequential(*list(vgg.children())[:31])
    network = net.Net(encoder, decoder, vgg=vgg)
    network.train()
    network.to(device)

    content_tf = train_transform()
    style_tf = train_transform()

    # # Data loading
    # transfroms = tv.transforms.Compose([
    #      tv.transforms.Resize(config.image_size),
    #      tv.transforms.CenterCrop(config.image_size),
    #      tv.transforms.ToTensor(),
    #      tv.transforms.Lambda(lambda x: x * 255)
    # ])
    # dataset = tv.datasets.ImageFolder(config.data_root, transfroms)
    # dataloader = data.DataLoader(dataset, config.batch_size)

    content_dataset = FlatFolderDataset(config.content_dir, content_tf)
    style_dataset = FlatFolderDataset(config.style_dir, style_tf)

    content_iter = iter(
        data.DataLoader(content_dataset,
                        batch_size=config.batch_size,
                        sampler=InfiniteSamplerWrapper(content_dataset),
                        num_workers=config.n_threads))
    style_iter = iter(
        data.DataLoader(style_dataset,
                        batch_size=config.batch_size,
                        sampler=InfiniteSamplerWrapper(style_dataset),
                        num_workers=config.n_threads))

    # Optimizer
    enoptimizer = torch.optim.Adam(network.encoder.parameters(), lr=config.lr)
    deoptimizer = torch.optim.Adam(network.decoder.parameters(), lr=config.lr)
    # # Loss meter
    # style_meter = tnt.meter.AverageValueMeter()
    # content_meter = tnt.meter.AverageValueMeter()
    vis = Visdom(env="loss")
    # style = utils.get_style_data(config.style_path)
    # vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    # style = style.to(device)

    content_loss, style_loss, iters = 0, 0, 0

    win_c = vis.line(np.array([content_loss]),
                     np.array([iters]),
                     win='content_loss')
    win_s = vis.line(np.array([style_loss]),
                     np.array([iters]),
                     win='style_loss')
    # for epoch in range(config.epoches):
    #     content_meter.reset()
    #     style_meter.reset()
    #     for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):
    # Train
    for i in tqdm(range(config.max_iter)):
        enoptimizer.zero_grad()
        deoptimizer.zero_grad()
        # x = x.to(device)
        # y = network(x, style)

        adjust_learning_rate(enoptimizer, iteration_count=i)
        adjust_learning_rate(deoptimizer, iteration_count=i)
        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)
        content_images.requires_grad_()
        style_images.requires_grad_()
        loss_c, loss_s = network(content_images, style_images, wct2)
        loss_c = config.content_weight * loss_c
        loss_s = config.style_weight * loss_s
        loss = loss_c + loss_s

        # optimizer.zero_grad()
        loss.backward()
        enoptimizer.step()
        deoptimizer.step()

        if i % 50 == 1:
            print('\niters:', i, 'loss:', loss.item(),
                  'loss_c:', loss_c.item(), 'loss_s:', loss_s.item())
        if i % 20 == 0:
            iters = np.array([i])
            content_loss = np.array([loss_c.item()])
            style_loss = np.array([loss_s.item()])
            vis.line(content_loss, iters, win_c, update='append')
            vis.line(style_loss, iters, win_s, update='append')

        writer.add_scalar('loss_content', loss_c.item(), i + 1)
        writer.add_scalar('loss_style', loss_s.item(), i + 1)

        if (i + 1) % config.save_model_interval == 0 or (i + 1) == config.max_iter:
            # Save decoder and encoder checkpoints on the CPU.
            for name, module in [('decoder', network.decoder),
                                 ('encoder', network.encoder)]:
                state_dict = module.state_dict()
                for key in state_dict.keys():
                    state_dict[key] = state_dict[key].to(torch.device('cpu'))
                torch.save(state_dict,
                           save_dir / '{:s}_iter_{:d}.pth.tar'.format(name, i + 1))
    writer.close()
Example #6
def fast_train(args):
    """Fast training"""

    device = torch.device("cuda" if args.cuda else "cpu")

    transformer = TransformerNet().to(device)
    if args.model:
        transformer.load_state_dict(torch.load(args.model))
    vgg = Vgg16(requires_grad=False).to(device)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    content_loader = DataLoader(content_dataset, 
                                batch_size=args.iter_batch_size, 
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    content_loader = iter(content_loader)
    style_transform = transforms.Compose([
            transforms.Resize((args.style_size, args.style_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

    style_image = utils.load_image(args.style_image)
    style_image = style_transform(style_image)
    style_image = style_image.unsqueeze(0).to(device)
    features_style = vgg(utils.normalize_batch(style_image.repeat(args.iter_batch_size, 1, 1, 1)))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    if args.only_in:
        optimizer = Adam([param for (name, param) in transformer.named_parameters() if "in" in name], lr=lr)
    else:
        optimizer = Adam(transformer.parameters(), lr=lr)

    for i in trange(args.update_step):
        contents = next(content_loader)[0].to(device)  # ImageFolder yields (images, labels)
        features_contents = vgg(utils.normalize_batch(contents))

        transformed = transformer(contents)
        features_transformed = vgg(utils.normalize_batch(transformed))
        loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # save model
    transformer.eval().cpu()
    style_name = os.path.splitext(os.path.basename(args.style_image))[0]
    save_model_filename = style_name + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
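
This example relies on two helpers that are not shown: utils.normalize_batch (ImageNet normalization of a 0-255 batch) and utils.gram_matrix (channel-correlation features for the style loss). The sketches below match the PyTorch fast_neural_style reference implementation, which this snippet appears to be adapted from; that this utils module does the same is an assumption:

def normalize_batch(batch):
    # Scale a 0-255 batch to [0, 1], then apply ImageNet mean/std normalization.
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    return (batch - mean) / std


def gram_matrix(y):
    # (b, ch, h, w) -> (b, ch, ch) Gram matrices, normalized by ch * h * w.
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    return features.bmm(features_t) / (ch * h * w)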
Example #7
def main():
    parser = argparse.ArgumentParser()
    # Basic options
    parser.add_argument('--content_dir',
                        type=str,
                        required=True,
                        help='Directory path to a batch of content images')
    parser.add_argument('--style_dir',
                        type=str,
                        required=True,
                        help='Directory path to a batch of style images')
    parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')

    # training options
    parser.add_argument('--save_dir',
                        default='./experiments',
                        help='Directory to save the model')
    parser.add_argument('--log_dir',
                        default='./logs',
                        help='Directory to save the log')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--lr_decay', type=float, default=5e-5)
    parser.add_argument('--max_iter', type=int, default=160000)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--style_weight', type=float, default=1.0)  # original default: 10
    parser.add_argument('--content_weight', type=float, default=1.0)
    parser.add_argument('--n_threads', type=int, default=8)
    parser.add_argument('--save_model_interval', type=int, default=20000)
    args = parser.parse_args()

    # 80000 iterations with batch_size=1; 160000 iterations with batch_size=4
    def adjust_learning_rate(optimizer, iteration_count):
        """Imitating the original implementation"""
        lr = args.lr / (1.0 + args.lr_decay * iteration_count)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    device = torch.device('cuda')
    save_dir = Path(args.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path(args.log_dir)
    log_dir.mkdir(exist_ok=True, parents=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    decoder = net.decoder
    checkpoint = torch.load("experiments\\ex3\\decoder_iter_80000.pth.tar")
    decoder.load_state_dict(checkpoint)
    vgg = net.vgg

    vgg.load_state_dict(torch.load(args.vgg))
    vgg = nn.Sequential(*list(vgg.children())[:31])
    network = net.Net(vgg, decoder)
    network.train()
    network.to(device)

    content_tf = train_transform()
    style_tf = train_transform()

    content_dataset = FlatFolderDataset(args.content_dir, content_tf)
    style_dataset = FlatFolderDataset(args.style_dir, style_tf)

    content_iter = iter(
        data.DataLoader(content_dataset,
                        batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(content_dataset),
                        num_workers=args.n_threads))
    style_iter = iter(
        data.DataLoader(style_dataset,
                        batch_size=args.batch_size,
                        sampler=InfiniteSamplerWrapper(style_dataset),
                        num_workers=args.n_threads))

    optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr)

    for i in tqdm(range(args.max_iter)):
        adjust_learning_rate(optimizer, iteration_count=i)
        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)
        loss_c, loss_s = network(content_images, style_images)
        loss_c = args.content_weight * loss_c
        loss_s = args.style_weight * loss_s
        loss = loss_c + loss_s

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('loss_content', loss_c.item(), i + 1)
        writer.add_scalar('loss_style', loss_s.item(), i + 1)

        if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
            state_dict = net.decoder.state_dict()
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(torch.device('cpu'))
            torch.save(state_dict,
                       save_dir / 'decoder_iter_{:d}.pth.tar'.format(i + 1))
    writer.close()
Example #8
# Load datasets from path
style_dataset = datasets.ImageFolder(root=args.style_dir, transform=train_transform)
content_dataset = datasets.ImageFolder(root=args.content_dir, transform=train_transform)

# Set up samplers
content_sampler = None
# style_sampler = None
if args.distributed:
    content_sampler = torch.utils.data.distributed.DistributedSampler(content_dataset)
    # style_sampler = torch.utils.data.distributed.DistributedSampler(style_dataset)

# Make data loaders; split the global batch size across distributed workers.
args.dist_batch_size = (int(args.batch_size / torch.distributed.get_world_size())
                        if args.distributed else args.batch_size)
content_loader = torch.utils.data.DataLoader(
    content_dataset, sampler=content_sampler, batch_size=args.dist_batch_size,
    shuffle=(content_sampler is None), drop_last=True, **kwargs)
style_loader = torch.utils.data.DataLoader(
    style_dataset, sampler=InfiniteSamplerWrapper(style_dataset),
    batch_size=args.dist_batch_size, **kwargs)



if not os.path.exists(args.save_dir) and args.local_rank == 0:
    os.mkdir(args.save_dir)
if not os.path.exists(args.log_dir) and args.local_rank == 0:
    os.mkdir(args.log_dir)
#writer = SummaryWriter(log_dir=args.log_dir)



# Create model objects.
# vgg and decoder need to be created as objects when using distributed training.
decoder = model.get_decoder()
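
The fragment above assumes that the process group is already initialized and that kwargs holds the usual loader options. A minimal setup sketch for one-process-per-GPU training launched with torchrun (everything below is an assumption about the surrounding script, not code from the project):

import os
import torch

if args.distributed:
    # torchrun exports LOCAL_RANK for each spawned process.
    args.local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Hypothetical loader kwargs consumed by the DataLoader calls above.
kwargs = {'num_workers': args.n_threads, 'pin_memory': True}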