def main():
    """Train the TF1-style CycleGAN on the ``trainA``/``trainB`` images.

    Reads the image dataset from ``DATASET_PATH`` and writes summaries to
    ``./logs``. ``args`` is a module-level global consumed by ``CycleGAN``.
    The summary writer is always closed, even if graph construction or
    training raises, so buffered events are flushed to disk.
    """
    image_dataset = read_images(DATASET_PATH)
    trainA = image_dataset['trainA']
    trainB = image_dataset['trainB']

    summary_writer = tf.summary.FileWriter('./logs')
    try:
        gan = CycleGAN(args)
        with tf.Session() as sess:
            gan.train(sess, summary_writer, trainA, trainB)
    finally:
        # close() flushes any pending events; without it the last summaries
        # written during training can be lost.
        summary_writer.close()
Beispiel #2
0
def create_model(op, device):
    """Instantiate the model selected by ``op.model.name``.

    Args:
        op: options object; ``op.model.name`` picks the model class and the
            whole object is forwarded to that class's constructor.
        device: forwarded unchanged to the model constructor.

    Returns:
        ``DCGAN(op, device)`` for ``'dcgan'`` or ``CycleGAN(op, device)`` for
        ``'cyclegan'``.

    Raises:
        ValueError: if ``op.model.name`` is not one of the supported names,
            instead of silently returning ``None``.
    """
    model_name = op.model.name

    if model_name == 'dcgan':
        return DCGAN(op, device)
    if model_name == 'cyclegan':
        return CycleGAN(op, device)
    raise ValueError(
        f"Unknown model name {model_name!r}; expected 'dcgan' or 'cyclegan'.")
Beispiel #3
0
def save_model():
    """Restore the latest checkpoint and export GA, GB, DA and DB as HDF5 files.

    The four Adam optimizers are never stepped here. They are attached to the
    checkpoint under the same attribute names the checkpoint was presumably
    written with, so that restoring sees a matching object graph.

    Side effects:
        Creates ``./models`` if needed and writes ``models/GA.h5``,
        ``models/GB.h5``, ``models/DA.h5`` and ``models/DB.h5``.
    """
    cycleGAN = CycleGAN()
    optimizerGA = tf.keras.optimizers.Adam(2e-4)
    optimizerGB = tf.keras.optimizers.Adam(2e-4)
    optimizerDA = tf.keras.optimizers.Adam(2e-4)
    optimizerDB = tf.keras.optimizers.Adam(2e-4)
    checkpoint = tf.train.Checkpoint(GA=cycleGAN.GA,
                                     GB=cycleGAN.GB,
                                     DA=cycleGAN.DA,
                                     DB=cycleGAN.DB,
                                     optimizerGA=optimizerGA,
                                     optimizerGB=optimizerGB,
                                     optimizerDA=optimizerDA,
                                     optimizerDB=optimizerDB)
    latest = tf.train.latest_checkpoint('checkpoints')
    if latest is None:
        # restore(None) loads nothing, so the exported files would contain
        # freshly initialised weights; make that visible instead of silent.
        print('Warning: no checkpoint found in ./checkpoints; '
              'exporting untrained weights.')
    checkpoint.restore(latest)
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('models', exist_ok=True)
    for name in ('GA', 'GB', 'DA', 'DB'):
        getattr(cycleGAN, name).save(os.path.join('models', name + '.h5'))
