def main():
    """Entry point: distributed ImageNet training with ChainerMN.

    Parses CLI options, sets the multiprocessing start method to
    'forkserver' *before* creating the ChainerMN communicator, scatters
    the dataset from rank 0 to all workers, and runs a Trainer with a
    multi-node optimizer/evaluator/checkpointer.
    """
    # Check if GPU is available
    # (ImageNet example does not support CPU execution)
    if not chainer.cuda.available:
        raise RuntimeError('ImageNet requires GPU support.')

    # Architecture name -> model class used by --arch.
    archs = {
        'alex': alex.Alex,
        'googlenet': googlenet.GoogLeNet,
        'googlenetbn': googlenetbn.GoogLeNetBN,
        'nin': nin.NIN,
        'resnet50': resnet50.ResNet50,
    }

    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('train', help='Path to training image-label list file')
    parser.add_argument('val', help='Path to validation image-label list file')
    parser.add_argument('--arch', '-a', choices=archs.keys(), default='nin',
                        help='Convnet architecture')
    parser.add_argument('--batchsize', '-B', type=int, default=32,
                        help='Learning minibatch size')
    parser.add_argument('--epoch', '-E', type=int, default=10,
                        help='Number of epochs to train')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob', '-j', type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean', '-m', default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume', '-r', default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out', '-o', default='result',
                        help='Output directory')
    parser.add_argument('--root', '-R', default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize', '-b', type=int, default=250,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--communicator', default='hierarchical')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    # Start method of multiprocessing module need to be changed if we
    # are using InfiniBand and MultiprocessIterator. This is because
    # processes often crash when calling fork if they are using
    # Infiniband.
    # (c.f., https://www.open-mpi.org/faq/?category=tuning#fork-warning )
    # Also, just setting the start method does not seem to be
    # sufficient to actually launch the forkserver processes, so also
    # start a dummy process.
    # See also our document:
    # https://chainermn.readthedocs.io/en/stable/tutorial/tips_faqs.html#using-multiprocessiterator
    # This must be done *before* ``chainermn.create_communicator``!!!
    multiprocessing.set_start_method('forkserver')
    # Dummy process: forces the forkserver to actually start now.
    p = multiprocessing.Process(target=lambda *x: x, args=())
    p.start()
    p.join()

    # Prepare ChainerMN communicator.
    comm = chainermn.create_communicator(args.communicator)
    # One GPU per process within a node.
    device = comm.intra_rank

    # Only rank 0 prints the run summary to avoid repeated output.
    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        print('Using {} communicator'.format(args.communicator))
        print('Using {} arch'.format(args.arch))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = archs[args.arch]()
    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)

    chainer.cuda.get_device_from_id(device).use()  # Make the GPU current
    model.to_gpu()

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    mean = np.load(args.mean)
    if comm.rank == 0:
        train = PreprocessedDataset(args.train, args.root, mean, model.insize)
        val = PreprocessedDataset(
            args.val, args.root, mean, model.insize, False)
    else:
        train = None
        val = None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    val = chainermn.scatter_dataset(val, comm)

    # A workaround for processes crash should be done before making
    # communicator above, when using fork (e.g. MultiProcessIterator)
    # along with Infiniband.
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9), comm)
    optimizer.setup(model)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    # --test shortens all intervals so a full cycle can be smoke-tested.
    checkpoint_interval = (10, 'iteration') if args.test else (1, 'epoch')
    val_interval = (10, 'iteration') if args.test else (1, 'epoch')
    log_interval = (10, 'iteration') if args.test else (1, 'epoch')

    # Multi-node checkpointer: maybe_load resumes from an existing
    # checkpoint before the trainer starts.
    checkpointer = chainermn.create_multi_node_checkpointer(
        name='imagenet-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=checkpoint_interval)

    # Create a multi node evaluator from an evaluator.
    evaluator = TestModeEvaluator(val_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator, trigger=val_interval)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.DumpGraph('main/loss'))
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'lr'
        ]), trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

    # --resume restores full trainer state from an npz snapshot.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
