def train_phase(predictor, train, valid, args):
    """Train ``predictor`` as a classifier and save its final weights.

    ``predictor`` is wrapped in a ``Classifier`` and optimized with Adam on
    ``train``; an ``Evaluator`` is attached for ``valid``. Trainer snapshots,
    a log, optional loss/accuracy plots and a console report are produced
    under ``args.out``, and after training the predictor's ``state_dict`` is
    written to ``args.out/predictor.pth``.

    Args:
        predictor: The network to train.
        train: Training dataset.
        valid: Validation dataset.
        args: Parsed options; reads ``batchsize``, ``gpu``, ``decay``,
            ``epoch``, ``out``, ``frequency``, ``plot`` and ``resume``.
    """
    batch_size = args.batchsize

    # The training iterator uses the default (repeating) settings; the
    # validation iterator makes a single ordered pass per evaluation.
    train_iter = iterators.SerialIterator(train, batch_size)
    valid_iter = iterators.SerialIterator(valid, batch_size,
                                          repeat=False, shuffle=False)

    target_device = torch.device(args.gpu)
    model = Classifier(predictor)
    model.to(target_device)

    # A negative decay value is clamped to 0, i.e. no weight decay.
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=target_device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    # frequency == -1 means a single snapshot at the last epoch.
    if args.frequency == -1:
        snapshot_every = args.epoch
    else:
        snapshot_every = max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(snapshot_every, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        plot_specs = (
            (['main/loss', 'validation/main/loss'], 'loss.png'),
            (['main/accuracy', 'validation/main/accuracy'], 'accuracy.png'),
        )
        for plot_keys, image_name in plot_specs:
            trainer.extend(
                extensions.PlotReport(plot_keys, 'epoch',
                                      file_name=image_name))

    report_columns = [
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
    ]
    trainer.extend(extensions.PrintReport(report_columns))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    weights_path = os.path.join(args.out, 'predictor.pth')
    torch.save(predictor.state_dict(), weights_path)
示例#2
0
    def test_iterator_repeat(self):
        """Check epoch bookkeeping over three epochs of a repeating iterator.

        With six items and a batch size of 2, each epoch is exactly three
        batches: ``epoch_detail`` advances by 2/6 per batch and
        ``is_new_epoch`` is set only by the third batch of every epoch.
        """
        # The dataset must hold six items for the x/6 fractions and the
        # one-epoch-per-three-batches expectations below to hold; with a
        # three-item dataset an epoch ends every 1.5 batches.
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.SerialIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batch1 = it.next()
            self.assertEqual(len(batch1), 2)
            self.assertIsInstance(batch1, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
            batch2 = it.next()
            self.assertEqual(len(batch2), 2)
            self.assertIsInstance(batch2, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)

            # One full epoch visits every item exactly once.
            self.assertEqual(
                sorted(batch1 + batch2 + batch3), dataset)
示例#3
0
    def test_invalid_order_sampler(self):
        """A sampler returning an order of the wrong length raises ValueError."""
        data = [1, 2, 3, 4, 5, 6]

        with self.assertRaises(ValueError):
            iterator = iterators.SerialIterator(
                data, len(data), order_sampler=InvalidOrderSampler())
            iterator.next()
示例#4
0
 def test_iterator_repeat(self):
     """Three epochs of a shuffled/sampled iterator over six items.

     Every epoch is three batches of two; ``epoch_detail`` grows by 2/6 per
     batch and ``is_new_epoch`` is asserted true only after the third batch.
     """
     dataset = [1, 2, 3, 4, 5, 6]
     it = iterators.SerialIterator(dataset, 2, shuffle=self.shuffle,
                                   order_sampler=self.order_sampler)
     for epoch in range(3):
         self.assertEqual(it.epoch, epoch)
         self.assertAlmostEqual(it.epoch_detail, epoch + 0 / 6)
         if epoch == 0:
             self.assertIsNone(it.previous_epoch_detail)
         else:
             self.assertAlmostEqual(it.previous_epoch_detail, epoch - 2 / 6)

         collected = []
         for step in (1, 2, 3):
             batch = it.next()
             self.assertEqual(len(batch), 2)
             if step < 3:
                 self.assertFalse(it.is_new_epoch)
             else:
                 self.assertTrue(it.is_new_epoch)
             collected += batch
             self.assertAlmostEqual(it.epoch_detail, epoch + 2 * step / 6)
             self.assertAlmostEqual(it.previous_epoch_detail,
                                    epoch + 2 * (step - 1) / 6)
         self.assertEqual(sorted(collected), dataset)
示例#5
0
    def test_iterator_repeat_not_even(self):
        """Unshuffled repeating iterator whose batches straddle epoch ends."""
        dataset = [1, 2, 3, 4, 5]
        it = iterators.SerialIterator(dataset, 2, shuffle=False)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 5)
        self.assertIsNone(it.previous_epoch_detail)

        # (expected batch, epoch reached if this batch starts a new one,
        #  total items consumed after the batch)
        expectations = [
            ([1, 2], None, 2),
            ([3, 4], None, 4),
            ([5, 1], 1, 6),
            ([2, 3], None, 8),
            ([4, 5], 2, 10),
        ]
        consumed_before = 0
        for batch, new_epoch, consumed in expectations:
            self.assertEqual(it.next(), batch)
            if new_epoch is None:
                self.assertFalse(it.is_new_epoch)
            else:
                self.assertTrue(it.is_new_epoch)
                self.assertEqual(it.epoch, new_epoch)
            self.assertAlmostEqual(it.epoch_detail, consumed / 5)
            self.assertAlmostEqual(it.previous_epoch_detail,
                                   consumed_before / 5)
            consumed_before = consumed
示例#6
0
    def test_iterator_repeat_not_even(self):
        """Each of the first two epochs (10 items) covers the dataset once."""
        dataset = [1, 2, 3, 4, 5]
        it = iterators.SerialIterator(dataset, 2, shuffle=self.shuffle,
                                      order_sampler=self.order_sampler)

        stream = []
        for _ in range(5):
            stream.extend(it.next())
        first_epoch, second_epoch = stream[:5], stream[5:]
        self.assertEqual(sorted(first_epoch), dataset)
        self.assertEqual(sorted(second_epoch), dataset)
示例#7
0
    def test_no_same_indices_order_sampler(self):
        """Batches contain no duplicate items with NoSameIndicesOrderSampler."""
        dataset = [1, 2, 3, 4, 5, 6]
        batchsize = 5
        sampler = NoSameIndicesOrderSampler(batchsize)

        it = iterators.SerialIterator(dataset, batchsize,
                                      order_sampler=sampler)
        for _ in range(5):
            distinct_items = numpy.unique(it.next())
            self.assertEqual(len(distinct_items), batchsize)
示例#8
0
    def test_iterator_not_repeat(self):
        """A non-repeating iterator yields one epoch, then keeps stopping."""
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.SerialIterator(dataset, 2, repeat=False,
                                      shuffle=self.shuffle,
                                      order_sampler=self.order_sampler)

        collected = []
        for _ in range(3):
            collected.extend(it.next())
        self.assertEqual(sorted(collected), dataset)

        # Exhaustion is sticky: every further call raises again.
        for _ in range(2):
            with self.assertRaises(StopIteration):
                it.next()
示例#9
0
    def setUp(self):
        # Two 3x4 tensors filled with uniform random values between -1 and 1.
        samples = []
        for _ in range(2):
            samples.append(torch.empty(3, 4).uniform_(-1, 1))
        self.data = samples

        # One sample per batch, a single pass in order.
        self.iterator = iterators.SerialIterator(
            self.data, 1, repeat=False, shuffle=False)
        self.target = DummyModel(self)
        self.evaluator = extensions.Evaluator(
            self.iterator, {}, eval_func=self.target, progress_bar=True)
def test_phase(predictor, test, args):
    """Evaluate ``predictor`` with MC sampling and plot variance by correctness.

    Restores ``args.out/predictor.pth``, runs ``MCSampler`` over ``test``
    (softmax over dim 1, argmax over dim 1 for the prediction and the mean
    over dim 1 of the variance), then saves ``args.out/eval.png``: a
    horizontal violin plot of the predicted variance for correct versus
    wrong predictions, titled with the overall accuracy.

    Args:
        predictor: Classification network whose weights are restored.
        test: Dataset exposing a ``labels`` array.
        args: Parsed options; reads ``batchsize``, ``out``, ``mc_iteration``
            and ``gpu``.
    """
    loader = iterators.SerialIterator(test, args.batchsize,
                                      repeat=False, shuffle=False)

    weights = torch.load(os.path.join(args.out, 'predictor.pth'))
    predictor.load_state_dict(weights)

    sampler = MCSampler(predictor,
                        mc_iteration=args.mc_iteration,
                        activation=partial(torch.softmax, dim=1),
                        reduce_mean=partial(torch.argmax, dim=1),
                        reduce_var=partial(torch.mean, dim=1))
    sampler.to(torch.device(args.gpu))

    pred, uncert = Inferencer(loader, sampler, device=args.gpu).run()

    os.makedirs(args.out, exist_ok=True)

    is_correct = pred == test.labels
    accuracy = np.sum(is_correct) / len(is_correct)
    is_wrong = np.logical_not(is_correct)

    groups = [uncert[is_correct], uncert[is_wrong]]
    tick_labels = [
        'Correct prediction\n(n=%d)' % len(groups[0]),
        'Wrong prediction\n(n=%d)' % len(groups[1]),
    ]

    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.violinplot(data=groups, inner='quartile', palette='Blues',
                        orient='h', cut=0)
    ax.set_xlabel('Predicted variance')
    ax.set_yticklabels(tick_labels)
    plt.title('Accuracy=%.3f' % accuracy)
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'eval.png'))
    plt.close()
