def train_phase(predictor, train, valid, args):
    """Train a classifier wrapping ``predictor`` and save its weights.

    Args:
        predictor: network to train; its ``state_dict`` is saved at the end.
        train: training dataset (iterated with shuffling by default).
        valid: validation dataset (iterated once per evaluation, unshuffled).
        args: parsed CLI namespace; reads ``batchsize``, ``gpu``, ``decay``,
            ``epoch``, ``out``, ``frequency``, ``plot`` and ``resume``.
    """
    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)
    model = Classifier(predictor)
    model.to(device)

    # setup an optimizer
    # weight_decay is clamped at 0 so a negative CLI value disables decay
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer,
                                                model, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))
    # trainer.extend(DumpGraph(model, 'main/loss'))

    # -1 means "snapshot once per training run" (i.e. at the final epoch)
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    # resume from a previously saved trainer snapshot, if requested
    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    # persist only the bare predictor weights for the test phase
    torch.save(predictor.state_dict(),
               os.path.join(args.out, 'predictor.pth'))
def test_iterator_repeat(self):
    """A repeating iterator over 3 items with batch size 2 wraps around.

    Over three "double epochs" every element must appear exactly twice,
    and epoch/epoch_detail/previous_epoch_detail must advance by 2/6 per
    batch (batch size 2 over an effective length of 6).
    """
    dataset = [1, 2, 3]
    it = iterators.SerialIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
        if i == 0:
            # undefined before the very first batch
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        # third batch crosses the dataset boundary -> a new epoch starts
        self.assertTrue(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)
        # 6 yielded items == the 3-element dataset seen exactly twice
        self.assertEqual(
            sorted(batch1 + batch2 + batch3), [1, 1, 2, 2, 3, 3])
def test_invalid_order_sampler(self):
    """A sampler producing an invalid order must make the iterator raise."""
    data = [1, 2, 3, 4, 5, 6]
    with self.assertRaises(ValueError):
        broken = iterators.SerialIterator(
            data, 6, order_sampler=InvalidOrderSampler())
        broken.next()
def test_iterator_repeat(self):
    """Repeating iterator with a custom order sampler covers each epoch.

    Uses the fixture's ``shuffle``/``order_sampler`` options; each epoch
    of three 2-element batches must contain every dataset element once.
    """
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.SerialIterator(dataset, 2,
                                  shuffle=self.shuffle,
                                  order_sampler=self.order_sampler)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
        if i == 0:
            # undefined before the very first batch
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        # last batch of the epoch flips is_new_epoch
        self.assertTrue(it.is_new_epoch)
        # regardless of ordering, the three batches form one full epoch
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)
def test_iterator_repeat_not_even(self):
    """Odd-sized dataset with repeat: batches straddle the epoch boundary.

    With 5 elements and batch size 2, the third batch is [5, 1] — the
    tail of epoch 0 plus the head of epoch 1 (unshuffled order).
    """
    dataset = [1, 2, 3, 4, 5]
    it = iterators.SerialIterator(dataset, 2, shuffle=False)
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 5)
    # undefined before the very first batch
    self.assertIsNone(it.previous_epoch_detail)
    self.assertEqual(it.next(), [1, 2])
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 2 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 0 / 5)
    self.assertEqual(it.next(), [3, 4])
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 5)
    # batch crosses the boundary: epoch increments mid-batch
    self.assertEqual(it.next(), [5, 1])
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(it.epoch, 1)
    self.assertAlmostEqual(it.epoch_detail, 6 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 5)
    self.assertEqual(it.next(), [2, 3])
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 8 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 6 / 5)
    self.assertEqual(it.next(), [4, 5])
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(it.epoch, 2)
    self.assertAlmostEqual(it.epoch_detail, 10 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 8 / 5)
def test_iterator_repeat_not_even(self):
    """Two passes over an odd-sized dataset each cover every element once."""
    data = [1, 2, 3, 4, 5]
    it = iterators.SerialIterator(
        data, 2, shuffle=self.shuffle, order_sampler=self.order_sampler)
    # five 2-element batches == exactly two full epochs of 5 items
    yielded = []
    for _ in range(5):
        yielded.extend(it.next())
    self.assertEqual(sorted(yielded[:5]), data)
    self.assertEqual(sorted(yielded[5:]), data)