def main():
    """Entry point: distributed ImageNet training with ChainerMN.

    Parses CLI options, prepares the multiprocessing 'forkserver' start
    method, creates the ChainerMN communicator, scatters the dataset from
    rank 0 to all workers, and runs a Trainer with a multi-node
    optimizer/evaluator/checkpointer.
    """
    # Check if GPU is available
    # (ImageNet example does not support CPU execution)
    if not chainer.cuda.available:
        raise RuntimeError("ImageNet requires GPU support.")

    # Architecture name -> model class used by --arch.
    archs = {
        'alex': alex.Alex,
        'googlenet': googlenet.GoogLeNet,
        'googlenetbn': googlenetbn.GoogLeNetBN,
        'nin': nin.NIN,
        'resnet50': resnet50.ResNet50,
    }

    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('train', help='Path to training image-label list file')
    parser.add_argument('val', help='Path to validation image-label list file')
    parser.add_argument('--arch', '-a', choices=archs.keys(), default='nin',
                        help='Convnet architecture')
    parser.add_argument('--batchsize', '-B', type=int, default=32,
                        help='Learning minibatch size')
    parser.add_argument('--epoch', '-E', type=int, default=10,
                        help='Number of epochs to train')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob', '-j', type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean', '-m', default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume', '-r', default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out', '-o', default='result',
                        help='Output directory')
    parser.add_argument('--root', '-R', default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize', '-b', type=int, default=250,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--communicator', default='hierarchical')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    # We need to change the start method of multiprocessing module if we are
    # using InfiniBand and MultiprocessIterator. This is because processes
    # often crash when calling fork if they are using Infiniband.
    # (c.f., https://www.open-mpi.org/faq/?category=tuning#fork-warning )
    #
    # FIX: this used to happen *after* ``chainermn.create_communicator``,
    # which is too late — per the ChainerMN tips & FAQs the start method
    # must be changed (and the forkserver actually launched via a dummy
    # process) *before* the communicator is created.
    multiprocessing.set_start_method('forkserver')
    # Setting the start method alone does not launch the forkserver;
    # starting a dummy process forces it to spawn now, before MPI init.
    p = multiprocessing.Process(target=lambda *x: x, args=())
    p.start()
    p.join()

    # Prepare ChainerMN communicator.
    comm = chainermn.create_communicator(args.communicator)
    # One GPU per process within a node.
    device = comm.intra_rank

    # Only rank 0 prints the run summary to avoid repeated output.
    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        print('Using {} communicator'.format(args.communicator))
        print('Using {} arch'.format(args.arch))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = archs[args.arch]()
    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)

    chainer.cuda.get_device_from_id(device).use()  # Make the GPU current
    model.to_gpu()

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    mean = np.load(args.mean)
    if comm.rank == 0:
        train = PreprocessedDataset(args.train, args.root, mean, model.insize)
        val = PreprocessedDataset(
            args.val, args.root, mean, model.insize, False)
    else:
        train = None
        val = None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    val = chainermn.scatter_dataset(val, comm)

    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9), comm)
    optimizer.setup(model)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    # --test shortens all intervals so a full cycle can be smoke-tested.
    checkpoint_interval = (10, 'iteration') if args.test else (1, 'epoch')
    val_interval = (10, 'iteration') if args.test else (1, 'epoch')
    log_interval = (10, 'iteration') if args.test else (1, 'epoch')

    # Multi-node checkpointer: maybe_load resumes from an existing
    # checkpoint before the trainer starts.
    checkpointer = chainermn.create_multi_node_checkpointer(
        name='imagenet-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=checkpoint_interval)

    # Create a multi node evaluator from an evaluator.
    evaluator = TestModeEvaluator(val_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator, trigger=val_interval)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.dump_graph('main/loss'))
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'lr'
        ]), trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

    # --resume restores full trainer state from an npz snapshot.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