示例#11
0
    def test_iterator_state_dict_backward_compat(self):
        """Resume a default iterator from a configured iterator's state.

        Two batches are drawn from an iterator built with this case's
        ``shuffle``/``order_sampler``. Its state dict is then loaded into a
        default ``SerialIterator``, which must finish the epoch.
        """
        dataset = [1, 2, 3, 4, 5, 6]

        def take(iterator):
            batch = iterator.next()
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            return batch

        it = iterators.SerialIterator(dataset, 2, shuffle=self.shuffle,
                                      order_sampler=self.order_sampler)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)

        first = take(it)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)

        second = take(it)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        # Deep-copy so the restored iterator does not alias the source's
        # state objects.
        snapshot = copy.deepcopy(it.state_dict())

        it = iterators.SerialIterator(dataset, 2)
        it.load_state_dict(snapshot)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        third = take(it)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(first + second + third), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
示例#12
0
    def test_iterator_not_repeat_not_even(self):
        """Unshuffled single pass over five items ending in a short batch."""
        dataset = [1, 2, 3, 4, 5]
        it = iterators.SerialIterator(dataset, 2, repeat=False, shuffle=False)

        self.assertAlmostEqual(it.epoch_detail, 0 / 5)
        self.assertIsNone(it.previous_epoch_detail)

        consumed_before = 0
        for expected in ([1, 2], [3, 4], [5]):
            self.assertEqual(it.next(), expected)
            consumed = consumed_before + len(expected)
            if consumed == len(dataset):
                # The final (short) batch completes the epoch.
                self.assertTrue(it.is_new_epoch)
                self.assertEqual(it.epoch, 1)
            self.assertAlmostEqual(it.epoch_detail, consumed / 5)
            self.assertAlmostEqual(it.previous_epoch_detail,
                                   consumed_before / 5)
            consumed_before = consumed

        self.assertRaises(StopIteration, it.next)