def test_no_same_indices_order_sampler(self):
    """NoSameIndicesOrderSampler must never repeat an index within a batch."""
    data = [1, 2, 3, 4, 5, 6]
    n_per_batch = 5
    it = iterators.SerialIterator(
        data, n_per_batch,
        order_sampler=NoSameIndicesOrderSampler(n_per_batch))
    # every drawn batch consists of distinct elements
    for _ in range(5):
        drawn = it.next()
        self.assertEqual(len(numpy.unique(drawn)), n_per_batch)
def test_iterator_not_repeat(self):
    """With repeat=False, one epoch is yielded and then StopIteration."""
    data = [1, 2, 3, 4, 5, 6]
    it = iterators.SerialIterator(
        data, 2, repeat=False,
        shuffle=self.shuffle, order_sampler=self.order_sampler)
    collected = []
    for _ in range(3):
        collected += it.next()
    # exactly one full epoch was produced
    self.assertEqual(sorted(collected), data)
    # further calls keep raising, not restarting
    self.assertRaises(StopIteration, it.next)
    self.assertRaises(StopIteration, it.next)
def setUp(self):
    """Build a tiny random dataset, an exhausting iterator, a dummy target
    model, and an Evaluator with the progress bar enabled."""
    samples = []
    for _ in range(2):
        samples.append(torch.empty(3, 4).uniform_(-1, 1))
    self.data = samples
    self.iterator = iterators.SerialIterator(
        self.data, 1, repeat=False, shuffle=False)
    self.target = DummyModel(self)
    self.evaluator = extensions.Evaluator(
        self.iterator, {}, eval_func=self.target, progress_bar=True)
def test_phase(predictor, test, args):
    """Run MC-dropout inference and plot uncertainty vs. correctness.

    Loads the weights saved by the train phase, predicts class labels and
    a per-sample predictive variance, then saves a violin plot comparing
    the variance distribution of correct vs. wrong predictions.

    Args:
        predictor: network whose weights are restored from
            ``<args.out>/predictor.pth``.
        test: test dataset; assumes it exposes a ``labels`` attribute
            comparable elementwise to the predictions — TODO confirm.
        args: reads ``batchsize``, ``gpu``, ``mc_iteration`` and ``out``.
    """
    # setup an iterator
    test_iter = iterators.SerialIterator(test, args.batchsize,
                                         repeat=False, shuffle=False)

    # setup an inferencer
    predictor.load_state_dict(
        torch.load(os.path.join(args.out, 'predictor.pth')))

    # MC sampling: softmax per draw, argmax of the mean, mean of the variance
    model = MCSampler(predictor,
                      mc_iteration=args.mc_iteration,
                      activation=partial(torch.softmax, dim=1),
                      reduce_mean=partial(torch.argmax, dim=1),
                      reduce_var=partial(torch.mean, dim=1))

    device = torch.device(args.gpu)
    model.to(device)

    infer = Inferencer(test_iter, model, device=args.gpu)
    pred, uncert = infer.run()

    # evaluate
    os.makedirs(args.out, exist_ok=True)

    match = pred == test.labels
    accuracy = np.sum(match) / len(match)
    # split uncertainties into correct vs. wrong predictions
    arr = [uncert[match], uncert[np.logical_not(match)]]

    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.violinplot(data=arr, inner='quartile',
                        palette='Blues', orient='h', cut=0)
    ax.set_xlabel('Predicted variance')
    ax.set_yticklabels([
        'Correct prediction\n(n=%d)' % len(arr[0]),
        'Wrong prediction\n(n=%d)' % len(arr[1])
    ])
    plt.title('Accuracy=%.3f' % accuracy)
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'eval.png'))
    plt.close()
