Example #1
    def build(self):
        self.net = PixelCNN()
        print(self.net, '\n')

        if self.config.mode == 'train':
            self.optimizer = self.config.optimizer(self.net.parameters())
            self.loss_fn = nn.CrossEntropyLoss()
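This fragment is the build step of the Solver class shown in full as Example #4 below.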
Example #2
def train(conf, data):
    X = tf.placeholder(
        tf.float32,
        shape=[None, conf.img_height, conf.img_width, conf.channel])
    model = PixelCNN(X, conf)

    trainer = tf.train.RMSPropOptimizer(1e-3)
    gradients = trainer.compute_gradients(model.loss)

    clipped_gradients = [(tf.clip_by_value(grad, -conf.grad_clip,
                                           conf.grad_clip), var)
                         for grad, var in gradients]
    optimizer = trainer.apply_gradients(clipped_gradients)

    saver = tf.train.Saver(tf.trainable_variables())

    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())
        if len(glob.glob(conf.ckpt_file + '*')) != 0:
            saver.restore(sess, conf.ckpt_file)
            print("Model Restored")
        else:
            print("Model reinitialized")
        if conf.epochs > 0:
            print("Started Model Training...")
        pointer = 0
        for i in range(conf.epochs):
            start_time = time.time()
            for j in range(conf.num_batches):
                if conf.data == "mnist":
                    batch_X, batch_y = data.train.next_batch(conf.batch_size)
                    batch_X = binarize(
                        batch_X.reshape([
                            conf.batch_size, conf.img_height, conf.img_width,
                            conf.channel
                        ]))
                    batch_y = one_hot(batch_y, conf.num_classes)
                else:
                    batch_X, pointer = get_batch(data, pointer,
                                                 conf.batch_size)
                data_dict = {X: batch_X}
                if conf.conditional:
                    # Note: batch_y is only set on the "mnist" path above
                    data_dict[model.h] = batch_y
                _, cost = sess.run([optimizer, model.loss],
                                   feed_dict=data_dict)
            print("Epoch: %d, Cost: %f, step time %fs" %
                  (i, cost, time.time() - start_time))
            if (i + 1) % 2 == 0:
                saver.save(sess, conf.ckpt_file)
                generate_samples(sess, X, model.h, model.pred, conf, "")

        generate_samples(sess, X, model.h, model.pred, conf, "")
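The TensorFlow examples call binarize and one_hot without defining them. A minimal sketch of what such helpers typically look like (the stochastic binarization and the exact signatures are assumptions, not code from these listings):

import numpy as np

def binarize(images):
    # Treat each grayscale intensity in [0, 1] as a Bernoulli probability
    return (np.random.uniform(size=images.shape) < images).astype(np.float32)

def one_hot(labels, num_classes):
    # Integer labels of shape (N,) -> one-hot matrix of shape (N, num_classes)
    out = np.zeros((len(labels), num_classes), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0
    return out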
Example #3
def train(conf, data):
    X = tf.placeholder(
        tf.float32,
        shape=[None, conf.img_height, conf.img_width, conf.channel])
    model = PixelCNN(X, conf)

    trainer = tf.train.RMSPropOptimizer(1e-3)
    gradients = trainer.compute_gradients(model.loss)

    clipped_gradients = [(tf.clip_by_value(grad, -conf.grad_clip,
                                           conf.grad_clip), var)
                         for grad, var in gradients]
    optimizer = trainer.apply_gradients(clipped_gradients)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as sess:
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter(conf.summary_path, sess.graph)
        sess.run(tf.global_variables_initializer())
        if os.path.exists(conf.ckpt_file):
            saver.restore(sess, conf.ckpt_file)
            print("Model Restored")

        if conf.epochs > 0:
            print("Started Model Training...")
        step = 0
        pointer = 0
        for i in range(conf.epochs):
            for j in range(conf.num_batches):
                if conf.data == "mnist":
                    batch_X, batch_y = data.train.next_batch(conf.batch_size)
                    batch_X = binarize(
                        batch_X.reshape([
                            conf.batch_size, conf.img_height, conf.img_width,
                            conf.channel
                        ]))
                    batch_y = one_hot(batch_y, conf.num_classes)
                else:
                    batch_X, pointer = get_batch(data, pointer,
                                                 conf.batch_size)
                data_dict = {X: batch_X}
                if conf.conditional:
                    # Note: batch_y is only set on the "mnist" path above
                    data_dict[model.h] = batch_y
                _, cost, summary = sess.run([optimizer, model.loss, merged],
                                            feed_dict=data_dict)
                writer.add_summary(summary, step)
                step += 1
            print("Epoch: %d, Cost: %f" % (i, cost))
            if (i + 1) % 10 == 0:
                saver.save(sess, conf.ckpt_file)
                generate_samples(sess, X, model.h, model.pred, conf, "")

        generate_samples(sess, X, model.h, model.pred, conf, "")
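Because this variant writes a merged summary to conf.summary_path at every step, the logged loss curve can be watched during training with the standard TensorBoard command line, e.g. tensorboard --logdir <summary_path>.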
Example #4
class Solver(object):
    def __init__(self, config, train_loader, test_loader):
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.n_batches_in_epoch = len(self.train_loader)
        self.total_data_size = len(self.train_loader.dataset)
        self.is_train = self.config.isTrain

    def build(self):
        self.net = PixelCNN()
        print(self.net, '\n')

        if self.config.mode == 'train':
            self.optimizer = self.config.optimizer(self.net.parameters())
            self.loss_fn = nn.CrossEntropyLoss()

    def train(self):

        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            epoch_i += 1

            # Draw a baseline set of samples before any training
            if epoch_i == 1:
                self.sample(epoch_i)

            self.net.train()
            self.batch_loss_history = []

            for batch_i, (image, label) in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                batch_i += 1
                # [batch_size, 3, 32, 32]
                image = Variable(image)

                # [batch_size, 3, 32, 32, 256]
                logit = self.net(image)
                logit = logit.contiguous()
                logit = logit.view(-1, 256)

                target = Variable(image.data.view(-1) * 255).long()

                batch_loss = self.loss_fn(logit, target)

                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                batch_loss = float(batch_loss.data)
                self.batch_loss_history.append(batch_loss)

                if batch_i > 1 and batch_i % self.config.log_interval == 0:
                    log_string = f'Epoch: {epoch_i} | Batch: ({batch_i}/{self.n_batches_in_epoch}) | '
                    log_string += f'Loss: {batch_loss:.3f}'
                    tqdm.write(log_string)

            epoch_loss = np.mean(self.batch_loss_history)
            tqdm.write(f'Epoch Loss: {epoch_loss:.2f}')

            self.test(epoch_i)
            self.sample(epoch_i)

    def test(self, epoch_i):
        """Compute error on test set"""

        test_errors = []
        # cuda.synchronize()
        start = time.time()

        self.net.eval()

        for image, label in self.test_loader:

            # [batch_size, channel, height, width]
            image = Variable(image.cuda(), volatile=True)

            # [batch_size, channel, height, width, 256]
            logit = self.net(image).contiguous()

            # [batch_size x channel x height x width, 256]
            logit = logit.view(-1, 256)

            # [batch_size x channel x height x width]
            target = Variable((image.data.view(-1) * 255).long())

            loss = F.cross_entropy(logit, target)

            test_error = float(loss.data)
            test_errors.append(test_error)

        # cuda.synchronize()
        time_test = time.time() - start
        log_string = f'Test done! | It took {time_test:.1f}s | '
        log_string += f'Test Loss: {np.mean(test_errors):.2f}'
        tqdm.write(log_string)

    def sample(self, epoch_i):
        """Sampling Images"""

        image_path = str(self.config.ckpt_dir.joinpath(f'epoch-{epoch_i}.png'))
        tqdm.write(f'Saved sampled images at {image_path}')
        self.net.eval()

        sample = torch.zeros(self.config.batch_size, 3, 32, 32)

        for i in range(32):
            for j in range(32):

                # [batch_size, channel, height, width, 256]
                out = self.net(Variable(sample, volatile=True))

                # out[:, :, i, j]
                # => [batch_size, channel, 256]
                probs = F.softmax(out[:, :, i, j], dim=2).data

                # Sample single pixel (each channel independently)
                for k in range(3):
                    # Draw an intensity in 0..255, then rescale to 0..1
                    pixel = torch.multinomial(probs[:, k], 1).float() / 255.
                    curr_shape = sample[:, k, i, j].shape
                    sample[:, k, i, j] = pixel.view(curr_shape)

        save_image(sample, image_path)
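A hypothetical driver for this Solver; the config fields named in the comment are inferred from the attributes the class reads and do not appear in the listing:

# config is assumed to provide: mode, isTrain, optimizer, n_epochs,
# log_interval, batch_size, and ckpt_dir
solver = Solver(config, train_loader, test_loader)
solver.build()
solver.train()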
Example #5
def train(conf, data):
    X = tf.placeholder(
        tf.float32,
        shape=[None, conf.img_height, conf.img_width, conf.channels])
    model = PixelCNN(X, conf)

    trainer = tf.train.RMSPropOptimizer(1e-3)
    gradients = trainer.compute_gradients(model.loss)
    clipped_gradients = [(tf.clip_by_value(grad, -conf.grad_clip,
                                           conf.grad_clip), var)
                         for grad, var in gradients]
    optimizer = trainer.apply_gradients(clipped_gradients)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(conf.ckpt_file):
            saver.restore(sess, conf.ckpt_file)
            print("Model Restored")
        if conf.epochs > 0:
            print("Started Model Training...")

        pointer = 0
        for i in range(conf.epochs):
            epoch_loss = 0.
            for j in range(conf.num_batches):
                if conf.data == "mnist_bw":
                    batch_X, batch_y = data.train.next_batch(
                        conf.batch_size)  # batch_X is N,HW; batch_y is N
                    batch_X = batch_X.reshape([
                        conf.batch_size, conf.img_height, conf.img_width,
                        conf.channels
                    ])  # N,H,W,C
                    batch_X = binarize(batch_X)
                    batch_y = one_hot(batch_y, conf.num_classes)  # N,10
                elif conf.data == "mnist":
                    batch_X, batch_y, pointer = get_batch(
                        data, pointer, conf.batch_size)
                    batch_X = batch_X.reshape([
                        conf.batch_size, conf.img_height, conf.img_width,
                        conf.channels
                    ])  # N,H,W,C
                    batch_y = one_hot(batch_y, conf.num_classes)  # N,10
                elif conf.data == "fashion":
                    batch_X, batch_y, pointer = get_batch(
                        data, pointer, conf.batch_size)
                    batch_X = batch_X.reshape([
                        conf.batch_size, conf.img_height, conf.img_width,
                        conf.channels
                    ])  # N,H,W,C
                    batch_y = one_hot(batch_y, conf.num_classes)  # N,10
                if i == 0 and j == 0:
                    save_images(batch_X, conf.batch_size, 1, "batch.png", conf)
                data_dict = {X: batch_X}
                if conf.conditional:
                    data_dict[model.h] = batch_y
                _, cost = sess.run([optimizer, model.loss],
                                   feed_dict=data_dict)
                epoch_loss += cost
            print("Epoch: %d, Cost: %f" % (i, epoch_loss))
            # Checkpoint and sample after every epoch
            saver.save(sess, conf.ckpt_file)
            generate_samples(sess, X, model.h, model.pred, "%d.png" % i,
                             conf)

        generate_samples(sess, X, model.h, model.pred, "final.png", conf)
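None of the listings define get_batch, and its signature even differs between Examples #2 and #3 (two return values) and this one (three). A minimal sketch of the three-value variant, assuming data is an (images, labels) pair of arrays:

def get_batch(data, pointer, batch_size):
    # Advance a cursor through the arrays, wrapping around at the end
    images, labels = data
    if (pointer + 1) * batch_size > len(images):
        pointer = 0
    start = pointer * batch_size
    return (images[start:start + batch_size],
            labels[start:start + batch_size],
            pointer + 1)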
Example #6
    # Assumed head: the listing begins mid-call, so the train loader is
    # reconstructed here to mirror the test loader below
    train_dataset = DataLoader(datasets.MNIST('./data',
                                              train=True,
                                              download=True,
                                              transform=base_tr),
                               batch_size=config.batch,
                               shuffle=True,
                               num_workers=4)
    test_dataset = DataLoader(datasets.MNIST('./data',
                                             train=False,
                                             transform=base_tr),
                              batch_size=config.batch,
                              shuffle=False,
                              num_workers=4)

    if config.gate:
        model = GatedPixelCNN(config.size,
                              config.layer,
                              conditional=config.conditional,
                              num_classes=num_classes)
    else:
        model = PixelCNN(config.size, config.layer)
    if is_gpu:
        model.cuda()
    optimizer = optim.Adam([*model.parameters()], lr=config.lr)
    os.makedirs(config.model, exist_ok=True)

    if config.mode == 'train':
        for epoch in range(config.epochs):
            model.train()
            err_tr, err_te = 0, 0
            batch_iter_tr, batch_iter_te = 0, 0
            batch_size = len(train_dataset)  # misnamed: number of batches, not batch size
            for batch_idx, train in enumerate(train_dataset):
                train_label = one_hot_encoder(train[1], num_classes)
                train_data = train[0]
                if is_gpu:
                    # Assumed continuation; the original listing is truncated here
                    train_data = train_data.cuda()
                    train_label = train_label.cuda()
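The fragment also relies on a one_hot_encoder helper that is not shown; assuming it mirrors to_one_hot from Example #7 below, a minimal sketch:

import torch

def one_hot_encoder(labels, num_classes):
    # Scatter a 1 into each row at the column given by the class label
    labels = labels.view(-1, 1)
    out = torch.zeros(labels.numel(), num_classes)
    out.scatter_(1, labels, 1)
    return out.float()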
Example #7
import torch
from torch.utils import data
from torchvision import datasets, transforms

from models import LabelNet, PixelCNN


n_classes = 10   # number of classes
n_epochs = 25    # number of epochs to train
n_layers = 7     # number of convolutional layers
n_channels = 16  # number of channels
device = 'cuda:0'

def to_one_hot(y, k=10):
    y = y.view(-1, 1)
    y_one_hot = torch.zeros(y.numel(), k)
    y_one_hot.scatter_(1, y, 1)
    return y_one_hot.float()

pixel_cnn = PixelCNN(n_channels, n_layers).to(device)
label_net = LabelNet().to(device)

trainloader = data.DataLoader(datasets.MNIST('data', train=True,
                                             download=True,
                                             transform=transforms.ToTensor()),
                              batch_size=128, shuffle=True,
                              num_workers=1, pin_memory=True)

testloader = data.DataLoader(datasets.MNIST('data', train=False,
                                            download=True,
                                            transform=transforms.ToTensor()),
                             batch_size=128, shuffle=False,
                             num_workers=1, pin_memory=True)

sample = torch.Tensor(120, 1, 28, 28).to(device)  # uninitialized buffer for 120 samples
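The sample buffer is presumably filled autoregressively, pixel by pixel, as in Example #4's sample method. A minimal sketch under loud assumptions: pixel_cnn is assumed to accept the image batch together with label_net's embedding, and to return 256-way logits of shape [batch, 256, 1, 28, 28]; neither signature appears in the listing.

import torch.nn.functional as F

# 12 samples of each of the 10 digit classes -> 120 conditioning labels
labels = to_one_hot(torch.arange(n_classes).repeat_interleave(12)).to(device)
sample.zero_()  # torch.Tensor(...) above is uninitialized memory

with torch.no_grad():
    for i in range(28):
        for j in range(28):
            logits = pixel_cnn(sample, label_net(labels))    # assumed signature
            probs = F.softmax(logits[:, :, 0, i, j], dim=1)  # [120, 256]
            pixel = torch.multinomial(probs, 1).squeeze(1)   # intensities 0..255
            sample[:, 0, i, j] = pixel.float() / 255.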