def prepare(self, dirname='test', device=None):
    outdir = os.path.join(self.temp_dir, dirname)
    self.updater = training.updaters.StandardUpdater(
        self.iterator, self.optimizer, self.model, device=device)
    self.trainer = training.Trainer(
        self.updater, (self.n_epochs, 'epoch'), out=outdir)
    self.trainer.extend(training.extensions.FailOnNonNumber())

def train_phase(predictor, train, valid, args):
    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)
    model = Classifier(predictor)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, model, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration',
            'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(),
               os.path.join(args.out, 'predictor.pth'))

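# Hedged sketch of the argument namespace the train_phase above expects.
# The attribute names (batchsize, epoch, frequency, gpu, decay, out, plot,
# resume) are taken from the usages in that function; the parser name,
# defaults, and help strings are illustrative assumptions, not taken from
# the original script.
import argparse


def parse_train_args_sketch():
    parser = argparse.ArgumentParser(description='train_phase arguments (sketch)')
    parser.add_argument('--batchsize', type=int, default=64,
                        help='number of samples per mini-batch')
    parser.add_argument('--epoch', type=int, default=20,
                        help='number of sweeps over the training set')
    parser.add_argument('--frequency', type=int, default=-1,
                        help='snapshot frequency in epochs (-1: once, at the final epoch)')
    parser.add_argument('--gpu', type=str, default='cpu',
                        help='torch.device specifier, e.g. "cpu" or "cuda:0" '
                             '(the original may use an integer GPU id instead)')
    parser.add_argument('--decay', type=float, default=0.0,
                        help='weight decay passed to Adam')
    parser.add_argument('--out', type=str, default='result',
                        help='output directory for logs and snapshots')
    parser.add_argument('--plot', action='store_true',
                        help='save loss/accuracy plots via PlotReport')
    parser.add_argument('--resume', type=str, default=None,
                        help='path to a trainer snapshot to resume from')
    return parser.parse_args()
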
def get_trainer_with_mock_updater(
        stop_trigger=(10, 'iteration'), iter_per_epoch=10, extensions=None):
    """Returns a :class:`~pytorch_trainer.training.Trainer` object with mock updater.

    The returned trainer can be used for testing the trainer itself and the
    extensions. A mock object is used as its updater. The update function set
    to the mock correctly increments the iteration counts
    (``updater.iteration``), and thus you can write a test relying on it.

    Args:
        stop_trigger: Stop trigger of the trainer.
        iter_per_epoch: The number of iterations per epoch.
        extensions: Extensions registered to the trainer.

    Returns:
        Trainer object with a mock updater.

    """
    if extensions is None:
        extensions = []
    check_available()
    updater = mock.Mock()
    updater.get_all_models.return_value = {}
    updater.iteration = 0
    updater.epoch = 0
    updater.epoch_detail = 0
    updater.is_new_epoch = True
    updater.previous_epoch_detail = None
    updater.state_dict.return_value = {}  # dummy state dict

    def update():
        updater.update_core()
        updater.iteration += 1
        updater.epoch = updater.iteration // iter_per_epoch
        updater.epoch_detail = updater.iteration / iter_per_epoch
        updater.is_new_epoch = (
            (updater.iteration - 1) // iter_per_epoch != updater.epoch)
        updater.previous_epoch_detail = (
            (updater.iteration - 1) / iter_per_epoch)

    updater.update = update
    trainer = training.Trainer(updater, stop_trigger, extensions=extensions)
    return trainer

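# A minimal usage sketch for the helper above (not part of the original
# module): it relies only on Trainer.extend/Trainer.run as used elsewhere in
# this file and on the mock updater's iteration counter. The test class name
# and the record() callback are hypothetical.
import unittest


class TestTrainerWithMockUpdaterSketch(unittest.TestCase):

    def test_extension_called_every_iteration(self):
        trainer = get_trainer_with_mock_updater(
            stop_trigger=(5, 'iteration'), iter_per_epoch=2)
        calls = []

        def record(trainer):
            # plain callables can be registered as extensions; this one just
            # records the iteration count at every invocation
            calls.append(trainer.updater.iteration)

        trainer.extend(record, trigger=(1, 'iteration'))
        trainer.run()
        self.assertEqual(calls, [1, 2, 3, 4, 5])
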
def train_phase(predictor, train, valid, args):
    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    class_weight = None  # NOTE: set this if your classes are imbalanced
    lossfun = partial(softmax_cross_entropy,
                      normalize=False, class_weight=class_weight)

    device = torch.device(args.gpu)
    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, model, device=device)

    frequency = max(args.iteration // 20, 1) \
        if args.frequency == -1 else max(1, args.frequency)

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }

    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}
    clims = {'x': 'minmax', 'y': None, 't': None}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps, clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy'
    ]
    trainer.extend(LogReport(keys=log_keys))

    # setup log plotter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()

def train_phase(predictor, train, valid, args):
    # visualize
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=train.x.ravel(), y=train.y.ravel(),
                         color='blue', s=55, alpha=0.3)
    ax.plot(train.x.ravel(), train.t.ravel(), color='red', linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Ground-truth', 'Observation'])
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)

    lossfun = noised_mean_squared_error
    accfun = lambda y, t: F.l1_loss(y[0], t)

    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, model, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch', file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration',
            'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(),
               os.path.join(args.out, 'predictor.pth'))

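# Hedged sketch (an assumption, not the original implementation) of the kind
# of loss the train_phase above is configured with: noised_mean_squared_error
# is imported from elsewhere, but the accfun (which uses y[0]) and the
# 'main/predictor/sigma' report suggest a predictor returning a
# (mean, log-variance) pair trained with a Gaussian negative log-likelihood.
# The helper name noised_mse_sketch is hypothetical.
import torch


def noised_mse_sketch(y, t):
    mean, log_var = y  # predictor output assumed to be (mu, log sigma^2)
    return torch.mean(0.5 * torch.exp(-log_var) * (mean - t) ** 2
                      + 0.5 * log_var)
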
def main():
    parser = argparse.ArgumentParser(description='PyTorch example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device', '-d', type=str, default='cpu',
                        help='Device specifier passed to torch.device, '
                        'e.g. "cpu", "cuda" or "cuda:0"')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--autoload', action='store_true',
                        help='Automatically load trainer snapshots in case'
                        ' of preemption or other temporary system failure')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument('--gpu', '-g', dest='device',
                       type=int, nargs='?', const=0,
                       help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    device = torch.device(args.device)

    print('Device: {}'.format(device))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = Classifier(MLP(784, args.unit, 10))
    model.to(device)

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Load the MNIST dataset
    transform = transforms.ToTensor()
    train = datasets.MNIST('data', train=True, download=True,
                           transform=transform)
    test = datasets.MNIST('data', train=False, transform=transform)
    train_iter = pytorch_trainer.iterators.SerialIterator(
        train, args.batchsize)
    test_iter = pytorch_trainer.iterators.SerialIterator(
        test, args.batchsize, repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, model, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=device),
                   call_before_training=True)

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    # trainer.extend(extensions.DumpGraph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)

    # Take a snapshot every ``frequency`` epochs, delete old stale snapshots,
    # and automatically load from snapshot files if any are already present
    # in the result directory.
    trainer.extend(extensions.snapshot(n_retains=1, autoload=args.autoload),
                   trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(), call_before_training=True)

    # Save two plot images to the result dir
    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              'epoch', file_name='loss.png'),
        call_before_training=True)
    trainer.extend(
        extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'],
            'epoch', file_name='accuracy.png'),
        call_before_training=True)

    # Print selected entries of the log to stdout.
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called
    # by either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']),
        call_before_training=True)

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume is not None:
        # Resume from a snapshot. Note: the state loaded here may be
        # overwritten by the --autoload option if snapshots already exist
        # in the output directory.
        trainer.load_state_dict(torch.load(args.resume))

    # Run the training
    trainer.run()
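
# A conventional command-line entry point for the main() defined above
# (not part of the original snippet, which ends at trainer.run()):
if __name__ == '__main__':
    main()
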
def train_phase(generator, train, valid, args):
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    model = Regressor(generator,
                      activation=torch.tanh,
                      lossfun=F.l1_loss,
                      accfun=F.l1_loss)

    discriminator = build_discriminator()
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    device = torch.device(args.gpu)
    model.to(device)
    discriminator.to(device)

    # setup an optimizer
    optimizer_G = torch.optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = DCGANUpdater(
        iterator=train_iter,
        optimizer={
            'gen': optimizer_G,
            'dis': optimizer_D,
        },
        model={
            'gen': model,
            'dis': discriminator,
        },
        alpha=args.alpha,
        device=args.gpu,
    )

    frequency = max(args.iteration // 80, 1) \
        if args.frequency == -1 else max(1, args.frequency)

    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # shift lr
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_G))
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_D))

    # setup a visualizer
    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None, clims=clims)

    # setup a validator
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph('loss_gen', filename='generative_loss.dot'))
    # trainer.extend(DumpGraph('loss_cond', filename='conditional_loss.dot'))
    # trainer.extend(DumpGraph('loss_dis', filename='discriminative_loss.dot'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        generator, 'generator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        discriminator, 'discriminator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = ['loss_gen', 'loss_cond', 'loss_dis',
                'validation/main/accuracy']
    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # setup log plotter
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [key for key in log_keys
                         if key.split('/')[-1].startswith(plot_key)]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(PrintReport(['iteration'] + log_keys + ['elapsed_time'],
                               n_step=1))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()