def test_iterator_state_dict_backward_compat(self):
    """A state dict saved mid-epoch restores into a fresh iterator.

    Two batches are drawn, the state is serialized, loaded into a newly
    constructed SerialIterator, and iteration must continue seamlessly to
    finish the epoch.
    """
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.SerialIterator(dataset, 2,
                                  shuffle=self.shuffle,
                                  order_sampler=self.order_sampler)
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    # undefined before the very first batch
    self.assertIsNone(it.previous_epoch_detail)
    batch1 = it.next()
    self.assertEqual(len(batch1), 2)
    self.assertIsInstance(batch1, list)
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 2 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
    batch2 = it.next()
    self.assertEqual(len(batch2), 2)
    self.assertIsInstance(batch2, list)
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    # deep-copy so the restored iterator cannot share mutable state
    state_dict = copy.deepcopy(it.state_dict())
    it = iterators.SerialIterator(dataset, 2)
    it.load_state_dict(state_dict)
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    batch3 = it.next()
    self.assertEqual(len(batch3), 2)
    self.assertIsInstance(batch3, list)
    self.assertTrue(it.is_new_epoch)
    # the resumed iterator completes the original epoch exactly
    self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
    self.assertAlmostEqual(it.epoch_detail, 6 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
def test_iterator_not_repeat_not_even(self):
    """repeat=False with an odd-sized dataset yields a short final batch.

    The last batch contains only [5]; afterwards the iterator raises
    StopIteration instead of wrapping around.
    """
    dataset = [1, 2, 3, 4, 5]
    it = iterators.SerialIterator(dataset, 2, repeat=False, shuffle=False)
    self.assertAlmostEqual(it.epoch_detail, 0 / 5)
    # undefined before the very first batch
    self.assertIsNone(it.previous_epoch_detail)
    self.assertEqual(it.next(), [1, 2])
    self.assertAlmostEqual(it.epoch_detail, 2 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 0 / 5)
    self.assertEqual(it.next(), [3, 4])
    self.assertAlmostEqual(it.epoch_detail, 4 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 5)
    # final batch is shorter than the batch size
    self.assertEqual(it.next(), [5])
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(it.epoch, 1)
    self.assertAlmostEqual(it.epoch_detail, 5 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 5)
    self.assertRaises(StopIteration, it.next)
def test_iterator_not_repeat(self):
    """repeat=False over an evenly divisible dataset stops after one epoch."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.SerialIterator(dataset, 2, repeat=False, shuffle=False)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    # undefined before the very first batch
    self.assertIsNone(it.previous_epoch_detail)
    self.assertEqual(it.next(), [1, 2])
    self.assertAlmostEqual(it.epoch_detail, 2 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
    self.assertEqual(it.next(), [3, 4])
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)
    self.assertEqual(it.next(), [5, 6])
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(it.epoch, 1)
    self.assertAlmostEqual(it.epoch_detail, 6 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
    # repeated calls keep raising rather than restarting
    for i in range(2):
        self.assertRaises(StopIteration, it.next)
def test_iterator_not_repeat_not_even(self):
    """repeat=False + odd dataset with a custom order sampler.

    The three batches together must equal the dataset, the final batch
    has one element, and further iteration raises StopIteration.
    """
    dataset = [1, 2, 3, 4, 5]
    it = iterators.SerialIterator(dataset, 2, repeat=False,
                                  shuffle=self.shuffle,
                                  order_sampler=self.order_sampler)
    self.assertAlmostEqual(it.epoch_detail, 0 / 5)
    # undefined before the very first batch
    self.assertIsNone(it.previous_epoch_detail)
    batch1 = it.next()
    self.assertAlmostEqual(it.epoch_detail, 2 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 0 / 5)
    batch2 = it.next()
    self.assertAlmostEqual(it.epoch_detail, 4 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 5)
    batch3 = it.next()
    self.assertAlmostEqual(it.epoch_detail, 5 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 5)
    self.assertRaises(StopIteration, it.next)
    # truncated final batch, and exactly one epoch in total
    self.assertEqual(len(batch3), 1)
    self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
def test_iterator_repeat(self):
    """Unshuffled repeating iterator yields identical batches every epoch."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.SerialIterator(dataset, 2, shuffle=False)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
        if i == 0:
            # undefined before the very first batch
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        # deterministic order: the same three batches each epoch
        self.assertEqual(it.next(), [1, 2])
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
        self.assertEqual(it.next(), [3, 4])
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
        self.assertEqual(it.next(), [5, 6])
        self.assertTrue(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)
def test_iterator_compatibilty(self):
    """State dicts are interchangeable between Serial and Multiprocess
    iterators, in both directions.

    For each ordered pair of iterator factories, two batches are drawn
    with one kind, the state is transplanted into the other kind, and the
    epoch must complete consistently.
    """
    dataset = [1, 2, 3, 4, 5, 6]
    iters = (
        lambda: iterators.SerialIterator(dataset, 2),
        lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
    )
    # try both directions: serial->multiprocess and multiprocess->serial
    for it_before, it_after in itertools.permutations(iters, 2):
        it = it_before()
        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)

        # deep-copy so the new iterator cannot share mutable state
        state_dict = copy.deepcopy(it.state_dict())
        it = it_after()
        it.load_state_dict(state_dict)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)

        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        # the transplanted iterator completes the original epoch exactly
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
def test_iterator_shuffle_nondivisible(self):
    """Shuffling reorders elements across epochs even when the batch size
    does not divide the dataset length."""
    data = list(range(10))
    it = iterators.SerialIterator(data, 3)
    stream = []
    # 7 batches of 3 cover a bit more than two epochs of 10
    for _ in range(7):
        stream += it.next()
    # the first two epochs should (with overwhelming probability) differ
    self.assertNotEqual(stream[0:10], stream[10:20])