def main():
    """Entry point: config-driven distributed training over ChainerMN.

    Every component (mask loader, model, optimizer, dataset) is resolved
    dynamically from the config dict via ``import_module`` + ``getattr``,
    so this function contains only wiring, no model-specific logic.
    """
    # pure_nccl communicator: one GPU per process within a node.
    comm = mn.create_communicator("pure_nccl")
    device = comm.intra_rank

    config = get_config()

    # Every rank loads the mask itself (it is passed into the model below).
    print("pid {}: mask loading...".format(comm.rank))
    load_mask_module = import_module(
        config["additional information"]["mask"]["loader"]["module"],
        config["additional information"]["mask"]["loader"]["package"])
    load_mask = getattr(
        load_mask_module,
        config["additional information"]["mask"]["loader"]["function"])
    mask = load_mask(
        **config["additional information"]["mask"]["loader"]["params"])
    print("pid {}: done.".format(comm.rank))

    if comm.rank == 0:
        print("mask.shape: {}".format(mask.shape))

    # Model class is resolved from config; it receives the communicator
    # and the mask in addition to its own params.
    model_module = import_module(config["model"]["module"],
                                 config["model"]["package"])
    Model = getattr(model_module, config["model"]["class"])
    model = Model(comm=comm, mask=mask, **config["model"]["params"])

    # Optimizer class is resolved the same way, then wrapped into a
    # multi-node optimizer that allreduces gradients across workers.
    optimizer_module = import_module(config["optimizer"]["module"],
                                     config["optimizer"]["package"])
    Optimizer = getattr(optimizer_module, config["optimizer"]["class"])
    optimizer = mn.create_multi_node_optimizer(
        Optimizer(**config["optimizer"]["params"]), comm)
    optimizer.setup(model)

    if device >= 0:
        chainer.backends.cuda.get_device_from_id(device).use()
        model.to_gpu()
        print("pid {}: GPU {} enabled".format(comm.rank, device))

    # Only rank 0 loads the full dataset; scatter_dataset splits it
    # evenly and distributes the shards to all workers.
    if comm.rank == 0:
        dataset_module = import_module(config["dataset"]["module"],
                                       config["dataset"]["package"])
        Dataset = getattr(dataset_module, config["dataset"]["class"])
        train_dataset = Dataset(**config["dataset"]["train"]["params"])
        valid_dataset = Dataset(**config["dataset"]["valid"]["params"])
    else:
        train_dataset = None
        valid_dataset = None
    train_dataset = mn.datasets.scatter_dataset(train_dataset, comm,
                                                shuffle=True)
    valid_dataset = mn.datasets.scatter_dataset(valid_dataset, comm,
                                                shuffle=True)

    train_iterator = Iterator(train_dataset, config["batch"]["train"])
    # Validation iterator: no repeat, no shuffle (positional False, False).
    valid_iterator = Iterator(valid_dataset, config["batch"]["valid"],
                              False, False)

    updater = Updater(train_iterator, optimizer, device=device)
    trainer = Trainer(updater, **config["trainer"]["params"])

    # Multi-node checkpointer; maybe_load resumes from a previous run.
    checkpointer = mn.create_multi_node_checkpointer(
        config["general"]["name"], comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer,
                   trigger=tuple(config["trainer"]["snapshot_interval"]))

    evaluator = Evaluator(valid_iterator, model, device=device)
    evaluator = mn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)
    # NOTE(review): observe_lr is registered on every rank (before the
    # rank-0 guard below) — confirm this is intentional.
    trainer.extend(observe_lr(), trigger=config["trainer"]["log_interval"])

    # Reporting extensions only on rank 0 to avoid duplicated output.
    if comm.rank == 0:
        trainer.extend(LogReport(trigger=config["trainer"]["log_interval"]))
        trainer.extend(PrintReport(
            ["epoch", "iteration", "main/loss", "validation/main/loss"]),
            trigger=config["trainer"]["log_interval"])
        trainer.extend(ProgressBar(update_interval=1))

    trainer.run()
