Example #1
def optimize(args):
    """    Gatys et al. CVPR 2017
    ref: Image Style Transfer Using Convolutional Neural Networks
    """
    # load the content and style target
    content_image = utils.tensor_load_rgbimage(args.content_image,
                                               size=args.content_size,
                                               keep_asp=True)
    content_image = content_image.unsqueeze(0)
    content_image = Variable(utils.preprocess_batch(content_image),
                             requires_grad=False)
    content_image = utils.subtract_imagenet_mean_batch(content_image)
    style_image = utils.tensor_load_rgbimage(args.style_image,
                                             size=args.style_size)
    style_image = style_image.unsqueeze(0)
    style_image = Variable(utils.preprocess_batch(style_image),
                           requires_grad=False)
    style_image = utils.subtract_imagenet_mean_batch(style_image)

    # load the pre-trained vgg-16 and extract features
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    if args.cuda:
        content_image = content_image.cuda()
        style_image = style_image.cuda()
        vgg.cuda()
    features_content = vgg(content_image)
    f_xc_c = Variable(features_content[1].data, requires_grad=False)
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # init optimizer
    output = Variable(content_image.data, requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    mse_loss = torch.nn.MSELoss()
    # optimizing the images
    tbar = trange(args.iters)
    for e in tbar:
        utils.imagenet_clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        features_y = vgg(output)
        content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)

        style_loss = 0.
        for m in range(len(features_y)):
            gram_y = utils.gram_matrix(features_y[m])
            gram_s = Variable(gram_style[m].data, requires_grad=False)
            style_loss += args.style_weight * mse_loss(gram_y, gram_s)

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()
        tbar.set_description(str(total_loss.data.cpu().numpy().item()))
    # save the image
    output = utils.add_imagenet_mean_batch(output)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
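
Every example in this listing builds its style targets with utils.gram_matrix, which is not shown here. A minimal sketch of the usual Gram-matrix helper follows; the normalization by ch * h * w is an assumption (Johnson-style fast-neural-style code normalizes this way), and the repository's version may differ:

import torch

def gram_matrix(y):
    # sketch of the helper the examples assume; normalization is assumed
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, h * w)
    gram = features.bmm(features.transpose(1, 2)) / (ch * h * w)
    return gram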
Example #2
def init_vgg16(model_folder):
    """load the vgg16 model feature"""
    if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')):
        if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')):
            os.system(
                'wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7 -O ' + os.path.join(model_folder, 'vgg16.t7'))
        vgglua = load_lua(os.path.join(model_folder, 'vgg16.t7'))
        vgg = Vgg16()
        for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
            dst.data[:] = src
        torch.save(vgg.state_dict(), os.path.join(model_folder, 'vgg16.weight'))
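
Note that load_lua lives in torch.utils.serialization, which exists only in PyTorch <= 0.4 (it was removed in 1.0), so this conversion step needs an old PyTorch install. The other examples call this helper implicitly; a typical first-run invocation, assuming the module layout used throughout this listing (net.Vgg16, utils), looks like:

import os
import torch
import utils
from net import Vgg16   # assumed module path for the Vgg16 definition

model_dir = "./models"   # hypothetical cache directory
utils.init_vgg16(model_dir)   # downloads and converts the weights once
vgg = Vgg16()
vgg.load_state_dict(torch.load(os.path.join(model_dir, "vgg16.weight")))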
Example #3
def optimize(args):
    style_image = utils.tensor_load_rgbimage(args.style_image,
                                             size=args.style_size)
    style_image = style_image.unsqueeze(0)
    style_image = Variable(utils.preprocess_batch(style_image),
                           requires_grad=False)
    # style_image = utils.subtract_imagenet_mean_batch(style_image)

    # generate the vector field that we want to stylize
    size = args.content_size
    vectors = np.zeros((size, size, 2), dtype=np.float32)
    eps = 1e-7
    for y in range(size):
        for x in range(size):
            xx = float(x - size / 2)
            yy = float(y - size / 2)
            rsq = xx**2 + yy**2
            if (rsq == 0):
                vectors[y, x, 0] = -1
                vectors[y, x, 1] = 1
            else:
                vectors[y, x, 0] = -yy / rsq if yy != 0 else eps
                vectors[y, x, 1] = xx / rsq if xx != 0 else eps
            # vectors[y, x, 0] = -1
            # vectors[y, x, 1] = 1

    # load the pre-trained vgg-16 and extract features
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, 'vgg16.weight')))
    if args.cuda:
        style_image = style_image.cuda()
        vgg.cuda()
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # output_size = torch.Size([1, size, size])
    # output = torch.randn(output_size) * 80 + 127
    # if args.cuda:
    #     output = output.cuda()
    # output = output.expand(3, size, size)
    # output = Variable(output, requires_grad=True)
    output_size = torch.Size([3, size, size])
    device = "cuda" if args.cuda else "cpu"
    output = Variable(torch.randn(output_size, device=device) * 80 + 127,
                      requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    mse_loss = torch.nn.MSELoss()

    loss = []
    # the LIC kernel is loop-invariant, so build it once outside the loop
    kernellen = 15
    kernel = np.sin(np.arange(kernellen) * np.pi / kernellen)
    kernel = kernel.astype(np.float32)
    tbar = trange(args.iters)
    for e in tbar:
        utils.clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        lic_input = output

        loss.append(args.content_weight * lic.line_integral_convolution(
            vectors, lic_input, kernel, args.cuda))

        # vgg_input = output.unsqueeze(0)
        # features_y = vgg(vgg_input)
        # style_loss = 0
        # for m in range(len(features_y)):
        #     gram_y = utils.gram_matrix(features_y[m])
        #     gram_s = Variable(gram_style[m].data, requires_grad=False)
        #     style_loss += args.style_weight * mse_loss(gram_y, gram_s)
        # style_loss.backward()
        # loss[e] += style_loss

        loss[e].backward()
        optimizer.step()
        tbar.set_description(str(loss[e].data.cpu().numpy().item()))

        # save the image
        if ((e + 1) % args.log_interval == 0):
            # print("iter: %d content_loss: %f style_loss %f" % (e, loss[e].item(), style_loss.item()))
            utils.tensor_save_bgrimage(output.data,
                                       "output_iter_" + str(e + 1) + ".jpg",
                                       args.cuda)
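
The per-pixel double loop that builds the circular vector field is quadratic-time Python; for large content sizes a vectorized NumPy equivalent is much faster. A sketch that reproduces the loop's special cases (the center override assumes an even size, since only then does a pixel land exactly on the origin):

import numpy as np

def circular_field(size, eps=1e-7):
    # vectorized version of the loop above; a sketch, not repository code
    ys, xs = np.mgrid[0:size, 0:size].astype(np.float32)
    xx = xs - size / 2
    yy = ys - size / 2
    rsq = xx ** 2 + yy ** 2
    vectors = np.empty((size, size, 2), dtype=np.float32)
    with np.errstate(divide="ignore", invalid="ignore"):
        vectors[..., 0] = np.where(yy != 0, -yy / rsq, eps)
        vectors[..., 1] = np.where(xx != 0, xx / rsq, eps)
    if size % 2 == 0:                  # pixel exactly at the origin
        vectors[size // 2, size // 2] = (-1, 1)
    return vectors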
Example #4
def optimize(args):
    content_image = utils.tensor_load_grayimage(args.content_image, size=args.content_size)
    content_image = content_image.unsqueeze(0)
    content_image = Variable(content_image, requires_grad=False)
    content_image = utils.subtract_imagenet_mean_batch_gray(content_image)
    style_image = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style_image = style_image.unsqueeze(0)
    style_image = Variable(utils.preprocess_batch(style_image), requires_grad=False)
    style_image = utils.subtract_imagenet_mean_batch(style_image)

    # generate the vector field that we want to stylize
    # size = args.content_size
    # vectors = np.zeros((size, size, 2), dtype=np.float32)

    # vortex_spacing = 0.5
    # extra_factor = 2.
    #
    # a = np.array([1, 0]) * vortex_spacing
    # b = np.array([np.cos(np.pi / 3), np.sin(np.pi / 3)]) * vortex_spacing
    # rnv = int(2 * extra_factor / vortex_spacing)
    # vortices = [n * a + m * b for n in range(-rnv, rnv) for m in range(-rnv, rnv)]
    # vortices = [(x, y) for (x, y) in vortices if -extra_factor < x < extra_factor and -extra_factor < y < extra_factor]
    #
    # xs = np.linspace(-1, 1, size).astype(np.float32)[None, :]
    # ys = np.linspace(-1, 1, size).astype(np.float32)[:, None]
    #
    # for (x, y) in vortices:
    #     rsq = (xs - x) ** 2 + (ys - y) ** 2
    #     vectors[..., 0] += (ys - y) / rsq
    #     vectors[..., 1] += -(xs - x) / rsq
    #
    # for y in range(size):
    #     for x in range(size):
    #         angles[y, x] = math.atan(vectors[y, x, 1] / vectors[y, x, 0]) * 180 / math.pi

    # for y in range(size):
    #     for x in range(size):
    #         xx = float(x - size / 2)
    #         yy = float(y - size / 2)
    #         rsq = xx ** 2 + yy ** 2
    #         if (rsq == 0):
    #             vectors[y, x, 0] = 0
    #             vectors[y, x, 1] = 0
    #         else:
    #             vectors[y, x, 0] = -yy / rsq
    #             vectors[y, x, 1] = xx / rsq
    # f = h5py.File("../datasets/fake/vector_fields/cat_test3.h5", 'r')
    # a_group_key = list(f.keys())[0]
    # vectors = f[a_group_key][:]
    # vectors = utils.tensor_load_vector_field(vectors)
    # vectors = Variable(vectors, requires_grad=False)

    # load the pre-trained vgg-16 and extract features
    vgg = Vgg16()
    # utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(torch.load(os.path.join(args.vgg_model_dir, 'vgg16.weight')))
    if args.cuda:
        style_image = style_image.cuda()
        vgg.cuda()
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # load the transformer net and extract features
    transformer_phi1 = TransformerNet()
    transformer_phi1.load_state_dict(torch.load(args.transformer_model_phi1_path))
    if args.cuda:
        # vectors = vectors.cuda()
        content_image = content_image.cuda()
        transformer_phi1.cuda()
    vectors = transformer_phi1(content_image)
    vectors = Variable(vectors.data, requires_grad=False)

    # init optimizer
    content_image_size = content_image.data.size()
    output_size = np.asarray(content_image_size)
    output_size[1] = 3
    output_size = torch.Size(output_size)
    device = "cuda" if args.cuda else "cpu"
    output = Variable(torch.randn(output_size, device=device), requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    mse_loss = torch.nn.MSELoss()
    cosine_loss = torch.nn.CosineEmbeddingLoss()
    # label = torch.ones(1, 1, args.content_size, args.content_size)
    # note: hard-coded to match the transformer's output shape for this content size
    label = torch.ones(1, 128, 128, 128)
    if args.cuda:
        label = label.cuda()

    # optimize the images
    transformer_phi2 = TransformerNet()
    transformer_phi2.load_state_dict(torch.load(args.transformer_model_phi2_path))
    if args.cuda:
        transformer_phi2.cuda()
    tbar = trange(args.iters)
    for e in tbar:
        utils.imagenet_clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        transformer_input = utils.gray_bgr_batch(output)
        transformer_y = transformer_phi2(transformer_input)
        content_loss = args.content_weight * cosine_loss(vectors, transformer_y, label)
        # content_loss = args.content_weight * mse_loss(vectors, transformer_y)

        vgg_input = output
        features_y = vgg(vgg_input)
        style_loss = 0
        for m in range(len(features_y)):
            gram_y = utils.gram_matrix(features_y[m])
            gram_s = Variable(gram_style[m].data, requires_grad=False)
            style_loss += args.style_weight * mse_loss(gram_y, gram_s)

        total_loss = content_loss + style_loss
        # total_loss = content_loss
        total_loss.backward()
        optimizer.step()
        tbar.set_description(str(total_loss.data.cpu().numpy().item()))
        if ((e+1) % args.log_interval == 0):
            print("iter: %d content_loss: %f style_loss %f" % (e, content_loss.item(), style_loss.item()))

    # save the image
    output = utils.add_imagenet_mean_batch_device(output, args.cuda)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
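
For reference, torch.nn.CosineEmbeddingLoss with a target of ones reduces to 1 - cos(x1, x2), which is what turns it into a direction-matching loss between the two vector fields here. A tiny self-contained check:

import torch

x1 = torch.tensor([[1.0, 0.0]])
x2 = torch.tensor([[0.0, 1.0]])
target = torch.ones(1)
loss = torch.nn.CosineEmbeddingLoss()(x1, x2, target)
print(loss)   # tensor(1.) -- orthogonal vectors, cosine similarity 0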
Example #5
def optimize(args):
    style_image = utils.tensor_load_rgbimage(args.style_image,
                                             size=args.style_size)
    style_image = style_image.unsqueeze(0)
    style_image = Variable(utils.preprocess_batch(style_image),
                           requires_grad=False)
    style_image = utils.subtract_imagenet_mean_batch(style_image)

    # generate the vector field that we want to backward from
    size = args.content_size
    vectors = np.zeros((size, size, 2), dtype=np.float32)
    vortex_spacing = 0.5
    extra_factor = 2.

    a = np.array([1, 0]) * vortex_spacing
    b = np.array([np.cos(np.pi / 3), np.sin(np.pi / 3)]) * vortex_spacing
    rnv = int(2 * extra_factor / vortex_spacing)
    vortices = [
        n * a + m * b for n in range(-rnv, rnv) for m in range(-rnv, rnv)
    ]
    vortices = [(x, y) for (x, y) in vortices
                if -extra_factor < x < extra_factor
                and -extra_factor < y < extra_factor]

    xs = np.linspace(-1, 1, size).astype(np.float32)[None, :]
    ys = np.linspace(-1, 1, size).astype(np.float32)[:, None]

    for (x, y) in vortices:
        rsq = (xs - x)**2 + (ys - y)**2
        vectors[..., 0] += (ys - y) / rsq
        vectors[..., 1] += -(xs - x) / rsq
    # for y in range(size):
    #     for x in range(size):
    #         xx = float(x - size / 2)
    #         yy = float(y - size / 2)
    #         rsq = xx ** 2 + yy ** 2
    #         if rsq == 0:
    #             vectors[y, x, 0] = 1
    #             vectors[y, x, 1] = 1
    #         else:
    #             vectors[y, x, 0] = -yy / rsq
    #             vectors[y, x, 1] = xx / rsq
    #         # vectors[y, x, 0] = 1
    #         # vectors[y, x, 1] = -1
    vectors = utils.tensor_load_vector_field(vectors)

    # load the pre-trained vgg-16 and extract features
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, 'vgg16.weight')))
    if args.cuda:
        style_image = style_image.cuda()
        vgg.cuda()
    features_style = vgg(style_image)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # load the sobel network
    sobel = Sobel()
    if args.cuda:
        vectors = vectors.cuda()
        sobel.cuda()

    # init optimizer
    vectors_size = vectors.data.size()
    output_size = np.asarray(vectors_size)
    output_size[1] = 3
    output_size = torch.Size(output_size)
    device = "cuda" if args.cuda else "cpu"
    output = Variable(torch.randn(output_size, device=device) * 30,
                      requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    cosine_loss = CosineLoss()
    mse_loss = torch.nn.MSELoss()

    # optimize the images
    tbar = trange(args.iters)
    for e in tbar:
        utils.imagenet_clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        sobel_input = utils.gray_bgr_batch(output)
        sobel_y = sobel(sobel_input)
        content_loss = args.content_weight * cosine_loss(vectors, sobel_y)

        vgg_input = output
        features_y = vgg(vgg_input)
        style_loss = 0
        for m in range(len(features_y)):
            gram_y = utils.gram_matrix(features_y[m])
            gram_s = Variable(gram_style[m].data, requires_grad=False)
            style_loss += args.style_weight * mse_loss(gram_y, gram_s)

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()
        if ((e + 1) % args.log_interval == 0):
            print("iter: %d content_loss: %f style_loss %f" %
                  (e, content_loss.item() / args.content_weight,
                   style_loss.item() / args.style_weight))
        tbar.set_description(str(total_loss.data.cpu().numpy().item()))

    # save the image
    output = utils.add_imagenet_mean_batch_device(output, args.cuda)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
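
The Sobel module acts as a fixed, non-trained edge detector whose two output channels are compared against the target vector field. One plausible implementation, a sketch only (it assumes a single-channel input; the repository's version may differ in kernels, padding, or channel handling):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Sobel(nn.Module):
    # fixed 3x3 Sobel kernels; input (b, 1, h, w) -> output (b, 2, h, w)
    def __init__(self):
        super(Sobel, self).__init__()
        kx = torch.tensor([[-1., 0., 1.],
                           [-2., 0., 2.],
                           [-1., 0., 1.]])
        ky = kx.t()
        self.register_buffer("weight", torch.stack([kx, ky]).unsqueeze(1))

    def forward(self, x):
        return F.conv2d(x, self.weight, padding=1)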
Example #6
def train(args):
    check_paths(args)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    style_model = Net(ngf=args.ngf)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(
            args.resume))
        style_model.load_state_dict(torch.load(args.resume))
    print(style_model)
    optimizer = Adam(style_model.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        style_model.cuda()
        vgg.cuda()

    style_loader = utils.StyleLoader(args.style_folder, args.style_size)

    tbar = trange(args.epochs)
    for e in tbar:
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            style_v = utils.subtract_imagenet_mean_batch(style_v)
            features_style = vgg(style_v)
            gram_style = [utils.gram_matrix(y) for y in features_style]

            y = style_model(x)
            xc = Variable(x.data.clone())

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = args.content_weight * mse_loss(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_y = utils.gram_matrix(features_y[m])
                gram_s = Variable(gram_style[m].data,
                                  requires_grad=False).repeat(
                                      args.batch_size, 1, 1, 1)
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                tbar.set_description(mesg)

            if (batch_id + 1) % (4 * args.log_interval) == 0:
                # save model
                style_model.eval()
                style_model.cpu()
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" + \
                    str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir,
                                               save_model_filename)
                torch.save(style_model.state_dict(), save_model_path)
                style_model.train()
                if args.cuda:
                    style_model.cuda()
                tbar.set_description("Checkpoint, trained model saved at " +
                                     save_model_path)

    # save model
    style_model.eval()
    style_model.cpu()
    save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + \
        str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(style_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #7
def train():
    check_point_path = ''

    transform = transforms.Compose([transforms.Scale(IMAGE_SIZE),
                                    transforms.CenterCrop(IMAGE_SIZE),
                                    transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.mul(255))])

    train_dataset = datasets.ImageFolder(DATASET_FOLDER, transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    style_model = Net(ngf=FILTER_CHANNEL, dv=device).to(device)
    if RESUME is not None:
        print('Resuming, initializing using weight from {}.'.format(RESUME))
        style_model.load_state_dict(torch.load(RESUME))
    print(style_model)
    optimizer = Adam(style_model.parameters(), LEARNING_RATE)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(VGG_DIR)
    vgg.load_state_dict(torch.load(os.path.join(VGG_DIR, "vgg16.weight")))
    vgg.to(device)

    style_loader = utils.StyleLoader(STYLE_FOLDER, IMAGE_SIZE, device)
    
    tbar = tqdm(range(EPOCHS))
    for e in tbar:
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x)).to(device)

            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            style_v = utils.subtract_imagenet_mean_batch(style_v, device)
            features_style = vgg(style_v)
            gram_style = [utils.gram_matrix(y) for y in features_style]

            y = style_model(x)
            xc = Variable(x.data.clone())

            y = utils.subtract_imagenet_mean_batch(y, device)
            xc = utils.subtract_imagenet_mean_batch(xc, device)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = CONT_WEIGHT * mse_loss(features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_y = utils.gram_matrix(features_y[m])
                gram_s = Variable(gram_style[m].data, requires_grad=False).repeat(BATCH_SIZE, 1, 1, 1)
                style_loss += STYLE_WEIGHT * mse_loss(gram_y.unsqueeze_(1), gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % 100 == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                agg_content_loss / (batch_id + 1),
                                agg_style_loss / (batch_id + 1),
                                (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                tbar.set_description(mesg)

            
            if (batch_id + 1) % (4 * 100) == 0:
                # save model
                style_model.eval()
                style_model.cpu()
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" +                     str(time.ctime()).replace(' ', '_') + "_" + str(
                    CONT_WEIGHT) + "_" + str(STYLE_WEIGHT) + ".model"
                save_model_path = os.path.join(SAVE_MODEL_DIR, save_model_filename)
                torch.save(style_model.state_dict(), save_model_path)
                if check_point_path:
                    os.remove(check_point_path)
                check_point_path = save_model_path
                style_model.train()
                style_model.to(device)
                tbar.set_description("Checkpoint, trained model saved at " +
                                     save_model_path)

    # save model
    style_model.eval()
    style_model.cpu()
    save_model_filename = "Final_epoch_" + str(EPOCHS) + "_" +         str(time.ctime()).replace(' ', '_') + "_" + str(
        CONT_WEIGHT) + "_" + str(STYLE_WEIGHT) + ".model"
    save_model_path = os.path.join(SAVE_MODEL_DIR, save_model_filename)
    torch.save(style_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #8
from __future__ import print_function
import tensorflow as tf
from utils import ImageGenerator
from net import Vgg16


generator = ImageGenerator('./data')
test_data, test_label = generator.get_test()

tf.reset_default_graph()

images = tf.placeholder(tf.float32, [None, 640, 640, 3])
true_out = tf.placeholder(tf.float32, [None, 4])
train_mode = tf.placeholder(tf.bool)

network = Vgg16(vgg16_npy_path='./train-save.npy')
network.build(images, train_mode)

# only the ops created inside this block (correct / accuracy) are pinned to
# GPU 2; the VGG graph built above keeps its default device placement
with tf.device('/gpu:2'):
    sess = tf.Session()

    sess.run(tf.global_variables_initializer())

    correct = tf.equal(tf.argmax(network.prob, 1), tf.argmax(true_out, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    acc = sess.run(accuracy, feed_dict={images: test_data, true_out: test_label, train_mode: False})
    print('Accuracy for test data: %s' % acc)
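
Feeding the whole test set of 640x640x3 images through a single sess.run can exhaust GPU memory; a batched variant of the same evaluation (batch size here is a hypothetical choice) is:

batch = 8   # hypothetical batch size
weighted_correct = 0.0
for i in range(0, len(test_data), batch):
    xb = test_data[i:i + batch]
    yb = test_label[i:i + batch]
    acc_b = sess.run(accuracy,
                     feed_dict={images: xb, true_out: yb, train_mode: False})
    weighted_correct += acc_b * len(xb)
print('Accuracy for test data: %s' % (weighted_correct / len(test_data)))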