Exemple #1
0
def pred(config, mode='cifar10'):
    """Load a PixelCNN and save a 5x5 grid of sampled images.

    The grid is written to ``images/sample.png``.

    Args:
        config: object providing ``nr_resnet``, ``nr_filters``,
            ``nr_logistic_mix`` and ``load_params``.
        mode: dataset name; only ``'cifar10'`` is supported here.

    Raises:
        ValueError: if ``mode`` is not ``'cifar10'``.
    """
    if mode != 'cifar10':
        # Previously ``obs`` was silently left unbound for any other mode,
        # which only surfaced later as an UnboundLocalError.
        raise ValueError('unsupported mode: {!r}'.format(mode))
    obs = (3, 32, 32)

    import os  # local import: only needed to make sure the output dir exists
    os.makedirs('images', exist_ok=True)

    sample_batch_size = 25
    model = PixelCNN(nr_resnet=config.nr_resnet, nr_filters=config.nr_filters,
                     input_channels=obs[0], nr_logistic_mix=config.nr_logistic_mix).cuda()

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')
    sample_op = lambda x: sample_from_discretized_mix_logistic(x, config.nr_logistic_mix)
    # Maps values from [-1, 1] back to [0, 1] before saving.
    rescaling_inv = lambda x: .5 * x + .5

    def sample(model):
        """Fill an all-zero batch one pixel position at a time, in raster order."""
        model.eval()
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2]).cuda()
        with torch.no_grad():
            for i in range(obs[1]):
                for j in range(obs[2]):
                    # Re-run the model on everything sampled so far and keep
                    # only the freshly sampled values at position (i, j).
                    out = model(data, sample=True)
                    out_sample = sample_op(out)
                    data[:, :, i, j] = out_sample[:, :, i, j]
        return data

    print('sampling...')
    sample_t = sample(model)
    sample_t = rescaling_inv(sample_t)
    save_image(sample_t, 'images/sample.png', nrow=5, padding=0)
Exemple #2
0
def main():
    """Sample images pixel by pixel from a saved PixelCNN and write them as a grid.

    Loads the weights from ``models/model_23.pt``, draws 64 images of size
    3x32x32 one sub-pixel at a time, and saves the result to ``sample.png``.
    """
    save_path = 'models/model_23.pt'
    no_images = 64
    images_size = 32
    images_channels = 3

    # Define and load model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = PixelCNN().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host;
    # without it the CPU fallback above could never actually be used.
    net.load_state_dict(torch.load(save_path, map_location=device))
    net.eval()

    sample = torch.zeros(no_images, images_channels, images_size, images_size).to(device)
    print('-------------------------------------SAMPLING!!!!!---------------------------------')

    # Sampling never backpropagates, so skip autograd bookkeeping for the
    # images_size * images_size * images_channels forward passes.
    with torch.no_grad():
        for i in tqdm(range(images_size)):
            for j in range(images_size):
                for c in range(images_channels):
                    out = net(sample)
                    # out is indexed as (batch, level, channel, row, col);
                    # softmax over dim 1 gives one distribution per image.
                    probs = torch.softmax(out[:, :, c, i, j], dim=1)
                    # squeeze(1) drops only the sample dimension, so this
                    # also works for a batch of one image.
                    # NOTE(review): 63.0 presumably equals (levels - 1) of the
                    # trained model - confirm against the training script.
                    sampled_levels = torch.multinomial(probs, 1).squeeze(1).float() / (63.0)
                    sample[:, c, i, j] = sampled_levels

    torchvision.utils.save_image(sample, 'sample.png', nrow=12, padding=0)
Exemple #3
0
def load_checkpoint(file_path, use_cuda=False):
    """Rebuild a (Gated)PixelCNN from a checkpoint file.

    Args:
        file_path: path handed to ``torch.load``.
        use_cuda: when true the checkpoint is loaded without remapping and
            the model is moved to the GPU; otherwise every storage is kept
            where ``torch.load`` first deserialises it.

    Returns:
        The model with the checkpoint's ``state_dict`` loaded.
    """
    load_kwargs = {}
    if not use_cuda:
        # Identity remap: leave each storage as deserialised.
        load_kwargs['map_location'] = lambda storage, location: storage
    checkpoint = torch.load(file_path, **load_kwargs)

    # The checkpoint records which architecture it was trained with.
    model_cls = GatedPixelCNN if checkpoint['gated'] else PixelCNN
    model = model_cls(checkpoint['data_channels'], checkpoint['out_dims'])
    model.load_state_dict(checkpoint['state_dict'])

    if use_cuda:
        model.cuda()
    return model