CosineAnnealing('lr', int(args.epoch), len(train) / (args.batchsize // comm.size), eta_min=args.eta_min, init=args.lr)) else: mst_epochs = [30, 60, 90] trainer.extend(extensions.ExponentialShift('lr', 0.1, init=args.lr), trigger=triggers.ManualScheduleTrigger( mst_epochs, 'epoch')) test_interval = 1, 'epoch' snapshot_interval = 10, 'epoch' log_interval = 10, 'iteration' checkpointer = chainermn.create_multi_node_checkpointer(name='pgp-chainer', comm=comm) checkpointer.maybe_load(trainer, optimizer) trainer.extend(checkpointer, trigger=test_interval) evaluater = extensions.Evaluator(test_iter, model, device=device) evaluater = chainermn.create_multi_node_evaluator(evaluater, comm) trainer.extend(evaluater, trigger=test_interval) if comm.rank == 0: trainer.extend(extensions.dump_graph('main/loss')) trainer.extend( extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'), trigger=snapshot_interval) trainer.extend(extensions.snapshot_object( model, 'model_epoch_{.updater.epoch}'), trigger=snapshot_interval)
def main():
    """Entry point: distributed EfficientNet/SEResNeXt ImageNet training.

    Selects an architecture (EfficientNet b0-b7 or SEResNeXt50), builds a
    ChainerMN communicator, scatters the dataset from rank 0, and trains
    with a configurable optimizer and LR schedule.
    """
    # Check if GPU is available
    # (ImageNet example does not support CPU execution)
    if not chainer.cuda.available:
        raise RuntimeError('ImageNet requires GPU support.')

    # 'b0'..'b7' are EfficientNet variants; 'se' selects SEResNeXt50.
    archs = [f'b{i}' for i in range(8)] + ['se']
    # Architecture-wise default input resolutions.
    patchsizes = {
        'b0': 224, 'b1': 240, 'b2': 260, 'b3': 300, 'b4': 380,
        'b5': 456, 'b6': 528, 'b7': 600, 'se': 224
    }
    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('--arch', '-a', choices=archs, default='b0')
    parser.add_argument('--patchsize', default=None, type=int,
                        help='The input size of images. If not specifed, '
                             'architecture-wise default values wil be used.')
    parser.add_argument('--batchsize', '-B', type=int, default=32,
                        help='Learning minibatch size')
    parser.add_argument('--optimizer', default='RMSProp')
    parser.add_argument('--lr', default=0.256, type=float)
    parser.add_argument('--cosine_annealing', action='store_true')
    parser.add_argument('--exponent', type=float, default=0.97)
    parser.add_argument('--exponent_trigger', type=float, default=2.6)
    parser.add_argument('--soft_label', action='store_true')
    parser.add_argument('--epoch', '-E', type=int, default=350,
                        help='Number of epochs to train')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob', '-j', type=int, default=3,
                        help='Number of parallel data loading processes')
    parser.add_argument('--resume', '-r', default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out', '-o', default='result',
                        help='Output directory')
    parser.add_argument('--root', '-R', default='../ssd/imagenet',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize', '-b', type=int, default=32,
                        help='Validation minibatch size')
    parser.add_argument('--workerwisebn', action='store_true')
    parser.add_argument('--no_dropconnect', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--communicator', default='pure_nccl')
    parser.add_argument('--no_autoaugment', action='store_true')
    parser.add_argument('--dtype', default='float32',
                        choices=['mixed16', 'float32'],
                        help='For now do not use mixed16')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    chainer.global_config.dtype = args.dtype

    comm = chainermn.create_communicator(args.communicator)
    # One GPU per process within a node.
    device = comm.intra_rank

    # Only rank 0 prints the run summary to avoid repeated output.
    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        print('Using {} communicator'.format(args.communicator))
        print('Using {} arch'.format(args.arch))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        mode = 'workerwise' if args.workerwisebn else 'synchronized'
        print(f'BatchNorm is {mode}')
        print('==========================================')

    # Soft labels require matching soft loss/accuracy functions.
    if args.soft_label:
        accfun = soft_accuracy
        lossfun = soft_softmax_cross_entropy
    else:
        accfun = F.accuracy
        lossfun = F.softmax_cross_entropy

    if args.arch != 'se':
        model = EfficientNet(args.arch, workerwisebn=args.workerwisebn,
                             no_dropconnect=args.no_dropconnect)
    else:
        model = SEResNeXt50()
    model = L.Classifier(model, lossfun=lossfun, accfun=accfun)
    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)

    chainer.cuda.get_device_from_id(device).use()  # Make the GPU current
    model.to_gpu()

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    patchsize = patchsizes[
        args.arch] if args.patchsize is None else args.patchsize
    patchsize = (patchsize, patchsize)
    train_transform, val_transform, _ = get_transforms(
        patchsize, no_autoaugment=args.no_autoaugment, soft=args.soft_label)
    if comm.rank == 0:
        train = ImageNetDataset(args.root, 'train')
        val = ImageNetDataset(args.root, 'val')
    else:
        train = None
        val = None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    val = chainermn.scatter_dataset(val, comm)
    # Transforms are applied per-worker, after scattering.
    train = chainer.datasets.TransformDataset(train, train_transform)
    val = chainer.datasets.TransformDataset(val, val_transform)

    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Create a multi node optimizer from a standard Chainer optimizer.
    # ``symbol`` is the hyperparameter name shifted by ExponentialShift
    # below ('lr' for SGD-family, 'alpha' for AdaBound).
    symbol = 'lr'
    # Normalize once: the whole selection is case-insensitive.
    opt_name = args.optimizer.lower()
    if opt_name == 'rmsprop':
        optimizer = chainer.optimizers.RMSprop(lr=args.lr, alpha=0.9)
    elif opt_name == 'momentumsgd':
        optimizer = chainer.optimizers.MomentumSGD(lr=args.lr)
    elif opt_name == 'corrected':
        optimizer = chainer.optimizers.CorrectedMomentumSGD(lr=args.lr)
    elif opt_name == 'adabound':
        optimizer = chainer.optimizers.AdaBound(alpha=args.lr, final_lr=0.5)
        symbol = 'alpha'
    else:
        # FIX: previously an unknown name fell through and crashed later
        # with UnboundLocalError; fail fast with a clear message instead.
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
    optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(1e-5))

    args.out = f'experiments/{args.arch}' + args.out
    save_args(args, args.out)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    # --test shortens all intervals so a full cycle can be smoke-tested.
    checkpoint_interval = (10, 'iteration') if args.test else (1, 'epoch')
    val_interval = (10, 'iteration') if args.test else (2, 'epoch')
    log_interval = (10, 'iteration') if args.test else (2, 'epoch')

    checkpointer = chainermn.create_multi_node_checkpointer(
        name='imagenet-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=checkpoint_interval)

    if args.cosine_annealing:
        schedule = lr_schedules.CosineLRSchedule(args.lr)
        # FIX: this guard used to compare the raw ``args.optimizer``
        # against exact-cased names ('MomentumSGD', 'Corrected') while
        # the selection above lowercased it, so e.g.
        # ``--optimizer momentumsgd --cosine_annealing`` silently ran
        # with no LR schedule at all. Compare case-insensitively.
        if opt_name in ('momentumsgd', 'corrected'):
            trainer.extend(lr_schedules.LearningRateScheduler(schedule))
    else:
        trainer.extend(extensions.ExponentialShift(symbol, args.exponent),
                       trigger=(args.exponent_trigger, 'epoch'))

    # Create a multi node evaluator from an evaluator.
    evaluator = TestModeEvaluator(val_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator, trigger=val_interval)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.DumpGraph('main/loss'))
        trainer.extend(extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}.npz'),
            trigger=val_interval)
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'lr'
        ]), trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=100))

    # --resume restores full trainer state from an npz snapshot.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