示例#13
0
    def test_iterator_not_repeat(self):
        """Unshuffled single pass over six items, then StopIteration."""
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.SerialIterator(dataset, 2, repeat=False, shuffle=False)

        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)

        consumed_before = 0
        for expected in ([1, 2], [3, 4], [5, 6]):
            self.assertEqual(it.next(), expected)
            consumed = consumed_before + len(expected)
            if consumed == len(dataset):
                self.assertTrue(it.is_new_epoch)
                self.assertEqual(it.epoch, 1)
            self.assertAlmostEqual(it.epoch_detail, consumed / 6)
            self.assertAlmostEqual(it.previous_epoch_detail,
                                   consumed_before / 6)
            consumed_before = consumed

        for _ in range(2):
            with self.assertRaises(StopIteration):
                it.next()
示例#14
0
    def test_iterator_not_repeat_not_even(self):
        """Shuffled single pass over five items ending in a one-item batch."""
        dataset = [1, 2, 3, 4, 5]
        it = iterators.SerialIterator(dataset, 2, repeat=False,
                                      shuffle=self.shuffle,
                                      order_sampler=self.order_sampler)

        self.assertAlmostEqual(it.epoch_detail, 0 / 5)
        self.assertIsNone(it.previous_epoch_detail)

        batches = []
        consumed_before = 0
        for consumed in (2, 4, 5):
            batches.append(it.next())
            self.assertAlmostEqual(it.epoch_detail, consumed / 5)
            self.assertAlmostEqual(it.previous_epoch_detail,
                                   consumed_before / 5)
            consumed_before = consumed
        self.assertRaises(StopIteration, it.next)

        first, second, last = batches
        self.assertEqual(len(last), 1)
        self.assertEqual(sorted(first + second + last), dataset)