Exemple #4
0
    def __init__(self, conf, X_train, X_test):
        """Store the data and configuration, then build the TF1 training graph.

        Args:
            conf: configuration object; the attributes read here are
                ``epochs``, ``batch_size``, ``learning_rate``,
                ``display_step``, ``q_levels``, ``classes``, ``height``,
                ``width``, ``channel``, ``show_figure`` and ``figure``.
            X_train: training data, stored as-is.
            X_test: test data, stored as-is.
        """
        self.X_train = X_train
        self.X_test = X_test
        self.conf = conf

        ########## parameter ##########
        self.epochs = conf.epochs
        self.batch_size = conf.batch_size
        self.learning_rate = conf.learning_rate
        self.display_step = conf.display_step

        ########## data ##########
        self.q_levels = conf.q_levels
        self.classes = conf.classes
        self.height = conf.height
        self.width = conf.width
        self.channel = conf.channel

        ########## sample ##########
        self.show_figure = conf.show_figure
        self.figure = conf.figure

        ########## placeholder ##########
        self.x = tf.placeholder(tf.float32, [None, self.height, self.width, self.channel])
        self.y = tf.placeholder(tf.int64, [None, self.height, self.width, self.channel])
        # self.y is rebound to the flattened label tensor; the placeholder
        # itself is no longer reachable through this attribute.
        self.y = tf.reshape(self.y, [-1])

        #### model pred: image prediction result ####
        self.pred = PixelCNN(self.conf, self.x).pred

        #### loss computation ####
        # Labels are flat, so this assumes self.pred has one row of logits
        # per label - TODO confirm against PixelCNN.pred.
        self.cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.pred, labels=self.y))

        #### optimization ####
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)

        #### accuracy ####
        self.correct_pred = tf.equal(tf.argmax(tf.nn.softmax(self.pred), 1), self.y)
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
Exemple #5
0
                         as_supervised=True)

    # shuffle_files must be the boolean False: the string 'False' used
    # previously is non-empty and therefore truthy, i.e. it asked for the
    # opposite of what was intended.
    ds_test = tfds.load(args.dataset,
                        split='test',
                        shuffle_files=False,
                        batch_size=32,
                        as_supervised=True)

    # Per-dataset settings.
    color_conditioning = hyperparams[args.dataset]['color_conditioning']
    input_shape = hyperparams[args.dataset]['input_shape']
    n_mixtures = hyperparams[args.dataset]['n_mixtures']
    # The OpenAI loss uses a fixed 250 epochs instead of the per-dataset value.
    n_epochs = hyperparams[
        args.dataset]['n_epochs'] if not args.use_openai_loss else 250

    model = PixelCNN(n_mixtures=n_mixtures,
                     color_conditioning=color_conditioning,
                     input_shape=input_shape)

    def neg_log_likelihood(target, output, n_mixtures, input_channels=3):
        """Mixture-model loss helper (presumably the negative log-likelihood).

        Args:
            target: tensor that is stacked ``n_mixtures`` times along a new
                last axis so it can be compared with every mixture component.
            output: 4-D tensor ``(B, H, W, input_channels * 3 * n_mixtures)``.
            n_mixtures: number of mixture components per channel.
            input_channels: number of image channels.
        """
        B, H, W, total_channels = output.shape
        assert total_channels == input_channels * 3 * n_mixtures, 'Total channels should be equal to 9 times the number of mixture models. (RGB + pi, mu, s)'
        # Split the channel axis into (input_channels, 3 * n_mixtures) so the
        # three parameter groups can be sliced off the last axis.
        output = tf.reshape(output,
                            shape=(B, H, W, input_channels, 3 * n_mixtures))
        means = output[..., :n_mixtures]
        log_scales_inverse = output[..., n_mixtures:2 * n_mixtures]
        mixture_scales = output[..., n_mixtures * 2:]

        # NOTE(review): after the softmax these sum to 1 over the mixture
        # axis, so despite the name they look like mixture weights - confirm.
        mixture_scales = tf.nn.softmax(mixture_scales, axis=4)  # last index
        scales_inverse = tf.math.exp(log_scales_inverse)

        # One copy of the target per mixture component.
        targets = tf.stack([target for _ in range(n_mixtures)], axis=-1)