Beispiel #4
0
    # load data
    print(f"Loading dataset {param.test_path_a.split('/')[1]}...")
    datasetx, datasety = get_datasets(param, train=False)
    dataloaderx = torch.utils.data.DataLoader(datasetx,
                                              batch_size=param.batch_size,
                                              shuffle=False,
                                              num_workers=param.num_workers)
    dataloadery = torch.utils.data.DataLoader(datasety,
                                              batch_size=param.batch_size,
                                              shuffle=False,
                                              num_workers=param.num_workers)

    print("Creating model...")
    device = torch.device(("cpu", "cuda:0")[torch.cuda.is_available()])
    model = CycleGAN(param, device)
    model.load(param)
    model.eval(
    )  # needs to be implemented, will essentailly just set all networks inside to eval

    print("Beginning to test...")
    if len(datasetx) < len(datasety):
        packed = zip(cycle(dataloaderx), dataloadery)
    else:
        packed = zip(cycle(dataloadery), dataloaderx)
    for i, (data_x, data_y) in enumerate(packed):
        fA, realA, maskA = data_x
        fB, realB, maskB = data_y
        realA = realA.view(-1, param.in_nc, param.image_size,
                           param.image_size).to(device)
        realB = realB.view(-1, param.in_nc, param.image_size,
Beispiel #5
0
import torch
from models import CycleGAN

# ---------------------------------------------------------------------------
# Experiment configuration (all values are passed positionally to CycleGAN)
# ---------------------------------------------------------------------------
dataroot = './data/final'                           # dataset root directory
sub_dirs = ['trainA', 'trainB', 'testA', 'testB']   # presumably folders under dataroot
batch_sizes = [1, 1, 3, 3]                          # presumably one per sub_dirs entry
workers = 2
lr, betas = 2e-4, (0.5, 0.999)                      # learning rate and beta pair
epochs = 200
gpu_ids = [0]
ckpt_dir = './ckpt/visceral'
results_dir = './results/visceral'

# CUDA sanity checks -- uncomment when debugging device selection:
# print(torch.cuda.is_available())
# print("Current device ", torch.cuda.current_device())
# print("How many device? ", torch.cuda.device_count())
# torch.cuda.set_device(6)
# print("Current device ", torch.cuda.current_device())

# ---------------------------------------------------------------------------
# Build the model and train. This runs at import time (no __main__ guard).
# ---------------------------------------------------------------------------
cg = CycleGAN(
    dataroot, sub_dirs, batch_sizes, workers,
    lr, betas, gpu_ids,
    ckpt_dir, results_dir,
)

cg.train(epochs, 10)
Beispiel #6
0
# Move axis 3 of B to position 1 (presumably channels-last -> channels-first,
# matching a step applied to A earlier in the script -- confirm).
B = np.moveaxis(B, 3, 1)

# Turn both domains' NumPy arrays into torch tensors.
A, B = torch.from_numpy(A), torch.from_numpy(B)

# Sample count per domain (index 0 -> A, index 1 -> B); get_batch draws
# random indices below these bounds.
data_sizes = [A.shape[0], B.shape[0]]

# Returns batch in form of (A,B)
def get_batch(batch_size):
	indsA = torch.randint(0,data_sizes[0],(batch_size,))
	indsB = torch.randint(0,data_sizes[1],(batch_size,))

	return Tensor(A[indsA]),Tensor(B[indsB])


# Build the model; move it to the GPU only when USE_CUDA is set.
model = CycleGAN()
if USE_CUDA:
    model.cuda()

# Load saved weights when LOAD_CHECKPOINTS is set.
if LOAD_CHECKPOINTS:
    model.load_checkpoint()

# Rendering stuff: a 4x4 grid of matplotlib axes for drawing samples.
fig, axs = plt.subplots(4, 4)
# Draws samples
def save_samples(title):
	A_sample, B_sample = get_batch(4)
	B_fake, A_fake = model.A_to_B(A_sample), model.B_to_A(B_sample)

	A_sample = npimage(A_sample)
	B_sample = npimage(B_sample)
	A_fake = npimage(A_fake)
Beispiel #7
0
def _load_paired_mnist(n_train, n_test):
    """Load the MNIST digit-3 and digit-5 subsets and pair their training splits.

    Args:
        n_train: forwarded to ``data.get_mnist`` (presumably the training size).
        n_test: forwarded to ``data.get_mnist`` (presumably the test size).

    Returns:
        ``(train_data, test_data1, test_data2)``: ``train_data`` is
        ``data.pair_datasets`` applied to both training splits;
        ``test_data1``/``test_data2`` are the digit-3/digit-5 test splits.
    """
    train_data1, test_data1 = data.get_mnist(n_train, n_test, 1, False,
                                             classes=[3])
    train_data2, test_data2 = data.get_mnist(n_train, n_test, 1, False,
                                             classes=[5])
    train_data = data.pair_datasets(train_data1, train_data2)
    return train_data, test_data1, test_data2


def _show_translations(translate, images, n_pixels, input_title):
    """Show each image next to ``translate(image)``, one figure per image.

    Args:
        translate: callable applied to the image reshaped to ``(1, n_pixels)``;
            the ``.data`` of its result is plotted.
        images: iterable of flattened images with ``n_pixels`` values each.
        n_pixels: pixels per image; must be a perfect square because each
            image is reshaped to ``(side, side)`` for display.
        input_title: title for the left (input) panel.
    """
    side = int(np.sqrt(n_pixels))
    for image in images:
        generated = translate(image.reshape(1, n_pixels))
        _, axarr = plt.subplots(1, 2)
        axarr[0].imshow(image.reshape((side, side)))
        axarr[0].set_title(input_title)
        axarr[1].imshow(generated.data.reshape((side, side)))
        axarr[1].set_title('generated sample')
        plt.show()


def main(model_type='CycleGLO'):
    """Train one of three experimental models on a small image subset.

    Args:
        model_type: ``'GLO'`` (CIFAR-10 class 3), ``'CycleGAN'`` or
            ``'CycleGLO'`` (MNIST digit 3 <-> digit 5). Defaults to
            ``'CycleGLO'``.

    Raises:
        ValueError: if ``model_type`` is not one of the values above.
    """
    n_train = 100
    n_test = 100

    if model_type == 'GLO':
        train, _ = data.get_cifar10(n_train, n_test, 1, True, classes=[3])
        n_vectors = len(train)
        n_pixels = len(train[0][0])
        code_dim = 64
        print(n_pixels)
        lossfun = F.mean_squared_error
        model = GLOModel(code_dim, n_vectors, 1, train, n_pixels, 50)
        model.train(lossfun, n_epochs=10)

        show_results(model)

    elif model_type == 'CycleGAN':
        train_data, test_data1, test_data2 = _load_paired_mnist(n_train,
                                                                n_test)

        # Create model
        n_pixels = len(train_data[0][0])
        g_hidden = 50
        d_hidden = 50
        alpha = 0.01
        beta = 0.5
        lambda1 = 10.0
        lambda2 = 3.0
        learningrate_decay = 0.0
        learningrate_interval = 1000
        max_buffer_size = 25
        model = CycleGAN(alpha, beta, lambda1, lambda2, n_pixels,
                         learningrate_decay, learningrate_interval, g_hidden,
                         d_hidden, max_buffer_size)

        # Train model
        n_epochs = 1000
        minibatch_size = 1
        batch_iter = iterators.MultiprocessIterator(train_data,
                                                    batch_size=minibatch_size,
                                                    n_processes=4)
        try:
            model.train(n_epochs, batch_iter)
        finally:
            # The iterator spawns worker processes; release them even when
            # training raises so they do not outlive this function.
            batch_iter.finalize()

        # Visualize result/test/performance
        print('Visualizing!')
        _show_translations(model.g, test_data1[:5], n_pixels,
                           'input image for G')
        _show_translations(model.f, test_data2[:5], n_pixels,
                           'input image for F')

    elif model_type == 'CycleGLO':
        # Test splits are loaded but not used by this experiment.
        train_data, _, _ = _load_paired_mnist(n_train, n_test)

        # Create model
        n_pixels = len(train_data[0][0])
        g_hidden = 50
        d_hidden = 50

        code_dim = 64
        alpha = 0.01
        beta = 0.5
        lambda1 = 10.0
        lambda2 = 3.0
        learningrate_decay = 0.0
        learningrate_interval = 1000
        max_buffer_size = 25
        model = CycleGLO(alpha, beta, lambda1, lambda2, n_pixels, code_dim,
                         train_data, learningrate_decay, learningrate_interval,
                         g_hidden, d_hidden, max_buffer_size)

        # Train model
        lossfun = F.mean_squared_error

        model.train(lossfun, n_epochs=1000)
        show_results(model, False)

    else:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'GLO', 'CycleGAN' "
            "or 'CycleGLO'.")