示例#15
0
 def test_iterator_repeat(self):
     """Unshuffled repeating iterator replays [1, 2], [3, 4], [5, 6]."""
     dataset = [1, 2, 3, 4, 5, 6]
     it = iterators.SerialIterator(dataset, 2, shuffle=False)
     expected_batches = ([1, 2], [3, 4], [5, 6])
     for epoch in range(3):
         self.assertEqual(it.epoch, epoch)
         self.assertAlmostEqual(it.epoch_detail, epoch + 0 / 6)
         if epoch == 0:
             self.assertIsNone(it.previous_epoch_detail)
         else:
             self.assertAlmostEqual(it.previous_epoch_detail, epoch - 2 / 6)
         for step, expected in enumerate(expected_batches, start=1):
             self.assertEqual(it.next(), expected)
             if step == len(expected_batches):
                 self.assertTrue(it.is_new_epoch)
             else:
                 self.assertFalse(it.is_new_epoch)
             self.assertAlmostEqual(it.epoch_detail, epoch + 2 * step / 6)
             self.assertAlmostEqual(it.previous_epoch_detail,
                                    epoch + 2 * (step - 1) / 6)
    def test_iterator_compatibilty(self):
        """State dicts transfer between Serial and Multiprocess iterators.

        For both orderings of the two iterator types, two batches are drawn
        from the first iterator, its state dict is loaded into the second,
        and the second must finish the epoch with the remaining items.
        """
        dataset = [1, 2, 3, 4, 5, 6]

        def make_serial():
            return iterators.SerialIterator(dataset, 2)

        def make_multiprocess():
            return iterators.MultiprocessIterator(dataset, 2, **self.options)

        factories = (make_serial, make_multiprocess)

        for make_source, make_target in itertools.permutations(factories, 2):
            it = make_source()

            self.assertEqual(it.epoch, 0)
            self.assertAlmostEqual(it.epoch_detail, 0 / 6)

            drawn = []
            for consumed in (2, 4):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                self.assertIsInstance(batch, list)
                self.assertFalse(it.is_new_epoch)
                self.assertAlmostEqual(it.epoch_detail, consumed / 6)
                drawn.append(batch)

            saved_state = copy.deepcopy(it.state_dict())

            it = make_target()
            it.load_state_dict(saved_state)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            last = it.next()
            self.assertEqual(len(last), 2)
            self.assertIsInstance(last, list)
            self.assertTrue(it.is_new_epoch)
            self.assertEqual(sorted(drawn[0] + drawn[1] + last), dataset)
            self.assertAlmostEqual(it.epoch_detail, 6 / 6)
示例#17
0
 def test_iterator_shuffle_nondivisible(self):
     # Batch size 3 does not divide 10 items; 7 batches cover two epochs,
     # whose item orders are expected to differ.
     items = list(range(10))
     it = iterators.SerialIterator(items, 3)
     flat = []
     for _ in range(7):
         flat.extend(it.next())
     first_epoch = flat[0:10]
     second_epoch = flat[10:20]
     self.assertNotEqual(first_epoch, second_epoch)
示例#18
0
def test_phase(predictor, test, args):
    """Evaluate a trained regressor and plot its two uncertainty estimates.

    Restores ``args.out/predictor.pth``, runs ``MCSampler`` over ``test``
    with ``args.mc_iteration`` passes, and writes two figures to
    ``args.out``:

    * ``eval_epistemic.png``
    * ``eval_aleatoric.png``

    Each figure shows the predictions with the corresponding uncertainty as
    error bars, plus the ground-truth curve ``test.t``.

    Args:
        predictor: Regression network whose weights are restored.
        test: Dataset exposing ``x`` and ``t`` arrays (inputs and targets).
        args: Parsed options; reads ``batchsize``, ``out``, ``mc_iteration``
            and ``gpu``.
    """
    # setup an iterator (single ordered pass)
    test_iter = iterators.SerialIterator(test,
                                         args.batchsize,
                                         repeat=False,
                                         shuffle=False)

    # setup an inferencer
    predictor.load_state_dict(
        torch.load(os.path.join(args.out, 'predictor.pth')))

    # NOTE(review): the two activations are presumably applied to the
    # predictor's mean and log-variance outputs respectively -- confirm
    # against MCSampler.
    model = MCSampler(predictor,
                      mc_iteration=args.mc_iteration,
                      activation=[lambda x: x, torch.exp],
                      reduce_mean=None,
                      reduce_var=None)

    device = torch.device(args.gpu)
    model.to(device)

    infer = Inferencer(test_iter, model, device=args.gpu)

    pred, epistemic_uncert, aleatory_uncert, _ = infer.run()

    # visualize (everything flattened to 1-D for plotting)
    x = test.x.ravel()
    t = test.t.ravel()
    pred = pred.ravel()

    _plot_prediction_with_uncertainty(
        x, t, pred, epistemic_uncert.ravel(), 'Epistemic uncertainty',
        os.path.join(args.out, 'eval_epistemic.png'))
    _plot_prediction_with_uncertainty(
        x, t, pred, aleatory_uncert.ravel(), 'Aleatoric uncertainty',
        os.path.join(args.out, 'eval_aleatoric.png'))