def main():
    """Entry point: distributed SSD detector training with ChainerMN.

    Trains an SSD300/SSD512 (ImageNet-pretrained backbone) on a
    DeepFashion bounding-box dataset scattered from rank 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', choices=('ssd300', 'ssd512'),
                        default='ssd300')
    parser.add_argument('--batchsize', type=int, default=32)
    parser.add_argument('--labelnum', type=int, default=50)
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--out', default='result')
    parser.add_argument('--resume')
    parser.add_argument('--image_label', '-il',
                        help='Path to training image-label list file')
    parser.add_argument('--bbox',
                        help='Path to training bbox list file')
    parser.add_argument('--image_label_test', '-ilt',
                        help='Path to training image-label list file')
    parser.add_argument('--bbox_test',
                        help='Path to training bbox list file')
    parser.add_argument('--image_root', '-TR', default='.',
                        help='Root directory path of image files')
    args = parser.parse_args()

    comm = chainermn.create_communicator('naive')
    if comm.mpi_comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(
            MPI.COMM_WORLD.Get_size()))

    if args.model == 'ssd300':
        model = SSD300(n_fg_class=args.labelnum, pretrained_model='imagenet')
    elif args.model == 'ssd512':
        model = SSD512(n_fg_class=args.labelnum, pretrained_model='imagenet')

    model.use_preset('evaluate')
    train_chain = MultiboxTrainChain(model)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    # Local import keeps the dependency optional for other entry points.
    from test_datasets import DeepFashionBboxDataset
    # Only rank 0 loads the datasets; scatter_dataset distributes shards.
    # Note that only ``train`` is wrapped in the SSD encoding transform;
    # the test set stays raw for the detection evaluator below.
    if comm.rank == 0:
        train = DeepFashionBboxDataset(args.bbox, args.image_label,
                                       args.image_root)
        test = DeepFashionBboxDataset(args.bbox_test, args.image_label_test,
                                      args.image_root)
        train = TransformDataset(
            train, Transform(model.coder, model.insize, model.mean))
    else:
        train, test = None, None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm, shuffle=True)

    train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False,
                                                 shuffle=False)

    # initial lr is set to 1e-3 by ExponentialShift
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)
    optimizer.setup(train_chain)
    # SSD convention: biases get doubled gradients, weights get decay.
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (120000, 'iteration'), args.out)

    checkpoint_interval = (1000, 'iteration')
    # NOTE(review): checkpointer name 'imagenet-example' looks copy-pasted
    # from the ImageNet script — confirm it should not be detector-specific.
    checkpointer = chainermn.create_multi_node_checkpointer(
        name='imagenet-example', comm=comm)
    checkpointer.maybe_load(trainer, optimizer)
    trainer.extend(checkpointer, trigger=checkpoint_interval)

    # Standard SSD step schedule: lr x0.1 at 80k and 100k iterations.
    trainer.extend(extensions.ExponentialShift('lr', 0.1, init=1e-3),
                   trigger=triggers.ManualScheduleTrigger([80000, 100000],
                                                          'iteration'))

    # NOTE(review): evaluator reports VOC label names while the data is
    # DeepFashion (args.labelnum classes) — verify this is intended.
    evaluator = DetectionVOCEvaluator(test_iter, model, use_07_metric=True,
                                      label_names=voc_bbox_label_names)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator, trigger=(10000, 'iteration'))

    # Reporting/snapshot extensions only on rank 0 to avoid duplication.
    if comm.rank == 0:
        log_interval = 10, 'iteration'
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'lr',
            'main/loss', 'main/loss/loc', 'main/loss/conf',
            'validation/main/map'
        ]), trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.extend(extensions.snapshot(), trigger=(10000, 'iteration'))
        trainer.extend(extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}'),
            trigger=(120000, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    trainer.run()
def train(config):
    """Run (possibly conditional) TGAN2 training from a config dict.

    Builds generator/discriminator, multi-node optimizers, and a Trainer
    with inception-score / preview / snapshot / checkpoint extensions,
    then runs the training loop.  ``config`` is deep-copied up front so
    the saved/hashed copy is unaffected by later in-place edits (e.g.
    ``report_keys.append`` below mutates the live config).
    """
    config_backup = copy.deepcopy(config)

    # Setup
    device, comm = get_device_communicator(config['gpu'],
                                           config['communicator'],
                                           config['seed'],
                                           config['batchsize'])
    chainer.config.comm = comm  # To use from the inside of models

    # Seed every RNG (Python, NumPy, CuPy) for reproducibility.
    if config.get('seed', None) is not None:
        random.seed(config['seed'])
        numpy.random.seed(config['seed'])
        cuda.cupy.random.seed(config['seed'])

    # Prepare dataset and models.  Only rank 0 instantiates the dataset;
    # scatter_dataset distributes evenly split shards to all workers.
    if not config['label']:
        # Unconditional TGAN.
        if comm.mpi_comm.rank == 0:
            dataset = make_instance(tgan2, config['dataset'])
        else:
            dataset = None
        dataset = chainermn.scatter_dataset(dataset, comm, shuffle=True)
        # Retrieve property from the original of SubDataset
        n_channels = dataset._dataset.n_channels
        gen = make_instance(tgan2, config['gen'],
                            args={'out_channels': n_channels})
        dis = make_instance(tgan2, config['dis'],
                            args={'in_channels': n_channels})
    else:
        # Conditional TGAN: models additionally receive the class count.
        if comm.mpi_comm.rank == 0:
            print('## NOTE: Training Conditional TGAN')
            dataset = make_instance(tgan2, config['dataset'],
                                    args={'label': True})
        else:
            dataset = None
        dataset = chainermn.scatter_dataset(dataset, comm, shuffle=True)
        # Retrieve property from the original of SubDataset
        n_channels = dataset._dataset.n_channels
        n_classes = dataset._dataset.n_classes
        gen = make_instance(tgan2, config['gen'], args={
            'out_channels': n_channels,
            'n_classes': n_classes
        })
        dis = make_instance(tgan2, config['dis'], args={
            'in_channels': n_channels,
            'n_classes': n_classes
        })
    if device >= 0:
        chainer.cuda.get_device(device).use()
        gen.to_gpu()
        dis.to_gpu()

    # Rank 0 reports parameter counts of both networks.
    if comm.mpi_comm.rank == 0:
        def print_params(link):
            # Sum of element counts over all named parameters.
            n_params = sum([p.size for n, p in link.namedparams()])
            print('# of params in {}:\t{}'.format(link.__class__.__name__,
                                                  n_params))
        print_params(gen)
        print_params(dis)

    # Prepare optimizers (wrapped for multi-node gradient allreduce).
    gen_optimizer = chainermn.create_multi_node_optimizer(
        make_instance(chainer.optimizers, config['gen_opt']), comm)
    dis_optimizer = chainermn.create_multi_node_optimizer(
        make_instance(chainer.optimizers, config['dis_opt']), comm)
    gen_optimizer.setup(gen)
    dis_optimizer.setup(dis)
    optimizers = {
        'generator': gen_optimizer,
        'discriminator': dis_optimizer,
    }

    iterator = chainer.iterators.MultithreadIterator(
        dataset, batch_size=config['batchsize'])
    # The updater class itself comes from the config.
    updater = make_instance(tgan2, config['updater'], args={
        'iterator': iterator,
        'optimizer': optimizers,
        'device': device
    })

    # Prepare trainer and its extensions
    trainer = training.Trainer(updater, (config['iteration'], 'iteration'),
                               out=config['out'])
    snapshot_interval = (config['snapshot_interval'], 'iteration')
    display_interval = (config['display_interval'], 'iteration')

    # Rank 0 only: evaluation, snapshots, previews, and logging.
    if comm.rank == 0:
        # Inception score
        if config.get('inception_score', None) is not None:
            conf_classifier = config['inception_score']['classifier']
            classifier = make_instance(tgan2, conf_classifier)
            if 'model_path' in conf_classifier:
                chainer.serializers.load_npz(conf_classifier['model_path'],
                                             classifier,
                                             path=conf_classifier['npz_path'])
            if device >= 0:
                classifier = classifier.to_gpu()
            is_conf = config['inception_score']
            is_args = {
                'batchsize': is_conf['batchsize'],
                'n_samples': is_conf['n_samples'],
                'splits': is_conf['splits'],
                'n_frames': is_conf['n_frames'],
            }
            trainer.extend(tgan2.make_inception_score_extension(
                gen, classifier, **is_args),
                trigger=(is_conf['interval'], 'iteration'))
        # Snapshot
        trainer.extend(extensions.snapshot_object(
            gen, 'generator_iter_{.updater.iteration}.npz'),
            trigger=snapshot_interval)
        # Do not save discriminator to save the space
        # trainer.extend(
        #     extensions.snapshot_object(
        #         dis, 'discriminator_iter_{.updater.iteration}.npz'),
        #     trigger=snapshot_interval)
        # Save movie
        if config.get('preview', None) is not None:
            preview_batchsize = config['preview']['batchsize']
            trainer.extend(tgan2.out_generated_movie(
                gen, dis, rows=config['preview']['rows'],
                cols=config['preview']['cols'],
                seed=0, dst=config['out'],
                batchsize=preview_batchsize),
                trigger=snapshot_interval)
        # Log
        trainer.extend(extensions.LogReport(trigger=display_interval))
        report_keys = config['report_keys']
        if config.get('inception_score', None) is not None:
            report_keys.append('IS_mean')
        trainer.extend(extensions.PrintReport(report_keys),
                       trigger=display_interval)
        trainer.extend(
            extensions.ProgressBar(update_interval=display_interval[0]))

    # Linear decay of Adam's alpha from its initial value to 0 between
    # the configured start iteration and the end of training.
    if ('linear_decay' in config) and (config['linear_decay']['start']
                                       is not None):
        if comm.rank == 0:
            print('Use linear decay: {}:{} -> {}:{}'.format(
                config['linear_decay']['start'], config['iteration'],
                config['gen_opt']['args']['alpha'], 0.))
        trainer.extend(
            extensions.LinearShift(
                'alpha', (config['gen_opt']['args']['alpha'], 0.),
                (config['linear_decay']['start'], config['iteration']),
                gen_optimizer))
        trainer.extend(
            extensions.LinearShift(
                'alpha', (config['dis_opt']['args']['alpha'], 0.),
                (config['linear_decay']['start'], config['iteration']),
                dis_optimizer))

    # Checkpointer: checkpoints are keyed by a SHA1 hash of the original
    # (pre-mutation) config, so the same config resumes the same run.
    config_hash = hashlib.sha1()
    config_hash.update(
        yaml.dump(config_backup, default_flow_style=False).encode('utf-8'))
    os.makedirs('snapshots', exist_ok=True)
    checkpointer = chainermn.create_multi_node_checkpointer(
        name='tgan2', comm=comm,
        path=f'snapshots/{config_hash.hexdigest()}')
    checkpointer.maybe_load(trainer, gen_optimizer)
    if trainer.updater.epoch > 0:
        print('Resuming from checkpoints: epoch =', trainer.updater.epoch)
    trainer.extend(checkpointer, trigger=snapshot_interval)

    # Copy config to result dir
    os.makedirs(config['out'], exist_ok=True)
    config_path = os.path.join(config['out'], 'config.yml')
    with open(config_path, 'w') as fp:
        fp.write(yaml.dump(config_backup, default_flow_style=False))

    # Run the training
    trainer.run()