Beispiel #8
0
    opt.isTrain = False

    if opt.load_path is None or not os.path.exists(opt.load_path):
        raise FileExistsError('Load path must be exist!!!')
    device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(
        opt.gpu_ids) > 0 else 'cpu'
    ckpt = torch.load(opt.load_path, map_location=device)
    cfg = ckpt['cfg'] if 'cfg' in ckpt.keys() else (None, None)

    # create model
    if opt.model == 'cyclegan':
        if opt.mask:
            model = MaskCycleGAN.MaskCycleGANModel(opt)
        else:
            model = CycleGAN.CycleGANModel(opt,
                                           cfg_AtoB=cfg[0],
                                           cfg_BtoA=cfg[1])
    elif opt.model == 'pix2pix':
        opt.norm = 'batch'
        opt.dataset_mode = 'aligned'
        opt.pool_size = 0
        if opt.mask:
            model = MaskPix2Pix.MaskPix2PixModel(opt)
        else:
            model = Pix2Pix.Pix2PixModel(opt,
                                         filter_cfgs=cfg[0],
                                         channel_cfgs=cfg[1])
    elif opt.model == 'mobilecyclegan':
        if opt.mask:
            model = MaskMobileCycleGAN.MaskMobileCycleGANModel(opt)
        else:
Beispiel #9
0
    opt.isTrain = True
    util.mkdirs(os.path.join(opt.checkpoints_dir, opt.name))
    logger = util.get_logger(
        os.path.join(opt.checkpoints_dir, opt.name, 'logger.log'))

    best_AtoB_fid = float('inf') if 'cityscapes' not in opt.dataroot else 0.0
    best_BtoA_fid = float('inf') if 'cityscapes' not in opt.dataroot else 0.0
    best_AtoB_epoch = 0
    best_BtoA_epoch = 0

    # create model
    if opt.model == 'cyclegan':
        if opt.mask:
            model = MaskCycleGAN.MaskCycleGANModel(opt)
        else:
            model = CycleGAN.CycleGANModel(opt)
    elif opt.model == 'pix2pix':
        opt.norm = 'batch'
        opt.dataset_mode = 'aligned'
        opt.pool_size = 0
        if opt.mask:
            model = MaskPix2Pix.MaskPix2PixModel(opt)
        else:
            model = Pix2Pix.Pix2PixModel(opt)
    elif opt.model == 'mobilecyclegan':
        if opt.mask:
            model = MaskMobileCycleGAN.MaskMobileCycleGANModel(opt)
        else:
            model = MobileCycleGAN.MobileCycleGANModel(opt)
    elif opt.model == 'mobilepix2pix':
        opt.norm = 'batch'
    parser = argparse.ArgumentParser()
    parser.add_argument("params_path", help="param file", type=str)
    args = parser.parse_args()
    print("Loading params...")
    param = Params(args.params_path)

    # load data
    print("Loading data...")
    x, y = get_test_dataset(param)
    filenamesx = x[0]
    datasetx = x[1]
    filenamesy = y[0]
    datasety = y[1]
    datasetx = datasetx.batch(param.batch_size)
    datasety = datasety.batch(param.batch_size)

    print("Creating model...")
    model = CycleGAN(param)
    model.load(param)

    print("Beginning to test...")
    for fx, data_x in zip(filenamesx, datasetx):
        print(fx.eval())
        print(data_x)
        for fy, data_y in zip(filenamesy, datasety):
            G_x_out, G_y_out = model(data_x, data_y)
            # Save G_x_out as fx->fy.jpg
            save_outputs(G_x_out, fx, fy, param.out_directory)
            # Save G_y_out as fy->fx.jpg
            save_outputs(G_y_out, fy, fx, param.out_directory)