def _plot_prediction_with_uncertainty(x, t, pred, uncert, uncert_label, path):
    """Save a scatter of ``pred`` with ``uncert`` error bars over curve ``t``.

    Legend entries are bound to explicit artist handles. If ``legend()`` is
    given only a list of labels, matplotlib pairs them with artists in its
    own collection order. That order does not match the labels here, so the
    error-bar caps and scatter points would be mislabeled.
    """
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=x, y=pred, color='blue', s=75)
    # seaborn returns the Axes; the points it just drew are its newest
    # collection on this freshly created figure.
    points = ax.collections[-1]
    bars = ax.errorbar(x,
                       pred,
                       yerr=uncert,
                       fmt='none',
                       capsize=10,
                       ecolor='gray',
                       linewidth=1.5)
    truth, = ax.plot(x, t, color='red', linewidth=1.5)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    ax.legend([truth, points, bars],
              ['Ground-truth', 'Prediction', uncert_label])
    plt.title('Result on testing data set')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
示例#19
0
def train_phase(predictor, train, valid, args):
    """Plot the training data, then train ``predictor`` as a regressor.

    Steps:

    1. Write a figure of the noisy observations ``train.y`` and the
       ground-truth curve ``train.t`` to ``args.out/train_dataset.png``.
    2. Train a ``Regressor`` with ``noised_mean_squared_error`` as the loss
       and an L1 "accuracy" on the first predictor output.
    3. Save the predictor weights to ``args.out/predictor.pth``.

    Args:
        predictor: The regression network to train.
        train: Training dataset exposing ``x``, ``y`` and ``t`` arrays.
        valid: Validation dataset.
        args: Parsed options; reads ``out``, ``batchsize``, ``gpu``,
            ``decay``, ``epoch``, ``frequency``, ``plot`` and ``resume``.
    """
    # savefig() does not create directories, and nothing before this point
    # creates args.out, so make sure it exists before the first figure.
    os.makedirs(args.out, exist_ok=True)

    # visualize
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=train.x.ravel(),
                         y=train.y.ravel(),
                         color='blue',
                         s=55,
                         alpha=0.3)
    # seaborn returns the Axes; its newest collection is the scatter above.
    observations = ax.collections[-1]
    truth, = ax.plot(train.x.ravel(), train.t.ravel(), color='red',
                     linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    # Explicit handles: a labels-only legend() assigns labels in matplotlib's
    # internal artist order, which can swap these two entries.
    ax.legend([truth, observations], ['Ground-truth', 'Observation'])
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=False)

    # setup a model
    device = torch.device(args.gpu)

    lossfun = noised_mean_squared_error

    def accfun(y, t):
        # L1 error between the first predictor output (presumably the
        # predicted mean -- confirm) and the target.
        return F.l1_loss(y[0], t)

    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer (negative decay values disable weight decay)
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    # frequency == -1 means a single snapshot at the last epoch.
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))

        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch',
                file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
示例#20
0
 def test_iterator_shuffle_divisible(self):
     # A batch spans the whole 10-item dataset, so two consecutive batches
     # are two complete epochs; their orders are expected to differ.
     dataset = list(range(10))
     it = iterators.SerialIterator(dataset, len(dataset),
                                   shuffle=self.shuffle,
                                   order_sampler=self.order_sampler)
     first = it.next()
     second = it.next()
     self.assertNotEqual(first, second)
