def test_savefun_and_writer_exclusive(self):
    # savefun and writer arguments cannot be specified together.
    def savefun(*args, **kwargs):
        assert False

    writer = extensions.snapshot_writers.SimpleWriter()
    with pytest.raises(TypeError):
        extensions.snapshot(savefun=savefun, writer=writer)

    trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(trainer, savefun=savefun, writer=writer)
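# A hedged sketch of the two valid, mutually exclusive configurations the
# test above enforces: pass *either* ``writer`` *or* ``savefun``, never both.
# Using ``torch.save`` as the savefun is an illustrative assumption.
snapshot_with_writer = extensions.snapshot(
    writer=extensions.snapshot_writers.SimpleWriter())
snapshot_with_savefun = extensions.snapshot(savefun=torch.save)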
def train_phase(predictor, train, valid, args):
    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)
    model = Classifier(predictor)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))
    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
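# A minimal sketch, not part of the original script: reloading the
# ``predictor.pth`` weights that ``train_phase`` saves above. The helper name
# and the ``map_location`` choice are illustrative; the predictor must be
# constructed with the same arguments as during training.
def load_predictor_for_inference(predictor, out):
    state = torch.load(os.path.join(out, 'predictor.pth'),
                       map_location='cpu')
    predictor.load_state_dict(state)
    predictor.eval()
    return predictor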
def test_remove_stale_snapshots(self):
    fmt = 'snapshot_iter_{.updater.iteration}'
    retain = 3
    snapshot = extensions.snapshot(filename=fmt, n_retains=retain,
                                   autoload=False)

    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = self.path
    trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)

    class TimeStampUpdater:
        t = time.time() - 100
        name = 'ts_updater'
        priority = 1  # This must be called after the snapshot is taken.

        def __call__(self, _trainer):
            filename = os.path.join(_trainer.out, fmt.format(_trainer))
            self.t += 1
            # For filesystems with low timestamp precision.
            os.utime(filename, (self.t, self.t))

    trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
    trainer.run()
    assert 10 == trainer.updater.iteration
    assert trainer._done

    pattern = os.path.join(trainer.out, "snapshot_iter_*")
    found = [os.path.basename(path) for path in glob.glob(pattern)]
    assert retain == len(found)
    found.sort()
    # snapshot_iter_(8, 9, 10) are expected to remain.
    expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)]
    expected.sort()
    assert expected == found

    trainer2 = testing.get_trainer_with_mock_updater()
    trainer2.out = self.path
    assert not trainer2._done
    snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
    # Just making sure no error occurs.
    snapshot2.initialize(trainer2)
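# A hedged usage sketch of the retention behaviour exercised above: in a real
# training script, keeping only the three newest snapshots and recovering
# automatically after a restart. The filename pattern matches the test; the
# trigger interval of 1000 iterations is an illustrative choice.
def extend_with_snapshot_retention(trainer):
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}',
                            n_retains=3, autoload=True),
        trigger=(1000, 'iteration'))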
def test_call(self):
    t = mock.MagicMock()
    c = mock.MagicMock(side_effect=[True, False])
    w = mock.MagicMock()
    snapshot = extensions.snapshot(target=t, condition=c, writer=w)
    trainer = mock.MagicMock()
    snapshot(trainer)
    snapshot(trainer)

    assert c.call_count == 2
    assert w.call_count == 1
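# A minimal sketch of a concrete ``condition`` (an assumption, not from the
# test): the mocks above show the condition is consulted on every call and
# the writer fires only when it returns True. A zero-argument callable is
# assumed here, which is compatible with the mock usage in the test.
class EveryOtherCall:
    """Lets every second snapshot request through."""

    def __init__(self):
        self.count = 0

    def __call__(self):
        self.count += 1
        return self.count % 2 == 1

# snapshot = extensions.snapshot(condition=EveryOtherCall())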
def train_phase(predictor, train, valid, args):
    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    class_weight = None  # NOTE: set this if you have one.
    lossfun = partial(softmax_cross_entropy, normalize=False,
                      class_weight=class_weight)

    device = torch.device(args.gpu)
    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=device)

    frequency = (max(args.iteration // 20, 1) if args.frequency == -1
                 else max(1, args.frequency))

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }
    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}
    clims = {'x': 'minmax', 'y': None, 't': None}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps, clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy'
    ]
    trainer.extend(LogReport(keys=log_keys))

    # setup a log plotter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
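# A hedged reading of the early-stopping setup above: with
# check_trigger=(frequency, 'iteration') and patients=p, training stops once
# 'validation/main/loss' has failed to improve for p consecutive checks, or
# when max_trigger fires, whichever comes first. Illustrative numbers:
# frequency=500 and p=4 tolerate at most 4 * 500 = 2000 iterations without
# improvement before stopping.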
def train_phase(predictor, train, valid, args):
    # visualize the training data set
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=train.x.ravel(), y=train.y.ravel(),
                         color='blue', s=55, alpha=0.3)
    ax.plot(train.x.ravel(), train.t.ravel(), color='red', linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Observation', 'Ground-truth'])  # labels follow plot order
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)
    lossfun = noised_mean_squared_error
    accfun = lambda y, t: F.l1_loss(y[0], t)
    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))
    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch', file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device', '-d', type=str, default='-1',
                        help='Device specifier. Either ChainerX device '
                        'specifier or an integer. If non-negative integer, '
                        'CuPy arrays with specified device id are used. If '
                        'negative integer, NumPy arrays are used')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--autoload', action='store_true',
                        help='Automatically load trainer snapshots in case '
                        'of preemption or other temporary system failure')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument('--gpu', '-g', dest='device',
                       type=int, nargs='?', const=0,
                       help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    device = torch.device(args.device)

    print('Device: {}'.format(device))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train.
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = Classifier(MLP(784, args.unit, 10))
    model.to(device)

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Load the MNIST dataset
    transform = transforms.ToTensor()
    train = datasets.MNIST('data', train=True, download=True,
                           transform=transform)
    test = datasets.MNIST('data', train=False, transform=transform)

    train_iter = pytorch_trainer.iterators.SerialIterator(
        train, args.batchsize)
    test_iter = pytorch_trainer.iterators.SerialIterator(
        test, args.batchsize, repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=device),
                   call_before_training=True)

    # Dump a computational graph from 'loss' variable at the first iteration.
    # The "main" refers to the target link of the "main" optimizer.
    # trainer.extend(extensions.DumpGraph('main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    # Take a snapshot every ``frequency`` epochs, delete stale snapshots, and
    # automatically load from a snapshot file if any already exists in the
    # output directory.
    trainer.extend(extensions.snapshot(n_retains=1, autoload=args.autoload),
                   trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(), call_before_training=True)

    # Save two plot images to the result dir
    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              'epoch', file_name='loss.png'),
        call_before_training=True)
    trainer.extend(
        extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'],
                              'epoch', file_name='accuracy.png'),
        call_before_training=True)

    # Print selected entries of the log to stdout.
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called
    # by either the updater or the evaluator.
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
    ]), call_before_training=True)

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume is not None:
        # Resume from a snapshot. Note: with --autoload, any snapshot found
        # in the output directory will overwrite the state loaded here.
        trainer.load_state_dict(torch.load(args.resume))

    # Run the training
    trainer.run()
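# Illustrative invocations of main() above (the script name, device id, and
# snapshot path are assumptions for the example):
#
#   python train_mnist.py --device 0 --frequency 1 --autoload
#   python train_mnist.py --device 0 --resume result/snapshot_iter_600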
def train_phase(generator, train, valid, args):
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    model = Regressor(generator,
                      activation=torch.tanh,
                      lossfun=F.l1_loss,
                      accfun=F.l1_loss)

    discriminator = build_discriminator()
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    device = torch.device(args.gpu)
    model.to(device)
    discriminator.to(device)

    # setup optimizers
    optimizer_G = torch.optim.Adam(model.parameters(), lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = DCGANUpdater(
        iterator=train_iter,
        optimizer={
            'gen': optimizer_G,
            'dis': optimizer_D,
        },
        model={
            'gen': model,
            'dis': discriminator,
        },
        alpha=args.alpha,
        device=args.gpu,
    )

    frequency = (max(args.iteration // 80, 1) if args.frequency == -1
                 else max(1, args.frequency))

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # shift lr
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_G))
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_D))

    # setup a visualizer
    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None, clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph('loss_gen', filename='generative_loss.dot'))
    # trainer.extend(DumpGraph('loss_cond', filename='conditional_loss.dot'))
    # trainer.extend(DumpGraph('loss_dis', filename='discriminative_loss.dot'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator, 'generator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = ['loss_gen', 'loss_cond', 'loss_dis',
                'validation/main/accuracy']
    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # setup a log plotter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=1))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
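# A worked example of the LinearShift schedule in train_phase above (numbers
# are illustrative): with args.lr = 2e-4 and args.iteration = 100000, both
# optimizers keep lr = 2e-4 until iteration 50000, then the rate decays
# linearly, reaching 1e-4 at iteration 75000 and 0.0 at iteration 100000.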