def main():
    """Train an inversion (model-inversion attack) network against a
    pre-trained FaceScrub classifier.

    Loads the frozen classifier checkpoint from ``out/classifier.pth``,
    trains the inversion network on CelebA, evaluates reconstruction loss
    on the FaceScrub train/test splits each epoch, and checkpoints the
    model with the best (lowest) reconstruction loss on the train split.
    """
    args = parser.parse_args()
    print("================================")
    print(args)
    print("================================")
    os.makedirs('out', exist_ok=True)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # pin_memory / worker processes only pay off when transferring to GPU.
    kwargs = {'num_workers': args.num_workers, 'pin_memory': True} if use_cuda else {}

    torch.manual_seed(args.seed)

    transform = transforms.Compose([transforms.ToTensor()])
    train_set = CelebA('./data/celebA', transform=transform)
    # Inversion attack on TRAIN data of facescrub classifier
    test1_set = FaceScrub('./data/facescrub', transform=transform, train=True)
    # Inversion attack on TEST data of facescrub classifier
    test2_set = FaceScrub('./data/facescrub', transform=transform, train=False)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    test1_loader = torch.utils.data.DataLoader(test1_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)
    test2_loader = torch.utils.data.DataLoader(test2_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    classifier = nn.DataParallel(Classifier(nc=args.nc, ndf=args.ndf, nz=args.nz)).to(device)
    inversion = nn.DataParallel(Inversion(nc=args.nc, ngf=args.ngf, nz=args.nz, truncation=args.truncation, c=args.c)).to(device)
    optimizer = optim.Adam(inversion.parameters(), lr=0.0002, betas=(0.5, 0.999), amsgrad=True)

    # Load the frozen classifier; the attack cannot proceed without it.
    path = 'out/classifier.pth'
    try:
        checkpoint = torch.load(path)
        classifier.load_state_dict(checkpoint['model'])
        epoch = checkpoint['epoch']
        best_cl_acc = checkpoint['best_cl_acc']
        print("=> loaded classifier checkpoint '{}' (epoch {}, acc {:.4f})".format(path, epoch, best_cl_acc))
    except Exception as e:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # no longer swallowed, and the failure reason is reported.
        print("=> load classifier checkpoint '{}' failed ({})".format(path, e))
        return

    # Train the inversion model, keeping the checkpoint with the lowest
    # reconstruction loss on test1 (the classifier's training split).
    best_recon_loss = float('inf')  # was the magic sentinel 999
    for epoch in range(1, args.epochs + 1):
        train(classifier, inversion, args.log_interval, device, train_loader, optimizer, epoch)
        recon_loss = test(classifier, inversion, device, test1_loader, epoch, 'test1')
        test(classifier, inversion, device, test2_loader, epoch, 'test2')

        if recon_loss < best_recon_loss:
            best_recon_loss = recon_loss
            state = {
                'epoch': epoch,
                'model': inversion.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_recon_loss': best_recon_loss
            }
            torch.save(state, 'out/inversion.pth')
            # test() is assumed to have written these per-epoch grids —
            # TODO(review): confirm against the test() implementation.
            shutil.copyfile('out/recon_test1_{}.png'.format(epoch), 'out/best_test1.png')
            shutil.copyfile('out/recon_test2_{}.png'.format(epoch), 'out/best_test2.png')
Exemplo n.º 2
0
def get_loader_labeled(image_dir, attr_path, selected_attrs, batch_size, train, new_size=None,
                       height=256, width=256, crop=True):
    """Construct a DataLoader over labeled CelebA images.

    Augmentations (color jitter, horizontal flip) apply only in training
    mode; an optional resize and a random crop precede conversion to a
    tensor normalized to [-1, 1]. Shuffling is enabled only for training.
    """
    ops = []
    if train:
        ops.append(transforms.ColorJitter(0.1, 0.1, 0.1, 0.1))
        ops.append(transforms.RandomHorizontalFlip())
    if new_size is not None:
        ops.append(transforms.Resize(new_size))
    if crop:
        ops.append(transforms.RandomCrop((height, width)))
    ops.append(transforms.ToTensor())
    ops.append(transforms.Normalize((0.5, 0.5, 0.5),
                                    (0.5, 0.5, 0.5)))

    dataset = CelebA(image_dir, attr_path, selected_attrs, transforms.Compose(ops), train)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=train)