def test_phase(predictor, test, args):
    """Run MC inference on a 1-D regression task and plot both uncertainties.

    Restores the trained predictor, draws MC samples that produce a mean
    prediction plus epistemic and aleatoric uncertainty estimates, and
    saves two error-bar plots against the ground truth.

    Args:
        predictor: network restored from ``<args.out>/predictor.pth``.
        test: dataset; assumes it exposes ``x`` and ``t`` arrays — TODO
            confirm against the dataset class.
        args: reads ``batchsize``, ``gpu``, ``mc_iteration`` and ``out``.
    """
    # setup an iterator
    test_iter = iterators.SerialIterator(test, args.batchsize,
                                         repeat=False, shuffle=False)

    # setup an inferencer
    predictor.load_state_dict(
        torch.load(os.path.join(args.out, 'predictor.pth')))

    # identity for the mean head, exp for the (log-)variance head;
    # no reduction so both uncertainty terms come back separately
    model = MCSampler(predictor,
                      mc_iteration=args.mc_iteration,
                      activation=[lambda x: x, torch.exp],
                      reduce_mean=None,
                      reduce_var=None)

    device = torch.device(args.gpu)
    model.to(device)

    infer = Inferencer(test_iter, model, device=args.gpu)
    pred, epistemic_uncert, aleatory_uncert, _ = infer.run()

    # visualize
    x = test.x.ravel()
    t = test.t.ravel()
    pred = pred.ravel()
    epistemic_uncert = epistemic_uncert.ravel()
    aleatory_uncert = aleatory_uncert.ravel()

    # plot 1: prediction with epistemic (model) uncertainty bars
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=x, y=pred, color='blue', s=75)
    ax.errorbar(x, pred, yerr=epistemic_uncert, fmt='none',
                capsize=10, ecolor='gray', linewidth=1.5)
    ax.plot(x, t, color='red', linewidth=1.5)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Ground-truth', 'Prediction', 'Epistemic uncertainty'])
    plt.title('Result on testing data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'eval_epistemic.png'))
    plt.close()

    # plot 2: prediction with aleatoric (data) uncertainty bars
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=x, y=pred, color='blue', s=75)
    ax.errorbar(x, pred, yerr=aleatory_uncert, fmt='none',
                capsize=10, ecolor='gray', linewidth=1.5)
    ax.plot(x, t, color='red', linewidth=1.5)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Ground-truth', 'Prediction', 'Aleatoric uncertainty'])
    plt.title('Result on testing data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'eval_aleatoric.png'))
    plt.close()
def train_phase(predictor, train, valid, args):
    """Train a heteroscedastic regressor and save its weights.

    First saves a scatter plot of the training data, then trains
    ``predictor`` wrapped in a ``Regressor`` with a noise-aware loss,
    logging loss/accuracy/sigma curves along the way.

    Args:
        predictor: network to train; its ``state_dict`` is saved at the end.
        train: training dataset; assumes ``x``, ``y`` and ``t`` array
            attributes for the visualization — TODO confirm.
        valid: validation dataset.
        args: reads ``batchsize``, ``gpu``, ``decay``, ``epoch``, ``out``,
            ``frequency``, ``plot`` and ``resume``.
    """
    # visualize the raw training set before training starts
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=train.x.ravel(), y=train.y.ravel(),
                         color='blue', s=55, alpha=0.3)
    ax.plot(train.x.ravel(), train.t.ravel(), color='red', linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Ground-truth', 'Observation'])
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)

    # loss is noise-aware; "accuracy" is the L1 error of the mean output
    # (y[0] is presumably the mean head — TODO confirm against predictor)
    lossfun = noised_mean_squared_error
    accfun = lambda y, t: F.l1_loss(y[0], t)

    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer,
                                                model, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))
    # trainer.extend(DumpGraph(model, 'main/loss'))

    # -1 means "snapshot once per training run" (i.e. at the final epoch)
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch', file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    # resume from a previously saved trainer snapshot, if requested
    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    # persist only the bare predictor weights for the test phase
    torch.save(predictor.state_dict(),
               os.path.join(args.out, 'predictor.pth'))
