Example #1
def stylize(**kwargs):
    opt = Config()

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')

    # Image preprocessing
    content_image = PIL.Image.open(opt.content_path)
    content_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(),
         tv.transforms.Lambda(lambda x: x * 255)])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device).detach()

    # Load the model
    style_model = TransformerNet().eval()
    style_model.load_state_dict(
        t.load(opt.model_path, map_location=t.device('cpu')))
    style_model.to(device)

    # Run the style transfer and save the result
    output = style_model(content_image)
    output_data = output.cpu().data[0]
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1),
                        opt.result_path)
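A minimal invocation of the function above; the paths are placeholders, and Config/TransformerNet are assumed to come from the surrounding repository:

stylize(content_path='content.jpg',
        model_path='checkpoints/transformer.pth',
        result_path='stylized.jpg')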
Example #2
def prepare_model(inference_config, device):
    stylization_model = TransformerNet().to(device)
    checkpoint_path = os.path.join(inference_config['model_binaries_path'],
                                   inference_config['model_name'])
    training_state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    state_dict = training_state['state_dict']
    stylization_model.load_state_dict(state_dict, strict=True)
    stylization_model.eval()
    return stylization_model
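Unlike the other examples, this loader expects the checkpoint to be a training-state dictionary carrying a 'state_dict' key rather than a bare state_dict. A sketch of writing a compatible checkpoint (paths are placeholders; only the 'state_dict' key is required by the loader above):

import os
import torch

model = TransformerNet()
training_state = {'state_dict': model.state_dict()}  # the key prepare_model reads back
torch.save(training_state, os.path.join('binaries', 'style.pth'))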
Example #3
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet().eval()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model, content_image,
                                            args.export_onnx).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
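The args namespace is built elsewhere; a plausible parser matching the attributes this function accesses (flag names and defaults are assumptions, not from the original source):

import argparse

parser = argparse.ArgumentParser(description='stylize a content image')
parser.add_argument('--content-image', required=True,
                    help='path of the image to stylize')
parser.add_argument('--content-scale', type=float, default=None,
                    help='optional factor for downscaling the content image')
parser.add_argument('--model', required=True,
                    help='a .pth state_dict or a .onnx model file')
parser.add_argument('--output-image', required=True,
                    help='where to write the stylized result')
parser.add_argument('--cuda', type=int, default=0,
                    help='set to 1 to run on GPU')
parser.add_argument('--export-onnx', default=None,
                    help='optional path to export the model to ONNX instead of stylizing')
stylize(parser.parse_args())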
Example #4
def stylize(imgname):
    transformer_name = './models/model_exp_{}.pt'.format(exp_num)
    content_image = load_image(imgname)
    content_image = style_transform(content_image)
    content_image = content_image.unsqueeze(0)
    content_image = Variable(content_image)  # Variable is deprecated; a no-op wrapper in modern PyTorch

    transformer = TransformerNet()
    model_dict = torch.load(transformer_name)
    transformer.load_state_dict(model_dict)

    if use_cuda:
        transformer.cuda()
        content_image = content_image.cuda()

    o = transformer(content_image)
    y = o.data.cpu()[0]
    name, backend = os.path.splitext(os.path.basename(imgname))
    save_style_name = os.path.join(
        os.path.dirname(imgname),
        '{}_stylized_{}{}'.format(name, exp_num, backend))
    save_image(save_style_name, y)
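This function leans on module-level globals defined elsewhere (exp_num, use_cuda, style_transform) plus load_image/save_image helpers; a plausible setup for the globals, as an assumption:

import torch
from torchvision import transforms

exp_num = 1                             # experiment id used in the file names
use_cuda = torch.cuda.is_available()
style_transform = transforms.Compose([  # same 0-255 scaling as the other examples
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255)),
])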
Example #5
def test():
    pretrain, code_id, word_id = source_prepare()
    print('%d different words, %d different codes' % (len(word_id), len(code_id)), flush=True)

    Acc = 0.
    for fold in range(5):
        Net = TransformerNet(Pretrain_type, pretrain, Max_seq_len, Embedding_size, Inner_hid_size, len(code_id), D_k,
                             D_v, dropout_ratio=Dropout, num_layers=Num_layers, num_head=Num_head, Freeze=Freeze_emb).cuda()
        Net.load_state_dict(torch.load(SAVE_DIR + 'data1_false_biobert_' + str(fold) + '_599'))
        Net.eval()

        # test_file = DATA_path + 'AskAPatient/AskAPatient.fold-' + str(fold) + '.test.txt'
        test_file = DATA_path + 'test_' + str(fold) + '.csv'
        test_data = tokenizer(word_id, code_id, test_file, pretrain_type=Pretrain_type)

        print('Fold %d: %d test data' % (fold, len(test_data.data)))
        print('max length: %d' % test_data.max_length, flush=True)

        test_data.reset_epoch()
        test_correct = 0
        i = 0
        while not test_data.epoch_finish:
            seq, label, seq_length, mask, seq_pos, standard_emb = test_data.get_batch(1)
            results = Net(seq, seq_pos, standard_emb)
            _, idx = results.max(1)
            test_correct += len((idx == label).nonzero())
            i += 1
        assert i == len(test_data.data)
        test_accuracy = float(test_correct) / float(i)

        print('[fold %d] test: %d correct, %.4f accuracy' % (fold, test_correct, test_accuracy), flush=True)

        Acc += test_accuracy

        del test_data
        gc.collect()

    print('final 5-fold test accuracy: %.4f' % (Acc / 5), flush=True)
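The per-batch correctness count works because nonzero() returns the indices of positions where the comparison is true; in isolation:

import torch

idx = torch.tensor([2, 0, 1])            # predicted code ids
label = torch.tensor([2, 1, 1])          # gold code ids
correct = len((idx == label).nonzero())  # positions where they agree -> 2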
Example #6
def stylize(args):
    device = torch.device("cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image)
       
    utils.save_image(args.output_image, output[0])
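The checkpoint cleanup above (also used in Example #3) drops the running_mean/running_var buffers that older PyTorch versions saved for InstanceNorm layers and that current versions reject; in isolation:

import re

state_dict = {'in1.running_mean': 0, 'in1.weight': 1, 'conv1.weight': 2}  # toy dict
for k in list(state_dict.keys()):
    if re.search(r'in\d+\.running_(mean|var)$', k):
        del state_dict[k]
print(sorted(state_dict))  # ['conv1.weight', 'in1.weight']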
Example #7
def train(**kwargs):
    opt = Config()
    for _k, _v in kwargs.items():
        setattr(opt, _k, _v)

    device = t.device("cuda" if t.cuda.is_available() else "cpu")
    vis = utils.Visualizer(opt.env)

    # Data loading
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Style transfer network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=t.device('cpu')))
    transformer.to(device)

    # Loss network (VGG16)
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Load the style image data
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # Gram matrices of the style image
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Loss statistics
    style_loss_avg = 0
    content_loss_avg = 0

    for epoch in range(opt.epoches):
        for ii, (x, _) in tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            # print(y.size())
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_x = vgg(x)
            features_y = vgg(y)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu3_3, features_x.relu3_3)

            # style loss
            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                with t.no_grad():
                    gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            content_loss_avg += content_loss.item()
            style_loss_avg += style_loss.item()

            if (ii + 1) % opt.plot_every == 0:
                vis.plot('content_loss', content_loss_avg / opt.plot_every)
                vis.plot('style_loss', style_loss_avg / opt.plot_every)
                content_loss_avg = 0
                style_loss_avg = 0
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

            if (ii + 1) % opt.save_every == 0:
                vis.save([opt.env])
                t.save(transformer.state_dict(),
                       'checkpoints/%s_style.pth' % (ii + 1))
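utils.gram_matrix is not shown here; a common definition consistent with its use above (a sketch matching typical fast-neural-style implementations, not necessarily the author's utils module):

def gram_matrix(y):
    # y: (batch, channels, height, width) feature maps from the VGG loss network
    b, ch, h, w = y.size()
    features = y.view(b, ch, h * w)
    gram = features.bmm(features.transpose(1, 2)) / (ch * h * w)
    return gram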