Beispiel #11
0
from utils.data_loader import *

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("params_path", help="param file", type=str)
    args = parser.parse_args()
    print("Loading params...")
    param = Params(args.params_path)

    # load data
    print(f"Loading dataset {param.train_path_a.split('/')[1]}...")
    datasetx, datasety = get_datasets(param, train=True)

    print("Creating model...")
    device = torch.device(("cpu", "cuda:0")[torch.cuda.is_available()])
    model = CycleGAN(param, device)
    model.train()

    dataloaderx = torch.utils.data.DataLoader(datasetx,
                                              batch_size=param.batch_size,
                                              shuffle=True,
                                              num_workers=param.num_workers)
    dataloadery = torch.utils.data.DataLoader(datasety,
                                              batch_size=param.batch_size,
                                              shuffle=True,
                                              num_workers=param.num_workers)

    for e in range(param.epochs):
        print(f'Starting epoch [{e+1} / {param.epochs}]')
        epoch_start_time = time.time()
        avgloss = 0.0
Beispiel #12
0
from utils.params import *
from utils.data_loader import *

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("params_path", help="param file", type=str)
    args = parser.parse_args()
    print("Loading params...")
    param = Params(args.params_path)

    # load data
    print("Loading data...")
    datasetx, datasety = get_train_dataset(param)

    print("Creating model...")
    model = CycleGAN(param)
    optimizer = tf.optimizers.Adam(param.lr, beta_1=param.beta_1)

    for e in range(param.epochs):
        print(f'Starting epoch {e+1}')
        epoch_start_time = time.time()
        avgloss = 0.0
        datasetx = datasetx.shuffle(buffer_size=param.buff_size).batch(
            param.batch_size)
        datasety = datasety.shuffle(buffer_size=param.buff_size).batch(
            param.batch_size)
        for i, (data_x, data_y) in enumerate(zip(cycle(datasetx), datasety)):
            data_x = tf.reshape(
                data_x, [-1, param.image_size, param.image_size, param.in_nc])
            data_y = tf.reshape(
                data_y, [-1, param.image_size, param.image_size, param.in_nc])