def test_iterator_shuffle_divisible(self):
    """Consecutive full-dataset batches differ when shuffling is active."""
    data = list(range(10))
    it = iterators.SerialIterator(
        data, 10, shuffle=self.shuffle, order_sampler=self.order_sampler)
    first_epoch = it.next()
    second_epoch = it.next()
    self.assertNotEqual(first_epoch, second_epoch)
def test_phase(predictor, test, args):
    """Evaluate a segmentation predictor with MC-dropout uncertainty.

    Restores the latest predictor snapshot, infers labels and per-pixel
    predictive variance, saves a 5-panel figure per test image, and
    finally plots Dice coefficient against mean predicted variance with
    their Pearson correlation.

    Args:
        predictor: network restored from the newest
            ``predictor_iter_*.pth`` under ``args.out``.
        test: dataset; assumes ``files['image']``/``files['label']`` path
            lists — TODO confirm against the dataset class.
        args: reads ``batchsize``, ``gpu``, ``mc_iteration`` and ``out``.
    """
    print('# samples:')
    print('-- test:', len(test))

    test_iter = iterators.SerialIterator(test, args.batchsize,
                                         repeat=False, shuffle=False)

    # setup a inferencer
    snapshot_file = find_latest_snapshot(
        'predictor_iter_{.updater.iteration:08}.pth', args.out)
    predictor.load_state_dict(torch.load(snapshot_file))
    print('Loaded a snapshot:', snapshot_file)

    # MC sampling: softmax per draw, argmax of the mean, mean of the variance
    model = MCSampler(predictor,
                      mc_iteration=args.mc_iteration,
                      activation=partial(torch.softmax, dim=1),
                      reduce_mean=partial(torch.argmax, dim=1),
                      reduce_var=partial(torch.mean, dim=1))

    device = torch.device(args.gpu)
    model.to(device)

    infer = Inferencer(test_iter, model, device=args.gpu)
    pred, uncert = infer.run()

    # evaluate
    os.makedirs(os.path.join(args.out, 'test'), exist_ok=True)

    acc_values = []
    uncert_values = []

    # clamp the colorbar at the 95th percentile to keep outliers readable
    uncert_clim = (0, np.percentile(uncert, 95))

    files = test.files['image']
    if isinstance(files, np.ndarray):
        files = files.tolist()
    commonpath = os.path.commonpath(files)

    plt.rcParams['font.size'] = 14

    for i, (p, u, imf, lbf) in enumerate(
            zip(pred, uncert, test.files['image'], test.files['label'])):

        im, _ = load_image(imf)
        im = im[:, :, ::-1]  # presumably BGR -> RGB — TODO confirm

        lb, _ = load_image(lbf)
        if lb.ndim == 3:
            lb = lb[:, :, 0]

        acc_values.append(eval_metric(p, lb))
        uncert_values.append(np.mean(u[p == 1]))  # NOTE: instrument class

        # 5 panels: image, prediction, ground truth, variance, error mask
        plt.figure(figsize=(20, 4))
        for j, (pic, cmap, clim, title) in enumerate(
                zip([im, p, lb, u, (p != lb).astype(np.uint8)],
                    [None, None, None, 'jet', 'jet'],
                    [None, None, None, uncert_clim, None],
                    [
                        'Input image\n%s' % os.path.relpath(imf, commonpath),
                        'Predicted label\n(DC=%.3f)' % acc_values[-1],
                        'Ground-truth label',
                        'Predicted variance\n(PV=%.4f)' % uncert_values[-1],
                        'Error'
                    ])):
            plt.subplot(1, 5, j + 1)
            plt.imshow(pic, cmap=cmap)
            plt.xticks([], [])
            plt.yticks([], [])
            plt.title(title)
            plt.clim(clim)

        plt.tight_layout()
        plt.savefig(os.path.join(args.out, 'test/%03d.png' % i))
        plt.close()

    # correlation between predicted variance and segmentation quality
    c = pearsonr(uncert_values, acc_values)

    plt.figure(figsize=(11, 11))
    ax = sns.scatterplot(x=uncert_values, y=acc_values, color='blue', s=50)
    ax.set_xlabel('Predicted variance')
    ax.set_ylabel('Dice coefficient')
    plt.grid()
    plt.title('r=%.3f' % c[0])
    plt.savefig(os.path.join(args.out, 'eval.png'))
    plt.close()
