def test_iterator_state_dict_backward_compat(self):
    """Progress saved via ``state_dict()`` must restore onto a new iterator."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)

    # Consume two batches, checking progress bookkeeping after each.
    consumed = []
    for step in (1, 2):
        batch = it.next()
        self.assertEqual(len(batch), 2)
        self.assertIsInstance(batch, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 * step / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 * (step - 1) / 6)
        consumed += batch

    # Snapshot the state, rebuild the iterator, and restore.
    state_dict = copy.deepcopy(it.state_dict())
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    it.load_state_dict(state_dict)
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    # The restored iterator finishes the epoch exactly where it left off.
    last_batch = it.next()
    self.assertEqual(len(last_batch), 2)
    self.assertIsInstance(last_batch, list)
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(sorted(consumed + last_batch), dataset)
    self.assertAlmostEqual(it.epoch_detail, 6 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
def test_reproduce_same_permutation(self):
    """Two iterators built from identically seeded samplers must agree."""
    dataset = [1, 2, 3, 4, 5, 6]

    def make_iterator():
        # Each call gets its own RandomState with the same seed.
        sampler = iterators.ShuffleOrderSampler(
            numpy.random.RandomState(self._seed))
        return iterators.MultiprocessIterator(
            dataset, 6, order_sampler=sampler)

    it1 = make_iterator()
    it2 = make_iterator()
    for _ in range(5):
        self.assertEqual(it1.next(), it2.next())
def test_iterator_repeat(self):
    """With repeat=True the iterator cycles the dataset across epochs."""
    dataset = [1, 2, 3]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for epoch in range(3):
        self.assertEqual(it.epoch, epoch)
        self.assertAlmostEqual(it.epoch_detail, epoch + 0 / 6)
        if epoch == 0:
            # No batch has been produced yet on the very first epoch.
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, epoch - 2 / 6)

        collected = []
        for step in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            if step == 2:
                # Third batch of size 2 over 3 examples wraps the epoch.
                self.assertTrue(it.is_new_epoch)
            else:
                self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(
                it.epoch_detail, epoch + 2 * (step + 1) / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, epoch + 2 * step / 6)
            collected += batch

        # Two passes over [1, 2, 3] were delivered within the three batches.
        self.assertEqual(sorted(collected), [1, 1, 2, 2, 3, 3])
def test_invalid_order_sampler(self):
    """A misbehaving order sampler must surface as a ValueError."""
    dataset = [1, 2, 3, 4, 5, 6]
    with self.assertRaises(ValueError):
        bad_it = iterators.MultiprocessIterator(
            dataset, 6, shuffle=None,
            order_sampler=_InvalidOrderSampler())
        bad_it.next()
def test_iterator_list_type(self):
    """Examples that are [int, ndarray] pairs keep their structure."""
    dataset = [[i, numpy.zeros((10,)) + i] for i in range(6)]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for epoch in range(3):
        self.assertEqual(it.epoch, epoch)
        self.assertAlmostEqual(it.epoch_detail, epoch)
        if epoch == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, epoch - 2 / 6)

        seen = {}
        for step in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            if step == 2:
                self.assertTrue(it.is_new_epoch)
            else:
                self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(
                it.epoch_detail, (3 * epoch + step + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, (3 * epoch + step) * 2 / 6)
            for example in batch:
                self.assertIsInstance(example, list)
                self.assertIsInstance(example[1], numpy.ndarray)
                seen[example[0]] = example[1]

        # Every example appeared exactly once with its payload intact.
        self.assertEqual(len(seen), len(dataset))
        for key, value in six.iteritems(seen):
            numpy.testing.assert_allclose(dataset[key][1], value)
def test_iterator_pickle_after_init(self):
    """An iterator pickled mid-epoch resumes from the same position."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)

    # Advance two batches before serializing.
    consumed = []
    for step in (1, 2):
        batch = it.next()
        self.assertEqual(len(batch), 2)
        self.assertIsInstance(batch, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 * step / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 * (step - 1) / 6)
        consumed += batch

    # Round-trip through pickle and verify the position survived.
    it = pickle.loads(pickle.dumps(it))
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    last_batch = it.next()
    self.assertEqual(len(last_batch), 2)
    self.assertIsInstance(last_batch, list)
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(sorted(consumed + last_batch), dataset)
    self.assertAlmostEqual(it.epoch_detail, 6 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
def test_stalled_getitem(self):
    """A __getitem__ slower than dataset_timeout warns but loses no data."""
    nth = self.nth
    batch_size = 2
    sleep = 0.5
    timeout = 0.1
    dataset = StallingDataset(nth, sleep)
    it = iterators.MultiprocessIterator(
        dataset, batch_size=batch_size, shuffle=False,
        dataset_timeout=timeout, repeat=False)
    warning_cls = iterators.MultiprocessIterator.TimeoutWarning

    data = []
    # Batches before the stalling example come through without a warning.
    for _ in range(nth // batch_size):
        data.append(it.next())
    # The batch containing the stalling example must emit TimeoutWarning.
    with testing.assert_warns(warning_cls):
        data.append(it.next())
    # Drain the remainder of the (non-repeating) dataset.
    while True:
        try:
            data.append(it.next())
        except StopIteration:
            break

    # Nothing may be dropped or reordered despite the stall.
    n_batches = (len(dataset) + batch_size - 1) // batch_size
    expected = [
        dataset.data[k * batch_size:(k + 1) * batch_size]
        for k in range(n_batches)]
    assert data == expected
def test_iterator_pickle_new(self):
    """A freshly constructed iterator must survive a pickle round trip."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)
    # Round-trip before any batch has been drawn.
    it = pickle.loads(pickle.dumps(it))
def test_unsupported_reset_finalized(self):
    """reset() on a finalized iterator raises NotImplementedError."""
    dataset = [1, 2, 3, 4]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    for _ in range(2):
        it.next()
    it.finalize()
    with self.assertRaises(NotImplementedError):
        it.reset()
def test_iterator_not_repeat(self):
    """With repeat=False the iterator stops after one full pass."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    collected = []
    for _ in range(3):
        collected.extend(it.next())
    self.assertEqual(sorted(collected), dataset)
    # Further calls keep raising StopIteration.
    for _ in range(2):
        with self.assertRaises(StopIteration):
            it.next()
def test_no_same_indices_order_sampler(self):
    """A sampler that forbids duplicates yields all-distinct batches."""
    dataset = [1, 2, 3, 4, 5, 6]
    batchsize = 5
    it = iterators.MultiprocessIterator(
        dataset, batchsize,
        order_sampler=_NoSameIndicesOrderSampler(batchsize))
    for _ in range(5):
        # Every element within a batch must be unique.
        self.assertEqual(len(numpy.unique(it.next())), batchsize)
def test_reset_repeat(self):
    """reset() restarts a repeating iterator cleanly every time."""
    dataset = [1, 2, 3, 4]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=True, **self.options)
    expected = sorted(2 * dataset)
    for _ in range(4):
        seen = []
        for _ in range(4):
            seen.extend(it.next())
        # Four batches of 2 over 4 examples == two full passes.
        self.assertEqual(sorted(seen), expected)
        it.reset()
def test_finalize_not_deadlock(self):
    """finalize() must return promptly even with worker processes busy."""
    dataset = numpy.ones((1000, 1000))
    it = iterators.MultiprocessIterator(dataset, 10, n_processes=4)
    for _ in range(10):
        it.next()
    # Run finalize on a daemon thread so a hang cannot block the suite;
    # if it has not finished within 5 seconds, call it a deadlock.
    finalizer = threading.Thread(target=it.finalize)
    finalizer.daemon = True
    finalizer.start()
    finalizer.join(5)
    self.assertFalse(finalizer.is_alive())
def test_iterator_not_repeat_not_even(self):
    """Non-divisible dataset: the final batch is short, then StopIteration."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    self.assertAlmostEqual(it.epoch_detail, 0 / 5)
    self.assertIsNone(it.previous_epoch_detail)

    # After each batch, epoch_detail reflects the cumulative count consumed.
    batches = []
    for consumed in (2, 4, 5):
        batch = it.next()
        batches.append(batch)
        self.assertAlmostEqual(it.epoch_detail, consumed / 5)
        self.assertAlmostEqual(
            it.previous_epoch_detail, (consumed - len(batch)) / 5)

    with self.assertRaises(StopIteration):
        it.next()
    # The last batch holds the single leftover example.
    self.assertEqual(len(batches[-1]), 1)
    self.assertEqual(sorted(sum(batches, [])), dataset)
def test_iterator_compatibilty(self):
    """Serial and multiprocess iterator state dicts are interchangeable."""
    # NOTE(review): "compatibilty" is a typo for "compatibility"; the name
    # is kept unchanged so test discovery/selection is unaffected.
    dataset = [1, 2, 3, 4, 5, 6]
    factories = (
        lambda: iterators.SerialIterator(dataset, 2),
        lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
    )
    # Try both directions: serial -> multiprocess and multiprocess -> serial.
    for make_before, make_after in itertools.permutations(factories, 2):
        it = make_before()
        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)

        consumed = []
        for step in (1, 2):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 2 * step / 6)
            consumed += batch

        # Transfer the state to the other iterator implementation.
        state_dict = copy.deepcopy(it.state_dict())
        it = make_after()
        it.load_state_dict(state_dict)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)

        last_batch = it.next()
        self.assertEqual(len(last_batch), 2)
        self.assertIsInstance(last_batch, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(consumed + last_batch), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
def test_iterator_repeat_not_even(self):
    """Repeating over a non-divisible dataset still covers it evenly."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    seen = []
    for _ in range(5):
        seen.extend(it.next())
    # Five batches of 2 over 5 examples == exactly two full passes.
    self.assertEqual(sorted(seen), sorted(dataset * 2))
def test_iterator_shuffle_divisible(self):
    """Consecutive full-dataset batches should be shuffled differently."""
    dataset = list(range(10))
    it = iterators.MultiprocessIterator(dataset, 10, **self.options)
    first_pass = it.next()
    second_pass = it.next()
    self.assertNotEqual(first_pass, second_pass)
def test_iterator_shuffle_nondivisible(self):
    """Shuffling must differ between passes even with an odd batch size."""
    dataset = list(range(10))
    it = iterators.MultiprocessIterator(dataset, 3, **self.options)
    flattened = []
    for _ in range(7):
        flattened.extend(it.next())
    # The first two passes over the dataset should not be identical.
    self.assertNotEqual(flattened[0:10], flattened[10:20])
def train_phase(predictor, train, valid, args):
    """Train ``predictor`` on ``train`` with early stopping on ``valid`` loss.

    Args:
        predictor: Model mapping inputs to class scores; snapshotted
            separately from the full classifier.
        train: Training dataset (must expose ``n_classes``).
        valid: Validation dataset.
        args: Parsed command-line options; fields used here: ``batchsize``,
            ``gpu``, ``lr``, ``decay``, ``iteration``, ``frequency``,
            ``pinfall``, ``out``, ``resume``.
    """
    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    class_weight = None  # NOTE: please set if you have..
    lossfun = partial(softmax_cross_entropy,
                      normalize=False, class_weight=class_weight)
    device = torch.device(args.gpu)
    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer; negative decay is clamped to 0 (disabled)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, model, device=device)

    # frequency == -1 means "auto": check ~20 times over the whole run
    frequency = max(args.iteration // 20, 1) \
        if args.frequency == -1 else max(1, args.frequency)
    # pinfall == -1 disables early stopping (infinite patience)
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))
    # BUGFIX(review): the original constructed a throwaway Trainer with a
    # plain (iteration, 'iteration') trigger and immediately overwrote it;
    # only this early-stopping Trainer was ever used.
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }
    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}
    clims = {'x': 'minmax', 'y': None, 't': None}
    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps, clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy'
    ]
    trainer.extend(LogReport(keys=log_keys))

    # setup log ploter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))
    trainer.extend(extensions.ProgressBar())

    # resume from a serialized trainer state if requested
    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()