Beispiel #13
0
def _make_lr_schedule():
    """Build the piecewise-constant learning-rate schedule for all optimizers.

    The rate holds at 2e-4 for the first ``dataset_size * 100`` steps, then
    drops by 2e-4 / 5 every ``dataset_size * 25`` steps. It is 0 once the
    step exceeds ``dataset_size * 200``. ``dataset_size`` is a module-level
    global.
    """
    boundaries = [
        dataset_size * 100 + i * dataset_size * 100 / 4 for i in range(5)
    ]
    values = list(reversed([i * 2e-4 / 5 for i in range(6)]))
    return tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=boundaries, values=values)


def _load_horse2zebra(split, train):
    """Return an endless iterator over batches of one horse2zebra split.

    The dataset must already be available locally (``download=False``).
    Training splits are parsed with ``parse_function_generator()``, shuffled
    with a ``batch_size`` buffer, batched by ``batch_size`` and prefetched.
    Test splits use ``parse_function_generator(isTrain=False)`` and batch 1.
    """
    dataset = tfds.load(name='cycle_gan/horse2zebra', split=split,
                        download=False).repeat(-1)
    if train:
        dataset = (dataset.map(parse_function_generator())
                   .shuffle(batch_size)
                   .batch(batch_size)
                   .prefetch(tf.data.experimental.AUTOTUNE))
    else:
        dataset = dataset.map(
            parse_function_generator(isTrain=False)).batch(1)
    return iter(dataset)


def _to_uint8(images):
    """Map values from [-1, 1] onto uint8 [0, 255], clipping anything outside."""
    return tf.cast(tf.clip_by_value((images + 1) * 127.5,
                                    clip_value_min=0.,
                                    clip_value_max=255.),
                   dtype=tf.uint8)