def test_phase(predictor, test, args):
    """Run MC-sampled segmentation inference and save per-image figures.

    Loads the latest ``predictor_iter_*.pth`` snapshot from ``args.out`` and
    samples the predictor ``args.mc_iteration`` times through ``MCSampler``.
    For every test sample it writes ``args.out/test/NNN.png`` with five
    panels: input, prediction, ground truth, predicted variance and error.
    It also saves ``args.out/eval.png``, a scatter of mean variance against
    ``eval_metric``, titled with their Pearson correlation.

    Args:
        predictor: Segmentation network whose weights are restored.
        test: Dataset exposing ``files['image']`` and ``files['label']``.
        args: Parsed options; reads ``batchsize``, ``out``, ``mc_iteration``
            and ``gpu``.
    """
    print('# samples:')
    print('-- test:', len(test))

    test_iter = iterators.SerialIterator(test,
                                         args.batchsize,
                                         repeat=False,
                                         shuffle=False)

    # setup an inferencer from the most recent predictor snapshot
    snapshot_file = find_latest_snapshot(
        'predictor_iter_{.updater.iteration:08}.pth', args.out)
    predictor.load_state_dict(torch.load(snapshot_file))

    print('Loaded a snapshot:', snapshot_file)

    model = MCSampler(predictor,
                      mc_iteration=args.mc_iteration,
                      activation=partial(torch.softmax, dim=1),
                      reduce_mean=partial(torch.argmax, dim=1),
                      reduce_var=partial(torch.mean, dim=1))

    device = torch.device(args.gpu)
    model.to(device)

    infer = Inferencer(test_iter, model, device=args.gpu)

    pred, uncert = infer.run()

    # evaluate
    os.makedirs(os.path.join(args.out, 'test'), exist_ok=True)

    acc_values = []
    uncert_values = []

    # colour range of the variance maps, shared by every figure
    uncert_clim = (0, np.percentile(uncert, 95))

    files = test.files['image']
    if isinstance(files, np.ndarray):
        files = files.tolist()
    # Common *directory* of the inputs. commonpath() of the file paths
    # themselves is the file itself when there is only one image, which
    # collapses the relpath() title to '.'. It is also '' for bare relative
    # names, which makes relpath() raise.
    commonpath = os.path.commonpath(
        [os.path.dirname(os.path.abspath(f)) for f in files])

    plt.rcParams['font.size'] = 14

    for i, (p, u, imf, lbf) in enumerate(
            zip(pred, uncert, test.files['image'], test.files['label'])):
        im, _ = load_image(imf)
        im = im[:, :, ::-1]  # reverse the channel order for display
        lb, _ = load_image(lbf)
        if lb.ndim == 3:
            lb = lb[:, :, 0]  # keep a single channel of the label image

        acc_values.append(eval_metric(p, lb))
        uncert_values.append(np.mean(u[p == 1]))  # NOTE: instrument class

        plt.figure(figsize=(20, 4))

        panels = zip(
            [im, p, lb, u, (p != lb).astype(np.uint8)],
            [None, None, None, 'jet', 'jet'],
            [None, None, None, uncert_clim, None],
            [
                'Input image\n%s' % os.path.relpath(imf, commonpath),
                'Predicted label\n(DC=%.3f)' % acc_values[-1],
                'Ground-truth label',
                'Predicted variance\n(PV=%.4f)' % uncert_values[-1],
                'Error',
            ])
        for j, (pic, cmap, clim, title) in enumerate(panels):
            plt.subplot(1, 5, j + 1)
            plt.imshow(pic, cmap=cmap)
            plt.xticks([], [])
            plt.yticks([], [])
            plt.title(title)
            plt.clim(clim)

        plt.tight_layout()
        plt.savefig(os.path.join(args.out, 'test/%03d.png' % i))
        plt.close()

    c = pearsonr(uncert_values, acc_values)

    plt.figure(figsize=(11, 11))
    ax = sns.scatterplot(x=uncert_values, y=acc_values, color='blue', s=50)
    ax.set_xlabel('Predicted variance')
    ax.set_ylabel('Dice coefficient')
    plt.grid()
    plt.title('r=%.3f' % c[0])
    plt.savefig(os.path.join(args.out, 'eval.png'))
    plt.close()
