예제 #1
0
    def d_loss(self, Xs, Xt, same_person):
        """Return the discriminator hinge losses for one batch.

        Args:
            Xs: source images. Sliced as ``Xs[:, :, 19:237, 19:237]`` and
                resized to 112x112 before ArcFace (presumably N, C, H, W --
                confirm). Also used as the "real" samples for the discriminator.
            Xt: target images passed to the generator.
            same_person: unused here; kept so the signature matches ``g_loss``.

        Returns:
            ``(L_fake, L_real, L_discriminator)`` where ``L_discriminator`` is
            the average of the fake and real terms.
        """
        with torch.no_grad():
            src_embed = self.arcface(F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112], mode='bilinear', align_corners=True))
            # The swapped face is only consumed detached below, so recording
            # the generator's autograd graph for it would just waste memory.
            Y_hat = self.generator(Xt, src_embed, return_attributes=False)

        # The discriminator output is iterated per entry; element 0 of each
        # entry is the score fed to the hinge loss.
        fake_D = self.discriminator(Y_hat.detach())
        L_fake = 0
        for di in fake_D:
            L_fake += hinge_loss(di[0], False)
        real_D = self.discriminator(Xs)
        L_real = 0
        for di in real_D:
            L_real += hinge_loss(di[0], True)

        L_discriminator = 0.5 * (L_real + L_fake)
        return L_fake, L_real, L_discriminator
예제 #2
0
    def g_loss(self, Xs, Xt, same_person):
        """Compute the generator objective and each of its terms.

        Terms:
            adversarial -- hinge loss on every discriminator output for Y_hat;
            source / target identity -- 1 - cosine similarity between ArcFace
                embeddings of the swapped image and of Xs / Xt;
            attributes -- half the sum over attribute levels of the per-sample
                mean squared difference between Xt's and Y_hat's attributes;
            reconstruction -- 0.5 * per-sample MSE(Y_hat, Xt), averaged over
                samples weighted by ``same_person``.

        Returns:
            ``(L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator)`` where
            the last entry is the weighted sum of the others.
        """
        def arcface_embed(images):
            # Crop to [19:237] spatially and resize to 112x112 for ArcFace.
            cropped = images[:, :, 19:237, 19:237]
            return self.arcface(F.interpolate(cropped, [112, 112], mode='bilinear', align_corners=True))

        with torch.no_grad():
            src_embed = arcface_embed(Xs)
            tgt_embed = arcface_embed(Xt)

        Y_hat, Xt_attr = self.generator(Xt, src_embed, return_attributes=True)

        L_adv = 0
        for scale_out in self.discriminator(Y_hat):
            L_adv += hinge_loss(scale_out[0], True)

        # Gradients flow through ArcFace into the generator here.
        fake_embed = arcface_embed(Y_hat)
        L_src_id = (1 - torch.cosine_similarity(src_embed, fake_embed, dim=1)).mean()
        L_tgt_id = (1 - torch.cosine_similarity(tgt_embed, fake_embed, dim=1)).mean()

        n = Xs.shape[0]
        Y_hat_attr = self.generator.get_attr(Y_hat)
        L_attr = 0
        for level, tgt_feat in enumerate(Xt_attr):
            squared = torch.pow(tgt_feat - Y_hat_attr[level], 2)
            L_attr += torch.mean(squared.reshape(n, -1), dim=1).mean()
        L_attr /= 2.0

        per_sample_mse = torch.mean(torch.pow(Y_hat - Xt, 2).reshape(n, -1), dim=1)
        L_rec = torch.sum(0.5 * per_sample_mse * same_person) / (same_person.sum() + 1e-6)

        L_generator = (
            self.adversarial_weight * L_adv
            + self.src_id_weight * L_src_id
            + self.tgt_id_weight * L_tgt_id
            + self.attributes_weight * L_attr
            + self.reconstruction_weight * L_rec
        )
        return L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator
예제 #3
0
def main():
    """Build the hinge-loss GAN graph described by ``FLAGS`` and train it.

    Picks a dataset loader from ``FLAGS.data_dir``, creates the TF input
    placeholders, instantiates the generator/discriminator pair returned by
    ``config.<FLAGS.model_name>``, wires up the hinge loss (plus the
    classifier loss when ``FLAGS.cgan`` is set) and TensorBoard summaries,
    and hands everything to ``BaseGANTrainer``.
    """
    size = FLAGS.img_size

    # debug: derive a log directory from model name, BN variant and phases
    if len(FLAGS.train_dir) < 1:
        bn_name = ["nobn", "caffebn", "simplebn", "defaultbn", "cbn"]
        FLAGS.train_dir = os.path.join(
            "logs", FLAGS.model_name + "_" + bn_name[FLAGS.bn] + "_" +
            str(FLAGS.phases))

    if FLAGS.cgan:
        # the label file is npy format, stored next to the data archive
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(FLAGS.data_dir,
                                              is_transform=True,
                                              augmentations=augmentations,
                                              img_size=(size, size))
        # Floor division keeps batch_size an int (`/=` produced a float that
        # leaked into the step count below); never let it drop to 0.
        FLAGS.batch_size = max(1, FLAGS.batch_size // 64)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size))

    dl = dataloader.TFDataloader(dataset, FLAGS.batch_size,
                                 dataset.file_num // FLAGS.batch_size)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Unconditional GAN: no labels, but the names are used below.
        c_label = c_noise = None

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config,
                                    FLAGS.model_name)(FLAGS.img_size,
                                                      dataset.class_num)
    disc_model.norm_mtd = FLAGS.bn

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # Fake pass: conditioned on the sampled class; separate "fake" phase when
    # more than one phase is configured.
    disc_model.set_label(c_noise)
    if FLAGS.phases > 1:
        disc_model.set_phase("fake")
    else:
        disc_model.set_phase("default")
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection=None)
    disc_model.set_reuse()

    # Real pass: conditioned on the ground-truth label.
    disc_model.set_label(c_label)
    if FLAGS.phases > 1:
        disc_model.set_phase("real")
    else:
        disc_model.set_phase("default")
    disc_real, real_cls_logits = disc_model(x_real, update_collection=None)
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample",
                                                 disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum, x_fake_sample]
    else:
        sample_method = None

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    int_sum_op = tf.summary.merge(int_sum_op)

    raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
        gen_model, disc_model, adv_weight=1.0, summary=False)
    disc_model.disc_real_loss = raw_disc_real
    disc_model.disc_fake_loss = raw_disc_fake

    # The classifier costs only exist for the conditional GAN, so only
    # summarize them in that case (previously a NameError without cgan).
    if FLAGS.cgan:
        real_cls_cost, fake_cls_cost = loss.classifier_loss(gen_model,
                                                            disc_model,
                                                            x_real,
                                                            c_label,
                                                            c_noise,
                                                            weight=1.0 /
                                                            dataset.class_num,
                                                            summary=False)
        subloss_names = ["fake_cls", "real_cls", "gen", "disc_real", "disc_fake"]
        sublosses = [
            fake_cls_cost, real_cls_cost, raw_gen_cost, raw_disc_real,
            raw_disc_fake
        ]
    else:
        subloss_names = ["gen", "disc_real", "disc_fake"]
        sublosses = [raw_gen_cost, raw_disc_real, raw_disc_fake]

    step_sum_op = []
    for n, l in zip(subloss_names, sublosses):
        step_sum_op.append(tf.summary.scalar(n, l))
    step_sum_op = tf.summary.merge(step_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(  #trainer.separated_gantrainer.SeparatedGANTrainer(#
        int_sum_op=int_sum_op,
        step_sum_op=step_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
예제 #4
0
def run(args, data_iter, model, gender, optimizers, epoch, train=True, pretrain=False):
    """Run one epoch of training or evaluation with a dependence penalty.

    The classification loss (hinge if ``args.loss == 'hinge'``, otherwise
    cross-entropy) is combined with ``args.c`` times a dependence criterion
    (HSIC or COCO, chosen by ``args.crit``) between the representation and
    the protected attribute ``label_g``. With ``args.fix`` the criterion is
    applied to ``z`` directly (third argument ``True``); otherwise to the
    projection ``model.classifier.map(relu(z))``. When training with
    ``args.hsic`` set, a second optimizer step updates that projection to
    maximize the criterion on detached features.

    Args:
        args: namespace read for device, data, loss, crit, fix, hsic and c.
        data_iter: iterable of ``(inputs, label, factor)`` batches that also
            exposes ``.dataset``.
        model: callable returning ``(y, z, _)``; must have ``classifier.map``.
        gender, epoch, pretrain: unused here; kept for caller compatibility.
        optimizers: ``(optimizer, optimizer_phi)``.
        train: whether to take optimizer steps (and run in train mode).

    Returns:
        ``(clf_loss, clf_acc, parity, hs)``: summed per-batch classifier loss,
        accuracy in percent, demographic-parity gap
        ``|P(pred=1 | g=1) - P(pred=1 | g=0)|``, and summed criterion value.

    Raises:
        ValueError: if ``args.crit`` is neither ``'hsic'`` nor ``'coco'``.
    """
    device = args.device
    dataset = args.data.rstrip('/').split('/')[-1]
    if args.loss == 'hinge':
        criterion = hinge_loss()
    else:
        criterion = torch.nn.CrossEntropyLoss()
    if args.crit == 'hsic':
        dependence = HSIC
    elif args.crit == 'coco':
        dependence = COCO
    else:
        # Previously this surfaced later as an unbound `kic`.
        raise ValueError("unknown dependence criterion: %r" % (args.crit,))
    optimizer, optimizer_phi = optimizers
    if train:
        model.train()
    else:
        model.eval()

    clf_loss = 0
    correct = 0
    total = 0
    total_m = 0  # samples with label_g == 1
    total_f = 0  # samples with label_g == 0
    y_m = 0      # positive predictions among label_g == 1
    y_f = 0      # positive predictions among label_g == 0
    hs = 0
    # No autograd graph is needed when only evaluating.
    with torch.set_grad_enabled(train):
        for data in data_iter:
            inputs, label, factor = [d.to(device) for d in data]
            label = label.long().squeeze(1)
            if dataset == 'adult.data':
                # For the adult data the protected attribute is the second
                # half of `factor`.
                label_g = factor.chunk(2, dim=-1)[1].squeeze(1).long()
            else:
                label_g = factor.long().squeeze(1)

            y, z, _ = model(inputs)
            phi = model.classifier.map(F.relu(z))
            loss = criterion(y, label)
            if args.fix:
                kic = dependence(z, label_g, True)
            else:
                kic = dependence(phi, label_g)
            total_loss = loss + args.c * kic

            if train:
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                if args.hsic:
                    # Adversarial step: train the projection to maximize the
                    # dependence on *detached* features. The former non-fix
                    # branch evaluated the criterion on the live `z`, whose
                    # graph was already freed by the backward above, and never
                    # involved `phi` at all.
                    optimizer_phi.zero_grad()
                    phi = model.classifier.map(F.relu(z.detach()))
                    neg_h = -dependence(phi, label_g)
                    neg_h.backward()
                    optimizer_phi.step()

            _, predicted = torch.max(y.data, 1)
            correct += (predicted == label).sum().item()
            total += label.size(0)

            pred_pos = predicted == 1
            in_group_m = label_g == 1
            in_group_f = label_g == 0
            total_m += in_group_m.sum().item()
            total_f += in_group_f.sum().item()
            # Logical AND: positives *within* each group. The previous `==`
            # between the masks counted agreements, which also added the
            # other group's negative predictions.
            y_m += (pred_pos & in_group_m).sum().item()
            y_f += (pred_pos & in_group_f).sum().item()

            clf_loss += loss.item()
            hs += kic.item()

    clf_acc = 100 * correct / total if total else 0.0
    rate_m = y_m / total_m if total_m else 0.0
    rate_f = y_f / total_f if total_f else 0.0
    parity = np.abs(rate_m - rate_f)

    return clf_loss, clf_acc, parity, hs
예제 #5
0
def main():
    """Build the hinge-loss GAN graph described by ``FLAGS`` and train it.

    Picks a dataset loader from ``FLAGS.data_dir``, creates the TF input
    placeholders, builds generator and discriminator from
    ``config.<FLAGS.model_name>``, wires up the hinge loss (plus the
    classifier loss when ``FLAGS.cgan`` is set), per-step and interval
    summaries (with variable histograms when either model has ``debug``
    set), and runs ``BaseGANTrainer``.
    """
    size = FLAGS.img_size

    if len(FLAGS.train_dir) < 1:
        # if the train dir is not set, then automatically decide one
        FLAGS.train_dir = os.path.join("logs",
                                       FLAGS.model_name + str(FLAGS.img_size))
        if FLAGS.cgan:
            FLAGS.train_dir += "_cgan"

    if FLAGS.cgan:
        # the label file should be npy format
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        # outdated
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(FLAGS.data_dir,
                                              is_transform=True,
                                              augmentations=augmentations,
                                              img_size=(size, size))
        # Floor division keeps batch_size an int (`/=` produced a float that
        # leaked into the step count below); never let it drop to 0.
        FLAGS.batch_size = max(1, FLAGS.batch_size // 64)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size))

    dl = dataloader.TFDataloader(dataset, FLAGS.batch_size,
                                 dataset.file_num // FLAGS.batch_size)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        c_label = c_noise = None

    # look up the config function from lib.config module
    gen_model, disc_model = getattr(config,
                                    FLAGS.model_name)(FLAGS.img_size,
                                                      dataset.class_num)

    gen_model.label = c_noise
    x_fake = gen_model(gen_input)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # Fake pass is conditioned on the sampled class, real pass on the label.
    disc_model.label = c_noise
    disc_fake, fake_cls_logits = disc_model(x_fake)
    disc_model.set_reuse()

    disc_model.label = c_label
    disc_real, real_cls_logits = disc_model(x_real)
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
        gen_model, disc_model, adv_weight=1.0, summary=False)
    disc_model.disc_real_loss = raw_disc_real
    disc_model.disc_fake_loss = raw_disc_fake

    if FLAGS.cgan:
        real_cls_cost, fake_cls_cost = loss.classifier_loss(gen_model,
                                                            disc_model,
                                                            x_real,
                                                            c_label,
                                                            c_noise,
                                                            weight=1.0 /
                                                            dataset.class_num,
                                                            summary=False)
        subloss_names = [
            "fake_cls", "real_cls", "gen", "disc_real", "disc_fake"
        ]
        sublosses = [
            fake_cls_cost, real_cls_cost, raw_gen_cost, raw_disc_real,
            raw_disc_fake
        ]
    else:
        subloss_names = ["gen", "disc_real", "disc_fake"]
        sublosses = [raw_gen_cost, raw_disc_real, raw_disc_fake]

    step_sum_op = []  # summary at every step

    for n, l in zip(subloss_names, sublosses):
        step_sum_op.append(tf.summary.scalar(n, l))
    if gen_model.debug or disc_model.debug:
        # Histogram every global variable whose op name contains either
        # model's name.
        for model_var in tf.global_variables():
            if gen_model.name in model_var.op.name or disc_model.name in model_var.op.name:
                step_sum_op.append(
                    tf.summary.histogram(model_var.op.name, model_var))
    step_sum_op = tf.summary.merge(step_sum_op)

    int_sum_op = []  # summary at some interval

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        step_sum_op=step_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label)

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Generator update dependency ####")
    for v in gen_model.update_ops:
        print("%s" % (v.name))
    print("=> #### Discriminator update dependency ####")
    for v in disc_model.update_ops:
        print("%s" % (v.name))
    ModelTrainer.init_training()
    ModelTrainer.train()
예제 #6
0
def main():
    """Build the hinge-loss GAN graph described by ``FLAGS`` and train it.

    Uses a ``DataLoader`` over the selected dataset, builds generator and
    discriminator from ``config.<FLAGS.model_name>``, adds the hinge loss
    (plus the classifier loss when ``FLAGS.cgan`` is set) and image
    summaries, starts the command-control thread and runs ``BaseGANTrainer``.
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
                                           img_size=(size, size),
                                           npy_dir=npy_dir)
    elif "cityscapes" in FLAGS.data_dir:
        augmentations = Compose([
            RandomCrop(size * 4),
            Scale(size * 2),
            RandomRotate(10),
            RandomHorizontallyFlip(),
            RandomSizedCrop(size)
        ])
        dataset = dataloader.cityscapesLoader(FLAGS.data_dir,
                                              is_transform=True,
                                              augmentations=augmentations,
                                              img_size=(size, size))
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
                                         npy_dir=npy_dir,
                                         img_size=(size, size),
                                         shuffle=True)

    # NOTE(review): batch_size is divided by 64 for every dataset here, and
    # becomes 0 when FLAGS.batch_size < 64 -- confirm this is intended.
    dl = DataLoader(dataset,
                    batch_size=FLAGS.batch_size // 64,
                    shuffle=True,
                    num_workers=NUM_WORKER,
                    collate_fn=dataloader.default_collate)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3],
                                   name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num],
                                 name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Unconditional GAN: `label=c_label` below previously raised NameError.
        c_label = c_noise = None

    gen_model, disc_model = getattr(config,
                                    FLAGS.model_name)(FLAGS.img_size,
                                                      dataset.class_num)
    gen_model.cbn_project = FLAGS.cbn_project

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # Real pass uses update_collection=None, fake pass "no_ops".
    disc_real, real_cls_logits = disc_model(x_real, update_collection=None)
    disc_model.set_reuse()
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection="no_ops")
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)),
            name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample",
                                                 disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum, x_fake_sample]
    else:
        sample_method = None

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    if FLAGS.cgan:
        loss.classifier_loss(gen_model,
                             disc_model,
                             x_real,
                             c_label,
                             c_noise,
                             weight=1.0)

    loss.hinge_loss(gen_model, disc_model, adv_weight=1.0)

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    command_controller = trainer.cmd_ctrl.CMDControl(ModelTrainer)
    command_controller.start_thread()

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))
    print("=> #### Moving Variable ####")
    for v in tf.global_variables():
        if "moving" in v.name:
            print("%s\t\t\t\t%s" % (v.name, str(v.get_shape().as_list())))

    ModelTrainer.init_training()
    ModelTrainer.train()
예제 #7
0
    def tower(gpu_id,
              gen_input,
              x_real,
              c_label=None,
              c_noise=None,
              update_collection=None,
              loss_collection=None):
        """
        The loss function builder of gen and disc for one GPU tower.

        Runs the generator and discriminator under phase "gpu<gpu_id>"
        (reusing variables after construction), builds the hinge loss and,
        when ``FLAGS.cgan`` is set, the classifier loss, then appends per-tower
        gradients to the enclosing-scope lists ``g_tower_grads``,
        ``d_tower_grads``, ``xs`` and ``grad_x`` (and names to ``x_name`` /
        ``grad_x_name`` for GPU 0).

        Args:
            gpu_id: tower index, used in the phase name.
            gen_input: generator input tensor(s).
            x_real: real image batch.
            c_label, c_noise: class inputs, only used when ``FLAGS.cgan``.
            update_collection: forwarded to both model calls.
            loss_collection: unused; kept for call compatibility.

        Returns:
            ``(gen_model.cost, disc_model.cost, [fake_cls, real_cls, gen,
            disc_real, disc_fake])``; the two classifier entries are constant
            zeros when ``FLAGS.cgan`` is off.
        """
        # Costs are reset per tower; the loss builders are expected to add
        # into them (assumption based on the reset -- confirm in lib.loss).
        gen_model.cost = disc_model.cost = 0

        gen_model.set_phase("gpu%d" % gpu_id)
        x_fake = gen_model(gen_input, update_collection=update_collection)
        gen_model.set_reuse()
        gen_model.x_fake = x_fake

        disc_model.set_phase("gpu%d" % gpu_id)
        disc_real, real_cls_logits = disc_model(
            x_real, update_collection=update_collection)
        disc_model.set_reuse()
        # Only tensors recorded during the fake pass are kept for the
        # input-gradient diagnostics below.
        disc_model.recorded_tensors = []
        disc_model.recorded_names = []
        disc_fake, fake_cls_logits = disc_model(
            x_fake, update_collection=update_collection)
        disc_model.disc_real = disc_real
        disc_model.disc_fake = disc_fake
        disc_model.real_cls_logits = real_cls_logits
        disc_model.fake_cls_logits = fake_cls_logits

        if FLAGS.cgan:
            # NOTE(review): confirm loss.classifier_loss returns
            # (fake, real) in this order.
            fake_cls_cost, real_cls_cost = loss.classifier_loss(
                gen_model,
                disc_model,
                x_real,
                c_label,
                c_noise,
                weight=1.0 / dataset.class_num,
                summary=False)
        else:
            # Previously unbound here, so the return raised NameError without
            # cgan; zeros keep the 5-element sub-loss list shape.
            fake_cls_cost = real_cls_cost = tf.constant(0.0)

        raw_gen_cost, raw_disc_real, raw_disc_fake = loss.hinge_loss(
            gen_model, disc_model, adv_weight=1.0, summary=False)

        gen_model.vars = [
            v for v in tf.trainable_variables() if gen_model.name in v.name
        ]
        disc_model.vars = [
            v for v in tf.trainable_variables() if disc_model.name in v.name
        ]
        g_grads = tf.gradients(gen_model.cost,
                               gen_model.vars,
                               colocate_gradients_with_ops=True)
        d_grads = tf.gradients(disc_model.cost,
                               disc_model.vars,
                               colocate_gradients_with_ops=True)
        g_grads = [
            tf.check_numerics(g, "G grad nan: " + str(g)) for g in g_grads
        ]
        d_grads = [
            tf.check_numerics(g, "D grad nan: " + str(g)) for g in d_grads
        ]
        g_tower_grads.append(g_grads)
        d_tower_grads.append(d_grads)

        # Gradients of the fake discriminator output w.r.t. every recorded
        # intermediate tensor (in reverse recording order), for debugging.
        tensors = gen_model.recorded_tensors + disc_model.recorded_tensors
        names = gen_model.recorded_names + disc_model.recorded_names
        if gpu_id == 0: x_name.extend(names)
        xs.append(tensors)
        names = names[::-1]
        tensors = tensors[::-1]
        grads = tf.gradients(disc_fake,
                             tensors,
                             colocate_gradients_with_ops=True)
        for n, g in zip(names, grads):
            print(n, g)
        grad_x.append(
            [tf.check_numerics(g, "BP nan: " + str(g)) for g in grads])
        if gpu_id == 0: grad_x_name.extend(names)
        disc_model.recorded_tensors = []
        disc_model.recorded_names = []
        gen_model.recorded_tensors = []
        gen_model.recorded_names = []

        return gen_model.cost, disc_model.cost, [
            fake_cls_cost, real_cls_cost, raw_gen_cost, raw_disc_real,
            raw_disc_fake
        ]
예제 #8
0
def main():
    """Build the mask-based GAN graph described by ``FLAGS`` and train it.

    On top of the hinge loss (plus the classifier loss when ``FLAGS.cgan`` is
    set), the generator cost gets three regularizers on
    ``gen_model.overlapped_mask``: a cosine diversity loss, a penalty on
    masks whose total weight strays more than twice the average from the
    average, and a total-variation smoothness term.
    """
    size = FLAGS.img_size

    if FLAGS.cgan:
        npy_dir = FLAGS.data_dir.replace(".zip", "") + '.npy'
    else:
        npy_dir = None

    if "celeb" in FLAGS.data_dir:
        dataset = dataloader.CelebADataset(FLAGS.data_dir,
            img_size=(size, size),
            npy_dir=npy_dir)
    else:
        dataset = dataloader.FileDataset(FLAGS.data_dir,
            npy_dir=npy_dir,
            img_size=(size, size),
            shuffle=True)
    dl = DataLoader(dataset, batch_size=FLAGS.batch_size, shuffle=True, num_workers=NUM_WORKER)

    # TF Input
    x_fake_sample = tf.placeholder(tf.float32, [None, size, size, 3], name="x_fake_sample")
    x_real = tf.placeholder(tf.float32, [None, size, size, 3], name="x_real")
    s_real = tf.placeholder(tf.float32, [None, size, size, 3], name='s_real')
    z_noise = tf.placeholder(tf.float32, [None, 128], name="z_noise")

    if FLAGS.cgan:
        c_noise = tf.placeholder(tf.float32, [None, dataset.class_num], name="c_noise")
        c_label = tf.placeholder(tf.float32, [None, dataset.class_num], name="c_label")
        gen_input = [z_noise, c_noise]
    else:
        gen_input = z_noise
        # Unconditional GAN: `label=c_label` below previously raised NameError.
        c_label = c_noise = None

    gen_model, disc_model = getattr(config, FLAGS.model_name)(FLAGS.img_size, dataset.class_num)
    gen_model.mask_num = FLAGS.mask_num
    gen_model.cbn_project = FLAGS.cbn_project

    x_fake = gen_model(gen_input, update_collection=None)
    gen_model.set_reuse()
    gen_model.x_fake = x_fake

    # Real pass uses update_collection="no_ops", fake pass None.
    disc_real, real_cls_logits = disc_model(x_real, update_collection="no_ops")
    disc_model.set_reuse()
    disc_fake, fake_cls_logits = disc_model(x_fake, update_collection=None)
    disc_model.disc_real = disc_real
    disc_model.disc_fake = disc_fake
    disc_model.real_cls_logits = real_cls_logits
    disc_model.fake_cls_logits = fake_cls_logits

    int_sum_op = []

    if FLAGS.use_cache:
        disc_fake_sample = disc_model(x_fake_sample)[0]
        disc_cost_sample = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake_sample,
                labels=tf.zeros_like(disc_fake_sample)), name="cost_disc_fake_sample")
        disc_cost_sample_sum = tf.summary.scalar("disc_sample", disc_cost_sample)

        fake_sample_grid = ops.get_grid_image_summary(x_fake_sample, 4)
        int_sum_op.append(tf.summary.image("fake sample", fake_sample_grid))

        sample_method = [disc_cost_sample, disc_cost_sample_sum, x_fake_sample]
    else:
        sample_method = None

    print("=> Mask num " + str(gen_model.overlapped_mask.get_shape()))

    # diverse mask
    diverse_loss, diverse_loss_sum = loss.cosine_diverse_distribution(gen_model.overlapped_mask)

    # make sure mask is not eliminated: penalize (with a margin of twice the
    # average) the deviation of each mask's total weight from the average
    mask_weight = tf.reduce_sum(gen_model.overlapped_mask, [1, 2])
    mask_num = mask_weight.get_shape().as_list()[-1]
    avg_map_weight = (size ** 2) / float(mask_num)
    diff_map = tf.abs(mask_weight - avg_map_weight)
    restricted_diff_map = tf.nn.relu(diff_map - 2 * avg_map_weight)
    restricted_var_loss = 1e-3 * tf.reduce_mean(restricted_diff_map)
    var_loss_sum = tf.summary.scalar("variance loss", restricted_var_loss)

    # semantic
    """
    uniform_loss = 0
    vgg_net = model.classifier.MyVGG16("lib/tensorflowvgg/vgg16.npy")
    vgg_net.build(tf.image.resize_bilinear(x_fake, (224, 224)))
    sf = vgg_net.conv3_3
    mask_shape = sf.get_shape().as_list()[1:3]
    print("=> VGG feature shape: " + str(mask_shape))
    diff_maps = []
    for i in range(mask_num):
        mask = tf.image.resize_bilinear(gen_model.overlapped_mask[:, :, :, i:i+1], mask_shape) # (batch, size, size, 1)
        mask = mask / tf.reduce_sum(mask, [1, 2], keepdims=True)
        expected_feature = tf.reduce_sum(mask * sf, [1, 2], keepdims=True) # (batch, 1, 1, 256)
        diff_map = tf.reduce_mean(tf.abs(mask * (sf - expected_feature)), [3]) # (batch, size, size)
        diff_maps.append(diff_map[0] / tf.reduce_max(diff_map[0]))
        restricted_diff_map = diff_map # TODO: add margin
        uniform_loss += 1e-3 * tf.reduce_mean(tf.reduce_sum(diff_map, [1, 2]))
    uniform_loss_sum = tf.summary.scalar("uniform loss", uniform_loss)
    """

    # smooth mask: total variation normalized by the pixel count
    tv_loss = tf.reduce_mean(tf.image.total_variation(gen_model.overlapped_mask)) / (size ** 2)
    tv_sum = tf.summary.scalar("TV loss", tv_loss)

    gen_model.cost += diverse_loss + tv_loss + restricted_var_loss

    gen_model.sum_op.extend([tv_sum, diverse_loss_sum, var_loss_sum])

    # Show the first sample's masks as a grid (edge = floor(sqrt(#masks))).
    edge_num = int(np.sqrt(gen_model.overlapped_mask.get_shape().as_list()[-1]))
    mask_seq = tf.transpose(gen_model.overlapped_mask[0], [2, 0, 1])
    grid_mask = tf.expand_dims(ops.get_grid_image_summary(mask_seq, edge_num), -1)
    int_sum_op.append(tf.summary.image("stroke mask", grid_mask))

    #uniform_diff_map = tf.expand_dims(ops.get_grid_image_summary(tf.stack(diff_maps, 0), edge_num), -1)
    #int_sum_op.append(tf.summary.image("uniform diff map", uniform_diff_map))

    grid_x_fake = ops.get_grid_image_summary(gen_model.x_fake, 4)
    int_sum_op.append(tf.summary.image("generated image", grid_x_fake))

    grid_x_real = ops.get_grid_image_summary(x_real, 4)
    int_sum_op.append(tf.summary.image("real image", grid_x_real))

    if FLAGS.cgan:
        loss.classifier_loss(gen_model, disc_model, x_real, c_label, c_noise,
        weight=1.0)

    loss.hinge_loss(gen_model, disc_model, adv_weight=1.0)

    int_sum_op = tf.summary.merge(int_sum_op)

    ModelTrainer = trainer.base_gantrainer.BaseGANTrainer(
        int_sum_op=int_sum_op,
        dataloader=dl,
        FLAGS=FLAGS,
        gen_model=gen_model,
        disc_model=disc_model,
        gen_input=gen_input,
        x_real=x_real,
        label=c_label,
        sample_method=sample_method)

    print("=> Build train op")
    ModelTrainer.build_train_op()

    print("=> ##### Generator Variable #####")
    gen_model.print_trainble_vairables()
    print("=> ##### Discriminator Variable #####")
    disc_model.print_trainble_vairables()
    print("=> ##### All Variable #####")
    for v in tf.trainable_variables():
        print("%s" % v.name)

    ModelTrainer.init_training()
    ModelTrainer.train()