示例#1
0
    def __init__(self, in_size, zsize=32, use_res=False, use_bn=False, depth=0):
        """Set up the convolutional encoder.

        Layout: three ``util.Block`` stages with channel counts 3 -> a -> b -> c, each followed
        by a square ``MaxPool2d`` (kernels ``p``, ``q`` and ``r``), then ``depth`` additional
        ``c -> c`` blocks, then ``util.Flatten`` and a ``Linear`` layer with ``zsize * 2`` outputs.
        ``a``, ``b``, ``c``, ``p``, ``q`` and ``r`` are not defined in this method; they are
        looked up in an enclosing (presumably module-level) scope.

        :param in_size: input spatial size; ``in_size[0]`` and ``in_size[1]`` are each
            integer-divided by ``p*q*r`` to size the final linear layer (presumably
            (height, width) -- TODO confirm).
        :param zsize: latent size; the encoder emits ``2 * zsize`` values per instance.
        :param use_res: forwarded to every ``util.Block`` as ``use_res``.
        :param use_bn: forwarded to every ``util.Block`` as ``batch_norm``.
        :param depth: number of extra ``c -> c`` blocks appended after the last pooling layer.
        """
        super().__init__()
        self.zsize = zsize

        block_opts = {'use_res': use_res, 'batch_norm': use_bn}

        # - Encoder: a conv stage followed by downsampling, three times over
        stages = ((3, a, p), (a, b, q), (b, c, r))
        layers = []
        for chan_in, chan_out, pool in stages:
            layers += [util.Block(chan_in, chan_out, **block_opts), MaxPool2d((pool, pool))]

        # optional extra blocks after the last pooling layer (a fresh Block per level)
        layers += [util.Block(c, c, **block_opts) for _ in range(depth)]

        # the three pooling layers shrink each spatial dimension by this factor in total
        shrink = p * q * r
        flat_features = (in_size[0] // shrink) * (in_size[1] // shrink) * c
        layers += [util.Flatten(), Linear(flat_features, zsize * 2)]

        self.encoder = Sequential(*layers)
示例#2
0
def _holdout_loaders(train, batch):
    """Split one training set into a 45 000 / 5 000 train/validation pair of loaders.

    Both loaders read from the same dataset object. ``util.ChunkSampler(start, num, total)``
    is used to take instances starting at 0 (45 000 of them) for training and starting at
    45 000 (5 000 of them) for validation -- assumed to be contiguous chunks, TODO confirm
    against ``util.ChunkSampler``.
    """
    num_train, num_val = 45000, 5000
    total = num_train + num_val

    trainloader = DataLoader(train,
                             batch_size=batch,
                             sampler=util.ChunkSampler(0, num_train, total))
    testloader = DataLoader(train,
                            batch_size=batch,
                            sampler=util.ChunkSampler(num_train, num_val, total))
    return trainloader, testloader


def _task_loaders(task, data, batch, final):
    """Create the data loaders for ``task``.

    :param task: one of 'mnist', 'cifar10', 'cifar100' or 'image-folder-bw'.
    :param data: root data directory. For the torchvision datasets the task name is appended
        to it; for 'image-folder-bw' it is used as-is (with '/train/' and '/test/' subfolders
        when ``final`` is set).
    :param batch: batch size.
    :param final: if True, load the real test split; otherwise hold out part of the training
        set for validation (see ``_holdout_loaders``).
    :return: ``(trainloader, testloader, shape, num_classes)``.
    :raises Exception: if ``task`` is not recognized.
    """
    # torchvision datasets that share one loading recipe: (class, input shape, number of classes)
    builtin = {
        'mnist': (torchvision.datasets.MNIST, (1, 28, 28), 10),
        'cifar10': (torchvision.datasets.CIFAR10, (3, 32, 32), 10),
        'cifar100': (torchvision.datasets.CIFAR100, (3, 32, 32), 100),
    }

    if task in builtin:
        dataset, shape, num_classes = builtin[task]
        root = data + os.sep + task
        normalize = transforms.Compose([transforms.ToTensor()])

        train = dataset(root=root, train=True, download=True, transform=normalize)

        if not final:
            trainloader, testloader = _holdout_loaders(train, batch)
            return trainloader, testloader, shape, num_classes

        test = dataset(root=root, train=False, download=True, transform=normalize)
        trainloader = DataLoader(train, batch_size=batch, shuffle=True, num_workers=2)
        testloader = DataLoader(test, batch_size=batch, shuffle=False, num_workers=2)
        return trainloader, testloader, shape, num_classes

    if task == 'image-folder-bw':
        tr = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

        if final:
            train = torchvision.datasets.ImageFolder(root=data + '/train/', transform=tr)
            test = torchvision.datasets.ImageFolder(root=data + '/test/', transform=tr)

            trainloader = DataLoader(train, batch_size=batch, shuffle=True)
            # evaluate on the held-out test folder (previously this loader wrapped `train`)
            testloader = DataLoader(test, batch_size=batch, shuffle=False)
        else:
            train = torchvision.datasets.ImageFolder(root=data, transform=tr)
            trainloader, testloader = _holdout_loaders(train, batch)

        return trainloader, testloader, (1, 100, 100), 10

    raise Exception('Task name {} not recognized'.format(task))


def _build_model(args, modelname, shape, num_classes, k, hidden, num_values, min_sigma,
                 subsample, dropout):
    """Construct the classifier named ``modelname``.

    The baseline models output unnormalized logits: ``F.cross_entropy`` applies log-softmax
    itself, so a trailing ``nn.Softmax`` would apply softmax twice and flatten the gradients.

    :return: ``(model, reinforce)`` where ``reinforce`` is True when the model returns
        ``(outputs, stochastic_nodes, actions)`` instead of plain outputs.
    :raises Exception: if ``modelname`` is not recognized.
    """
    activation = nn.ReLU()

    if modelname == 'baseline':
        model = nn.Sequential(util.Flatten(),
                              nn.Linear(prod(shape), hidden),
                              activation,
                              nn.Linear(hidden, num_classes))
        return model, False

    if modelname == 'baseline-conv':
        c, h, w = shape
        # spatial size after the 8x and 4x pooling layers, times the final 32 channels
        hid = floor(floor(w / 8) / 4) * floor(floor(h / 8) / 4) * 32

        model = nn.Sequential(nn.Conv2d(c, 4, kernel_size=5, padding=2), activation,
                              nn.Conv2d(4, 4, kernel_size=5, padding=2), activation,
                              nn.MaxPool2d(kernel_size=8),
                              nn.Conv2d(4, 16, kernel_size=5, padding=2), activation,
                              nn.Conv2d(16, 16, kernel_size=5, padding=2), activation,
                              nn.MaxPool2d(kernel_size=4),
                              nn.Conv2d(16, 32, kernel_size=5, padding=2), activation,
                              nn.Conv2d(32, 32, kernel_size=5, padding=2), activation,
                              util.Flatten(),
                              nn.Linear(hid, 128),
                              nn.Dropout(dropout),
                              nn.Linear(128, num_classes))
        return model, False

    if modelname == 'ash':
        model = ASHModel(shape=shape,
                         k=k,
                         glimpses=args.num_glimpses,
                         num_values=num_values,
                         min_sigma=min_sigma,
                         subsample=subsample,
                         hidden=hidden,
                         num_classes=num_classes,
                         gadditional=args.gadditional,
                         radditional=args.radditional,
                         region=(args.chunk, args.chunk),
                         reinforce=False)
        return model, False

    if modelname == 'ash-reinforce':
        model = ASHModel(shape=shape,
                         k=k,
                         glimpses=args.num_glimpses,
                         num_values=num_values,
                         min_sigma=min_sigma,
                         subsample=subsample,
                         hidden=hidden,
                         num_classes=num_classes,
                         reinforce=True,
                         rfboost=args.rfboost)
        return model, True

    raise Exception('Model name {} not recognized'.format(modelname))


def _eval_accuracy(model, loader, cuda, reinforce):
    """Return the fraction of instances in ``loader`` whose argmax prediction equals the label.

    Runs in eval mode and without autograd bookkeeping (no gradients are needed here).
    """
    model.eval()

    total = 0.0
    correct = 0.0

    with torch.no_grad():
        for inputs, labels in loader:
            if cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            if reinforce:
                outputs, _, _ = model(inputs)
            else:
                outputs = model(inputs)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total


def go(args,
       batch=64,
       epochs=350,
       k=3,
       modelname='baseline',
       cuda=False,
       seed=1,
       lr=0.001,
       subsample=None,
       num_values=-1,
       min_sigma=0.0,
       tb_dir=None,
       data='./data',
       hidden=32,
       task='mnist',
       final=False,
       dropout=0.0,
       plot_every=1):
    """Train a classifier on ``task`` and report held-out accuracy after every epoch.

    :param args: namespace read by the ASH models (``num_glimpses``, ``gadditional``,
        ``radditional``, ``chunk``, ``rfboost``).
    :param batch: batch size.
    :param epochs: number of training epochs.
    :param k: glimpse size passed to ``ASHModel``.
    :param modelname: 'baseline', 'baseline-conv', 'ash' or 'ash-reinforce'.
    :param cuda: move the model and every batch to the GPU.
    :param seed: torch seed; a negative value draws (and prints) a random seed instead.
    :param lr: Adam learning rate.
    :param subsample: passed to ``ASHModel``.
    :param num_values: passed to ``ASHModel``.
    :param min_sigma: passed to ``ASHModel``.
    :param tb_dir: tensorboard ``SummaryWriter`` log directory.
    :param data: root data directory.
    :param hidden: hidden layer size.
    :param task: 'mnist', 'image-folder-bw', 'cifar10' or 'cifar100'.
    :param final: evaluate on the test split instead of a validation hold-out.
    :param dropout: dropout rate of the 'baseline-conv' model.
    :param plot_every: plot the ASH glimpses every this many epochs.
    :raises Exception: for an unrecognized ``task`` or ``modelname``.
    """
    if seed < 0:
        seed = random.randint(0, 1000000)
        print('random seed: ', seed)
    # Always seed torch, so a randomly drawn (printed) seed actually reproduces the run.
    torch.manual_seed(seed)

    tbw = SummaryWriter(log_dir=tb_dir)

    trainloader, testloader, shape, num_classes = _task_loaders(task, data, batch, final)

    model, reinforce = _build_model(args, modelname, shape, num_classes,
                                    k=k, hidden=hidden, num_values=num_values,
                                    min_sigma=min_sigma, subsample=subsample,
                                    dropout=dropout)

    if cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    step = 0  # number of training instances seen so far (tensorboard x-axis)

    util.makedirs('./mnist/')

    for epoch in range(epochs):

        model.train()

        for i, (inputs, labels) in tqdm(enumerate(trainloader, 0)):

            if cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()

            if reinforce:
                outputs, stoch_nodes, actions = model(inputs)
            else:
                outputs = model(inputs)

            # one loss value per instance in the batch
            mloss = F.cross_entropy(outputs, labels, reduction='none')

            if reinforce:
                # REINFORCE term: the (detached) negative model loss acts as the reward
                # for every sampled action.
                rloss = 0.0
                for node, action in zip(stoch_nodes, actions):
                    reward = -mloss.detach().unsqueeze(1).expand_as(action)
                    rloss = rloss - node.log_prob(action) * reward

                loss = rloss.sum(dim=1) + mloss

                tbw.add_scalar('mnist/train-loss', float(loss.mean().item()), step)
                tbw.add_scalar('mnist/model-loss', float(mloss.mean().item()), step)
                tbw.add_scalar('mnist/reinf-loss',
                               float(rloss.sum(dim=1).mean().item()), step)
            else:
                loss = mloss
                # mean, matching the reinforce branch's use of the same tag
                tbw.add_scalar('mnist/train-loss', float(loss.mean().item()), step)

            loss.sum().backward()
            optimizer.step()

            step += inputs.size(0)

            if epoch % plot_every == 0 and i == 0 and isinstance(model, ASHModel):
                model.plot(inputs[:10, ...])
                plt.savefig('mnist/attention.glimpses.{:03}.pdf'.format(epoch))
                # release the figure(s) so they do not pile up over the epochs
                plt.close('all')

        accuracy = _eval_accuracy(model, testloader, cuda, reinforce)

        tbw.add_scalar('mnist1d/per-epoch-test-acc', accuracy, epoch)
        print('EPOCH {}: {} accuracy '.format(epoch, accuracy))

    LOG.info('Finished Training.')
示例#3
0
    def __init__(self,
                 in_size,
                 k,
                 adaptive=True,
                 gadditional=0,
                 radditional=0,
                 region=None,
                 sigma_scale=0.1,
                 num_values=-1,
                 min_sigma=0.0,
                 subsample=None,
                 preprocess=None):
        """Set up a layer that maps an input image to a ``(channels, k, k)`` output.

        Every output element gets one index tuple
        ``(out_ch, out_row, out_col, in_ch, in_row, in_col)``. The input channel equals the
        output channel, and the two input-pixel columns (4 and 5) are the learned columns
        handed to the parent class as ``learn_cols``.

        :param in_size: input size as a 3-sequence ``(channels, height, width)``.
        :param k: output height and width.
        :param adaptive: if True, build ``self.preprocess`` (``preprocess`` if given, else a
            default conv net ending in a 4-unit linear layer) and a ``bbox_offset`` buffer
            ``[-1, 1, -1, 1]``; if False, use a learnable ``self.bound`` initialised to
            ``[-1, 1, -1, 1]``.
        :param gadditional: forwarded to the parent constructor.
        :param radditional: forwarded to the parent constructor.
        :param region: forwarded to the parent constructor.
        :param sigma_scale: stored as ``self.sigma_scale``.
        :param num_values: if > 0, ``self.values`` has this many entries; otherwise one
            value per index tuple.
        :param min_sigma: stored as ``self.min_sigma``.
        :param subsample: forwarded to the parent constructor.
        :param preprocess: optional module replacing the default network (adaptive only).
        :raises ValueError: if ``in_size`` does not have exactly three elements.
        """
        # Validate up front. The old ``assert`` ran only after ``in_size`` had already been
        # unpacked, so it could never fire. ``tuple()`` also lets list inputs work with the
        # tuple concatenation used for ``lc_sizes`` below.
        in_size = tuple(in_size)
        if len(in_size) != 3:
            raise ValueError(
                'in_size should have three elements (channels, height, width), got {}'.format(in_size))

        ci, hi, wi = in_size
        out_size = co, ho, wo = ci, k, k

        # one row per output element, columns ordered (channel, row, col)
        indices = torch.LongTensor(list(np.ndindex((k, k, co))))[:, (2, 0, 1)]

        pixel_indices = indices[:, 1:3].clone()

        # append the input part: same channel as the output, pixel (0, 0) as starting point
        indices = torch.cat(
            [indices, indices[:, 0:1], torch.zeros_like(indices[:, 1:3])],
            dim=1)

        # columns 4 and 5 (input row/col) are learned; their sizes are (hi, wi)
        self.lc = [4, 5]
        self.lc_sizes = [(out_size + in_size)[i] for i in self.lc]

        super().__init__(in_rank=3,
                         out_size=(co, ho, wo),
                         temp_indices=indices,
                         learn_cols=self.lc,
                         gadditional=gadditional,
                         radditional=radditional,
                         region=region,
                         subsample=subsample)

        # scale to [0,1] in each dim
        pixel_indices = pixel_indices.float() / torch.FloatTensor(
            [k, k]).unsqueeze(0).expand_as(pixel_indices)

        self.register_buffer('pixel_indices', pixel_indices)

        self.in_size = in_size
        self.k = k
        self.sigma_scale = sigma_scale
        self.num_values = num_values
        self.min_sigma = min_sigma
        self.out_size = out_size
        self.adaptive = adaptive

        if self.adaptive:
            activation = nn.ReLU()

            p1 = 4
            p2 = 2

            c, h, w = in_size

            if preprocess is not None:
                self.preprocess = preprocess
            else:
                # default preprocess: flattened size after the p1 and p2 poolings (at least 1)
                # times the final 32 channels
                hid = max(
                    1,
                    floor(floor(w / p1) / p2) * floor(floor(h / p1) / p2)) * 32

                self.preprocess = nn.Sequential(
                    nn.Conv2d(c, 4, kernel_size=5, padding=2),
                    activation,
                    nn.Conv2d(4, 4, kernel_size=5, padding=2),
                    activation,
                    nn.MaxPool2d(kernel_size=p1),
                    nn.Conv2d(4, 16, kernel_size=5, padding=2),
                    activation,
                    nn.Conv2d(16, 16, kernel_size=5, padding=2),
                    activation,
                    nn.MaxPool2d(kernel_size=p2),
                    nn.Conv2d(16, 32, kernel_size=5, padding=2),
                    activation,
                    nn.Conv2d(32, 32, kernel_size=5, padding=2),
                    activation,
                    util.Flatten(),
                    nn.Linear(hid, 64),
                    nn.Dropout(DROPOUT),
                    activation,
                    nn.Linear(64, 64),
                    nn.Dropout(DROPOUT),
                    activation,
                    nn.Linear(64, 4),
                )

            # fixed (non-learned) offset; presumably combined with the 4 preprocess outputs
            # to form a bounding box -- TODO confirm in the forward/hyper method
            self.register_buffer('bbox_offset',
                                 torch.FloatTensor([-1, 1, -1, 1]))

        else:  # if not adaptive: a single learned bounding box
            self.bound = Parameter(torch.FloatTensor([-1, 1, -1, 1]))

        # one sigma per index tuple
        self.sigmas = Parameter(torch.randn((indices.size(0), )))

        if num_values > 0:
            self.values = Parameter(torch.randn((num_values, )))
        else:
            self.values = Parameter(torch.randn((indices.size(0), )))

        self.bias = Parameter(torch.zeros(*self.out_size))
示例#4
0
    def __init__(self,
                 shape,
                 k,
                 glimpses,
                 num_values,
                 min_sigma,
                 subsample,
                 hidden,
                 num_classes,
                 reinforce=False,
                 gadditional=None,
                 radditional=None,
                 region=None,
                 rfboost=2.0):
        """Set up a glimpse-based image classifier.

        A conv net (``self.preprocess``) produces ``4`` values per glimpse (``12`` when
        ``reinforce`` is set). ``glimpses`` ``SimpleImageLayer`` instances each produce
        ``k * k * channels`` features, and two linear layers map the concatenated features
        to ``num_classes`` outputs.

        :param shape: input size ``(channels, height, width)``.
        :param k: output height/width of each glimpse layer.
        :param glimpses: number of glimpse layers.
        :param num_values: forwarded to each ``SimpleImageLayer``.
        :param min_sigma: forwarded to each ``SimpleImageLayer``.
        :param subsample: forwarded to each ``SimpleImageLayer``.
        :param hidden: size of the hidden classifier layer.
        :param num_classes: number of output classes.
        :param reinforce: use the 12-values-per-glimpse preprocess output.
        :param gadditional: forwarded to each ``SimpleImageLayer``.
        :param radditional: forwarded to each ``SimpleImageLayer``.
        :param region: forwarded to each ``SimpleImageLayer``.
        :param rfboost: scale of the learnable ``bbox_offset`` parameter.
        """
        super().__init__()

        self.reinforce = reinforce
        self.rfboost = rfboost
        self.num_glimpses = glimpses

        activation = nn.ReLU()

        p1 = 4
        p2 = 2

        ch1, ch2, ch3 = 32, 64, 128

        c, h, w = shape
        # flattened size after the p1 and p2 poolings (at least 1), times ch3 channels
        hid = max(1,
                  floor(floor(w / p1) / p2) * floor(floor(h / p1) / p2)) * ch3
        hidlin = 512

        self.preprocess = nn.Sequential(
            nn.Conv2d(c, ch1, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(ch1, ch1, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=p1),
            nn.Conv2d(ch1, ch2, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(ch2, ch2, kernel_size=3, padding=1),
            activation,
            nn.MaxPool2d(kernel_size=p2),
            nn.Conv2d(ch2, ch3, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(ch3, ch3, kernel_size=3, padding=1),
            activation,
            util.Flatten(),
            nn.Linear(hid, hidlin),
            nn.Dropout(DROPOUT),
            activation,
            nn.Linear(hidlin, hidlin),
            nn.Dropout(DROPOUT),
            activation,
            nn.Linear(hidlin, (4 if not self.reinforce else 12) * glimpses))

        # nn.ModuleList (not a plain list) so the glimpse layers are registered submodules:
        # their parameters reach model.parameters() / the optimizer, .cuda() moves them,
        # and they are included in state_dict(). (Checkpoints saved before this change will
        # lack these keys.)
        self.hyperlayers = nn.ModuleList(
            SimpleImageLayer(shape,
                             k=k,
                             adaptive=True,
                             gadditional=gadditional,
                             radditional=radditional,
                             region=region,
                             num_values=num_values,
                             min_sigma=min_sigma,
                             subsample=subsample)
            for _ in range(glimpses))

        # each glimpse contributes k * k * channels features
        self.lin1 = nn.Linear(k * k * shape[0] * glimpses, hidden)
        self.lin2 = nn.Linear(hidden, num_classes)

        self.k = k
        self.is_cuda = False

        # learnable bounding-box offset, initialised to [-1, 1, -1, 1] scaled by rfboost
        self.bbox_offset = Parameter(
            torch.FloatTensor([-1, 1, -1, 1]) * self.rfboost)
def _group_by_label(loader, num_classes=10):
    """
    Split every batch yielded by ``loader`` into single-image tensors grouped by label.

    :param loader: iterable of ``(inputs, labels)`` batches; ``inputs`` must be 4-D and each
        ``labels[i].item()`` must be a key in ``range(num_classes)``.
    :param num_classes: number of label buckets to create.
    :return: dict mapping each label to a list of ``(1, C, H, W)`` tensors, in loader order.
    """
    grouped = {label: [] for label in range(num_classes)}
    for inps, labels in loader:
        b, _, _, _ = inps.size()  # the 4-way unpack rejects non-4-D input
        for i in range(b):
            grouped[labels[i].item()].append(inps[i:i + 1, :, :, :])
    return grouped


def go(arg):
    """
    Train and evaluate a learned-key sorting model on sequences of MNIST digits.

    For each of ``arg.reps`` repetitions, a fresh ``sort.SortLayer`` and a CNN (``tokeys``) are
    trained jointly. The CNN maps each 28x28 digit to one scalar sort key. Intermediate sorting
    stages and final outputs are plotted to ``./mnist-sort/<rep>/`` every ``arg.plot_every``
    iterations.

    Every ``arg.dot_every`` iterations, the fraction of evaluation sequences that come out in
    label order when sorted by the predicted keys is recorded. The per-repetition accuracies
    are saved with ``np.save`` and plotted to ``./quicksort/result.{png,pdf}``.

    :param arg: parsed command-line namespace. Attributes read here: ``final``, ``data``,
        ``batch``, ``limit``, ``seed``, ``iterations``, ``dot_every``, ``plot_every``,
        ``reps``, ``size``, ``additional``, ``sigma_scale``, ``min_sigma``, ``certainty``,
        ``cuda``, ``lr`` and ``loss`` (one of ``'plain'``, ``'means'``, ``'separate'``).
    :return: None
    :raises Exception: if ``arg.loss`` is not one of the recognised loss names.
    """

    torch.set_printoptions(precision=10)

    # - Load and organize the data
    trans = torchvision.transforms.ToTensor()
    if arg.final:
        train = torchvision.datasets.MNIST(root=arg.data, train=True, download=True, transform=trans)
        trainloader = torch.utils.data.DataLoader(train, batch_size=arg.batch, shuffle=True, num_workers=2)

        test = torchvision.datasets.MNIST(root=arg.data, train=False, download=True, transform=trans)
        testloader = torch.utils.data.DataLoader(test, batch_size=arg.batch, shuffle=False, num_workers=2)

    else:
        # Non-final runs validate on a held-out chunk of the MNIST training set
        # (presumably indices 45000-49999 given util.ChunkSampler(start, num, total) — confirm).
        NUM_TRAIN = 45000
        NUM_VAL = 5000
        total = NUM_TRAIN + NUM_VAL

        train = torchvision.datasets.MNIST(root=arg.data, train=True, download=True, transform=trans)

        trainloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total))
        testloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total))

    num_classes = 10

    # Bucket individual images by digit label; these dicts are what gen() samples from below.
    train = _group_by_label(trainloader, num_classes)
    if arg.limit is not None:
        train = {label: imgs[:arg.limit] for label, imgs in train.items()}

    # Evaluation buckets must come from the held-out loader, not the training loader.
    test = _group_by_label(testloader, num_classes)

    torch.manual_seed(arg.seed)
    np.random.seed(arg.seed)
    random.seed(arg.seed)

    # Accuracy is measured at every i with i % dot_every == 0, i.e. ceil(iterations / dot_every) times.
    ndots = (arg.iterations + arg.dot_every - 1) // arg.dot_every

    results = np.zeros((arg.reps, ndots))

    for r in range(arg.reps):
        print('starting {} out of {} repetitions'.format(r, arg.reps))
        util.makedirs('./mnist-sort/{}'.format(r))

        model = sort.SortLayer(arg.size, additional=arg.additional, sigma_scale=arg.sigma_scale,
                               sigma_floor=arg.min_sigma, certainty=arg.certainty)

        # - channel sizes
        c1, c2, c3 = 16, 64, 128
        h1, h2 = 256, 128

        # CNN mapping one 28x28 digit to a single scalar sort key. The three 2x2 max-pools
        # reduce 28 -> 14 -> 7 -> 3, hence 9 * c3 inputs to the first Linear layer.
        tokeys = nn.Sequential(
            nn.Conv2d(1, c1, (3, 3), padding=1), nn.ReLU(),
            nn.Conv2d(c1, c1, (3, 3), padding=1), nn.ReLU(),
            nn.Conv2d(c1, c1, (3, 3), padding=1), nn.ReLU(),
            nn.BatchNorm2d(c1),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(c1, c2, (3, 3), padding=1), nn.ReLU(),
            nn.Conv2d(c2, c2, (3, 3), padding=1), nn.ReLU(),
            nn.Conv2d(c2, c2, (3, 3), padding=1), nn.ReLU(),
            nn.BatchNorm2d(c2),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(c2, c3, (3, 3), padding=1), nn.ReLU(),
            nn.Conv2d(c3, c3, (3, 3), padding=1), nn.ReLU(),
            nn.Conv2d(c3, c3, (3, 3), padding=1), nn.ReLU(),
            nn.BatchNorm2d(c3),
            nn.MaxPool2d((2, 2)),
            util.Flatten(),
            nn.Linear(9 * c3, h1), nn.ReLU(),
            nn.Linear(h1, h2), nn.ReLU(),
            nn.Linear(h2, 1),
        )

        if arg.cuda:
            model.cuda()
            tokeys.cuda()

        optimizer = optim.Adam(list(model.parameters()) + list(tokeys.parameters()), lr=arg.lr)

        for i in trange(arg.iterations):

            x, t, l = gen(arg.batch, train, arg.size)

            if arg.cuda:
                x, t = x.cuda(), t.cuda()

            x, t = Variable(x), Variable(t)

            optimizer.zero_grad()

            keys = tokeys(x.view(arg.batch * arg.size, 1, 28, 28))
            keys = keys.view(arg.batch, arg.size)

            # keys = keys * 0.0 + l   # debug: replace learned keys by the true labels

            keys.retain_grad()

            x = x.view(arg.batch, arg.size, -1)
            t = t.view(arg.batch, arg.size, -1)

            ys, ts, keys = model(x, keys=keys, target=t)

            if arg.loss == 'plain':
                # just compare the final output to the target
                loss = util.xent(ys[-1], t).mean()
            elif arg.loss == 'means':
                # Compare the output to the back-sorted target at each step. The first and last
                # stages are compared directly; at intermediate depth d only the means of the
                # 2**d buckets are compared, weighted by the bucket size.
                loss = 0.0
                loss = loss + util.xent(ys[0], ts[0]).mean()
                loss = loss + util.xent(ys[-1], ts[-1]).mean()

                for d in range(1, len(ys) - 1):
                    numbuckets = 2 ** d
                    bucketsize = arg.size // numbuckets

                    xb = ys[d][:, None, :, :].view(arg.batch, numbuckets, bucketsize, -1)
                    tb = ts[d][:, None, :, :].view(arg.batch, numbuckets, bucketsize, -1)

                    xb = xb.mean(dim=2)
                    tb = tb.mean(dim=2)

                    loss = loss + util.xent(xb, tb).mean() * bucketsize

            elif arg.loss == 'separate':
                # compare the output to the back-sorted target at each step
                loss = 0.0
                # NOTE(review): this term compares the target with itself; kept unchanged — confirm intent.
                loss = loss + util.xent(ts[-1], ts[-1]).mean()

                for d in range(len(ys)):
                    loss = loss + util.xent(ys[d], ts[d]).mean()

            else:
                raise Exception('Loss {} not recognized.'.format(arg.loss))

            loss.backward()

            optimizer.step()

            tbw.add_scalar('mnist-sort/loss/{}/{}'.format(arg.size, r), loss.data.item(), i * arg.batch)

            # Plot intermediates, and targets
            if i % arg.plot_every == 0:

                optimizer.zero_grad()

                x, t, l = gen(arg.batch, train, arg.size)

                if arg.cuda:
                    x, t = x.cuda(), t.cuda()

                x, t = Variable(x), Variable(t)

                keys = tokeys(x.view(arg.batch * arg.size, 1, 28, 28))
                keys = keys.view(arg.batch, arg.size)

                # keys = keys * 0.0 + l   # debug: replace learned keys by the true labels

                x = x.view(arg.batch, arg.size, -1)
                t = t.view(arg.batch, arg.size, -1)

                ys, ts, _ = model(x, keys=keys, target=t)

                b, n, s = ys[0].size()

                if arg.loss == 'means':
                    # Replace each intermediate stage by its bucket means, broadcast back to full
                    # length, so the plot shows exactly what the 'means' loss compares.
                    for d in range(1, len(ys) - 1):
                        numbuckets = 2 ** d
                        bucketsize = arg.size // numbuckets

                        xb = ys[d][:, None, :, :].view(arg.batch, numbuckets, bucketsize, s)
                        tb = ts[d][:, None, :, :].view(arg.batch, numbuckets, bucketsize, s)

                        xb = xb.mean(dim=2, keepdim=True)\
                            .expand(arg.batch, numbuckets, bucketsize, s)\
                            .contiguous().view(arg.batch, n, s)
                        tb = tb.mean(dim=2, keepdim=True)\
                            .expand(arg.batch, numbuckets, bucketsize, s)\
                            .contiguous().view(arg.batch, n, s)

                        ys[d] = xb
                        ts[d] = tb

                # One row per sorting stage (assumes arg.size is a power of two — confirm);
                # outputs on the left half, targets on the right half.
                md = int(np.log2(arg.size))
                plt.figure(figsize=(arg.size * 2, md + 1))

                c = 1
                for row in range(md + 1):
                    for col in range(arg.size * 2):
                        ax = plt.subplot(md + 1, arg.size * 2, c)

                        images = ys[row] if col < arg.size else ts[row]
                        im = images[0].view(arg.size, 28, 28)[col % arg.size].data.cpu().numpy()

                        ax.imshow(im, cmap='bone_r' if col < arg.size else 'pink_r')

                        clean(ax)

                        c += 1

                plt.figtext(0.3, 0.95, "input", va="center", ha="center", size=15)
                plt.figtext(0.7, 0.95, "target", va="center", ha="center", size=15)

                plt.savefig('./mnist-sort/{}/intermediates.{:04}.pdf'.format(r, i))
                plt.close()  # release the figure; otherwise every plot step leaks one

            # Plot the progress
            if i % arg.plot_every == 0:

                optimizer.zero_grad()

                x, t, l = gen(arg.batch, train, arg.size)

                if arg.cuda:
                    x, t = x.cuda(), t.cuda()

                x, t = Variable(x), Variable(t)

                keys = tokeys(x.view(arg.batch * arg.size, 1, 28, 28))
                keys = keys.view(arg.batch, arg.size)
                # keys = keys * 0.01 + l   # debug: bias learned keys towards the true labels
                keys.retain_grad()  # the key gradients are printed under each input image

                x = x.view(arg.batch, arg.size, -1)
                t = t.view(arg.batch, arg.size, -1)

                # NOTE(review): unlike the model calls above, this passes train= (not target=)
                # and unpacks two values; confirm SortLayer.forward supports this call form.
                yt, _ = model(x, keys=keys, train=True)

                loss = F.mse_loss(yt, t)  # compute the loss

                loss.backward()

                yi, _ = model(x, keys=keys, train=False)

                input = x[0].view(arg.size, 28, 28).data.cpu().numpy()
                target = t[0].view(arg.size, 28, 28).data.cpu().numpy()
                output_inf = yi[0].view(arg.size, 28, 28).data.cpu().numpy()
                output_train = yt[0].view(arg.size, 28, 28).data.cpu().numpy()

                plt.figure(figsize=(arg.size * 3, 4 * 3))
                for col in range(arg.size):

                    ax = plt.subplot(4, arg.size, col + 1)
                    ax.imshow(target[col], cmap='gray_r')
                    clean(ax)

                    if col == 0:
                        ax.set_ylabel('target')

                    ax = plt.subplot(4, arg.size, col + arg.size + 1)
                    ax.imshow(input[col], cmap='gray_r')
                    clean(ax)
                    ax.set_xlabel('{:.2}, {:.2}'.format(keys[0, col], - keys.grad[0, col]))

                    if col == 0:
                        ax.set_ylabel('input')

                    ax = plt.subplot(4, arg.size, col + arg.size * 2 + 1)
                    ax.imshow(output_inf[col], cmap='gray_r')
                    clean(ax)

                    if col == 0:
                        ax.set_ylabel('inference')

                    ax = plt.subplot(4, arg.size, col + arg.size * 3 + 1)
                    ax.imshow(output_train[col], cmap='gray_r')
                    clean(ax)

                    if col == 0:
                        ax.set_ylabel('training')

                plt.savefig('./mnist-sort/{}/mnist.{:04}.pdf'.format(r, i))
                plt.close()

            if i % arg.dot_every == 0:
                # - Compute the accuracy: the fraction of evaluation sequences that come out in
                #   label order when sorted by their predicted keys.
                NUM = 10_000
                tot = 0.0
                correct = 0.0
                with torch.no_grad():

                    for _ in range(NUM // arg.batch):
                        x, t, l = gen(arg.batch, test, arg.size)

                        if arg.cuda:
                            x, t, l = x.cuda(), t.cuda(), l.cuda()

                        x, t, l = Variable(x), Variable(t), Variable(l)

                        keys = tokeys(x.view(arg.batch * arg.size, 1, 28, 28))
                        keys = keys.view(arg.batch, arg.size)

                        # Compare sorted label *values*, not argsort indices. When a sequence
                        # repeats a digit, the order torch.sort gives tied labels is arbitrary,
                        # so index comparison could reject a correct key ordering.
                        gold, _ = torch.sort(l, dim=1)
                        _, mine = torch.sort(keys, dim=1)
                        mine_labels = l.gather(1, mine)

                        tot += x.size(0)
                        correct += ((mine_labels != gold).sum(dim=1) == 0).sum().item()

                    acc = correct / tot
                    print('acc', acc)

                    results[r, i // arg.dot_every] = acc

                    tbw.add_scalar('mnist-sort/testloss/{}/{}'.format(arg.size, r), acc, i * arg.batch)

    np.save('results.{}.np'.format(arg.size), results)
    print('experiments finished')

    plt.figure(figsize=(10, 5))
    ax = plt.gca()

    if results.shape[0] > 1:
        ax.errorbar(x=np.arange(ndots) * arg.dot_every, y=np.mean(results, axis=0),
                    yerr=np.std(results, axis=0),
                    label='size {0}x{0}, r={1}'.format(arg.size, arg.reps))
    else:
        ax.plot(np.arange(ndots) * arg.dot_every, np.mean(results, axis=0),
                label='size {0}x{0}'.format(arg.size))

    ax.legend()

    util.basic(ax)

    ax.spines['bottom'].set_position('zero')
    ax.set_ylim(0.0, 1.0)

    plt.xlabel('iterations')
    plt.ylabel('accuracy')  # results holds correct/tot, i.e. accuracy

    plt.savefig('./quicksort/result.png')
    plt.savefig('./quicksort/result.pdf')
    plt.close()
    def __init__(self,
                 data_size,
                 k,
                 emb_size=16,
                 radd=32,
                 gadd=32,
                 range=128,
                 min_sigma=0.0,
                 directed=True,
                 fix_value=False,
                 encoder=False):
        """
        Build an MLP image encoder/decoder, an n-by-n ``MatrixHyperlayer`` (``self.adj``) and
        one learnable embedding per item.

        :param data_size: 4-tuple ``(n, c, h, w)``. Only ``n`` is used here; it sizes
            ``self.adj`` and ``self.embedding``.
        :param k: passed to ``MatrixHyperlayer`` as its third positional argument.
        :param emb_size: dimensionality of each embedding vector.
        :param radd: forwarded as ``radditional`` to ``MatrixHyperlayer``.
        :param gadd: forwarded as ``gadditional`` to ``MatrixHyperlayer``.
        :param range: forwarded as ``region=(range,)`` to ``MatrixHyperlayer``. It shadows the
            builtin inside this method but is part of the public signature.
        :param min_sigma: forwarded to ``MatrixHyperlayer``.
        :param directed: not used by this constructor.
        :param fix_value: forwarded to ``MatrixHyperlayer``.
        :param encoder: not used by this constructor; the MLP encoder is always built.
        """
        super().__init__()

        self.data_shape = data_size
        n, _, _, _ = data_size  # the 4-way unpack also rejects malformed data_size

        # Hidden sizes of the MLP encoder and decoder.
        h2, h3 = 128, 64

        # Encoder is only used during pretraining. It maps a flattened 28x28 image to
        # emb_size * 2 outputs (presumably a mean and a (log-)variance — confirm in forward()).
        self.encoder = nn.Sequential(util.Flatten(), nn.Linear(28 * 28, h2),
                                     nn.ReLU(), nn.Linear(h2, h3), nn.ReLU(),
                                     nn.Linear(h3, emb_size * 2))

        # Decoder maps one embedding to a (1, 28, 28) image with values in (0, 1) (Sigmoid output).
        self.decoder = nn.Sequential(nn.Linear(emb_size, h3), nn.ReLU(),
                                     nn.Linear(h3, h2), nn.ReLU(),
                                     nn.Linear(h2, h3), nn.ReLU(),
                                     nn.Linear(h3, 28 * 28), nn.Sigmoid(),
                                     util.Reshape((1, 28, 28)))

        self.adj = MatrixHyperlayer(n,
                                    n,
                                    k,
                                    radditional=radd,
                                    gadditional=gadd,
                                    region=(range, ),
                                    min_sigma=min_sigma,
                                    fix_value=fix_value)

        # One embedding vector per item, initialised from a standard normal distribution.
        self.embedding = Parameter(torch.randn(n, emb_size))

        self.emb_size = emb_size
示例#7
0
def load_model(name, big=True):
    """
    Build a VGG-style 10-class convolutional classifier with a configurable activation.

    The network has three stages. Each stage is three stride-1, same-padded convolutions,
    each followed by the activation module, and then a 2x2 max-pool. The head is Flatten ->
    Linear(4*4*C, 10) -> Softmax over the class dimension. The input has 3 channels;
    presumably 32x32 images (three halvings reach 4x4) — confirm against the caller.

    :param name: experiment name selecting the activation. The ``-lambda`` and ``-sigloss``
        variants use the same activation as their base name here.
    :param big: if True, use 5x5 kernels with 16/32/64 channels; otherwise 3x3 kernels with
        8/16/32 channels.
    :return: the assembled ``Sequential`` model.
    :raises ValueError: (a subclass of Exception) if ``name`` is not recognised.
    """
    activations = {
        'relu': Relu,
        'sigmoid': Sigmoid,
        'relu-lambda': Relu,
        'sigmoid-lambda': Sigmoid,
        'relu-sigloss': Relu,
        'sigmoid-sigloss': Sigmoid,
        'bn-relu': BNRelu,
        'relu-bn': ReluBN,
        'bn-sigmoid': BNSigmoid,
        'sigmoid-bn': SigmoidBN,
    }
    if name not in activations:
        raise ValueError('Model "{}" not recognized.'.format(name))
    activation = activations[name]

    if big:
        channels, kernel_size, padding = (16, 32, 64), 5, 2
    else:
        channels, kernel_size, padding = (8, 16, 32), 3, 1

    layers = []
    in_channels = 3
    for out_channels in channels:
        # The padding keeps the spatial size; only the max-pool halves it.
        for _ in range(3):
            layers.append(nn.Conv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=kernel_size,
                                    stride=1,
                                    padding=padding))
            layers.append(activation(out_channels))
            in_channels = out_channels
        layers.append(nn.MaxPool2d(stride=2, kernel_size=2))

    layers.extend([
        util.Flatten(),
        nn.Linear(4 * 4 * in_channels, 10),
        # Explicit dim: the Linear output is (batch, 10), which the deprecated implicit-dim
        # rule would also have resolved to dim=1.
        nn.Softmax(dim=1),
    ])

    return Sequential(*layers)