def test_phase(generator, test, args):
    """Run MC-sampled image-to-image inference and save per-image figures.

    Loads the latest ``generator_iter_*.pth`` snapshot from ``args.out`` and
    samples the generator ``args.mc_iteration`` times through ``MCSampler``
    (``tanh`` activation, variance averaged over dim 1). For every test
    sample it writes ``args.out/test/NNN.png`` with five panels: input,
    prediction, ground truth, predicted variance and absolute error.

    Args:
        generator: Image-to-image network whose weights are restored.
        test: Dataset exposing ``files['image']`` and ``files['label']``.
        args: Parsed options; reads ``batchsize``, ``out``, ``mc_iteration``
            and ``gpu``.
    """
    print('# samples:')
    print('-- test:', len(test))

    test_iter = iterators.SerialIterator(test, args.batchsize,
                                         repeat=False, shuffle=False)

    # setup an inferencer from the most recent generator snapshot
    snapshot_file = find_latest_snapshot(
        'generator_iter_{.updater.iteration:08}.pth', args.out)
    generator.load_state_dict(torch.load(snapshot_file))
    print('Loaded a snapshot:', snapshot_file)

    model = MCSampler(generator,
                      mc_iteration=args.mc_iteration,
                      activation=torch.tanh,
                      reduce_mean=None,
                      reduce_var=partial(torch.mean, dim=1))

    device = torch.device(args.gpu)
    model.to(device)

    infer = Inferencer(test_iter, model, device=args.gpu)

    pred, uncert = infer.run()

    # evaluate
    os.makedirs(os.path.join(args.out, 'test'), exist_ok=True)

    acc_values = []
    uncert_values = []

    # colour ranges shared by every figure
    uncert_clim = (0, np.percentile(uncert, 95))
    error_clim = (0, 1)

    files = test.files['image']
    if isinstance(files, np.ndarray):
        files = files.tolist()
    # Common *directory* of the inputs. commonpath() of the file paths
    # themselves is the file itself when there is only one image, which
    # collapses the relpath() title to '.'. It is also '' for bare relative
    # names, which makes relpath() raise.
    commonpath = os.path.commonpath(
        [os.path.dirname(os.path.abspath(f)) for f in files])

    plt.rcParams['font.size'] = 14

    for i, (p, u, imf, lbf) in enumerate(
            zip(pred, uncert, test.files['image'], test.files['label'])):
        im, _ = load_image(imf)
        lb, _ = load_image(lbf)
        im = im.astype(np.float32)
        lb = lb.astype(np.float32)

        # move axis 0 last (presumably CHW -> HWC) to match im/lb
        p = p.transpose(1, 2, 0)

        # reverse the channel order and map [-1, 1] to [0, 1] for display
        # NOTE(review): assumes load_image() returns data scaled to [-1, 1].
        im = (im[:, :, ::-1] + 1.) / 2.
        lb = (lb[:, :, ::-1] + 1.) / 2.
        p = (p[:, :, ::-1] + 1.) / 2.

        # per-pixel absolute error averaged over the last (channel) axis
        error = np.mean(np.abs(p - lb), axis=-1)

        acc_values.append(eval_metric(p, lb))
        uncert_values.append(np.mean(u))

        plt.figure(figsize=(20, 4))

        panels = zip(
            [im, p, lb, u, error],
            [None, None, None, 'jet', 'jet'],
            [None, None, None, uncert_clim, error_clim],
            [
                'Input image\n%s' % os.path.relpath(imf, commonpath),
                'Predicted label\n(MAE=%.3f)' % acc_values[-1],
                'Ground-truth label',
                'Predicted variance\n(PV=%.4f)' % uncert_values[-1],
                'Error',
            ])
        for j, (pic, cmap, clim, title) in enumerate(panels):
            plt.subplot(1, 5, j + 1)
            plt.imshow(pic, cmap=cmap)
            plt.xticks([], [])
            plt.yticks([], [])
            plt.title(title)
            plt.clim(clim)

        plt.tight_layout()
        plt.savefig(os.path.join(args.out, 'test/%03d.png' % i))
        plt.close()
