Example #1
def test_per_user(args):
    dataset = Watermark(args.img_size, train=False, dev=False)
    loader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False)
    msg_dist = torch.distributions.Bernoulli(probs=0.5*torch.ones(args.msg_l))
    list_msg = msg_dist.sample([args.n_users])   # one random message per user

    np.savetxt("./foo.csv", list_msg.numpy(), delimiter=",")
    net = HiddenTest(args, dataset).to(args.device)
    net.load_state_dict(torch.load(path.save_path))
    list_stats = []
    with torch.no_grad():
        for msg in tqdm(list_msg):
            stats = {
                'enc_loss': 0,
                'dec_loss': 0,
                'accuracy0': 0,
                'accuracy3': 0,
                'avg_acc': 0,
                'num_right_bits': 0,
                'lab_dist_orig_watermarked': 0,
            }
            for img in loader:
                msg_batched = msg.repeat(img.shape[0], 1)
                img, msg_batched = img.to(args.device), msg_batched.to(args.device)
                batch_stats = net.stats(img, msg_batched, args.noise_type)
                for k in stats:
                    stats[k] += len(img) * batch_stats[k]

            for k in stats:
                stats[k] = stats[k] / len(dataset)
            list_stats.append(stats)
    with open("list_stats.p", "wb") as f:
        pickle.dump(list_stats, f)
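
These snippets read their configuration from an argparse-style object. A hypothetical invocation, with field names taken from the attribute accesses above but illustrative values (not the repository's actual defaults):

from argparse import Namespace
import torch

args = Namespace(
    img_size=128,          # side length of the cover images
    batch_size=32,
    msg_l=30,              # watermark bits per message
    n_users=100,           # one random message is drawn per user
    noise_type='jpeg',     # distortion applied between encoder and decoder
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
test_per_user(args)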
Example #2
def train(args):
    test_process, queue = start_test_process(args)
    log_file = open(log_filename, 'w+', buffering=1)

    dataset = Watermark(args.img_size, args.msg_l, train=True)
    net = HiddenTrain(args, dataset).to(args.device)
    loader = DataLoader(dataset=dataset,
                        batch_size=args.batch_size,
                        shuffle=True)

    with trange(args.epochs, unit='epoch') as tqdm_bar:
        for epoch_i in tqdm_bar:
            for batch_i, (img, msg) in enumerate(loader):
                img, msg = img.to(args.device), msg.to(args.device)
                stats = net.optimize(img, msg)
                tqdm_bar.set_postfix(**stats)

            if epoch_i % args.save_freq == 0:
                log_file.write("Epoch {} | {}\n".format(
                    epoch_i,
                    " ".join([f"{k}: {v:.3f}" for k, v in stats.items()])))
                torch.save(net.state_dict(), path.save_path)

            if epoch_i % args.test_freq == 0:
                queue.put((epoch_i, net.state_dict()))

    log_file.close()
    queue.join()
    test_process.terminate()
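
train() relies on a start_test_process helper that is not shown. A minimal sketch of what it might look like, pairing a joinable queue with the test_worker of Example #5 below (the repository's actual helper may differ):

import torch.multiprocessing as mp

def start_test_process(args):
    # Spawn test_worker in its own process; 'spawn' avoids fork-related
    # CUDA issues. The worker blocks on queue.get() for (epoch_i, state_dict)
    # snapshots and calls queue.task_done() after evaluating each one.
    ctx = mp.get_context('spawn')
    queue = ctx.JoinableQueue()
    process = ctx.Process(target=test_worker, args=(args, queue), daemon=True)
    process.start()
    return process, queue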
Example #3
def save_img(args):
    dataset = Watermark(args.img_size, train=False, dev=False)
    loader = DataLoader(dataset=dataset, batch_size=args.n_imgs, shuffle=False)
    msg_dist = torch.distributions.Bernoulli(probs=0.5 *
                                             torch.ones(args.msg_l))

    net = DFW(args, dataset).to(args.device)
    net.set_depth(max_depth)
    net.load_state_dict(torch.load(path.save_path, map_location='cuda'))
    net.eval()

    hamming_coder = HammingCoder(device=args.device)

    save_dir = './examples_' + args.noise_type
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    with torch.no_grad():
        img = next(iter(loader))
        msg = msg_dist.sample([img.shape[0]])
        img, msg = img.to(args.device), msg.to(args.device)

        hamming_msg = torch.stack([hamming_coder.encode(x) for x in msg])
        watermark = net.encoder(hamming_msg)
        encoded_img = (img + watermark).clamp(-1, 1)
        noised_img, _ = net.noiser([encoded_img, img])
        decoded_msg_logit = net.decoder(noised_img)

        pred_without_hamming_dec = (torch.sigmoid(decoded_msg_logit) >
                                    0.5).int()
        pred_msg = torch.stack(
            [hamming_coder.decode(x) for x in pred_without_hamming_dec])

        store_images(img, msg, watermark, encoded_img, noised_img, pred_msg,
                     save_dir)
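
HammingCoder is imported from elsewhere in the repository. A minimal Hamming(7,4) sketch with the same encode/decode interface, assuming the message length is divisible by 4 (the real class may differ):

import torch

class HammingCoder:
    # Codeword layout per 4-bit data block: [p1, p2, d1, p3, d2, d3, d4].
    def __init__(self, device='cpu'):
        self.device = device

    def encode(self, msg):
        # msg: 1-D tensor of 0/1 values, length divisible by 4.
        d1, d2, d3, d4 = msg.int().view(-1, 4).unbind(dim=1)
        p1, p2, p3 = d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d2 ^ d3 ^ d4
        code = torch.stack([p1, p2, d1, p3, d2, d3, d4], dim=1)
        return code.reshape(-1).float().to(self.device)

    def decode(self, code):
        # code: 1-D tensor of 0/1 values, length divisible by 7.
        blocks = code.int().reshape(-1, 7).clone()
        # Parity checks; the syndrome is the 1-indexed error position.
        s1 = blocks[:, [0, 2, 4, 6]].sum(dim=1) % 2
        s2 = blocks[:, [1, 2, 5, 6]].sum(dim=1) % 2
        s3 = blocks[:, [3, 4, 5, 6]].sum(dim=1) % 2
        err = s1 + 2 * s2 + 4 * s3          # 0 means "no single-bit error"
        rows = torch.arange(blocks.shape[0], device=blocks.device)
        mask = err > 0
        blocks[rows[mask], (err[mask] - 1).long()] ^= 1   # correct that bit
        return blocks[:, [2, 4, 5, 6]].reshape(-1).float().to(self.device)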
Example #4
def test(args):
    dataset = Watermark(args.img_size, train=False, dev=False)
    loader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False)
    msg_dist = torch.distributions.Bernoulli(probs=0.5*torch.ones(args.msg_l))
    
    net = HiddenTest(args, dataset).to(args.device)
    net.noiser.to(args.device)
    net.load_state_dict(torch.load(path.save_path, map_location='cuda'))
    
    stats = {
        'enc_loss': 0,
        'dec_loss': 0,
        'accuracy0': 0,
        'accuracy3': 0,
        'avg_acc': 0,
        'num_right_bits': 0,
        'lab_dist_orig_watermarked': 0
    }

    with torch.no_grad():
        for img in loader:
            msg = msg_dist.sample([img.shape[0]])
            img, msg = img.to(args.device), msg.to(args.device)
            batch_stats = net.stats(img, msg, args.noise_type)
            for k in stats:
                stats[k] += len(img) * batch_stats[k]

    for k in stats:
        stats[k] = stats[k] / len(dataset)
    print("Noise type: ", args.noise_type, " ".join([f"{k}: {v:.3f}"for k, v in stats.items()]))
Example #5
def test_worker(args, queue):
    log_file = open(log_filename, 'w+', buffering=1)

    dataset = Watermark(args.img_size, train=False, dev=False)
    net = HiddenTest(args, dataset).to(args.test_device)
    loader = DataLoader(dataset=dataset,
                        batch_size=args.batch_size,
                        shuffle=False)
    msg_dist = torch.distributions.Bernoulli(probs=0.5 *
                                             torch.ones(args.msg_l))
    while True:
        epoch_i, state_dict = queue.get()
        net.load_state_dict(state_dict)

        stats = {'D_loss': 0, 'enc_loss': 0, 'dec_loss': 0, 'adv_loss': 0}

        with torch.no_grad():
            for img in loader:
                msg = msg_dist.sample([img.shape[0]])
                img, msg = img.to(args.test_device), msg.to(args.test_device)
                batch_stats = net.stats(img, msg)
                for k in stats:
                    stats[k] += len(img) * batch_stats[k]

        for k in stats:
            stats[k] = stats[k] / len(dataset)

        log_file.write("Epoch {} | {}\n".format(
            epoch_i, " ".join([f"{k}: {v:.3f}" for k, v in stats.items()])))
        queue.task_done()
Example #6
def test(args):
    dataset = Watermark(args.img_size, args.msg_l, train=False)
    net = HiddenTest(args, dataset).to(args.test_device)
    loader = DataLoader(dataset=dataset,
                        batch_size=args.batch_size,
                        shuffle=False)

    net.load_state_dict(torch.load(path.save_path))

    stats = {
        'G_loss': 0,
        'enc_loss': 0,
        'dec_loss': 0,
    }

    with torch.no_grad():
        for img, msg in loader:
            img, msg = img.to(args.test_device), msg.to(args.test_device)
            batch_stats = net.stats(img, msg)
            for k in stats:
                stats[k] += len(img) * batch_stats[k]

    for k in stats:
        stats[k] = stats[k] / len(dataset)

    print(" ".join([f"{k}: {v:.3f}" for k, v in stats.items()]))
Example #7
def test_worker(args, queue):
    log_file = open(log_filename, 'w+', buffering=1)

    dataset = Watermark(args.img_size, args.msg_l, train=False)
    net = HiddenTest(args, dataset).to(args.test_device)
    loader = DataLoader(dataset=dataset,
                        batch_size=args.batch_size,
                        shuffle=False)

    while True:
        epoch_i, state_dict = queue.get()
        net.load_state_dict(state_dict)

        stats = {
            'G_loss': 0,
            'enc_loss': 0,
            'dec_loss': 0,
        }

        with torch.no_grad():
            for img, msg in loader:
                img, msg = img.to(args.test_device), msg.to(args.test_device)
                batch_stats = net.stats(img, msg)
                for k in stats:
                    stats[k] += len(img) * batch_stats[k]

        for k in stats:
            stats[k] = stats[k] / len(dataset)

        log_file.write("Epoch {} | {}\n".format(
            epoch_i, " ".join([f"{k}: {v:.3f}" for k, v in stats.items()])))
        queue.task_done()
Example #8
def save_img(args):
    if not os.path.exists(save_dir):   # save_dir is assumed module-level
        os.mkdir(save_dir)
    dataset = Watermark(args.img_size, args.msg_l, train=False)
    net = Hidden(args, dataset).to(args.device)
    loader = DataLoader(dataset=dataset, batch_size=args.n_imgs, shuffle=False)
    
    net.load_state_dict(torch.load(path.save_path))

    with torch.no_grad():
        img, msg = next(iter(loader))
        img, msg = img.to(args.device), msg.to(args.device)
        
        net.eval()
        encoded_img = net.encoder(img, msg)
        noised_img = net.noiser(encoded_img)
        decoded_msg = net.decoder(noised_img)

    # Move (N, C, H, W) tensors to (N, H, W, C) numpy arrays for plotting.
    def convert(img):
        return np.moveaxis(denormalize(img).cpu().numpy(), [1, 2, 3], [3, 1, 2])
    img = convert(img)
    encoded_img = convert(encoded_img)
    noised_img = convert(noised_img)
    msg = msg.cpu().numpy()
    decoded_msg = (decoded_msg > 0.5).float().cpu().numpy()

    for i in range(args.n_imgs):
        fig = plt.figure()
        gridspec = fig.add_gridspec(ncols=5, nrows=1, width_ratios=[2, 2, 2, 1, 1])
        axes = [fig.add_subplot(gridspec[0, i]) for i in range(5)]
        ax1, ax2, ax3, ax4, ax5 = axes
        
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])

        ax1.set_title('Original\nImage')
        ax1.imshow(img[i])

        ax2.set_title('Encoded\nImage')
        ax2.imshow(encoded_img[i])

        ax3.set_title('Noised\nImage')
        ax3.imshow(noised_img[i])

        ax4.set_title('Original\nMsg')
        ax4.imshow(msg[i][:, None], cmap='gray', aspect=2/31)
        
        ax5.set_title('Decoded\nMsg')
        ax5.imshow(decoded_msg[i][:, None], cmap='gray', aspect=2/31)
        
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'{i}.png'), bbox_inches='tight')
        plt.close(fig)   # free the figure so memory does not grow with n_imgs
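
The plotting code relies on a denormalize helper defined elsewhere. A minimal sketch, assuming the dataset normalizes images to [-1, 1]:

def denormalize(img):
    # Invert a (x - 0.5) / 0.5 normalization back to [0, 1] for imshow.
    return (img * 0.5 + 0.5).clamp(0, 1)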
Example #9
def pretrain(args):
    dataset = Watermark(args.img_size, train=True, dev=False)
    msg_dist = torch.distributions.Bernoulli(probs=0.5 *
                                             torch.ones(args.msg_l))

    net = DFWTrain(args, dataset).to(args.device)

    print('Pre-training Start')
    depth = 1
    while depth <= pretrain_depth:
        net.set_depth(depth)
        msg = msg_dist.sample((args.batch_size, )).to(args.device)
        stats = net.pre_optimize(msg)
        if stats['loss'] < 0.05:
            print(f"Grown: {depth}/{pretrain_depth} | loss: {stats['loss']}")
            if depth == 2:
                depth += 2   # grow two levels at once, skipping depth 3
            else:
                depth += 1
    torch.save(net.state_dict(), pretrain_filename)
    print('Pre-trained Weight Saved')
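
The DFW snippets also reference module-level constants defined elsewhere in the repository; illustrative placeholders only (the real values are not shown here):

pretrain_depth = 4                   # deepest block grown during pre-training
max_depth = 8                        # full depth used by train() and save_img()
pretrain_filename = './pretrain.pt'  # where pretrain() writes its weights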
Example #10
def test(args):
    dataset = Watermark(args.img_size, train=False, dev=False)
    net = HiddenTest(args, dataset).to(args.test_device)
    loader = DataLoader(dataset=dataset,
                        batch_size=args.batch_size,
                        shuffle=False)
    msg_dist = torch.distributions.Bernoulli(probs=0.5 *
                                             torch.ones(args.msg_l))
    net.load_state_dict(torch.load(path.save_path))

    stats = {'D_loss': 0, 'enc_loss': 0, 'dec_loss': 0, 'adv_loss': 0}

    with torch.no_grad():
        for img in loader:
            msg = msg_dist.sample([img.shape[0]])
            img, msg = img.to(args.test_device), msg.to(args.test_device)
            batch_stats = net.stats(img, msg)
            for k in stats:
                stats[k] += len(img) * batch_stats[k]

    for k in stats:
        stats[k] = stats[k] / len(dataset)

    print(" ".join([f"{k}: {v:.3f}" for k, v in stats.items()]))
Example #11
def train(args):
    test_process, queue = start_test_process(args)
    log_file = open(log_filename, 'w+', buffering=1)

    train_set = Watermark(args.img_size, train=True, dev=False)
    dev_set = Watermark(args.img_size, train=False, dev=True)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              shuffle=True)
    dev_loader = DataLoader(dataset=dev_set,
                            batch_size=args.batch_size,
                            shuffle=True)
    msg_dist = torch.distributions.Bernoulli(probs=0.5 *
                                             torch.ones(args.msg_l))

    net = DFWTrain(args, train_set).to(args.device)

    if not os.path.exists(pretrain_filename):
        raise FileNotFoundError('Pre-trained weight not found')
    net.load_state_dict(torch.load(pretrain_filename))
    print('Pre-trained Weight Loaded')

    net.set_depth(max_depth)
    valid_loss_min = np.Inf
    with trange(args.epochs, unit='epoch') as tqdm_bar:

        for epoch_i in tqdm_bar:
            ######################
            #  train the model   #
            ######################
            enc_scale = args.enc_scale * min(1,
                                             epoch_i / args.annealing_epochs)
            limit = max(1, 5 - 4 * epoch_i / args.annealing_epochs)

            for batch_i, img in enumerate(train_loader):
                msg = msg_dist.sample([img.shape[0]])
                img, msg = img.to(args.device), msg.to(args.device)
                stats = net.optimize(img, msg, enc_scale, limit)
                tqdm_bar.set_postfix(**stats)

            if epoch_i % args.save_freq == 0:
                log_file.write("Epoch {} | {}\n".format(
                    epoch_i,
                    " ".join([f"{k}: {v:.3f}" for k, v in stats.items()])))

            ######################
            # validate the model #
            ######################
            list_valid_loss = []
            for dev_batch, img in enumerate(dev_loader):
                msg = msg_dist.sample([img.shape[0]])
                img, msg = img.to(args.device), msg.to(args.device)
                stats = net.evaluate(img, msg, enc_scale, limit)
                list_valid_loss.append(stats['loss'])
            valid_loss = np.mean(list_valid_loss)
            if valid_loss <= valid_loss_min:
                message = ('Validation loss decreased '
                           '({:.6f} --> {:.6f}). Saving model ...\n'.format(
                               valid_loss_min, valid_loss))
                print(message)
                log_file.write(message)
                torch.save(net.state_dict(), path.save_path)
                valid_loss_min = valid_loss
    log_file.close()
    queue.join()
    test_process.terminate()
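
The enc_scale and limit schedules in train() anneal linearly over the first args.annealing_epochs epochs: the encoder loss is phased in from 0 while the clipping limit decays from 5 down to 1. A worked example with hypothetical values args.enc_scale=1.0 and args.annealing_epochs=10:

for epoch_i in (0, 5, 10, 20):
    enc_scale = 1.0 * min(1, epoch_i / 10)
    limit = max(1, 5 - 4 * epoch_i / 10)
    print(epoch_i, enc_scale, limit)
# 0   0.0  5.0   -> decoder loss dominates at the start
# 5   0.5  3.0
# 10  1.0  1.0   -> both schedules saturate
# 20  1.0  1.0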