Exemplo n.º 3
0
def main():
    """Build and run a TF1 energy-based model (EBM) experiment.

    Selects a dataset from FLAGS, constructs the static graph — including
    Langevin-dynamics sampling loops for negative samples and (optionally)
    attention masks — then trains and/or tests. All configuration comes
    from module-level FLAGS; graph inputs are tf.placeholder tensors.
    """

    # Experiment/log directory setup.
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    logger = TensorBoardOutputFormat(logdir)

    config = tf.ConfigProto()

    sess = tf.Session(config=config)
    LABEL = None
    print("Loading data...")
    # Dataset selection: each branch sets `dataset`, `label_size`, and the
    # LABEL / LABEL_POS placeholders sized to the conditioning vector.
    if FLAGS.dataset == 'cubes':
        dataset = Cubes(cond_idx=FLAGS.cond_idx)
        test_dataset = dataset

        if FLAGS.cond_idx == 0:
            label_size = 2
        elif FLAGS.cond_idx == 1:
            label_size = 1
        elif FLAGS.cond_idx == 2:
            label_size = 3
        elif FLAGS.cond_idx == 3:
            label_size = 20

        LABEL = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
    elif FLAGS.dataset == 'color':
        dataset = CubesColor()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        label_size = 301
    elif FLAGS.dataset == 'pos':
        dataset = CubesPos()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        label_size = 2
    elif FLAGS.dataset == "pairs":
        dataset = Pairs(cond_idx=0)
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        label_size = 6
    elif FLAGS.dataset == "continual":
        dataset = CubesContinual()
        test_dataset = dataset

        if FLAGS.prelearn_model_shape:
            LABEL = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            label_size = 20
        else:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

    elif FLAGS.dataset == "cross":
        dataset = CubesCrossProduct(FLAGS.ratio, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, joint_baseline=FLAGS.joint_baseline)
        test_dataset = dataset

        if FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            label_size = 1
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

        # NOTE(review): joint_baseline overrides the cond_size/cond_pos
        # placeholders above if both are set.
        if FLAGS.joint_baseline:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            label_size = 3

    elif FLAGS.dataset == 'celeba':
        # NOTE(review): this branch never assigns label_size, so the
        # non-joint path below (print at "label size here") would raise
        # NameError for celeba — confirm intended usage.
        dataset = CelebA(cond_idx=FLAGS.celeba_cond_idx)
        test_dataset = dataset
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64,
            classes=2)

    if FLAGS.joint_baseline:
        # Joint (non-EBM) baseline: a plain generator trained with MSE.
        # Other stuff for joint model
        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.99, beta2=0.999)

        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)
        NOISE = tf.placeholder(shape=(None, 128), dtype=tf.float32)
        HIER_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        channel_num = 3

        model = CubesNetGen(num_channels=channel_num, label_size=label_size)
        weights = model.construct_weights('context_0')
        output = model.forward(NOISE, weights, reuse=False, label=LABEL)
        print(output.get_shape())
        mse_loss = tf.reduce_mean(tf.square(output - X))
        gvs = optimizer.compute_gradients(mse_loss)
        train_op = optimizer.apply_gradients(gvs)
        # NOTE(review): gvs is filtered AFTER apply_gradients here, so the
        # filtering affects only the exported 'gvs' entry, not training.
        gvs = [(k, v) for (k, v) in gvs if k is not None]

        # target_vars mirrors the EBM branch's keys (with tf.zeros(1)
        # fillers) so downstream train()/test() can treat both uniformly.
        target_vars = {}
        target_vars['train_op'] = train_op
        target_vars['X'] = X
        target_vars['X_NOISE'] = X_NOISE
        target_vars['ATTENTION_MASK'] = ATTENTION_MASK
        target_vars['eps_begin'] = tf.zeros(1)
        target_vars['gvs'] = gvs
        target_vars['energy_pos'] = tf.zeros(1)
        target_vars['energy_neg'] = tf.zeros(1)
        target_vars['loss_energy'] = tf.zeros(1)
        target_vars['loss_ml'] = tf.zeros(1)
        target_vars['total_loss'] = mse_loss
        target_vars['attention_mask'] = tf.zeros(1)
        target_vars['attention_grad'] = tf.zeros(1)
        target_vars['x_off'] = tf.reduce_mean(tf.abs(output - X))
        target_vars['x_mod'] = tf.zeros(1)
        target_vars['x_grad'] = tf.zeros(1)
        target_vars['NOISE'] = NOISE
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['HIER_LABEL'] = HIER_LABEL

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)
    else:
        # Full EBM branch: build energy model, optional pre-learned models,
        # and the Langevin-dynamics sampling loops per GPU.
        print("label size here ", label_size)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        # NOTE(review): spelled HEIR_LABEL here but HIER_LABEL in the joint
        # branch / target_vars key — likely a long-standing typo pair.
        HEIR_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)

        if FLAGS.dataset != "celeba":
            model = CubesNet(num_channels=channel_num, label_size=label_size)

        heir_model = HeirNet(num_channels=FLAGS.cond_func)

        # Optionally restore previously trained ("prelearned") energy models
        # whose energies are added during Langevin sampling.
        models_pretrain = []
        if FLAGS.prelearn_model:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label)
            weights = model_prelearn.construct_weights('context_1')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp)
            if (FLAGS.prelearn_iter != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                # Map checkpoint names saved under context_0 onto the
                # context_1 variables created above ([:-2] strips ":0").
                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
                v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        if FLAGS.prelearn_model_shape:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label_shape)
            weights = model_prelearn.construct_weights('context_2')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label_shape), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp_shape)
            if (FLAGS.prelearn_iter_shape != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter_shape))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
                v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        print("Done loading...")

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

        batch_size = FLAGS.batch_size

        weights = model.construct_weights('context_0')

        if FLAGS.heir_mask:
            weights = heir_model.construct_weights('heir_0', weights=weights)

        Y = tf.placeholder(shape=(None), dtype=tf.int32)

        # Varibles to run in training

        # Split batch placeholders across GPUs (one shard per tower).
        X_SPLIT = tf.split(X, FLAGS.num_gpus)
        X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
        LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
        LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
        LABEL_SPLIT_INIT = list(LABEL_SPLIT)
        attention_mask = ATTENTION_MASK
        tower_grads = []
        tower_gen_grads = []
        x_mod_list = []

        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.99)

        # Per-tower graph construction; closures below capture `j`.
        for j in range(FLAGS.num_gpus):

            x_mod = X_SPLIT[j]
            if FLAGS.comb_mask:
                steps = tf.constant(0)
                c = lambda i, x: tf.less(i, FLAGS.num_steps)

                # One Langevin step on the attention mask only: add noise,
                # descend the energy gradient, then smooth and detach.
                def langevin_attention_step(counter, attention_mask):
                    attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)
                    energy_noise = energy_start = model.forward(
                                x_mod,
                                weights,
                                attention_mask,
                                label=LABEL_SPLIT[j],
                                reuse=True,
                                stop_at_grad=False,
                                stop_batch=True)

                    if FLAGS.heir_mask:
                        energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                        energy_noise = energy_noise + energy_heir

                    attention_grad = tf.gradients(
                        FLAGS.temperature * energy_noise, [attention_mask])[0]
                    energy_noise_old = energy_noise  # NOTE(review): unused

                    # Clip gradient norm for now
                    attention_mask = attention_mask - (FLAGS.attention_lr) * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                    counter = counter + 1

                    return counter, attention_mask

                steps, attention_mask = tf.while_loop(c, langevin_attention_step, (steps, attention_mask))

                # attention_mask = tf.Print(attention_mask, [attention_mask])

                # Positive-phase energy on real data with the refined mask.
                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        tf.stop_gradient(attention_mask),
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            else:
                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        attention_mask,
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            print("Building graph...")
            x_mod = x_orig = X_NOISE_SPLIT[j]

            x_grads = []

            loss_energys = []

            eps_begin = tf.zeros(1)

            steps = tf.constant(0)
            c_cond = lambda i, x, y: tf.less(i, FLAGS.num_steps)

            # One Langevin step on the sample x_mod (and mask when comb_mask):
            # inject noise, descend the summed energy gradient (including any
            # prelearned models), optionally clip, then clamp to valid range.
            def langevin_step(counter, x_mod, attention_mask):

                lr = FLAGS.step_lr

                x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                                 mean=0.0,
                                                 stddev=0.001 * FLAGS.rescale * FLAGS.noise_scale)
                attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)

                energy_noise = model.forward(
                            x_mod,
                            weights,
                            attention_mask,
                            label=LABEL_SPLIT[j],
                            reuse=True,
                            stop_at_grad=False,
                            stop_batch=True)

                if FLAGS.prelearn_model:
                    for m_i, w_i, l_i in models_pretrain:
                        energy_noise = energy_noise + m_i.forward(
                                    x_mod,
                                    w_i,
                                    attention_mask,
                                    label=l_i,
                                    reuse=True,
                                    stop_at_grad=False,
                                    stop_batch=True)


                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_noise = energy_heir + energy_noise

                x_grad, attention_grad = tf.gradients(
                    FLAGS.temperature * energy_noise, [x_mod, attention_mask])

                if not FLAGS.comb_mask:
                    attention_grad = tf.zeros(1)
                energy_noise_old = energy_noise  # NOTE(review): unused

                # Optional projection of the sample gradient (l2 or l-inf).
                if FLAGS.proj_norm != 0.0:
                    if FLAGS.proj_norm_type == 'l2':
                        x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                    elif FLAGS.proj_norm_type == 'li':
                        x_grad = tf.clip_by_value(
                            x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                    else:
                        print("Other types of projection are not supported!!!")
                        assert False

                # Clip gradient norm for now
                x_last = x_mod - (lr) * x_grad

                if FLAGS.comb_mask:
                    attention_mask = attention_mask - FLAGS.attention_lr * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                x_mod = x_last
                x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

                counter = counter + 1

                return counter, x_mod, attention_mask


            steps, x_mod, attention_mask = tf.while_loop(c_cond, langevin_step, (steps, x_mod, attention_mask))

            attention_mask = tf.stop_gradient(attention_mask)
            # attention_mask = tf.Print(attention_mask, [attention_mask])

            energy_eval = model.forward(x_mod, weights, attention_mask, label=LABEL_SPLIT[j],
                                        stop_at_grad=False, reuse=True)
            x_grad, attention_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod, attention_mask])
            x_grads.append(x_grad)

            # Negative-phase energy on the sampled (detached) x_mod.
            energy_neg = model.forward(
                    tf.stop_gradient(x_mod),
                    weights,
                    tf.stop_gradient(attention_mask),
                    label=LABEL_SPLIT[j],
                    stop_at_grad=False,
                    reuse=True)

            if FLAGS.heir_mask:
                energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                energy_neg = energy_heir + energy_neg


            temp = FLAGS.temperature

            # Mean absolute deviation between samples and real data (metric).
            x_off = tf.reduce_mean(
                tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

            loss_energy = model.forward(
                x_mod,
                weights,
                attention_mask,
                reuse=True,
                label=LABEL,
                stop_grad=True)

            print("Finished processing loop construction ...")

            target_vars = {}

            if FLAGS.antialias:
                # NOTE(review): stride_3 is defined outside this view.
                antialias = tf.tile(stride_3, (1, 1, tf.shape(x_mod)[3], tf.shape(x_mod)[3]))
                inp = tf.nn.conv2d(x_mod, antialias, [1, 2, 2, 1], padding='SAME')

            test_x_mod = x_mod

            # Label entropy diagnostic (only meaningful when class-conditional).
            if FLAGS.cclass or FLAGS.model_cclass:
                label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
                label_prob = label_sum / tf.reduce_sum(label_sum)
                label_ent = -tf.reduce_sum(label_prob *
                                           tf.math.log(label_prob + 1e-7))
            else:
                label_ent = tf.zeros(1)

            target_vars['label_ent'] = label_ent

            if FLAGS.train:
                # Maximum-likelihood objective variants: importance-weighted
                # logsumexp, plain contrastive divergence, or softplus margin.
                if FLAGS.objective == 'logsumexp':
                    pos_term = temp * energy_pos
                    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                    coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'cd':
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = -tf.reduce_mean(temp * energy_neg)
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'softplus':
                    loss_ml = FLAGS.ml_coeff * \
                        tf.nn.softplus(temp * (energy_pos - energy_neg))

                loss_total = tf.reduce_mean(loss_ml)

                if not FLAGS.zero_kl:
                    loss_total = loss_total + tf.reduce_mean(loss_energy)

                # L2 regularization on energy magnitudes keeps them bounded.
                loss_total = loss_total + \
                    FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

                print("Started gradient computation...")
                gvs = optimizer.compute_gradients(loss_total)
                gvs = [(k, v) for (k, v) in gvs if k is not None]

                print("Applying gradients...")

                tower_grads.append(gvs)

                print("Finished applying gradients.")

                target_vars['loss_ml'] = loss_ml
                target_vars['total_loss'] = loss_total
                target_vars['loss_energy'] = loss_energy
                target_vars['weights'] = weights
                target_vars['gvs'] = gvs

            # Export graph endpoints consumed by train()/test().
            target_vars['X'] = X
            target_vars['Y'] = Y
            target_vars['LABEL'] = LABEL
            target_vars['HIER_LABEL'] = HEIR_LABEL
            target_vars['LABEL_POS'] = LABEL_POS
            target_vars['X_NOISE'] = X_NOISE
            target_vars['energy_pos'] = energy_pos
            target_vars['attention_grad'] = attention_grad

            if len(x_grads) >= 1:
                target_vars['x_grad'] = x_grads[-1]
                target_vars['x_grad_first'] = x_grads[0]
            else:
                target_vars['x_grad'] = tf.zeros(1)
                target_vars['x_grad_first'] = tf.zeros(1)

            target_vars['x_mod'] = x_mod
            target_vars['x_off'] = x_off
            target_vars['temp'] = temp
            target_vars['energy_neg'] = energy_neg
            target_vars['test_x_mod'] = test_x_mod
            target_vars['eps_begin'] = eps_begin
            target_vars['ATTENTION_MASK'] = ATTENTION_MASK
            target_vars['models_pretrain'] = models_pretrain
            if FLAGS.comb_mask:
                target_vars['attention_mask'] = tf.nn.softmax(attention_mask)
            else:
                target_vars['attention_mask'] = tf.zeros(1)

        if FLAGS.train:
            # Average per-tower gradients and build the single train op.
            grads = average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads)
            target_vars['train_op'] = train_op

    # sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    # Count trainable parameters for a quick sanity log.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Resume from checkpoint when requested or when only testing.
    if (FLAGS.resume_iter != -1 or not FLAGS.train):
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    test(target_vars, saver, sess, logger, data_loader)
Exemplo n.º 4
0
# Script fragment: builds the test-time dataset/dataloader and model for
# AttGAN. NOTE(review): `args`/`args_`, `join`, and `data` come from earlier,
# unseen parts of this file.
args.custom_attr = args_.custom_attr
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)

