def __init__(self, batch_size, iterations, data_dir="../DIRNet-Keras/MNIST_data"):
    """Pre-sample a fixed pool of (x, y) MNIST pairs into ``self.xy_batch``.

    Draws one batch of ``batch_size`` pairs, then ``iterations`` more, from a
    training ``MNISTDataHandler``.  Each batch's ``batch_x`` and ``batch_y`` are
    concatenated along axis 1, and all batches are stacked along axis 0.

    Args:
        batch_size: number of pairs requested per ``sample_pair`` call.
        iterations: number of additional batches drawn after the first one.
        data_dir: directory passed to ``MNISTDataHandler``; defaults to the
            previously hard-coded path.

    Sets ``self.batch_size``, ``self.iter``, ``self.len``
    (``batch_size * (iterations + 1)``) and ``self.xy_batch``.
    """
    self.batch_size = batch_size
    self.iter = iterations
    # One initial batch plus `iterations` more, mirroring the sampling below.
    self.len = self.batch_size * self.iter + self.batch_size
    dh = MNISTDataHandler(data_dir, is_train=True)

    def _sample_xy():
        # Join one sampled pair side by side along axis 1.
        batch_x, batch_y = dh.sample_pair(self.batch_size)
        return np.concatenate([batch_x, batch_y], axis=1)

    # Collect all batches first and concatenate once.  Growing the array with
    # np.concatenate on every iteration recopies all earlier rows (quadratic).
    pairs = [_sample_xy()]
    pairs.extend(_sample_xy() for _ in range(self.iter))
    self.xy_batch = np.concatenate(pairs, axis=0)
    print("length-->", self.xy_batch.shape[0])
Пример #2
0
def main(num_pairs=1):
    """Restore a trained DIRNet checkpoint and deploy it on MNIST test pairs.

    For each ``i`` in ``range(num_pairs)``, creates ``<config.result_dir>/<i>``.
    It samples a batch from the test handler with ``sample_pair(batch_size, i)``
    and calls ``reg.deploy`` into that directory.

    Args:
        num_pairs: number of result subdirectories / deploy calls.  Defaults to
            1, matching the previous hard-coded ``range(1)``.  The original code
            carried a ``#10`` comment on that line.
    """
    # The context manager closes the session when the block exits,
    # including when restore/deploy raises.
    with tf.Session() as sess:
        config = get_config(is_train=False)
        mkdir(config.result_dir)

        reg = DIRNet(sess, config, "DIRNet", is_train=False)
        reg.restore(config.ckpt_dir)
        dh = MNISTDataHandler("MNIST_data", is_train=False)

        for i in range(num_pairs):
            result_i_dir = config.result_dir + "/{}".format(i)
            mkdir(result_i_dir)

            # `i` is forwarded as sample_pair's second argument; its meaning
            # (e.g. a digit label) is defined by MNISTDataHandler — TODO confirm.
            batch_x, batch_y = dh.sample_pair(config.batch_size, i)
            reg.deploy(result_i_dir, batch_x, batch_y)
Пример #3
0
def main(save_every=1000):
    """Train DIRNet on randomly sampled MNIST pairs, checkpointing periodically.

    Runs ``config.iteration`` training steps and prints the loss of each.
    Every ``save_every`` steps it deploys on the current batch into
    ``config.tmp_dir`` and saves a checkpoint to ``config.ckpt_dir``.  A final
    checkpoint is saved after the loop if the last step was not already saved.

    Args:
        save_every: checkpoint/deploy interval in iterations; defaults to the
            previously hard-coded 1000.
    """
    # The context manager guarantees the session is closed on exit.
    with tf.Session() as sess:
        config = get_config(is_train=True)
        mkdir(config.tmp_dir)
        mkdir(config.ckpt_dir)
        reg = DIRNet(sess, config, "DIRNet", is_train=True)
        dh = MNISTDataHandler("MNIST_data", is_train=True)

        for i in range(config.iteration):
            batch_x, batch_y = dh.sample_pair(config.batch_size)
            loss = reg.fit(batch_x, batch_y)
            print("iter {:>6d} : {}".format(i + 1, loss))

            if (i + 1) % save_every == 0:
                reg.deploy(config.tmp_dir, batch_x, batch_y)
                reg.save(config.ckpt_dir)

        # Without this, up to save_every - 1 trailing training steps would
        # never be written to disk.
        if config.iteration % save_every != 0:
            reg.save(config.ckpt_dir)
Пример #4
0
    # Evaluation fragment: builds a DIRNetDeform on CPU, loads a checkpoint,
    # and runs the model on single MNIST test pairs.  The enclosing function
    # header and the rest of the loop body are outside this view.
    from utils import get_logger
    logger = get_logger()

    model = DIRNetDeform(logger, device='cpu')
    resultDir = 'imgs'
    os.makedirs(resultDir, exist_ok=True)  # no error if the directory already exists

    # set your checkpoint here
    checkpointFileName = 'dirConvnet_model-zncc_SSD_unet-2_gpu_deformConv_affine.ckpt'
    # Alternative checkpoint, kept for reference:
    #    checkpointFileName = 'dirConvnet_model-zncc_unet_deformConv_affine-mnist_gpu.ckpt'

    # map_location remaps tensors saved on 'cuda:0' to CPU, so a checkpoint
    # written on GPU loads on a CPU-only machine.  torch.load unpickles the
    # file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpointFileName, map_location={'cuda:0': 'cpu'})
    # The checkpoint is a mapping; the model weights are under 'model_state_dict'.
    model.load_state_dict(checkpoint['model_state_dict'])
    shape = [28, 28]  # NOTE(review): unused in the visible lines; presumably image H x W — confirm
    dh = MNISTDataHandler("../DIRNet-Keras/MNIST_data", is_train=False)

    # Alternative dataset, kept for reference:
    #    dh = FashionMNISTDataHandler("../fashionMnist", is_train=False)

    # Accumulators initialised to 0.  They are not updated in the visible
    # lines; presumably the truncated loop body sums into them — TODO confirm.
    Iwarped_mean = 0
    targetImage_mean = 0
    sourceImage_mean = 0
    length = 25  # number of single-pair samples drawn below

    for i in range(length):
        batch_x, batch_y = dh.sample_pair(1)
        # Join the two images along axis 1 into one float tensor as model input.
        input2model = torch.FloatTensor(
            np.concatenate([batch_x, batch_y], axis=1))
        # NOTE(review): eval() is loop-invariant and could be called once before
        # the loop.  No torch.no_grad() is visible here, so the forward pass may
        # record autograd history unless the caller disables it.
        model.eval()

        Iwarped = model(input2model)
