Example #1
    def test_iterator_state_dict_backward_compat(self):
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        state_dict = copy.deepcopy(it.state_dict())

        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        it.load_state_dict(state_dict)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
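Most of these tests pass **self.options to the iterator; that dict comes from the enclosing, parameterized test class, which is not part of the excerpt. A minimal sketch of such a setup, assuming the chainer-style MultiprocessIterator keyword arguments n_processes, n_prefetch and shared_mem; the import path and the concrete values are assumptions, not the original test configuration:

import unittest

from chainer import iterators  # assumed import; the excerpts only reference an "iterators" module


class MultiprocessIteratorOptionsExample(unittest.TestCase):

    def setUp(self):
        # Hypothetical parameterization; the original test class likely runs
        # the same tests for several combinations of these options.
        self.options = {'n_processes': 2, 'n_prefetch': 1, 'shared_mem': None}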
Example #2
    def test_reproduce_same_permutation(self):
        dataset = [1, 2, 3, 4, 5, 6]
        order_sampler1 = iterators.ShuffleOrderSampler(
            numpy.random.RandomState(self._seed))
        it1 = iterators.MultiprocessIterator(dataset,
                                             6,
                                             order_sampler=order_sampler1)
        order_sampler2 = iterators.ShuffleOrderSampler(
            numpy.random.RandomState(self._seed))
        it2 = iterators.MultiprocessIterator(dataset,
                                             6,
                                             order_sampler=order_sampler2)
        for _ in range(5):
            self.assertEqual(it1.next(), it2.next())
Example #3
    def test_iterator_repeat(self):
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batch1 = it.next()
            self.assertEqual(len(batch1), 2)
            self.assertIsInstance(batch1, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
            batch2 = it.next()
            self.assertEqual(len(batch2), 2)
            self.assertIsInstance(batch2, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)

            self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example #4
    def test_invalid_order_sampler(self):
        dataset = [1, 2, 3, 4, 5, 6]

        with self.assertRaises(ValueError):
            it = iterators.MultiprocessIterator(
                dataset, 6, shuffle=None, order_sampler=_InvalidOrderSampler())
            it.next()
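The _InvalidOrderSampler helper used above is not part of the excerpt. Since the iterator is expected to raise ValueError, a plausible stand-in is an order sampler that returns an order whose length does not match the dataset, which chainer-style iterators reject; this is a guess at the helper, not its actual definition:

import numpy


class _InvalidOrderSampler(object):

    def __call__(self, current_order, current_position):
        # Return one index too few so that the iterator's length check fails.
        return numpy.arange(len(current_order) - 1)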
Example #5
    def test_iterator_list_type(self):
        dataset = [[i, numpy.zeros((10, )) + i] for i in range(6)]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batches = {}
            for j in range(3):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                if j != 2:
                    self.assertFalse(it.is_new_epoch)
                else:
                    self.assertTrue(it.is_new_epoch)
                self.assertAlmostEqual(it.epoch_detail,
                                       (3 * i + j + 1) * 2 / 6)
                self.assertAlmostEqual(it.previous_epoch_detail,
                                       (3 * i + j) * 2 / 6)
                for x in batch:
                    self.assertIsInstance(x, list)
                    self.assertIsInstance(x[1], numpy.ndarray)
                    batches[x[0]] = x[1]

            self.assertEqual(len(batches), len(dataset))
            for k, v in six.iteritems(batches):
                numpy.testing.assert_allclose(dataset[k][1], v)
Example #6
    def test_iterator_pickle_after_init(self):
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        pickled_it = pickle.dumps(it)
        it = pickle.loads(pickled_it)

        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
Example #7
    def test_stalled_getitem(self):
        nth = self.nth
        batch_size = 2
        sleep = 0.5
        timeout = 0.1

        dataset = StallingDataset(nth, sleep)
        it = iterators.MultiprocessIterator(dataset,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            dataset_timeout=timeout,
                                            repeat=False)

        # TimeoutWarning should be issued.
        warning_cls = iterators.MultiprocessIterator.TimeoutWarning
        data = []
        # No warning until the stalling batch
        for i in range(nth // batch_size):
            data.append(it.next())
        # Warning on the stalling batch
        with testing.assert_warns(warning_cls):
            data.append(it.next())
        # Retrieve data until the end
        while True:
            try:
                data.append(it.next())
            except StopIteration:
                break

        # All data must be retrieved
        assert data == [
            dataset.data[i * batch_size:(i + 1) * batch_size]
            for i in range((len(dataset) + batch_size - 1) // batch_size)
        ]
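Example #7 relies on a StallingDataset helper that is not shown. From its usage it needs a data attribute, a length, and a __getitem__ whose nth item takes longer than dataset_timeout to fetch. A minimal sketch under those assumptions (the dataset contents are arbitrary; the real helper may differ):

import time


class StallingDataset(object):

    def __init__(self, nth, sleep):
        self.data = list(range(10))  # arbitrary contents; only the timing matters
        self.nth = nth
        self.sleep = sleep

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        if i == self.nth:
            # Exceeds dataset_timeout, so fetching this item triggers TimeoutWarning.
            time.sleep(self.sleep)
        return self.data[i]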
Example #8
    def test_iterator_pickle_new(self):
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        pickled_it = pickle.dumps(it)
        it = pickle.loads(pickled_it)
Example #9
    def test_unsupported_reset_finalized(self):
        dataset = [1, 2, 3, 4]
        it = iterators.MultiprocessIterator(dataset,
                                            2,
                                            repeat=False,
                                            **self.options)
        it.next()
        it.next()
        it.finalize()
        self.assertRaises(NotImplementedError, it.reset)
Example #10
    def test_iterator_not_repeat(self):
        dataset = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(dataset,
                                            2,
                                            repeat=False,
                                            **self.options)

        batches = sum([it.next() for _ in range(3)], [])
        self.assertEqual(sorted(batches), dataset)
        for _ in range(2):
            self.assertRaises(StopIteration, it.next)
Example #11
    def test_no_same_indices_order_sampler(self):
        dataset = [1, 2, 3, 4, 5, 6]
        batchsize = 5

        it = iterators.MultiprocessIterator(
            dataset,
            batchsize,
            order_sampler=_NoSameIndicesOrderSampler(batchsize))
        for _ in range(5):
            batch = it.next()
            self.assertEqual(len(numpy.unique(batch)), batchsize)
Example #12
    def test_reset_repeat(self):
        dataset = [1, 2, 3, 4]
        it = iterators.MultiprocessIterator(dataset,
                                            2,
                                            repeat=True,
                                            **self.options)

        for trial in range(4):
            batches = sum([it.next() for _ in range(4)], [])
            self.assertEqual(sorted(batches), sorted(2 * dataset))
            it.reset()
Example #13
    def test_finalize_not_deadlock(self):
        dataset = numpy.ones((1000, 1000))
        it = iterators.MultiprocessIterator(dataset, 10, n_processes=4)
        for _ in range(10):
            it.next()

        t = threading.Thread(target=lambda: it.finalize())
        t.daemon = True
        t.start()
        t.join(5)
        deadlock = t.is_alive()

        self.assertFalse(deadlock)
Example #14
    def test_iterator_not_repeat_not_even(self):
        dataset = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(dataset,
                                            2,
                                            repeat=False,
                                            **self.options)

        self.assertAlmostEqual(it.epoch_detail, 0 / 5)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertAlmostEqual(it.epoch_detail, 2 / 5)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 5)
        batch2 = it.next()
        self.assertAlmostEqual(it.epoch_detail, 4 / 5)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 5)
        batch3 = it.next()
        self.assertAlmostEqual(it.epoch_detail, 5 / 5)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 5)
        self.assertRaises(StopIteration, it.next)

        self.assertEqual(len(batch3), 1)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example #15
    def test_iterator_compatibility(self):
        dataset = [1, 2, 3, 4, 5, 6]

        iters = (
            lambda: iterators.SerialIterator(dataset, 2),
            lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
        )

        for it_before, it_after in itertools.permutations(iters, 2):
            it = it_before()

            self.assertEqual(it.epoch, 0)
            self.assertAlmostEqual(it.epoch_detail, 0 / 6)
            batch1 = it.next()
            self.assertEqual(len(batch1), 2)
            self.assertIsInstance(batch1, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 2 / 6)
            batch2 = it.next()
            self.assertEqual(len(batch2), 2)
            self.assertIsInstance(batch2, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            state_dict = copy.deepcopy(it.state_dict())

            it = it_after()
            it.load_state_dict(state_dict)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
            self.assertAlmostEqual(it.epoch_detail, 6 / 6)
Example #16
    def test_iterator_repeat_not_even(self):
        dataset = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        batches = sum([it.next() for _ in range(5)], [])
        self.assertEqual(sorted(batches), sorted(dataset * 2))
Example #17
    def test_iterator_shuffle_divisible(self):
        dataset = list(range(10))
        it = iterators.MultiprocessIterator(dataset, 10, **self.options)
        self.assertNotEqual(it.next(), it.next())
Example #18
    def test_iterator_shuffle_nondivisible(self):
        dataset = list(range(10))
        it = iterators.MultiprocessIterator(dataset, 3, **self.options)
        out = sum([it.next() for _ in range(7)], [])
        self.assertNotEqual(out[0:10], out[10:20])
def train_phase(predictor, train, valid, args):

    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=True)

    # setup a model
    class_weight = None  # NOTE: set this if you need per-class weighting

    lossfun = partial(softmax_cross_entropy,
                      normalize=False,
                      class_weight=class_weight)

    device = torch.device(args.gpu)

    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)

    frequency = (max(args.iteration // 20, 1)
                 if args.frequency == -1 else max(1, args.frequency))

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }

    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}

    clims = {'x': 'minmax', 'y': None, 't': None}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps,
                                 clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter,
                             model,
                             valid_file,
                             visualizer=visualizer,
                             n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy'
    ]

    trainer.extend(LogReport(keys=log_keys))

    # setup log plotter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
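train_phase only reads batchsize, gpu, lr, decay, iteration, frequency, pinfall, out and resume from args. A minimal sketch of a matching argument parser, purely to illustrate the expected fields; the defaults are assumptions, and building predictor, train and valid is omitted:

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='train a model')
    parser.add_argument('--batchsize', type=int, default=16)
    parser.add_argument('--gpu', type=str, default='cuda:0')  # passed to torch.device()
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--decay', type=float, default=0.)    # weight decay for Adam
    parser.add_argument('--iteration', type=int, default=10000)
    parser.add_argument('--frequency', type=int, default=-1)  # -1: derived from iteration
    parser.add_argument('--pinfall', type=int, default=-1)    # -1: unlimited early-stopping patience
    parser.add_argument('--out', type=str, default='results')
    parser.add_argument('--resume', type=str, default=None)
    return parser.parse_args()


# Usage (predictor, train and valid are built elsewhere):
#     args = parse_args()
#     train_phase(predictor, train, valid, args)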