print(args)


# Pick the dataset: user-supplied custom images, or the test split of
# CelebA / CelebA-HQ.
if args.custom_img:
    output_path = join('output', args.experiment_name, 'custom_testing')
    from data import Custom
    test_dataset = Custom(args.custom_data, args.custom_attr, args.img_size, 'test', args.attrs)
else:
    output_path = join('output', args.experiment_name, 'sample_testing')
    if args.data == 'CelebA':
        from data import CelebA
        test_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'test', args.attrs)
    if args.data == 'CelebA-HQ':
        from data import CelebA_HQ
        test_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path, args.img_size, 'test', args.attrs)
os.makedirs(output_path, exist_ok=True)
# batch_size=1: test images are processed one at a time, in order.
test_dataloader = data.DataLoader(
    test_dataset, batch_size=1, num_workers=args.num_workers,
    shuffle=False, drop_last=False
)
if args.num_test is None:
    print('Testing images:', len(test_dataset))
else:
    print('Testing images:', min(len(test_dataset), args.num_test))


attgan = AttGAN(args)
Exemplo n.º 5
0
    # Paths
    checkpoint_path = join('results', args.experiment_name, 'checkpoint')
    sample_path = join('results', args.experiment_name, 'sample')
    summary_path = join('results', args.experiment_name, 'summary')
    os.makedirs(checkpoint_path, exist_ok=True)
    os.makedirs(sample_path, exist_ok=True)
    os.makedirs(summary_path, exist_ok=True)
    with open(join('results', args.experiment_name, 'setting.json'),
              'w',
              encoding='utf-8') as f:
        json.dump(vars(args), f, indent=2, sort_keys=True)
    writer = SummaryWriter(summary_path)

    # Data
    selected_attrs = [args.target_attr]
    train_dset = CelebA(args.data_path, args.attr_path, args.image_size,
                        'train', selected_attrs)
    train_data = data.DataLoader(train_dset,
                                 args.batch_size,
                                 shuffle=True,
                                 drop_last=True)
    train_data = loop(train_data)
    test_dset = CelebA(args.data_path, args.attr_path, args.image_size, 'test',
                       selected_attrs)
    test_data = data.DataLoader(test_dset, args.num_samples)
    for fixed_reals, fixed_labels in test_data:
        # Get the first batch of images from the testing set
        fixed_reals, fixed_labels = fixed_reals.to(
            device), fixed_labels.type_as(fixed_reals).to(device)
        fixed_target_labels = 1 - fixed_labels
        break
    del test_dset