Пример #5
0
def _first_n_per_digit(data, label, n):
    """Return the first ``n`` samples of each digit 0-9, stacked on a new axis 0.

    Selection uses a boolean mask, which keeps dataset order.  Unlike
    ``index_select(0, nonzero().squeeze())``, it also works when a digit
    occurs exactly once.  ``torch.stack`` raises if some digit has fewer than
    ``n`` samples in ``data``, because the per-digit pools must be equal size.
    """
    return torch.stack([data[label == d][:n] for d in range(10)], dim=0)


def main(train_per_digit=3000, test_per_digit=500):
    """Train DIRNet on per-digit MNIST pools and periodically deploy on test pairs.

    Takes the first batch of the MNIST train and test loaders (images resized
    to 16x16).  It keeps the first ``train_per_digit`` / ``test_per_digit``
    images of each digit and wraps those pools in ``MNISTDataHandler``.  It then
    trains with Adam + StepLR for ``config.iteration`` steps.  Every 100 steps
    it prints the summed loss and deploys on a test pair into ``config.tmp_dir``.

    Args:
        train_per_digit: images kept per digit from the first training batch
            (default 3000, the previous hard-coded value).
        test_per_digit: images kept per digit from the first test batch
            (default 500, the previous hard-coded value).
    """
    config = get_config(is_train=True)
    mkdir(config.tmp_dir)
    mkdir(config.ckpt_dir)

    model = DIRNet(config)

    # `tf` is used as a transforms module here (Compose/Resize/ToTensor) —
    # presumably torchvision.transforms, not TensorFlow; confirm at import site.
    # `train_batch` / `test_batch` are presumably module-level constants.
    transform = tf.Compose([tf.Resize([16, 16]), tf.ToTensor()])
    train_loader = DataLoader(MNIST('mnist',
                                    train=True,
                                    download=True,
                                    transform=transform),
                              batch_size=train_batch)
    test_loader = DataLoader(MNIST('mnist',
                                   train=False,
                                   download=True,
                                   transform=transform),
                             batch_size=test_batch)

    # Only the first batch of each loader is used.  Taking it directly avoids
    # loading and transforming every remaining batch just to discard it.
    data, label = next(iter(train_loader))
    digit = _first_n_per_digit(data, label, train_per_digit)

    data, label = next(iter(test_loader))
    digit_t = _first_n_per_digit(data, label, test_per_digit)

    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = StepLR(optim, step_size=200, gamma=0.5)

    train_pr = MNISTDataHandler(digit)
    test_pr = MNISTDataHandler(digit_t)

    total_loss = 0.0
    for i in range(config.iteration):
        batch_x, batch_y = train_pr.sample_pair(config.batch_size)
        optim.zero_grad()
        _, loss = model(batch_x, batch_y)
        loss.backward()
        optim.step()
        scheduler.step()
        # .item() takes a detached Python float.  Summing the tensor itself
        # would keep each step's autograd graph alive until the next reset.
        total_loss += loss.item()

        if (i + 1) % 100 == 0:
            print("iter {:>6d} : {}".format(i + 1, total_loss))
            total_loss = 0.0
            batch_x, batch_y = test_pr.sample_pair(config.batch_size)
            model.deploy(config.tmp_dir, batch_x, batch_y)
Пример #6
0
  # Body of a function whose header is outside this view: builds a
  # BaseStructure, calls its agCizdir() method, and returns 0.
  ana = BaseStructure()

  # Disabled addBranchEntropy calls, kept for reference (presumably they add
  # branches to the structure before it is drawn — confirm in BaseStructure).
  #ana.addBranchEntropy([0,1,2,3,4])
  #ana.addBranchEntropy([0,1,2,3,4])
  #ana.addBranchEntropy([0,1,2,3,4])
  #ana.addBranchEntropy([0,1,2,3,4])
  #ana.addBranchEntropy([0,1,2,4,4])



  # NOTE(review): the name reads as Turkish "ağ çizdir" ("draw the network"),
  # presumably rendering the structure — confirm in BaseStructure.
  ana.agCizdir()


  # Always returns 0 (presumably used as a success status by the caller).
  return 0

  

  

if __name__ == "__main__":
  # NOTE(review): this binds a module-level global `dh` that nothing in the
  # visible script reads.  main() may still refer to it as a global, so
  # confirm before removing.  Presumably it loads the MNIST training data.
  dh = MNISTDataHandler("MNIST_data", is_train=True)
  main()

# TODO
# 1. encode image for input
# 2. build tree from input
# 3. display tree
# 4. classify
# 5. deeper thinking