def train_phase(predictor, train, valid, args):
    """Train a segmentation ``predictor`` with early stopping.

    ``predictor`` is wrapped in a ``Classifier`` whose loss is
    ``softmax_cross_entropy(normalize=False)`` and optimized with Adam.
    Training stops at ``args.iteration`` iterations at the latest. An
    ``EarlyStoppingTrigger`` monitors ``validation/main/loss`` with
    ``patients`` taken from ``args.pinfall`` (``-1`` -> ``np.inf``).

    Every ``frequency`` iterations it:

    * runs the ``Validator``, which writes ``validation/iter_*.png``;
    * snapshots the trainer and the predictor;
    * plots the loss/accuracy curves.

    Args:
        predictor: Segmentation network to train.
        train: Training dataset (must expose ``n_classes``).
        valid: Validation dataset.
        args: Parsed options; reads ``batchsize``, ``gpu``, ``lr``,
            ``decay``, ``iteration``, ``frequency``, ``pinfall``, ``out``
            and ``resume``.
    """
    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=True)

    # setup a model
    class_weight = None  # NOTE: set per-class weights here if you have them

    lossfun = partial(softmax_cross_entropy,
                      normalize=False,
                      class_weight=class_weight)

    device = torch.device(args.gpu)

    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer (negative decay values disable weight decay)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)

    # frequency == -1 means "20 checkpoints over the whole run"
    if args.frequency == -1:
        frequency = max(args.iteration // 20, 1)
    else:
        frequency = max(1, args.frequency)

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    # The trainer is built once, with the early-stopping trigger.
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }

    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}

    clims = {'x': 'minmax', 'y': None, 't': None}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps,
                                 clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter,
                             model,
                             valid_file,
                             visualizer=visualizer,
                             n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy'
    ]

    trainer.extend(LogReport(keys=log_keys))

    # setup log plotter: one figure per metric, train and validation curves
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
def train_phase(generator, train, valid, args):
    """Train ``generator`` adversarially against a freshly built discriminator.

    ``generator`` is wrapped in a ``Regressor`` (tanh output activation, L1
    loss, L1 used as the "accuracy" metric). A discriminator is built with
    ``build_discriminator()``. Both networks are optimized with Adam through a
    ``DCGANUpdater``. Training runs for at most ``args.iteration`` iterations
    and may stop earlier through an ``EarlyStoppingTrigger`` that monitors
    ``validation/main/loss``.

    Args:
        generator: Generator network. It is also snapshotted on its own.
        train: Training dataset, consumed by a ``SerialIterator``.
        valid: Validation dataset, consumed by a non-repeating
            ``SerialIterator``.
        args: Parsed command-line options. Fields read here: ``batchsize``,
            ``out``, ``gpu``, ``lr``, ``beta``, ``decay``, ``alpha``,
            ``iteration``, ``frequency``, ``pinfall``, ``resume``.

    Side effects:
        Creates ``args.out`` if it is missing and writes ``discriminator.json``
        into it. Then runs the trainer, whose extensions write snapshots,
        logs, plots and validation images under ``args.out``.
    """

    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    # NOTE(review): the validation iterator is shuffled, presumably so that
    # the visualized samples differ between validation runs -- confirm.
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup models
    model = Regressor(generator,
                      activation=torch.tanh,
                      lossfun=F.l1_loss,
                      accfun=F.l1_loss)

    discriminator = build_discriminator()
    # Nothing above creates ``args.out``, so make sure it exists before
    # writing the discriminator config into it (no-op if already present).
    os.makedirs(args.out, exist_ok=True)
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    device = torch.device(args.gpu)

    model.to(device)
    discriminator.to(device)

    # setup optimizers: identical Adam hyper-parameters for both networks
    adam_kwargs = dict(lr=args.lr,
                       betas=(args.beta, 0.999),
                       weight_decay=max(args.decay, 0))
    optimizer_G = torch.optim.Adam(model.parameters(), **adam_kwargs)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

    # setup a trainer
    updater = DCGANUpdater(
        iterator=train_iter,
        optimizer={
            'gen': optimizer_G,
            'dis': optimizer_D,
        },
        model={
            'gen': model,
            'dis': discriminator,
        },
        alpha=args.alpha,
        device=args.gpu,
    )

    # Interval (in iterations) for validation, snapshots and plots.
    # -1 means roughly 80 intervals over the whole run.
    if args.frequency == -1:
        frequency = max(args.iteration // 80, 1)
    else:
        frequency = max(1, args.frequency)

    # Stop at args.iteration, or earlier when early stopping on the validation
    # loss fires. pinfall == -1 gives infinite patience, i.e. no early stop.
    patience = np.inf if args.pinfall == -1 else max(1, args.pinfall)
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=patience)

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # Shift both learning rates linearly from args.lr to 0 over the second
    # half of training (generator first, then discriminator).
    for optimizer in (optimizer_G, optimizer_D):
        trainer.extend(
            extensions.LinearShift('lr', (args.lr, 0.0),
                                   (args.iteration // 2, args.iteration),
                                   optimizer=optimizer))

    # Set up a visualizer. Images are displayed with color limits [-1, 1],
    # the output range of the tanh activation used by the model.
    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None,
                                 clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph('loss_gen', filename='generative_loss.dot'))
    # trainer.extend(DumpGraph('loss_cond', filename='conditional_loss.dot'))
    # trainer.extend(DumpGraph('loss_dis', filename='discriminative_loss.dot'))

    # snapshots: full trainer state plus each network on its own
    trainer.extend(
        extensions.snapshot(
            filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(
        extensions.snapshot_object(
            generator, 'generator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(
        extensions.snapshot_object(
            discriminator, 'discriminator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = ['loss_gen', 'loss_cond', 'loss_dis',
                'validation/main/accuracy']

    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # Set up the log plotter: one figure per metric family, selected by the
    # last path component of each log key.
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [key for key in log_keys
                         if key.split('/')[-1].startswith(plot_key)]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=1))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Map stored tensors onto the current training device, so that a
        # snapshot taken on one device (e.g. a GPU) can be resumed on another
        # (e.g. the CPU).
        # NOTE(review): on torch >= 2.6 torch.load defaults to
        # weights_only=True, which may reject non-tensor trainer state --
        # confirm against the installed torch version.
        trainer.load_state_dict(torch.load(args.resume, map_location=device))

    # train
    trainer.run()