Exemplo n.º 6
0
    # TODO: support conditional inputs

    args = parser.parse_args()
    opts = {k: v for k, v in args._get_kwargs()}

    latent_size = 512
    sigmoid_at_end = args.gan in ['lsgan', 'gan']
    if hasattr(args, 'no_tanh'):
        tanh_at_end = False
    else:
        tanh_at_end = True

    G = Generator(num_channels=3,
                  latent_size=latent_size,
                  resolution=args.target_resol,
                  fmap_max=latent_size,
                  fmap_base=8192,
                  tanh_at_end=tanh_at_end)
    D = Discriminator(num_channels=3,
                      resolution=args.target_resol,
                      fmap_max=latent_size,
                      fmap_base=8192,
                      sigmoid_at_end=sigmoid_at_end)
    print(G)
    print(D)
    data = CelebA()
    noise = RandomNoiseGenerator(latent_size, 'gaussian')
    pggan = PGGAN(G, D, data, noise, opts)
    pggan.train()
Exemplo n.º 7
0
print(args)

args.lr_base = args.lr
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)

os.makedirs(join('output', args.experiment_name), exist_ok=True)
os.makedirs(join('output', args.experiment_name, 'checkpoint'), exist_ok=True)
os.makedirs(join('output', args.experiment_name, 'sample_training'),
            exist_ok=True)
