def go(arg):
    tbw = SummaryWriter(log_dir=arg.tb_dir)
    if arg.task == 'mnist':
        transform = Compose([Pad(padding=2), ToTensor()])

        trainset = torchvision.datasets.MNIST(root=arg.data_dir,
                                              train=True,
                                              download=True,
                                              transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)
        testset = torchvision.datasets.MNIST(root=arg.data_dir,
                                             train=False,
                                             download=True,
                                             transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 1, 32, 32

    elif arg.task == 'imagenet64':
        transform = Compose([ToTensor()])
        trainset = torchvision.datasets.ImageFolder(root=arg.data_dir +
                                                    os.sep + 'train',
                                                    transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)
        testset = torchvision.datasets.ImageFolder(root=arg.data_dir + os.sep +
                                                   'valid',
                                                   transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 3, 64, 64
    else:
        raise Exception('Task {} not recognized.'.format(arg.task))

    krn = arg.kernel_size
    pad = krn // 2
    OUTCN = 64
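    # Three modules: a convolutional VAE encoder and decoder, plus a gated
    # PixelCNN that models the pixels conditioned on the decoder's output.
    # A padding of kernel_size // 2 keeps the spatial resolution unchanged.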
    encoder = models.ImEncoder(in_size=(H, W),
                               zsize=arg.zsize,
                               depth=arg.vae_depth,
                               colors=C)
    # decoder = util.Lambda(lambda x: x)  # identity
    decoder = models.ImDecoder(in_size=(H, W),
                               zsize=arg.zsize,
                               depth=arg.vae_depth,
                               out_channels=OUTCN)
    pixcnn = models.CGated((C, H, W), (arg.zsize, ),
                           arg.channels,
                           num_layers=arg.num_layers,
                           k=krn,
                           padding=pad)
    #######################
    if arg.loadPreModel:
        encoder, decoder, pixcnn = loadModel(encoder, decoder, pixcnn)
    ########################
    mods = [encoder, decoder, pixcnn]
    if torch.cuda.is_available():
        for m in mods:
            m.cuda()
    print('Constructed network', encoder, decoder, pixcnn)

    sample_zs = torch.randn(12, arg.zsize)
    sample_zs = sample_zs.unsqueeze(1).expand(12, 6, -1).contiguous().view(
        72, 1, -1).squeeze(1)
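    # sample_zs: 12 distinct latent vectors, each repeated 6 times (72 rows in
    # total), created once so the same latents are sampled every epoch.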
    # Prepare 144 sample images at the chosen resolution: 72 generated from
    # scratch and 72 seeded with test data, so they fit a 12 by 12 grid.
    sample_init_zeros = torch.zeros(72, C, H, W)
    sample_init_seeds = torch.zeros(72, C, H, W)
    sh, sw = H // SEEDFRAC, W // SEEDFRAC
    # Init second half of sample with patches from test set, to seed the sampling
    testbatch = util.readn(testloader, n=12)
    testbatch = testbatch.unsqueeze(1).expand(12, 6, C, H, W)
    testbatch = testbatch.contiguous().view(72, 1, C, H, W).squeeze(1)
    sample_init_seeds[:, :, :sh, :] = testbatch[:, :, :sh, :]

    params = []
    for m in mods:
        params.extend(m.parameters())
    optimizer = Adam(params, lr=arg.lr)
    instances_seen = 0
    for epoch in range(arg.epochs):
        # Train
        err_tr = []
        for m in mods:
            m.train(True)
        for i, (input, _) in enumerate(tqdm.tqdm(trainloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break
            # Prepare the input
            b, c, h, w = input.size()
            if torch.cuda.is_available():
                input = input.cuda()
            target = (input.data * 255).long()
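            # Pixel intensities rescaled to integers in [0, 255]: these are the
            # class targets for the PixelCNN's 256-way softmax per subpixel.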
            input, target = Variable(input), Variable(target)
            # Forward pass
            zs = encoder(input)
            kl_loss = util.kl_loss(*zs)
            z = util.sample(*zs)
            out = decoder(z)
            rec = pixcnn(input, out)
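            # The gated PixelCNN models the input autoregressively, conditioned
            # on the decoder output; `rec` holds its per-subpixel logits.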
            rec_loss = cross_entropy(rec, target,
                                     reduction='none').view(b, -1).sum(dim=1)
            loss = (rec_loss + kl_loss).mean()
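            # Negative ELBO: reconstruction NLL summed per image plus the KL
            # term, averaged over the batch.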
            instances_seen += input.size(0)
            tbw.add_scalar('pixel-models/vae/training/kl-loss',
                           kl_loss.mean().data.item(), instances_seen)
            tbw.add_scalar('pixel-models/vae/training/rec-loss',
                           rec_loss.mean().data.item(), instances_seen)
            err_tr.append(loss.data.item())
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate
        # - we evaluate on the test set, since this is only a simple reproduction experiment;
        #   make sure to split off a validation set if you want to tune hyperparameters for anything important
        err_te = []

        for m in mods:
            m.train(False)
        for i, (input, _) in enumerate(tqdm.tqdm(testloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break
            b, c, h, w = input.size()
            if torch.cuda.is_available():
                input = input.cuda()
            target = (input.data * 255).long()
            input, target = Variable(input), Variable(target)

            zs = encoder(input)
            kl_loss = util.kl_loss(*zs)
            z = util.sample(*zs)
            out = decoder(z)
            rec = pixcnn(input, out)
            rec_loss = cross_entropy(rec, target,
                                     reduction='none').view(b, -1).sum(dim=1)
            loss = (rec_loss + kl_loss).mean()
            err_te.append(loss.data.item())

        tbw.add_scalar('pixel-models/test-loss',
                       sum(err_te) / len(err_te), epoch)
        print('epoch={:02}; training loss: {:.3f}; test loss: {:.3f}'.format(
            epoch,
            sum(err_tr) / len(err_tr),
            sum(err_te) / len(err_te)))

        for m in mods:
            m.train(False)
        sample_zeros = draw_sample(sample_init_zeros,
                                   decoder,
                                   pixcnn,
                                   sample_zs,
                                   seedsize=(0, 0))
        sample_seeds = draw_sample(sample_init_seeds,
                                   decoder,
                                   pixcnn,
                                   sample_zs,
                                   seedsize=(sh, W))
        sample = torch.cat([sample_zeros, sample_seeds], dim=0)

        torchvision.utils.save_image(
            sample,
            'myResults/sample_{:02d}.png'.format(epoch),
            nrow=12,
            padding=0)

    saveModel(encoder, decoder, pixcnn)
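
Both examples call `util.kl_loss` and `util.sample`, whose definitions are not shown on this page. The sketch below is only an assumption about what such helpers conventionally compute, for an encoder that returns a mean and a log-variance pair; the actual `util` module may differ.

import torch

def kl_loss(zmean, zlvar):
    # KL divergence between N(zmean, exp(zlvar)) and the standard normal,
    # summed over the latent dimensions: one value per instance.
    return 0.5 * torch.sum(zlvar.exp() + zmean.pow(2) - 1.0 - zlvar, dim=1)

def sample(zmean, zlvar):
    # Reparameterization trick: z = mean + std * eps, with eps ~ N(0, I).
    eps = torch.randn_like(zmean)
    return zmean + (0.5 * zlvar).exp() * eps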
Example 2
def go(arg):

    tbw = SummaryWriter(log_dir=arg.tb_dir)

    ## Load the data
    if arg.task == 'mnist':
        trainset = torchvision.datasets.MNIST(root=arg.data_dir,
                                              train=True,
                                              download=True,
                                              transform=ToTensor())
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.MNIST(root=arg.data_dir,
                                             train=False,
                                             download=True,
                                             transform=ToTensor())
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 1, 28, 28
        CLS = 10

    elif arg.task == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                                train=True,
                                                download=True,
                                                transform=ToTensor())
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=arg.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        testset = torchvision.datasets.CIFAR10(root=arg.data_dir,
                                               train=False,
                                               download=True,
                                               transform=ToTensor())
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=arg.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)
        C, H, W = 3, 32, 32
        CLS = 10
    else:
        raise Exception('Task {} not recognized.'.format(arg.task))

    ## Set up the model

    if arg.model == 'gated':

        model = models.CGated((C, H, W), (CLS, ),
                              arg.channels,
                              num_layers=arg.num_layers,
                              k=arg.kernel_size,
                              padding=arg.kernel_size // 2)

    else:
        raise Exception('model "{}" not recognized'.format(arg.model))

    print('Constructed network', model)

    # Prepare 144 sample images at the chosen resolution: 72 generated from
    # scratch and 72 seeded with test data, so they fit a 12 by 12 grid.
    sample_init_zeros = torch.zeros(72, C, H, W)
    sample_init_seeds = torch.zeros(72, C, H, W)

    sh, sw = H // SEEDFRAC, W // SEEDFRAC

    # Init second half of sample with patches from test set, to seed the sampling
    testbatch = util.readn(testloader, n=12)
    testcls_seeds = util.readn(testloader, n=12, cls=True, maxval=CLS)

    testbatch = testbatch.unsqueeze(1).expand(12, 6, C, H, W)
    testbatch = testbatch.contiguous().view(72, 1, C, H, W).squeeze(1)
    sample_init_seeds[:, :, :sh, :] = testbatch[:, :, :sh, :]
    testcls_seeds = testcls_seeds.unsqueeze(1).expand(
        12, 6, CLS).contiguous().view(72, 1, CLS).squeeze(1)

    # Get classes for the unseeded part
    testcls_zeros = util.readn(testloader, n=24, cls=True, maxval=CLS)[12:]
    testcls_zeros = testcls_zeros.unsqueeze(1).expand(
        12, 6, CLS).contiguous().view(72, 1, CLS).squeeze(1)
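    # Each of these 12 labels is repeated 6 times to match the 72-row sample
    # tensors above.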

    optimizer = Adam(model.parameters(), lr=arg.lr)

    if torch.cuda.is_available():
        model.cuda()

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
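    # NVIDIA Apex automatic mixed precision, opt level O1: eligible ops are
    # cast to float16 while the model parameters stay in float32.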

    instances_seen = 0
    for epoch in range(arg.epochs):

        # Train
        err_tr = []
        model.train(True)

        for i, (input, classes) in enumerate(tqdm.tqdm(trainloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break

            # Prepare the input
            b, c, h, w = input.size()

            classes = util.one_hot(classes, CLS)

            if torch.cuda.is_available():
                input, classes = input.cuda(), classes.cuda()

            target = (input.data * 255).long()
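            # Pixel intensities rescaled to integers in [0, 255]: these are the
            # class targets for the PixelCNN's 256-way softmax per subpixel.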

            input, classes, target = Variable(input), Variable(
                classes), Variable(target)

            # Forward pass
            result = model(input, classes)
            loss = cross_entropy(result, target)
            loss = loss * util.LOG2E  # Convert from nats to bits
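            # With cross_entropy's default mean reduction this is the average
            # loss per subpixel, i.e. bits per dimension after the conversion.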

            instances_seen += input.size(0)
            tbw.add_scalar('pixel-models/training-loss', loss.data.item(),
                           instances_seen)
            err_tr.append(loss.data.item())

            # Backward pass
            optimizer.zero_grad()

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            optimizer.step()

        if epoch % arg.eval_every == 0 and epoch != 0:
            with torch.no_grad():

                # Evaluate
                # - we evaluate on the test set, since this is only a simple reproduction experiment;
                #   make sure to split off a validation set if you want to tune hyperparameters for anything important

                err_te = []
                model.train(False)

                for i, (input, classes) in enumerate(tqdm.tqdm(testloader)):
                    if arg.limit is not None and i * arg.batch_size > arg.limit:
                        break

                    classes = util.one_hot(classes, CLS)

                    if torch.cuda.is_available():
                        input, classes = input.cuda(), classes.cuda()

                    target = (input.data * 255).long()

                    input, classes, target = Variable(input), Variable(
                        classes), Variable(target)

                    result = model(input, classes)
                    loss = cross_entropy(result, target)
                    loss = loss * util.LOG2E  # Convert from nats to bits

                    err_te.append(loss.data.item())

                tbw.add_scalar('pixel-models/test-loss',
                               sum(err_te) / len(err_te), epoch)
                print('epoch={:02}; training loss: {:.3f}; test loss: {:.3f}'.
                      format(epoch,
                             sum(err_tr) / len(err_tr),
                             sum(err_te) / len(err_te)))

                model.train(False)
                sample_zeros = draw_sample(sample_init_zeros,
                                           testcls_zeros,
                                           model,
                                           seedsize=(0, 0),
                                           batch_size=arg.batch_size)
                sample_seeds = draw_sample(sample_init_seeds,
                                           testcls_seeds,
                                           model,
                                           seedsize=(sh, W),
                                           batch_size=arg.batch_size)
                sample = torch.cat([sample_zeros, sample_seeds], dim=0)

                utils.save_image(sample,
                                 'sample_{:02d}.png'.format(epoch),
                                 nrow=12,
                                 padding=0)
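
The second example conditions the gated PixelCNN on one-hot class vectors produced by `util.one_hot(classes, CLS)`. A minimal stand-in for such a helper, inferred only from the call sites and not from the actual `util` module, could be:

import torch

def one_hot(labels, num_classes):
    # Turn a batch of integer labels of shape (b,) into float one-hot vectors
    # of shape (b, num_classes).
    out = torch.zeros(labels.size(0), num_classes, device=labels.device)
    out.scatter_(1, labels.unsqueeze(1), 1.0)
    return out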