def test_phase(generator, test, args):
    """Evaluate an image-translation generator with MC-dropout uncertainty.

    Restores the latest generator snapshot, samples translated images and
    a per-pixel predictive variance, and saves a 5-panel comparison figure
    (input, prediction, ground truth, variance, absolute error) per image.

    Args:
        generator: network restored from the newest
            ``generator_iter_*.pth`` under ``args.out``.
        test: dataset; assumes ``files['image']``/``files['label']`` path
            lists — TODO confirm against the dataset class.
        args: reads ``batchsize``, ``gpu``, ``mc_iteration`` and ``out``.
    """
    print('# samples:')
    print('-- test:', len(test))

    test_iter = iterators.SerialIterator(test, args.batchsize,
                                         repeat=False, shuffle=False)

    # setup a inferencer
    snapshot_file = find_latest_snapshot(
        'generator_iter_{.updater.iteration:08}.pth', args.out)
    generator.load_state_dict(torch.load(snapshot_file))
    print('Loaded a snapshot:', snapshot_file)

    # tanh outputs in [-1, 1]; keep the mean image, reduce variance per pixel
    model = MCSampler(generator,
                      mc_iteration=args.mc_iteration,
                      activation=torch.tanh,
                      reduce_mean=None,
                      reduce_var=partial(torch.mean, dim=1))

    device = torch.device(args.gpu)
    model.to(device)

    infer = Inferencer(test_iter, model, device=args.gpu)
    pred, uncert = infer.run()

    # evaluate
    os.makedirs(os.path.join(args.out, 'test'), exist_ok=True)

    acc_values = []
    uncert_values = []

    # clamp the variance colorbar at the 95th percentile for readability
    uncert_clim = (0, np.percentile(uncert, 95))
    error_clim = (0, 1)

    files = test.files['image']
    if isinstance(files, np.ndarray):
        files = files.tolist()
    commonpath = os.path.commonpath(files)

    plt.rcParams['font.size'] = 14

    for i, (p, u, imf, lbf) in enumerate(zip(pred, uncert,
                                             test.files['image'],
                                             test.files['label'])):

        im, _ = load_image(imf)
        lb, _ = load_image(lbf)

        im = im.astype(np.float32)
        lb = lb.astype(np.float32)

        p = p.transpose(1,2,0)  # CHW -> HWC for plotting

        # channel flip (presumably BGR -> RGB — TODO confirm) and
        # rescale from [-1, 1] to [0, 1] for display
        im = (im[:,:,::-1] + 1.) / 2.
        lb = (lb[:,:,::-1] + 1.) / 2.
        p = (p[:,:,::-1] + 1.) / 2.

        error = np.mean(np.abs(p-lb), axis=-1)

        acc_values.append( eval_metric(p,lb) )
        uncert_values.append( np.mean(u) )

        # 5 panels: input, prediction, ground truth, variance, error
        plt.figure(figsize=(20,4))

        for j, (pic, cmap, clim, title) in enumerate(zip(
                [im, p, lb, u, error],
                [None, None, None, 'jet', 'jet'],
                [None, None, None, uncert_clim, error_clim],
                ['Input image\n%s' % os.path.relpath(imf, commonpath),
                 'Predicted label\n(MAE=%.3f)' % acc_values[-1],
                 'Ground-truth label',
                 'Predicted variance\n(PV=%.4f)' % uncert_values[-1],
                 'Error'])):

            plt.subplot(1,5, j+1)
            plt.imshow(pic, cmap=cmap)
            plt.xticks([], [])
            plt.yticks([], [])
            plt.title(title)
            plt.clim(clim)

        plt.tight_layout()
        plt.savefig(os.path.join(args.out, 'test/%03d.png' % i))
        plt.close()