Exemple #6
0
def main():
    """Train a PixelCNN on quantised CIFAR-10, checkpointing after every epoch.

    Side effects: downloads CIFAR-10 into ``data``, appends each epoch's
    average loss to ``hst.txt``, writes one checkpoint per epoch under
    ``models`` and calls ``sampling`` on epochs where ``epoch % 3 == 1``.
    """
    path = 'data'
    batch_size = 64

    layers = 10
    kernel = 7
    # NOTE(review): ``channels`` is used both as the network width and, via
    # ``channels - 1``, as the number of quantisation levels - confirm that
    # this coupling is intended.
    channels = 128
    epochs = 25
    save_path = 'models'

    normalize = transforms.Lambda(lambda image: np.array(image) / 255.0)

    def quantisize(image, levels):
        """Map values to integer bin indices over ``levels`` equal-width bins."""
        return np.digitize(image, np.arange(levels) / levels) - 1

    discretize = transforms.Compose([
        transforms.Lambda(lambda image: quantisize(image, (channels - 1))),
        transforms.ToTensor()
    ])
    cifar_transform = transforms.Compose([normalize, discretize])

    train = datasets.CIFAR10(root=path, train=True, download=True, transform=cifar_transform)
    test = datasets.CIFAR10(root=path, train=False, download=True, transform=cifar_transform)

    train = data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    test = data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    net = PixelCNN(num_layers=layers, kernel_size=kernel, num_channels=channels).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())
    loss_overall = []

    # Create the checkpoint directory once instead of re-checking every epoch.
    os.makedirs(save_path, exist_ok=True)

    for i in range(epochs):
        if (i % 3) == 1:
            sampling(net, i - 1, channels)
        net.train(True)
        step = 0
        loss_ = 0.0
        for images, _ in tqdm(train, desc='Epoch {}/{}'.format(i + 1, epochs)):
            images = images.to(device)
            # The integer levels are the targets; the network input is the
            # same data rescaled to [0, 1].
            normalized_images = images.float() / ((channels - 1))
            optimizer.zero_grad()

            output = net(normalized_images)
            loss = criterion(output, images)
            loss.backward()
            optimizer.step()

            # .item() accumulates a plain float; summing the tensors kept
            # autograd history around and wrote tensor reprs to hst.txt.
            loss_ += loss.item()
            step += 1

        print('Epoch:'+str(i)+'       , '+ 'Average loss: ', loss_/step)
        with open("hst.txt", "a") as myfile:
            myfile.write(str(loss_/step) + '\n')
        if i == epochs - 1:
            torch.save(net.state_dict(), save_path+'/model_'+'Last'+'.pt')
        else:
            torch.save(net.state_dict(), save_path+'/model_'+str(i)+'.pt')
        print('model saved')
Exemple #7
0
IMAGES_FOLDER = 'generated_images'
BIAS = False

# Both loaders use the batch size from the command line; only the training
# loader shuffles.
train_loader = DataLoader(
    dataset=BinaryMNIST(),
    batch_size=args.b,
    shuffle=True,
    num_workers=1,
    pin_memory=True)
test_loader = DataLoader(
    dataset=BinaryMNIST(train=False),
    batch_size=args.b,
    shuffle=False,
    num_workers=1,
    pin_memory=True)

# Model and optimiser are configured from the command-line arguments.
model = PixelCNN(args.ch, args.hl, args.k, BIAS)

optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

# Wall-clock reference taken just before training starts.
start = time.time()