with open(join('output', args.experiment_name, 'setting.txt'), 'w') as f:
    f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))

if args.data == 'CelebA':
    from data import CelebA
    train_dataset = CelebA(args.data_path, args.attr_path, args.img_size,
                           'train', args.attrs)
    valid_dataset = CelebA(args.data_path, args.attr_path, args.img_size,
                           'valid', args.attrs)
if args.data == 'CelebA-HQ':
    from data import CelebA_HQ
    train_dataset = CelebA_HQ(args.data_path, args.attr_path,
                              args.image_list_path, args.img_size, 'train',
                              args.attrs)
    valid_dataset = CelebA_HQ(args.data_path, args.attr_path,
                              args.image_list_path, args.img_size, 'valid',
                              args.attrs)
train_dataloader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   shuffle=True,
                                   drop_last=True)
Exemplo n.º 8
0
    test_annos[i, index] = value

# Deterministic test-time preprocessing: resize, center-crop, then scale to
# [-1, 1] via Normalize with per-channel mean = std = 0.5.
# NOTE(review): the name `tf` shadows the common TensorFlow alias — confirm
# TensorFlow is not imported as `tf` elsewhere in this file.
tf = transforms.Compose([
    transforms.Resize(image_size), 
    transforms.CenterCrop(image_size), 
    transforms.ToTensor(), 
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_dataset = datasets.ImageFolder(root='test_imgs', transform=tf)
test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size)


# Paired train/val data for the selected dataset; test_num=2000 presumably
# holds out 2000 images for testing — verify against data.CelebA/CelebAHQ.
if dataset == 'celeba':
    from data import CelebA, PairedData
    train_dset = CelebA(
        data_path, image_size, selected_attr=selected_attributes, mode='train', test_num=2000
    )
    train_data = PairedData(train_dset, batch_size)
    valid_dset = CelebA(
        data_path, image_size, selected_attr=selected_attributes, mode='val', test_num=2000
    )
    valid_data = PairedData(valid_dset, batch_size)
if dataset == 'celeba-hq':
    from data import CelebAHQ, PairedData
    train_dset = CelebAHQ(
        data_path, image_size, selected_attr=selected_attributes, mode='train', test_num=2000
    )
    train_data = PairedData(train_dset, batch_size)
    valid_dset = CelebAHQ(
        data_path, image_size, selected_attr=selected_attributes, mode='val', test_num=2000
    )
Exemplo n.º 9
0
    attrs_default = [
        'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair',
        'Bushy_Eyebrows', 'Eyeglasses', 'Male', 'Mouth_Slightly_Open',
        'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
    ]
    attrs_list = attrs_default.copy()
    batch_size = 32
    n_samples = 16

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    PLL = PerceptualLoss()
    model_e_d = LSA_VAE(8, len(attrs_list))

    from data import CelebA
    train_dataset = CelebA(data_path, attr_path, img_size, 'train', attrs_list)
    valid_dataset = CelebA(data_path, attr_path, img_size, 'valid', attrs_list)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=True,
                                  drop_last=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=n_samples,
                                  num_workers=num_workers,
                                  shuffle=False,
                                  drop_last=False)

    #    SGD_optimizer = optim.SGD(model_e_d.parameters(), lr=0.0005, momentum=0.5, weight_decay=1e-4)
    ADAM_optimizer = optim.Adam(model_e_d.parameters(),
                                lr=0.0005,
Exemplo n.º 10
0
# Logger
# NOTE(review): Python 2 script (uses xrange). d_train/g_train, d_real_loss,
# d_fake_loss, k, m_global, lr, z, x_fake, x_real, d_real, and tb are all
# defined in earlier, unseen parts of this file.
base_dir = 'results/gamma={:.1f}_run={:d}'.format(args.gamma, args.run)
writer = tb.FileWriter(os.path.join(base_dir, 'log.csv'), args=args, overwrite=args.run >= 999)
writer.add_var('d_real', '{:8.4f}', d_real_loss)
writer.add_var('d_fake', '{:8.4f}', d_fake_loss)
writer.add_var('k', '{:8.4f}', k * 1)
writer.add_var('M', '{:8.4f}', m_global)
writer.add_var('lr', '{:8.6f}', lr * 1)
writer.add_var('iter', '{:>8d}')
writer.initialize()

sess = tf.Session()
load_model(sess)
f_gen = tb.function(sess, [z], x_fake)
f_rec = tb.function(sess, [x_real], d_real)
celeba = CelebA(args.data)

# Alternatively try grouping d_train/g_train together
all_tensors = [d_train, g_train, d_real_loss, d_fake_loss]
# d_tensors = [d_train, d_real_loss]
# g_tensors = [g_train, d_fake_loss]
for i in xrange(args.max_iter):
    x = celeba.next_batch(args.bs)
    # NOTE(review): rebinding `z` here shadows the placeholder used above.
    z = np.random.uniform(-1, 1, (args.bs, args.e_size))

    feed_dict = {'x:0': x, 'z:0': z, 'k:0': args.k, 'lr:0': args.lr, 'g:0': args.gamma}
    _, _, d_real_loss, d_fake_loss = sess.run(all_tensors, feed_dict)
    # _, d_real_loss = sess.run(d_tensors, feed_dict)
    # _, d_fake_loss = sess.run(g_tensors, feed_dict)
    # Equilibrium control: adjust k toward gamma * L(real) = L(fake),
    # clamped to [0, 1] — the BEGAN-style update rule, presumably.
    args.k = np.clip(args.k + args.lambd * (args.gamma * d_real_loss - d_fake_loss), 0., 1.)
Exemplo n.º 11
0
# Launcher script: build the generator and CelebA-HQ inpainting dataset,
# create a timestamped model directory, persist the config, then train.
from networks import create_Generator
from data import CelebA
from model import InpaintingModel
import json, time, os
from config import config

G = create_Generator(config)
# Image archive plus a separate HDF5 of hole masks for inpainting.
celeba = CelebA(os.path.expanduser('~/datasets/celeba-hq-1024x1024.h5'),
                os.path.expanduser('~/datasets/holes_hq.hdf5'))
# Substitute the launch timestamp into the configured model directory.
config['model_dir'] = config['model_dir'].replace(
    '<time>', time.strftime("%Y-%b-%d %H_%M_%S"))
os.makedirs(config['model_dir'])
# Snapshot the effective config next to the model for reproducibility.
with open(os.path.join(config['model_dir'], 'config.json'), 'w') as f:
    json.dump(config, f)

model = InpaintingModel(G, celeba, config)
model.run()