def train_phase(predictor, train, valid, args):
    """Train a segmentation classifier with early stopping and validation
    visualization.

    Fix vs. original: the trainer was constructed twice — first with a
    plain ``(args.iteration, 'iteration')`` trigger and then immediately
    re-constructed with the early-stopping trigger, making the first
    construction dead code. The redundant construction is removed; the
    behavior (training with the early-stopping trigger) is unchanged.

    Args:
        predictor: network to train; snapshots of it are saved periodically.
        train: training dataset; assumes an ``n_classes`` attribute.
        valid: validation dataset.
        args: reads ``batchsize``, ``gpu``, ``lr``, ``decay``,
            ``iteration``, ``out``, ``frequency``, ``pinfall`` and
            ``resume``.
    """
    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    # NOTE(review): shuffle=True on the validation iterator looks intended
    # to vary the visualized samples each run — confirm this is deliberate.
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    class_weight = None  # NOTE: please set if you have..
    lossfun = partial(softmax_cross_entropy,
                      normalize=False, class_weight=class_weight)

    device = torch.device(args.gpu)
    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer,
                                                model, device=device)

    # -1 means "20 checkpoints over the whole run"
    frequency = max(args.iteration // 20, 1) \
        if args.frequency == -1 else max(1, args.frequency)

    # early stopping on validation loss; pinfall == -1 disables patience
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),  # logits -> hard label map
        't': lambda x: x
    }
    cmap = np.array([[0, 0, 0], [0, 0, 1]])  # background black, class blue
    cmaps = {'x': None, 'y': cmap, 't': cmap}
    clims = {'x': 'minmax', 'y': None, 't': None}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps, clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy'
    ]
    trainer.extend(LogReport(keys=log_keys))

    # setup log ploter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))

    trainer.extend(extensions.ProgressBar())

    # resume from a previously saved trainer snapshot, if requested
    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
def train_phase(generator, train, valid, args):
    """Train an image-translation GAN (generator + discriminator).

    Wraps ``generator`` in a Regressor for validation metrics, builds a
    discriminator, trains both with a DCGAN-style updater under early
    stopping, linearly decays both learning rates over the second half of
    training, and periodically snapshots everything.

    Args:
        generator: generator network to train.
        train: training dataset.
        valid: validation dataset.
        args: reads ``batchsize``, ``gpu``, ``lr``, ``beta``, ``decay``,
            ``alpha``, ``iteration``, ``out``, ``frequency``, ``pinfall``
            and ``resume``.
    """
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    # NOTE(review): shuffle=True on the validation iterator looks intended
    # to vary the visualized samples each run — confirm this is deliberate.
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    model = Regressor(generator,
                      activation=torch.tanh,
                      lossfun=F.l1_loss,
                      accfun=F.l1_loss)

    discriminator = build_discriminator()
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    device = torch.device(args.gpu)
    model.to(device)
    discriminator.to(device)

    # setup an optimizer (one per network, shared hyper-parameters)
    optimizer_G = torch.optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = DCGANUpdater(
        iterator=train_iter,
        optimizer={
            'gen': optimizer_G,
            'dis': optimizer_D,
        },
        model={
            'gen': model,
            'dis': discriminator,
        },
        alpha=args.alpha,
        device=args.gpu,
    )

    # -1 means "80 checkpoints over the whole run"
    frequency = max(args.iteration//80, 1) \
        if args.frequency == -1 else max(1, args.frequency)

    # early stopping on validation loss; pinfall == -1 disables patience
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # shift lr: linear decay to zero over the second half of training
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration//2, args.iteration),
                               optimizer=optimizer_G))
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration//2, args.iteration),
                               optimizer=optimizer_D))

    # setup a visualizer (images live in [-1, 1])
    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None, clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph('loss_gen', filename='generative_loss.dot'))
    # trainer.extend(DumpGraph('loss_cond', filename='conditional_loss.dot'))
    # trainer.extend(DumpGraph('loss_dis', filename='discriminative_loss.dot'))

    # snapshot the whole trainer plus each network separately
    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator, 'generator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = ['loss_gen', 'loss_cond', 'loss_dis',
                'validation/main/accuracy']
    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # setup log ploter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [key for key in log_keys
                         if key.split('/')[-1].startswith(plot_key)]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(PrintReport(['iteration'] + log_keys + ['elapsed_time'],
                               n_step=1))

    trainer.extend(extensions.ProgressBar())

    # resume from a previously saved trainer snapshot, if requested
    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()