for i in range(args.e):
    print(f'Epoch {i + 1}...')

    #training
    for images, labels in train_loader:
        images = images.to(model.device)
        out = model(images)  #(batch_size, 1, 28, 28)

        neg_log_likelihood = -(
            images * torch.log(out + 1e-7) +
Exemple #8
0
def train(config, mode='cifar10'):
    """Train a PixelCNN, printing per-epoch losses and periodically saving samples.

    Args:
        config: object providing ``lr``, ``nr_resnet``, ``nr_filters``,
            ``nr_logistic_mix``, ``batch_size``, ``lr_decay``,
            ``load_params``, ``max_epochs`` and ``save_interval``.
        mode: one of ``'cifar10'``, ``'faces'`` or ``'mnist'``.

    Raises:
        ValueError: if ``mode`` is not one of the supported names.

    Side effects: creates ``models/`` and ``images/``, writes a checkpoint
    and a sample grid every ``config.save_interval`` epochs.
    """
    if mode not in ('cifar10', 'faces', 'mnist'):
        # Previously an unknown mode left ``obs``/``loss_op`` unbound and
        # failed much later with an UnboundLocalError.
        raise ValueError('unsupported mode: {!r}'.format(mode))

    model_name = 'pcnn_lr:{:.5f}_nr-resnet{}_nr-filters{}'.format(config.lr, config.nr_resnet, config.nr_filters)
    # Create each directory independently: with both calls in one try block,
    # an already existing 'models' directory skipped creating 'images'.
    for directory in ('models', 'images'):
        os.makedirs(directory, exist_ok=True)

    seed = np.random.randint(0, 10000)
    print("Random Seed: ", seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = True

    trainset, train_loader, testset, test_loader, classes = load_data(mode=mode, batch_size=config.batch_size)
    if mode in ('cifar10', 'faces'):
        obs = (3, 32, 32)
        loss_op = lambda real, fake: discretized_mix_logistic_loss(real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic(x, config.nr_logistic_mix)
    else:  # 'mnist'
        obs = (1, 28, 28)
        loss_op = lambda real, fake: discretized_mix_logistic_loss_1d(real, fake, config.nr_logistic_mix)
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, config.nr_logistic_mix)
    sample_batch_size = 25
    # Maps values from [-1, 1] back to [0, 1] before saving.
    rescaling_inv = lambda x: .5 * x + .5

    model = PixelCNN(nr_resnet=config.nr_resnet, nr_filters=config.nr_filters,
                     input_channels=obs[0], nr_logistic_mix=config.nr_logistic_mix).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config.lr_decay)

    if config.load_params:
        load_part_of_model(model, config.load_params)
        print('model parameters loaded')

    def sample(model):
        """Fill an all-zero batch one pixel position at a time, in raster order."""
        model.eval()
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2]).cuda()
        with tqdm(total=obs[1] * obs[2]) as pbar, torch.no_grad():
            for i in range(obs[1]):
                for j in range(obs[2]):
                    out = model(data, sample=True)
                    out_sample = sample_op(out)
                    data[:, :, i, j] = out_sample[:, :, i, j]
                    pbar.update(1)
        return data

    print('starting training')
    for epoch in range(config.max_epochs):
        model.train()
        torch.cuda.synchronize()
        train_loss = 0.
        train_count = 0
        time_ = time.time()
        with tqdm(total=len(train_loader)) as pbar:
            for data, _ in train_loader:
                # Inputs do not need gradients; only the parameters do.
                data = data.cuda()

                output = model(data)
                loss = loss_op(data, output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                train_count += data.size(0)
                pbar.update(1)

        # Normalise by the number of samples actually seen. The old
        # ``batch_idx * batch_size`` was one batch short and was zero when
        # the loader had a single batch.
        # NOTE(review): assumes loss_op returns a sum over the batch.
        deno = max(train_count, 1) * np.prod(obs)
        print('train loss : %s' % (train_loss / deno), end='\t')

        # decrease learning rate
        scheduler.step()

        model.eval()
        test_loss = 0.
        test_count = 0
        with tqdm(total=len(test_loader)) as pbar, torch.no_grad():
            for data, _ in test_loader:
                data = data.cuda()

                output = model(data)
                loss = loss_op(data, output)
                test_loss += loss.item()
                test_count += data.size(0)
                del loss, output
                pbar.update(1)
        deno = max(test_count, 1) * np.prod(obs)
        print('test loss : {:.4f}, time : {:.4f}'.format((test_loss / deno), (time.time() - time_)))

        torch.cuda.synchronize()

        if (epoch + 1) % config.save_interval == 0:
            torch.save(model.state_dict(), 'models/{}_{}.pth'.format(model_name, epoch))
            print('sampling...')
            sample_t = sample(model)
            sample_t = rescaling_inv(sample_t)
            save_image(sample_t, 'images/{}_{}.png'.format(model_name, epoch), nrow=5, padding=0)
    h5f = h5py.File(hdf5_path, 'r')
else:
    with h5py.File(hdf5_path, 'w') as h5f:
        for name, shape in init_vars:

            val = tf.train.load_variable(tf_path, name)

            print(val.dtype)
            print("Loading TF weight {} with shape {}, {}".format(
                name, shape, val.shape))
            torch.from_numpy(np.array(val))
            if 'model' in name:
                new_name = name.replace('/', '.')
                print(new_name)
                h5f.create_dataset(str(new_name), data=val)

    h5f = h5py.File(hdf5_path, 'r')

model = PixelCNN(nr_resnet=5,
                 nr_filters=160,
                 input_channels=3,
                 nr_logistic_mix=10)

#print(model.state_dict().keys())
# Close the HDF5 handle even when conversion, loading or saving raises;
# previously a failure anywhere below left the file open.
try:
    converter = TF2Pytorch(h5f)
    converter.load_pixelcnn()

    model.load_state_dict(converter.state_dict)
    torch.save(model.state_dict(), ckpt_path)
finally:
    h5f.close()
Exemple #10
0
def train_prior(config, RANDOM_SEED, MODEL, TRAIN_NUM, BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE,
                GRAD_CLIP, K, D, BETA, NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD, **kwargs):
    """Train a PixelCNN prior over the latents of a VQVAE (TF1 graph mode).

    Builds an ImageNet input batch, a VQVAE whose weights are loaded from
    ``MODEL``, and a PixelCNN under the ``pixelcnn`` variable scope, then runs
    ``TRAIN_NUM`` steps. Checkpoints and TensorBoard summaries go to
    ``<dirname(MODEL)>/pixelcnn``.

    Args:
        config: object whose ``as_matrix()`` result is logged as a text
            summary; the local name is later rebound to a ``tf.ConfigProto``.
        RANDOM_SEED: seed for both NumPy and TensorFlow.
        MODEL: path passed to ``vq_net.load``; its directory hosts the logs.
        TRAIN_NUM: number of training iterations.
        BATCH_SIZE: batch size passed to ``_build_batch``.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: arguments of
            ``tf.train.exponential_decay``.
        GRAD_CLIP, K, D, NUM_LAYERS, NUM_FEATURE_MAPS: forwarded to
            ``PixelCNN`` (``K`` and ``D`` also to ``VQVAE``).
        BETA: forwarded to ``VQVAE``.
        SUMMARY_PERIOD: log the loss and summaries every this many steps.
        SAVE_PERIOD: checkpoint and log sampled images every this many steps.
        **kwargs: accepted and ignored.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn')
    # >>>>>>> DATASET
    train_dataset = imagenet.get_split('train', 'datasets/ILSVRC2012')
    # NOTE(review): the meaning of the trailing 4 is defined by _build_batch,
    # which is not visible here.
    ims, labels = _build_batch(train_dataset, BATCH_SIZE, 4)
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        # Open and immediately leave the scope just to capture the scope
        # object that is handed to VQVAE below.
        with tf.variable_scope('params') as params:
            pass
        vq_net = VQVAE(None, None, BETA, ims, K, D, _imagenet_arch, params,
                       False)
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        # NOTE(review): 1000 presumably is the number of ImageNet classes used
        # for conditioning (labels are sampled from [0, 1000) below) - confirm.
        net = PixelCNN(learning_rate, global_step, GRAD_CLIP,
                       vq_net.k.get_shape()[1], vq_net.embeds, K, D, 1000,
                       NUM_LAYERS, NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # collections=[] keeps this one-off summary out of merge_all().
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
        sample_summary_op = tf.summary.image('samples',
                                             sample_images,
                                             max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # From here on ``config`` is the session configuration, not the argument.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # No ops may be added to the graph after this point.
    sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess, MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        for step in tqdm(xrange(TRAIN_NUM), dynamic_ncols=True):
            # First fetch latent codes and labels for a batch, then feed them
            # to the prior for one optimisation step.
            batch_xs, batch_ys = sess.run([vq_net.k, labels])
            it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                                   feed_dict={
                                       net.X: batch_xs,
                                       net.h: batch_ys
                                   })

            if (it % SAVE_PERIOD == 0):
                net.save(sess, LOG_DIR, step=it)
                # Sample latents for 10 random class ids, decode them and log
                # the resulting images.
                sampled_zs, log_probs = net.sample_from_prior(
                    sess, np.random.randint(0, 1000, size=(10, )), 2)
                sampled_ims = sess.run(vq_net.gen,
                                       feed_dict={vq_net.latent: sampled_zs})
                summary_writer.add_summary(
                    sess.run(sample_summary_op,
                             feed_dict={sample_images: sampled_ims}), it)

            if (it % SUMMARY_PERIOD == 0):
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op,
                                   feed_dict={
                                       net.X: batch_xs,
                                       net.h: batch_ys
                                   })
                summary_writer.add_summary(summary, it)

    except Exception as e:
        # Hand the exception to the coordinator instead of raising it here.
        coord.request_stop(e)
    finally:
        # Always write a final checkpoint and shut the input threads down.
        net.save(sess, LOG_DIR)

        coord.request_stop()
        coord.join(threads)
# Training input pipeline: shuffle -> batch -> prepare -> duplicate -> prefetch.
train_ds = train_ds.shuffle(BUFFER_SIZE)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.map(prepare, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.map(duplicate)
train_ds = train_ds.prefetch(AUTOTUNE)

# Test input pipeline: identical, minus the shuffle.
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.map(prepare, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(duplicate)
test_ds = test_ds.prefetch(AUTOTUNE)

# Define model
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = PixelCNN(hidden_dim=args.hidden_dim, n_res=args.n_res)
    model.compile(optimizer='adam', loss=bits_per_dim_loss)

# Learning rate scheduler
# NOTE(review): the model above is compiled with the string 'adam', so this
# schedule only takes effect if later code applies it - confirm.
steps_per_epochs = info.splits['train'].num_examples // args.batch
decay_per_epoch = args.lr_decay ** steps_per_epochs
schedule = tfk.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=args.learning_rate,
    decay_steps=1,
    decay_rate=decay_per_epoch)

# Callbacks
# NOTE(review): this rebinds the module-level name ``time`` to a string.
time = datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join('.', 'logs', 'pixelcnn', time)
Exemple #12
0
NB_SAMPLES = 10
NB_CLASSES = 10

def get_device():
    """Return the CUDA device if TRY_CUDA allows it and CUDA is available, else the CPU."""
    # The explicit comparison with False is kept on purpose so that values
    # other than a real boolean behave exactly as before.
    cuda_ok = (TRY_CUDA != False) and torch.cuda.is_available()
    return torch.device('cuda' if cuda_ok else 'cpu')

device = torch.device('cuda' if TRY_CUDA and torch.cuda.is_available() else 'cpu')
print(f"> Using device {device}")

try:
    print(f"> Loading PixelCNN from file {sys.argv[1]}")
    model = PixelCNN(IMAGE_DIM, 16, 5, 256, 10).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(sys.argv[1], map_location=device))
    model.eval()
    print("> Loaded PixelCNN succesfully!")
except Exception as err:
    # Catch Exception rather than everything so Ctrl-C still interrupts, and
    # show the underlying error instead of hiding it.
    print("! Failed to load state dict!")
    print("! Make sure model is of correct size and path is correct!")
    print(f"! {err}")
    # Non-zero exit status so callers can detect the failure.
    sys.exit(1)

with torch.no_grad():
    sample = torch.zeros(NB_SAMPLES*NB_CLASSES, *IMAGE_DIM).to(device)
    cond = torch.tensor([d for d in range(NB_CLASSES) for _ in range(NB_SAMPLES)]).to(device)

    pb = tqdm(total=IMAGE_DIM[0]*IMAGE_DIM[1]*IMAGE_DIM[2])

    for c in range(IMAGE_DIM[0]):
Exemple #13
0
        },
        "cifar10": {
            "input_shape": (32, 32, 3),
            "color_conditioning": True,
            "n_mixtures": 10,
            "epochs": 5
        }
    }

    # Settings for the dataset selected on the command line.
    dataset_params = hyperparams[args.dataset]
    n_mixtures = dataset_params['n_mixtures']
    color_conditioning = dataset_params['color_conditioning']
    input_shape = dataset_params['input_shape']
    epochs = dataset_params['epochs']

    model = PixelCNN(
        n_mixtures=n_mixtures,
        color_conditioning=color_conditioning,
        input_shape=input_shape)
    # Build with a fixed batch of 16 so the weights exist before loading.
    model.build(input_shape=(16, *input_shape))
    # The weight file name encodes the dataset and the epoch count; the
    # OpenAI sampler on non-MNIST data uses the 250-epoch weights.
    model.load_weights(
        f'weights/pixel_cnn_{args.dataset}_{epochs if (not args.use_openai_sampler or args.dataset == "mnist") else 250}.h5'
    )

    # Start from uniform noise in [-1, 1).
    random_input = np.random.uniform(
        low=-1, high=1, size=(16, *input_shape)).astype(np.float32)

    output = model(random_input)

    for x in range(input_shape[0]):
        for y in range(input_shape[1]):
            for c in range(input_shape[-1]):
Exemple #14
0
IMAGE_DIM = (3, 32, 32)


def get_device():
    """Return the CUDA device if TRY_CUDA allows it and CUDA is available, else the CPU."""
    # The explicit comparison with False is kept on purpose so that values
    # other than a real boolean behave exactly as before.
    cuda_ok = (TRY_CUDA != False) and torch.cuda.is_available()
    return torch.device('cuda' if cuda_ok else 'cpu')


# Use the GPU only when it is both requested and available.
if TRY_CUDA and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"> Using device {device}")
print("> Instantiating PixelCNN")
model = PixelCNN(IMAGE_DIM, 32, 5, 256, 10).to(device)

print("> Loading dataset")
# CIFAR-10 train and test splits, downloaded into ./data on first use.
train_dataset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True,
    transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True,
    transform=torchvision.transforms.ToTensor())
# train_dataset = torchvision.datasets.ImageFolder('data/pokemon', transform=torchvision.transforms.Compose([
# torchvision.transforms.Resize(32),
# torchvision.transforms.Grayscale(),
Exemple #15
0
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from model import PixelCNN

BATCH_SIZE = 32
# The script referenced an undefined ``lr``; 1e-3 is Adam's own default.
LEARNING_RATE = 1e-3
dataset = datasets.CIFAR10(root='cifar10', train=True,
                           transform=transforms.ToTensor(), download=True)

trainloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


# is_available must be *called*: the bare function object is always truthy,
# which selected "cuda:0" even on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = PixelCNN().to(device)


criterion = nn.CrossEntropyLoss()
# Adam's keyword is ``lr``; ``learning_rate`` raises a TypeError.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

def train(trainloader):
    """Run one pass over ``trainloader`` and return the last batch's loss.

    Uses the module-level ``model``, ``device``, ``criterion`` and
    ``optimizer``.

    Args:
        trainloader: iterable yielding ``(image_batch, label)`` pairs; the
            labels are ignored.

    Returns:
        The loss tensor of the final batch, or ``None`` when the loader
        yields no batches (previously an UnboundLocalError).
    """
    model.train()
    loss = None
    for img, _ in trainloader:
        img = img.to(device)
        optimizer.zero_grad()
        output = model(img)
        # NOTE(review): the input image itself is used as the target -
        # confirm this matches what ``criterion`` expects for this model.
        loss = criterion(output, img)
        loss.backward()
        optimizer.step()
    return loss
Exemple #16
0
def train_prior(config,
                RANDOM_SEED,
                MODEL,
                TRAIN_NUM,
                BATCH_SIZE,
                LEARNING_RATE,
                DECAY_VAL,
                DECAY_STEPS,
                DECAY_STAIRCASE,
                GRAD_CLIP,
                K,
                D,
                BETA,
                NUM_LAYERS,
                NUM_FEATURE_MAPS,
                SUMMARY_PERIOD,
                SAVE_PERIOD,
                **kwargs):
    """Train a PixelCNN prior over pre-extracted latents (TF1 graph mode).

    Latents and labels are read from ``ks_ys.npz`` next to ``MODEL``. A
    ``GumbelVAE`` in 'decode' mode, loaded from ``MODEL``, turns sampled
    latents into images for TensorBoard. Checkpoints and summaries go to
    ``<dirname(MODEL)>/pixelcnn``.

    Args:
        config: object whose ``as_matrix()`` result is logged as a text
            summary; the local name is later rebound to a ``tf.ConfigProto``.
        RANDOM_SEED: seed for both NumPy and TensorFlow.
        MODEL: path passed to ``vq_net.load``; its directory hosts the
            latents file and the logs.
        TRAIN_NUM: number of training iterations.
        BATCH_SIZE: number of latents fetched per step.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: arguments of
            ``tf.train.exponential_decay``.
        GRAD_CLIP, K, D, NUM_LAYERS, NUM_FEATURE_MAPS: forwarded to
            ``PixelCNN`` (``K`` and ``D`` also to ``GumbelVAE``).
        BETA: forwarded to ``GumbelVAE``.
        SUMMARY_PERIOD: log the loss and summaries every this many steps;
            sampled images are logged every ``2 * SUMMARY_PERIOD`` steps.
        SAVE_PERIOD: checkpoint every this many steps.
        **kwargs: accepted and ignored.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL),'pixelcnn')
    # >>>>>>> DATASET
    class Latents():
        # Wraps the saved latents ('ks') and labels ('ys') as train/validation
        # DataSet objects; the first ``validation_size`` rows are validation.
        def __init__(self,path,validation_size=5000):
            from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
            from tensorflow.contrib.learn.python.learn.datasets import base

            data = np.load(path)
            train = DataSet(data['ks'][validation_size:], data['ys'][validation_size:],reshape=False,dtype=np.uint8,one_hot=False) #dtype won't bother even in the case when latent is int32 type.
            validation = DataSet(data['ks'][:validation_size], data['ys'][:validation_size],reshape=False,dtype=np.uint8,one_hot=False)
            #test = DataSet(data['test_x'],np.argmax(data['test_y'],axis=1),reshape=False,dtype=np.float32,one_hot=False)
            # Second dimension of the latent array, passed to PixelCNN below.
            self.size = data['ks'].shape[1]
            self.data = base.Datasets(train=train, validation=validation, test=None)
    latent = Latents(os.path.join(os.path.dirname(MODEL),'ks_ys.npz'))
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        # Open and immediately leave the scope just to capture the scope
        # object that is handed to GumbelVAE below.
        with tf.variable_scope('params') as params:
            pass
        # Placeholder input and temperature that, per their names, are not
        # used when the network is built in 'decode' mode.
        _not_used = tf.placeholder(tf.float32,[None,24,24,1])
        tau_notused = 0.5
        vq_net = GumbelVAE(tau_notused,None,None,BETA,_not_used,K,D,_mnist_arch,params,'decode')
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr',learning_rate)

        #net = PixelCNN(learning_rate,global_step,grad_clip,latent_data.size,vq_net.embeds,K,D,10,num_layers,num_feature_maps)
        # NOTE(review): 10 presumably is the number of conditioning classes
        # (np.arange(10) is used when sampling below) - confirm.
        net = PixelCNN(learning_rate,global_step,GRAD_CLIP,
                       latent.size,vq_net.embeds,K,D,
                       10,NUM_LAYERS,NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss',net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        # collections=[] keeps this one-off summary out of merge_all().
        config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[])

        sample_images = tf.placeholder(tf.float32,[None,24,24,1])
        sample_summary_op = tf.summary.image('samples',sample_images,max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # From here on ``config`` is the session configuration, not the argument.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    #sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess,MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True):
        batch_xs, batch_ys = latent.data.train.next_batch(BATCH_SIZE)
        it,loss,_ = sess.run([global_step,net.loss,net.train_op],feed_dict={net.X:batch_xs,net.h:batch_ys})

        if( it % SAVE_PERIOD == 0 ):
            net.save(sess,LOG_DIR,step=it)

        if( it % SUMMARY_PERIOD == 0 ):
            tqdm.write('[%5d] Loss: %1.3f'%(it,loss))
            summary = sess.run(summary_op,feed_dict={net.X:batch_xs,net.h:batch_ys})
            summary_writer.add_summary(summary,it)

        if( it % (SUMMARY_PERIOD * 2) == 0 ):
            # Sample latents for class ids 0..9, decode them and log the
            # resulting images.
            sampled_zs,log_probs = net.sample_from_prior(sess,np.arange(10),2)
            sampled_ims = sess.run(vq_net.gen,feed_dict={vq_net.latent:sampled_zs})
            summary_writer.add_summary(
                sess.run(sample_summary_op,feed_dict={sample_images:sampled_ims}),it)

    # Final checkpoint once all steps have run.
    net.save(sess,LOG_DIR)
Exemple #17
0
            x = torch.cat((x, x, x), dim=0)
        return x

    # MNIST loaders; both splits go through the same ``preprocess`` transform.
    # NOTE(review): the test loader is shuffled as well - confirm intended.
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=preprocess),
        batch_size=args.batch_size,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, download=True,
                       transform=preprocess),
        batch_size=args.batch_size,
        shuffle=True)

    # Build the gated or the plain PixelCNN, as selected by args.gated.
    if args.gated:
        model = GatedPixelCNN(args.data_channels, args.out_dims)
    else:
        model = PixelCNN(args.data_channels, args.out_dims)
    if args.cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    def train(epoch):
        model.train()
        loss_meter = AverageMeter()

        for batch_idx, (data, _) in enumerate(train_loader):
            data = Variable(data)
            target = Variable((data.data * (args.out_dims - 1)).long())

            if args.cuda:
                data = data.cuda()