def main():
    """Train CycleGAN on horse2zebra until all four losses fall below 0.01.

    Resumes from the latest checkpoint in ``./checkpoints`` and writes
    TensorBoard summaries there every 500 GA-optimizer steps. Saves a
    checkpoint every 10000 steps. On exit, exports GA, GB, DA and DB to
    ``./models/*.h5``.
    """
    # models
    cycleGAN = CycleGAN()
    # The schedule holds no state, so one instance can drive all four
    # optimizers. Keeping a reference lets us report the current rate
    # without reaching into private optimizer attributes (``_hyper``).
    lr_schedule = _make_lr_schedule()
    optimizerGA = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5)
    optimizerGB = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5)
    optimizerDA = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5)
    optimizerDB = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5)

    # load dataset
    A = _load_horse2zebra('trainA', train=True)
    B = _load_horse2zebra('trainB', train=True)
    testA = _load_horse2zebra('testA', train=False)
    testB = _load_horse2zebra('testB', train=False)
    # restore from existing checkpoint
    checkpoint = tf.train.Checkpoint(GA=cycleGAN.GA,
                                     GB=cycleGAN.GB,
                                     DA=cycleGAN.DA,
                                     DB=cycleGAN.DB,
                                     optimizerGA=optimizerGA,
                                     optimizerGB=optimizerGB,
                                     optimizerDA=optimizerDA,
                                     optimizerDB=optimizerDB)
    checkpoint.restore(tf.train.latest_checkpoint('checkpoints'))
    # create log
    log = tf.summary.create_file_writer('checkpoints')
    # running loss averages, reset after every report
    avg_ga_loss = tf.keras.metrics.Mean(name='GA loss', dtype=tf.float32)
    avg_gb_loss = tf.keras.metrics.Mean(name='GB loss', dtype=tf.float32)
    avg_da_loss = tf.keras.metrics.Mean(name='DA loss', dtype=tf.float32)
    avg_db_loss = tf.keras.metrics.Mean(name='DB loss', dtype=tf.float32)
    while True:
        imageA, _ = next(A)
        imageB, _ = next(B)
        # persistent=True: four separate gradient() calls use one recording.
        with tf.GradientTape(persistent=True) as tape:
            outputs = cycleGAN((imageA, imageB))
            GA_loss = cycleGAN.GA_loss(outputs)
            GB_loss = cycleGAN.GB_loss(outputs)
            DA_loss = cycleGAN.DA_loss(outputs)
            DB_loss = cycleGAN.DB_loss(outputs)
        # calculate discriminator gradients
        da_grads = tape.gradient(DA_loss, cycleGAN.DA.trainable_variables)
        avg_da_loss.update_state(DA_loss)
        db_grads = tape.gradient(DB_loss, cycleGAN.DB.trainable_variables)
        avg_db_loss.update_state(DB_loss)
        # calculate generator gradients
        ga_grads = tape.gradient(GA_loss, cycleGAN.GA.trainable_variables)
        avg_ga_loss.update_state(GA_loss)
        gb_grads = tape.gradient(GB_loss, cycleGAN.GB.trainable_variables)
        avg_gb_loss.update_state(GB_loss)
        # A persistent tape keeps its recorded resources until it is dropped.
        del tape
        # update discriminator weights
        optimizerDA.apply_gradients(
            zip(da_grads, cycleGAN.DA.trainable_variables))
        optimizerDB.apply_gradients(
            zip(db_grads, cycleGAN.DB.trainable_variables))
        # update generator weights
        optimizerGA.apply_gradients(
            zip(ga_grads, cycleGAN.GA.trainable_variables))
        optimizerGB.apply_gradients(
            zip(gb_grads, cycleGAN.GB.trainable_variables))
        if tf.equal(optimizerGA.iterations % 500, 0):
            imageA, _ = next(testA)
            imageB, _ = next(testB)
            outputs = cycleGAN((imageA, imageB))
            real_A = _to_uint8(imageA)
            real_B = _to_uint8(imageB)
            fake_B = _to_uint8(outputs[1])
            fake_A = _to_uint8(outputs[7])
            with log.as_default():
                tf.summary.scalar('generator A loss',
                                  avg_ga_loss.result(),
                                  step=optimizerGA.iterations)
                tf.summary.scalar('generator B loss',
                                  avg_gb_loss.result(),
                                  step=optimizerGB.iterations)
                tf.summary.scalar('discriminator A loss',
                                  avg_da_loss.result(),
                                  step=optimizerDA.iterations)
                tf.summary.scalar('discriminator B loss',
                                  avg_db_loss.result(),
                                  step=optimizerDB.iterations)
                tf.summary.image('real A', real_A, step=optimizerGA.iterations)
                tf.summary.image('fake B', fake_B, step=optimizerGA.iterations)
                tf.summary.image('real B', real_B, step=optimizerGA.iterations)
                tf.summary.image('fake A', fake_A, step=optimizerGA.iterations)
            print('Step #%d GA Loss: %.6f GB Loss: %.6f DA Loss: %.6f DB Loss: %.6f lr: %.6f' %
                  (optimizerGA.iterations, avg_ga_loss.result(),
                   avg_gb_loss.result(), avg_da_loss.result(),
                   avg_db_loss.result(),
                   lr_schedule(optimizerGA.iterations)))
            avg_ga_loss.reset_states()
            avg_gb_loss.reset_states()
            avg_da_loss.reset_states()
            avg_db_loss.reset_states()
        if tf.equal(optimizerGA.iterations % 10000, 0):
            # save model
            checkpoint.save(os.path.join('checkpoints', 'ckpt'))
        if GA_loss < 0.01 and GB_loss < 0.01 and DA_loss < 0.01 and DB_loss < 0.01:
            break
    # save the network structure with weights
    os.makedirs('models', exist_ok=True)
    for name in ('GA', 'GB', 'DA', 'DB'):
        getattr(cycleGAN, name).save(os.path.join('models', name + '.h5'))
Beispiel #14
0
from models import CycleGAN
from utils import Option

if __name__ == '__main__':
    opt = Option()
    opt.batch_size = 1

    # Options copied verbatim from the parsed command-line arguments.
    for attr_name in ('save_iter', 'niter', 'lmbd', 'pic_dir'):
        setattr(opt, attr_name, getattr(args, attr_name))

    # Fixed training hyper-parameters.
    opt.idloss = 0.0
    opt.lr = 0.0002
    opt.d_iter = 1

    # The perceptual (feature) loss is enabled whenever its weight is non-zero.
    opt.perceptionloss = bool(args.lmbd_feat != 0)
    opt.lmbd_feat = args.lmbd_feat

    # Overlay every command-line argument onto the options, then summarise.
    opt.__dict__.update(vars(args))
    opt.summary()

    cycleGAN = CycleGAN(opt)

    IG = ImageGenerator(
        path_trainA=args.path_trainA,
        path_trainB=args.path_trainB,
        resize=opt.resize,
        crop=opt.crop,
    )

    cycleGAN.fit(IG)