Example #1
import argparse
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

# PreprocessDataset, VGGEncoder, Decoder, style_swap, TVloss and denorm are
# project-local and assumed to be importable from the repository.


def main():
    parser = argparse.ArgumentParser(description='Style Swap by Pytorch')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=4,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=3,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--patch_size',
                        '-p',
                        type=int,
                        default=5,
                        help='Size of extracted patches from style features')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--learning_rate',
                        '-lr',
                        type=float,
                        default=1e-4,
                        help='learning rate for Adam')
    parser.add_argument('--tv_weight',
                        type=float,
                        default=1e-6,
                        help='weight for total variation loss')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=500,
                        help='Interval of snapshot to generate image')
    parser.add_argument('--train_content_dir',
                        type=str,
                        default='/data/chen/content',
                        help='content images directory for train')
    parser.add_argument('--train_style_dir',
                        type=str,
                        default='/data/chen/style',
                        help='style images directory for train')
    parser.add_argument('--test_content_dir',
                        type=str,
                        default='/data/chen/content',
                        help='content images directory for test')
    parser.add_argument('--test_style_dir',
                        type=str,
                        default='/data/chen/style',
                        help='style images directory for test')
    parser.add_argument('--save_dir',
                        type=str,
                        default='result',
                        help='save directory for result and loss')

    args = parser.parse_args()

    # create directories to save results (exist_ok avoids a crash when only
    # some of the three subdirectories already exist)
    os.makedirs(args.save_dir, exist_ok=True)

    loss_dir = f'{args.save_dir}/loss'
    model_state_dir = f'{args.save_dir}/model_state'
    image_dir = f'{args.save_dir}/image'

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(model_state_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)

    # set device on GPU if available, else CPU
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device(f'cuda:{args.gpu}')
        print(f'# CUDA available: {torch.cuda.get_device_name(args.gpu)}')
    else:
        device = torch.device('cpu')

    print(f'# Minibatch-size: {args.batch_size}')
    print(f'# epoch: {args.epoch}')
    print('')

    # prepare dataset and dataLoader
    train_dataset = PreprocessDataset(args.train_content_dir,
                                      args.train_style_dir)
    test_dataset = PreprocessDataset(args.test_content_dir,
                                     args.test_style_dir)
    iters = len(train_dataset)
    print(f'Length of train image pairs: {iters}')

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True)
    test_iter = iter(test_loader)

    # set model and optimizer
    encoder = VGGEncoder().to(device)
    decoder = Decoder().to(device)
    optimizer = Adam(decoder.parameters(), lr=args.learning_rate)

    # start training
    criterion = nn.MSELoss()
    loss_list = []

    for e in range(1, args.epoch + 1):
        print(f'Start {e} epoch')
        for i, (content, style) in enumerate(tqdm(train_loader), 1):
            content = content.to(device)
            style = style.to(device)
            content_feature = encoder(content)
            style_feature = encoder(style)

            # style swap runs per sample: each pair has its own style-patch dictionary
            style_swap_res = []
            for b in range(content_feature.shape[0]):
                c = content_feature[b].unsqueeze(0)
                s = style_feature[b].unsqueeze(0)
                cs = style_swap(c, s, args.patch_size, 1)
                style_swap_res.append(cs)
            style_swap_res = torch.cat(style_swap_res, 0)

            out_style_swap = decoder(style_swap_res)
            out_content = decoder(content_feature)
            out_style = decoder(style_feature)

            out_style_swap_latent = encoder(out_style_swap)
            out_content_latent = encoder(out_content)
            out_style_latent = encoder(out_style)

            image_reconstruction_loss = criterion(
                content, out_content) + criterion(style, out_style)

            feature_reconstruction_loss = criterion(style_feature, out_style_latent) +\
                criterion(content_feature, out_content_latent) +\
                criterion(style_swap_res, out_style_swap_latent)

            tv_loss = TVloss(out_style_swap, args.tv_weight) + TVloss(out_content, args.tv_weight) \
                + TVloss(out_style, args.tv_weight)

            loss = image_reconstruction_loss + feature_reconstruction_loss + tv_loss

            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(
                f'[{e}/total {args.epoch} epoch],[{i} /'
                f'total {round(iters/args.batch_size)} iteration]: {loss.item()}'
            )

            if i % args.snapshot_interval == 0:
                try:
                    content, style = next(test_iter)
                except StopIteration:  # restart the test loader when exhausted
                    test_iter = iter(test_loader)
                    content, style = next(test_iter)
                content = content.to(device)
                style = style.to(device)
                with torch.no_grad():
                    content_feature = encoder(content)
                    style_feature = encoder(style)
                    style_swap_res = []
                    for b in range(content_feature.shape[0]):
                        c = content_feature[b].unsqueeze(0)
                        s = style_feature[b].unsqueeze(0)
                        cs = style_swap(c, s, args.patch_size, 1)
                        style_swap_res.append(cs)
                    style_swap_res = torch.cat(style_swap_res, 0)
                    out_style_swap = decoder(style_swap_res)
                    out_content = decoder(content_feature)
                    out_style = decoder(style_feature)

                content = denorm(content, device)
                style = denorm(style, device)
                out_style_swap = denorm(out_style_swap, device)
                out_content = denorm(out_content, device)
                out_style = denorm(out_style, device)
                res = torch.cat(
                    [content, style, out_content, out_style, out_style_swap],
                    dim=0)
                res = res.to('cpu')
                save_image(res,
                           f'{image_dir}/{e}_epoch_{i}_iteration.png',
                           nrow=content_feature.shape[0])
        torch.save(decoder.state_dict(), f'{model_state_dir}/{e}_epoch.pth')
    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('train loss')
    plt.savefig(f'{loss_dir}/train_loss.png')
    with open(f'{loss_dir}/loss_log.txt', 'w') as f:
        for l in loss_list:
            f.write(f'{l}\n')
    print(f'Loss saved in {loss_dir}')
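
Examples #1 and #4 both call a project-local `style_swap` that is not shown. The sketch below reconstructs the patch-swap operation these scripts build on (Chen & Schmidt's fast patch-based style transfer) using standard PyTorch ops; the function name, the epsilon guards, and the overlap averaging are assumptions, not the repository's actual code.

import torch
import torch.nn.functional as F


def style_swap_sketch(content_feature, style_feature, patch_size=5, stride=1):
    """Replace each content patch by its best-matching style patch."""
    c = style_feature.size(1)
    # Extract style patches and reshape them into conv filters: (L, C, k, k).
    patches = F.unfold(style_feature, kernel_size=patch_size, stride=stride)
    patches = patches.permute(0, 2, 1).reshape(-1, c, patch_size, patch_size)

    # Normalized cross-correlation between content regions and style patches.
    norms = patches.reshape(patches.size(0), -1).norm(dim=1).clamp(min=1e-8)
    correlation = F.conv2d(content_feature,
                           patches / norms.view(-1, 1, 1, 1),
                           stride=stride)

    # One-hot selection of the best-matching patch at every spatial location.
    one_hot = torch.zeros_like(correlation)
    one_hot.scatter_(1, correlation.argmax(dim=1, keepdim=True), 1.0)

    # Paste the selected (unnormalized) patches back, averaging overlaps.
    swapped = F.conv_transpose2d(one_hot, patches, stride=stride)
    overlap = F.conv_transpose2d(one_hot, torch.ones_like(patches),
                                 stride=stride)
    return swapped / overlap.clamp(min=1e-8)

With stride 1 every output pixel is covered by up to patch_size**2 selected patches, so dividing by the overlap count keeps the reconstruction unbiased.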
Example #2
import argparse
import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

# PreprocessDataset, VGGEncoder, CoAttention, Decoder, VGGAttn, Dimg, denorm,
# calc_content_loss and calc_style_loss are project-local and assumed to be
# importable from the repository.


def main():
    parser = argparse.ArgumentParser(
        description='AdaIN Style Transfer by Pytorch')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=12,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--learning_rate',
                        '-lr',
                        type=float,
                        default=5e-5,
                        help='learning rate for Adam')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=900,
                        help='Interval of snapshot to generate image')
    parser.add_argument('--train_content_dir',
                        type=str,
                        default='../content',
                        help='content images directory for train')
    parser.add_argument('--train_style_dir',
                        type=str,
                        default='../style',
                        help='style images directory for train')
    parser.add_argument('--test_content_dir',
                        type=str,
                        default='content',
                        help='content images directory for test')
    parser.add_argument('--test_style_dir',
                        type=str,
                        default='style',
                        help='style images directory for test')
    parser.add_argument('--save_dir',
                        type=str,
                        default='.',
                        help='save directory for result and loss')
    parser.add_argument('--reuse',
                        default=None,
                        help='model state path to load for reuse')

    args = parser.parse_args()

    print(args.save_dir)
    # create directory to save
    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    loss_dir = f'{args.save_dir}/loss'
    model_state_dir = f'{args.save_dir}/model_state'
    image_dir = f'{args.save_dir}/image'

    if not os.path.exists(loss_dir):
        os.mkdir(loss_dir)
    if not os.path.exists(model_state_dir):
        os.mkdir(model_state_dir)
    if not os.path.exists(image_dir):
        os.mkdir(image_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"  # must be set before the first CUDA call

    print(f'# Minibatch-size: {args.batch_size}')
    print(f'# epoch: {args.epoch}')
    print('')

    # prepare dataset and dataLoader
    train_dataset = PreprocessDataset(args.train_content_dir,
                                      args.train_style_dir)
    test_dataset = PreprocessDataset(args.test_content_dir,
                                     args.test_style_dir)
    iters = len(train_dataset)
    print(f'Length of train image pairs: {iters}')

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False)
    test_iter = iter(test_loader)

    device_ids = [0, 1, 2]
    #    Re_encoder=nn.DataParallel(ReEncoder(),device_ids).cuda()
    vgg_encoder = nn.DataParallel(VGGEncoder(), device_ids).cuda()
    attn = nn.DataParallel(CoAttention(channel=512), device_ids).cuda()
    decoder = nn.DataParallel(Decoder(), device_ids).cuda()
    vggattn = nn.DataParallel(VGGAttn(), device_ids).cuda()
    D_img = Dimg().cuda()

    if args.reuse is not None:
        # `model` was undefined here in the original; the decoder is the module
        # trained and saved below, so its state is loaded (an assumption).
        decoder.load_state_dict(torch.load(args.reuse))

#    optimizer_Re_encoder = Adam(Re_encoder.parameters(), lr=args.learning_rate)
    optimizer_decoder = Adam(decoder.parameters(), lr=args.learning_rate)
    optimizer_attn = Adam(attn.parameters(), lr=args.learning_rate)
    optimizer_vggattn = Adam(filter(lambda p: p.requires_grad,
                                    vggattn.parameters()),
                             lr=args.learning_rate)
    optimizer_D_img = Adam(D_img.parameters(), lr=args.learning_rate)

    # start training
    # MSE_Loss, real_label, fake_label and g_steps are assumed to be defined at
    # module level (an LSGAN-style MSE criterion and its target label tensors).
    loss_list_1 = []
    loss_list_2 = []
    loss_list_D_img = []
    lam = 10.0
    for e in range(1, args.epoch + 1):
        print(f'Start {e} epoch')
        for i, (content, style) in enumerate(tqdm(train_loader), 1):
            content = content.cuda()
            style = style.cuda()
            t_1 = vggattn(content, style, output_last_feature=True)
            t_2 = vggattn(style, content, output_last_feature=True)

            c1s2 = decoder(t_1)
            output_features_1 = vgg_encoder(images=c1s2,
                                            output_last_feature=True)
            output_middle_features_1 = vgg_encoder(images=c1s2,
                                                   output_last_feature=False)
            style_middle_features_1 = vgg_encoder(images=style,
                                                  output_last_feature=False)
            loss_c_1 = calc_content_loss(output_features_1, t_1)
            loss_s_1 = calc_style_loss(output_middle_features_1,
                                       style_middle_features_1)

            c2s1 = decoder(t_2)
            output_features_2 = vgg_encoder(images=c2s1,
                                            output_last_feature=True)
            output_middle_features_2 = vgg_encoder(images=c2s1,
                                                   output_last_feature=False)
            style_middle_features_2 = vgg_encoder(images=content,
                                                  output_last_feature=False)
            loss_c_2 = calc_content_loss(output_features_2, t_2)
            loss_s_2 = calc_style_loss(output_middle_features_2,
                                       style_middle_features_2)

            D_content = D_img(content.to('cuda:1'))
            D_style = D_img(style.to('cuda:1'))
            D_c1s2 = D_img(c1s2.to('cuda:1'))
            D_c2s1 = D_img(c2s1.to('cuda:1'))

            D_loss = MSE_Loss(D_content, fake_label) + MSE_Loss(
                D_style, fake_label) + MSE_Loss(D_c1s2, real_label) + MSE_Loss(
                    D_c2s1, real_label)

            loss = loss_c_1 + lam * loss_s_1 + loss_c_2 + lam * loss_s_2 + 0.01 * D_loss.to(
                'cuda:0')

            loss_list_1.append(loss.sum().item())

            optimizer_vggattn.zero_grad()
            optimizer_decoder.zero_grad()
            loss.sum().backward(retain_graph=True)
            optimizer_decoder.step()
            optimizer_vggattn.step()

            t_1_c1s1 = vggattn(c1s2, c2s1, output_last_feature=True)
            t_2_c2s2 = vggattn(c2s1, c1s2, output_last_feature=True)
            c1s1 = decoder(t_1_c1s1)
            c2s2 = decoder(t_2_c2s2)
            output_features_c1s1 = vgg_encoder(images=c1s1,
                                               output_last_feature=True)
            output_middle_features_c1s1 = vgg_encoder(
                images=c1s1, output_last_feature=False)
            style_middle_features_c1s1 = vgg_encoder(images=content,
                                                     output_last_feature=False)
            c_old = vgg_encoder(images=content, output_last_feature=True)
            loss_c_c1s1 = calc_content_loss(output_features_c1s1,
                                            c_old)  # compare with the original content image
            loss_s_c1s1 = calc_style_loss(output_middle_features_c1s1,
                                          style_middle_features_c1s1)

            output_features_c2s2 = vgg_encoder(images=c2s2,
                                               output_last_feature=True)
            s_old = vgg_encoder(images=style, output_last_feature=True)
            output_middle_features_c2s2 = vgg_encoder(
                images=c2s2, output_last_feature=False)
            style_middle_features_c2s2 = vgg_encoder(images=style,
                                                     output_last_feature=False)
            loss_c_c2s2 = calc_content_loss(output_features_c2s2,
                                            s_old)  # compare with the original style image
            loss_s_c2s2 = calc_style_loss(output_middle_features_c2s2,
                                          style_middle_features_c2s2)

            mse_c1s1 = MSE_Loss(content, c1s1)
            mse_c2s2 = MSE_Loss(style, c2s2)

            loss = loss_c_c1s1 + lam * loss_s_c1s1 + loss_c_c2s2 + lam * loss_s_c2s2 + 10 * mse_c1s1 + 10 * mse_c2s2
            loss_list_1.append(loss.sum().item())

            optimizer_vggattn.zero_grad()
            optimizer_decoder.zero_grad()
            loss.sum().backward(retain_graph=True)
            optimizer_decoder.step()
            optimizer_vggattn.step()
            # discriminator update: recompute the D outputs on detached images,
            # so each step builds a fresh graph and no gradient reaches the generator
            for g_index in range(g_steps):
                optimizer_D_img.zero_grad()
                D_content = D_img(content.to('cuda:1'))
                D_style = D_img(style.to('cuda:1'))
                D_c1s2 = D_img(c1s2.detach().to('cuda:1'))
                D_c2s1 = D_img(c2s1.detach().to('cuda:1'))
                D_loss = MSE_Loss(D_content, real_label) + MSE_Loss(
                    D_style, real_label) + MSE_Loss(
                        D_c1s2, fake_label) + MSE_Loss(D_c2s1, fake_label)
                D_loss.backward()
                optimizer_D_img.step()

            print(
                f'[{e}/total {args.epoch} epoch],[{i} /'
                f'total {round(iters/args.batch_size)} iteration]: {loss.sum().item()}'
            )

            if i % args.snapshot_interval == 0:

                content = denorm(content)
                style = denorm(style)
                c1s2 = denorm(c1s2)
                c2s1 = denorm(c2s1)
                c1s1 = denorm(c1s1)
                c2s2 = denorm(c2s2)
                res = torch.cat([content, style, c1s2, c2s1, c1s1, c2s2],
                                dim=0)

                res = res.to('cpu')
                save_image(res,
                           f'{image_dir}/{e}_epoch_{i}_iteration.png',
                           nrow=args.batch_size)


        # torch.save(attn.state_dict(), f'{model_state_dir}/attn_{e}_epoch.pth')
        torch.save(vgg_encoder.state_dict(),
                   f'{model_state_dir}/vgg_encoder_{e}_epoch.pth')
        torch.save(decoder.state_dict(),
                   f'{model_state_dir}/decoder_{e}_epoch.pth')
        torch.save(D_img.state_dict(),
                   f'{model_state_dir}/D_img_{e}_epoch.pth')
        torch.save(vggattn.state_dict(),
                   f'{model_state_dir}/vggattn{e}_epoch.pth')

        with open(f'{loss_dir}/loss_log.txt', 'w') as f:
            for l in loss_list_1:
                f.write(f'{l}\n')

    # plt.plot(range(len(loss_list)), loss_list)
    # plt.xlabel('iteration')
    # plt.ylabel('loss')
    # plt.title('train loss')
    # plt.savefig(f'{loss_dir}/train_loss.png')
    print(f'Loss saved in {loss_dir}')
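
Example #2 relies on `calc_content_loss` and `calc_style_loss`, whose definitions are not included. A minimal sketch under the usual AdaIN formulation, where the style loss matches per-channel mean/std statistics across several VGG layers; the function names and the eps value are assumptions.

import torch.nn.functional as F


def calc_mean_std(features, eps=1e-5):
    # Per-channel spatial statistics of a (N, C, H, W) feature map.
    n, c = features.shape[:2]
    flat = features.reshape(n, c, -1)
    mean = flat.mean(dim=2).reshape(n, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().reshape(n, c, 1, 1)
    return mean, std


def calc_content_loss(out_features, target_features):
    # Plain MSE between the stylized features and the target features.
    return F.mse_loss(out_features, target_features)


def calc_style_loss(out_middle_features, style_middle_features):
    # Match mean/std statistics over a list of intermediate VGG features.
    loss = 0.0
    for out, style in zip(out_middle_features, style_middle_features):
        out_mean, out_std = calc_mean_std(out)
        style_mean, style_std = calc_mean_std(style)
        loss = loss + F.mse_loss(out_mean, style_mean) \
                    + F.mse_loss(out_std, style_std)
    return loss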
Example #3
import argparse
import os

import matplotlib.pyplot as plt
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

# Model, PreprocessDataset and denorm are project-local and assumed to be
# importable from the repository.


def main():
    parser = argparse.ArgumentParser(description='AdaIN Style Transfer by Pytorch')
    parser.add_argument('--batch_size', '-b', type=int, default=8,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--learning_rate', '-lr', type=float, default=5e-5,
                        help='learning rate for Adam')
    parser.add_argument('--snapshot_interval', type=int, default=1000,
                        help='Interval of snapshot to generate image')
    parser.add_argument('--train_content_dir', type=str, default='content',
                        help='content images directory for train')
    parser.add_argument('--train_style_dir', type=str, default='style',
                        help='style images directory for train')
    parser.add_argument('--test_content_dir', type=str, default='content',
                        help='content images directory for test')
    parser.add_argument('--test_style_dir', type=str, default='style',
                        help='style images directory for test')
    parser.add_argument('--save_dir', type=str, default='result',
                        help='save directory for result and loss')
    parser.add_argument('--reuse', default=None,
                        help='model state path to load for reuse')

    args = parser.parse_args()

    # create directories to save results (exist_ok avoids a crash when only
    # some of the three subdirectories already exist)
    os.makedirs(args.save_dir, exist_ok=True)

    loss_dir = f'{args.save_dir}/loss'
    model_state_dir = f'{args.save_dir}/model_state'
    image_dir = f'{args.save_dir}/image'

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(model_state_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)

    # set device on GPU if available, else CPU
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device(f'cuda:{args.gpu}')
        print(f'# CUDA available: {torch.cuda.get_device_name(args.gpu)}')
    else:
        device = torch.device('cpu')

    print(f'# Minibatch-size: {args.batch_size}')
    print(f'# epoch: {args.epoch}')
    print('')

    # prepare dataset and dataLoader
    train_dataset = PreprocessDataset(args.train_content_dir, args.train_style_dir)
    test_dataset = PreprocessDataset(args.test_content_dir, args.test_style_dir)
    iters = len(train_dataset)
    print(f'Length of train image pairs: {iters}')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
    test_iter = iter(test_loader)

    # set model and optimizer
    model = Model().to(device)
    if args.reuse is not None:
        model.load_state_dict(torch.load(args.reuse))
    optimizer = Adam(model.parameters(), lr=args.learning_rate)

    # start training
    loss_list = []
    for e in range(1, args.epoch + 1):
        print(f'Start {e} epoch')
        for i, (content, style) in enumerate(tqdm(train_loader), 1):
            content = content.to(device)
            style = style.to(device)
            loss = model(content, style)
            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'[{e}/total {args.epoch} epoch],[{i} /'
                  f'total {round(iters/args.batch_size)} iteration]: {loss.item()}')

            if i % args.snapshot_interval == 0:
                try:
                    content, style = next(test_iter)
                except StopIteration:  # restart the test loader when exhausted
                    test_iter = iter(test_loader)
                    content, style = next(test_iter)
                content = content.to(device)
                style = style.to(device)
                with torch.no_grad():
                    out = model.generate(content, style)
                content = denorm(content, device)
                style = denorm(style, device)
                out = denorm(out, device)
                res = torch.cat([content, style, out], dim=0)
                res = res.to('cpu')
                save_image(res, f'{image_dir}/{e}_epoch_{i}_iteration.png', nrow=args.batch_size)
        torch.save(model.state_dict(), f'{model_state_dir}/{e}_epoch.pth')
    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('train loss')
    plt.savefig(f'{loss_dir}/train_loss.png')
    with open(f'{loss_dir}/loss_log.txt', 'w') as f:
        for l in loss_list:
            f.write(f'{l}\n')
    print(f'Loss saved in {loss_dir}')
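
Example #3 hides the whole objective inside `Model`. Its core is the AdaIN operation: re-normalize content features with the style features' channel statistics. A self-contained sketch (names assumed):

def adain(content_features, style_features, eps=1e-5):
    # Per-channel mean/std over the spatial dims of (N, C, H, W) feature maps.
    def stats(x):
        flat = x.reshape(*x.shape[:2], -1)
        mean = flat.mean(dim=2)[..., None, None]
        std = (flat.var(dim=2) + eps).sqrt()[..., None, None]
        return mean, std

    c_mean, c_std = stats(content_features)
    s_mean, s_std = stats(style_features)
    # Shift the content features to the style statistics:
    # t = sigma(s) * (c - mu(c)) / sigma(c) + mu(s)
    return s_std * (content_features - c_mean) / c_std + s_mean

In the reference AdaIN recipe the decoder reconstructs an image from adain(f(content), f(style)), and that same tensor serves as the content-loss target, which is why a single forward pass can return the loss as in `loss = model(content, style)`.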
Example #4
import os

import numpy as np
import torch
import torch.nn as nn
import visdom
from torch.utils.data import DataLoader
from torchvision import transforms

# VGGEncoder, Decoder, PreprocessDataset, style_swap, TVLoss and the global
# config object `opt` are project-local and assumed to be importable.


def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = torch.device('cuda') if opt.gpu else torch.device('cpu')
    os.makedirs(opt.out_dir, exist_ok=True)
    os.makedirs(opt.save_dir, exist_ok=True)

    if opt.vis:
        vis = visdom.Visdom(env='Style-Swap')

    VggNet = VGGEncoder(opt.relu_level).to(device)
    InvNet = Decoder(opt.relu_level).to(device)
    VggNet.train()
    InvNet.train()

    train_trans = transforms.Compose([
        transforms.Resize(size=opt.img_size),
        transforms.CenterCrop(size=opt.img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = PreprocessDataset(opt.content, opt.style, train_trans)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=opt.minibatch,
                                  shuffle=True,
                                  drop_last=True)

    optimizer = torch.optim.Adam(InvNet.parameters(), lr=opt.lr)

    criterion = nn.MSELoss()

    loss_list = []
    i = 0
    for epoch in range(1, opt.max_epoch + 1):
        for image in train_dataloader:
            content = image['c_img'].to(device)
            style = image['s_img'].to(device)
            cf = VggNet(content)
            sf = VggNet(style)
            csf = style_swap(cf, sf, opt.patch_size, stride=3)
            I_stylized = InvNet(csf)
            I_c = InvNet(cf)
            I_s = InvNet(sf)

            P_stylized = VggNet(I_stylized)  # size: 2 x 256 x 64 x 64
            P_c = VggNet(I_c)
            P_s = VggNet(I_s)

            loss_stylized = criterion(P_stylized, csf) + criterion(
                P_c, cf) + criterion(P_s, sf)
            loss_tv = TVLoss(I_stylized, opt.tv_weight)
            loss = loss_stylized + loss_tv
            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(
                "%d / %d epoch\tloss: %.4f\tloss_stylized: %.4f loss_tv: %.4f"
                % (epoch, opt.max_epoch, loss.item() / opt.minibatch,
                   loss_stylized.item() / opt.minibatch,
                   loss_tv.item() / opt.minibatch))
            i += 1
            if opt.vis:  # vis only exists when --vis was enabled
                vis.line(Y=np.array([loss.item()]),
                         X=np.array([i]),
                         win='train_loss',
                         update='append')
        torch.save(InvNet.state_dict(),
                   f'{opt.save_dir}/InvNet_{epoch}_epoch.pth')

    with open('loss_log.txt', 'w') as f:
        for l in loss_list:
            f.write(f'{l}\n')
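
`TVloss` in Example #1 and `TVLoss` here are not defined in the snippets. A standard anisotropic total-variation penalty is the likely shape; the function name and the exact reduction (sum vs. mean) are assumptions.

def tv_loss_sketch(img, tv_weight):
    # Penalize squared differences between neighbouring pixels of a
    # (N, C, H, W) batch, encouraging spatially smooth decoder outputs.
    diff_h = (img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2).sum()
    diff_w = (img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2).sum()
    return tv_weight * (diff_h + diff_w)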
Example #5
import argparse
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import visdom
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

# Config, Model, PreprocessDataset, get_loader and denorm are project-local
# and assumed to be importable from the repository.


def train(**kwargs):
    opt = Config()
    for k, v in kwargs.items():
        setattr(opt, k, v)

    device = torch.device(opt.gpu_id)
    vis = visdom.Visdom(port=2333,
                        env='gin')  # python -m visdom.server -p 2333

    train_dataset = PreprocessDataset('/data/lzd/train_data/content',
                                      '/data/lzd/train_data/style')
    train_loader = DataLoader(train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True)

    test_dataset = PreprocessDataset('/data/lzd/test_data/content',
                                     '/data/lzd/test_data/style')
    test_loader = DataLoader(test_dataset,
                             batch_size=opt.test_bs,
                             shuffle=False)
    test_iter = iter(test_loader)

    iters = len(train_dataset)
    print(f'Length of train image pairs: {iters}')
    # model = Model()
    # model= torch.nn.DataParallel(model, device_ids=[0,1,2,3])
    # model.to(device)
    model = Model().to(device)
    optimizer = Adam([{
        'params': model.decoder.parameters(),
        'lr': opt.lr
    }, {
        'params': model.gat.parameters(),
        'lr': 0.0005
    }],
                     lr=opt.lr)

    for e in range(1, opt.epoch + 1):  # +1 so the configured epoch count actually runs
        print(f'start {e} epoch:')
        for i, (content, style) in enumerate(train_loader, 1):
            content = content.to(device)  # [8, 3, 256, 256]
            style = style.to(device)  # [8, 3, 256, 256]
            loss_c, loss_s = model(content, style)
            loss = loss_c + 2 * loss_s
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(
                f'[{e}/{opt.epoch} epoch],[{i} /'
                f'{round(iters/opt.batch_size)}]: {loss_c.item()} and {loss_s.item()}'
            )

            if i % opt.loss_interval == 0:
                vis.line(Y=np.array([loss_c.item()]),
                         X=np.array([
                             (e - 1) * round(iters / opt.batch_size) + i
                         ]),
                         win='loss_c',
                         update='append',
                         opts=dict(xlabel='iteration',
                                   ylabel='Content loss',
                                   title='loss_c',
                                   legend=['Loss']))
                vis.line(Y=np.array([loss_s.item()]),
                         X=np.array([
                             (e - 1) * round(iters / opt.batch_size) + i
                         ]),
                         win='loss_s',
                         update='append',
                         opts=dict(xlabel='iteration',
                                   ylabel='style loss',
                                   title='loss_s',
                                   legend=['Loss']))
                vis.line(Y=np.array([loss.item()]),
                         X=np.array([
                             (e - 1) * round(iters / opt.batch_size) + i
                         ]),
                         win='loss',
                         update='append',
                         opts=dict(xlabel='iteration',
                                   ylabel='Total loss',
                                   title='loss',
                                   legend=['Loss']))

            if i % opt.img_interval == 0:
                try:
                    c, s = next(test_iter)
                except StopIteration:  # restart the test loader when exhausted
                    test_iter = iter(test_loader)
                    c, s = next(test_iter)
                c = c.to(device)
                s = s.to(device)
                with torch.no_grad():
                    out = model.generate(c, s)
                c = denorm(c, device)
                s = denorm(s, device)
                out = denorm(out, device)
                res = torch.cat([c, s, out], dim=0)
                vis.images(torch.clamp(res, 0, 1),
                           win='image',
                           nrow=opt.test_bs)


def main():
    parser = argparse.ArgumentParser(
        description=
        'Structure-emphasized Multimodal Style Transfer by CHEN CHEN')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=4,
                        help='number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=1,
                        help='number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID(negative value indicate CPU)')
    parser.add_argument('--learning_rate',
                        '-lr',
                        type=float,
                        default=1e-5,
                        help='learning rate for Adam')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=1000,
                        help='Interval of snapshot to generate image')
    parser.add_argument(
        '--alpha',
        default=1,
        help=
        'fusion degree, should be a float or a list which length is n_cluster')
    parser.add_argument('--gamma',
                        type=float,
                        default=1,
                        help='weight of style loss')
    parser.add_argument('--train_content_dir',
                        type=str,
                        default='/data/chen/content',
                        help='content images directory for train')
    parser.add_argument('--train_style_dir',
                        type=str,
                        default='/data/chen/style',
                        help='style images directory for train')
    parser.add_argument('--test_content_dir',
                        type=str,
                        default='/data/chen/content',
                        help='content images directory for test')
    parser.add_argument('--test_style_dir',
                        type=str,
                        default='/data/chen/style',
                        help='style images directory for test')
    parser.add_argument('--save_dir',
                        type=str,
                        default='result',
                        help='save directory for result and loss')
    parser.add_argument('--reuse',
                        default=None,
                        help='model state path to load for reuse')

    args = parser.parse_args()

    # create directories to save results (exist_ok avoids a crash when only
    # some of the three subdirectories already exist)
    os.makedirs(args.save_dir, exist_ok=True)

    loss_dir = f'{args.save_dir}/loss'
    model_state_dir = f'{args.save_dir}/model_state'
    image_dir = f'{args.save_dir}/image'

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(model_state_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)

    # set device on GPU if available, else CPU
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device(f'cuda:{args.gpu}')
        print(f'# CUDA available: {torch.cuda.get_device_name(args.gpu)}')
    else:
        device = torch.device('cpu')
        print('# CUDA unavailable')

    print(f'# Minibatch-size: {args.batch_size}')
    print(f'# epoch: {args.epoch}')
    print('')

    # prepare dataset and dataLoader
    train_dataset = PreprocessDataset(args.train_content_dir,
                                      args.train_style_dir)
    test_dataset = PreprocessDataset(args.test_content_dir,
                                     args.test_style_dir)
    data_length = len(train_dataset)
    print(f'Length of train image pairs: {data_length}')

    train_loader = get_loader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    test_loader = get_loader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True)
    test_iter = iter(test_loader)

    # set model and optimizer
    model = Model(alpha=args.alpha, device=device, pre_train=True)

    if args.reuse is not None:
        model.load_state_dict(
            torch.load(args.reuse, map_location=lambda storage, loc: storage))
        print(f'{args.reuse} loaded')

    optimizer = Adam(model.parameters(), lr=args.learning_rate)

    prev_model = copy.deepcopy(model)
    prev_optimizer = copy.deepcopy(optimizer)

    # start training
    loss_list = []
    for e in range(1, args.epoch + 1):
        print(f'Start {e} epoch')
        i = 1
        for content_path, style_path, content_tensor, style_tensor in tqdm(
                train_loader):
            loss = model(content_path, style_path, content_tensor,
                         style_tensor, args.gamma)
            if torch.isnan(loss):
                # loss diverged: roll back to the last stable model and optimizer
                model = prev_model
                optimizer = torch.optim.Adam(model.parameters())
                optimizer.load_state_dict(prev_optimizer.state_dict())
            else:
                prev_model = copy.deepcopy(model)
                prev_optimizer = copy.deepcopy(optimizer)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_list.append(loss.item())

                # print(f'[{e}/total {args.epoch} epoch],[{i} /'
                #       f'total {round(data_length/args.batch_size)} iteration]: {loss.item()}')

                if i % args.snapshot_interval == 0:
                    try:
                        content_path, style_path, content_tensor, style_tensor = next(
                            test_iter)
                    except StopIteration:  # restart the test loader
                        test_iter = iter(test_loader)
                        content_path, style_path, content_tensor, style_tensor = next(
                            test_iter)
                    content = content_tensor.to(device)
                    style = style_tensor.to(device)
                    with torch.no_grad():
                        out = model.generate(content_path, style_path,
                                             content_tensor, style_tensor)
                    res = torch.cat([content, style, out], dim=0)
                    res = res.to('cpu')
                    save_image(res,
                               f'{image_dir}/{e}_epoch_{i}_iteration.png',
                               nrow=args.batch_size)
                    torch.save(
                        model.state_dict(),
                        f'{model_state_dir}/{e}_epoch_{i}_iteration.pth')
                i += 1
        torch.save(model.state_dict(), f'{model_state_dir}/{e}_epoch.pth')
    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('train loss')
    plt.savefig(f'{loss_dir}/train_loss.png')
    with open(f'{loss_dir}/loss_log.txt', 'w') as f:
        for l in loss_list:
            f.write(f'{l}\n')
    print(f'Loss saved in {loss_dir}')
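
Several of the examples call a project-local `denorm(tensor, device)` before saving images. Given the ImageNet normalization visible in Example #4's transforms, a plausible sketch is the inverse affine map; the function name and the final clamp are assumptions.

import torch


def denorm_sketch(tensor, device):
    # Invert the ImageNet mean/std normalization used by the dataset transforms.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)
    return torch.clamp(tensor * std + mean, 0.0, 1.0)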