net, updater, log_dir, fields_to_print, epochs=args.epochs, snapshot_interval=args.snapshot_interval, print_interval=args.log_interval, extra_extensions=( evaluator, epoch_evaluator, model_snapshotter, bbox_plotter, (curriculum, (args.test_interval, 'iteration')), ), postprocess=log_postprocess, do_logging=args.no_log, model_files=[ get_definition_filepath(localization_net), get_definition_filepath(recognition_net), get_definition_filepath(net), ] ) # create interactive prompt that can be used to issue commands while the training is in progress open_interactive_prompt( bbox_plotter=bbox_plotter[0], curriculum=curriculum, ) trainer.run()
def main():
    """Entry point: distributed (ChainerMN) training of a KISS text recognition model.

    Parses CLI arguments (merged with a config file), builds train/validation
    datasets on rank 0 and scatters them to all MPI workers, assembles the
    LSTM localizer + Transformer recognizer with RAdam multi-node optimizers,
    and runs a chainer Trainer. Logging, snapshotting, evaluation and bbox
    visualization extensions are registered on rank 0 only.

    NOTE(review): original source indentation was lost; the grouping of the
    rank-0-only trainer extensions under ``comm.rank == 0`` below was
    reconstructed from the code's semantics — verify against upstream.
    """
    parser = argparse.ArgumentParser(
        description="Train a KISS model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("log_name", help="name of log")
    parser.add_argument("-c", "--config", default="config.cfg", help="path to config file to use")
    # FIX: help text read "gpu if to use" (typo for "id")
    parser.add_argument("-g", "--gpu", nargs='+', default=["-1"], help="gpu id to use (-1 means cpu)")
    parser.add_argument("-l", "--log-dir", default='tests', help="path to log dir")
    parser.add_argument(
        "--snapshot-interval", type=int, default=10000,
        help="number of iterations after which a snapshot will be taken")
    parser.add_argument("--log-interval", type=int, default=100, help="log interval")
    parser.add_argument(
        "--port", type=int, default=1337,
        help="port that is used by bbox plotter to send predictions on test image")
    parser.add_argument(
        "--rl", dest="resume_localizer",
        help="path to snapshot that is to be used to resume training of localizer")
    # FIX: help text read "that us to be used"
    parser.add_argument(
        "--rr", dest="resume_recognizer",
        help="path to snapshot that is to be used to pre-initialize recognizer")
    parser.add_argument("--num-layers", type=int, default=18, help="Resnet Variant to use")
    parser.add_argument(
        "--no-imgaug", action='store_false', dest='use_imgaug', default=True,
        help="disable image augmentation with `imgaug`, but use naive image augmentation instead")
    parser.add_argument(
        "--rdr", "--rotation-dropout-ratio", dest="rotation_dropout_ratio", type=float, default=0,
        help="ratio for dropping rotation params in text localization network")
    parser.add_argument(
        "--save-gradient-information", action='store_true', default=False,
        help="enable tensorboard gradient plotter")
    parser.add_argument(
        "--dump-graph", action='store_true', default=False,
        help="dump computational graph to file")
    parser.add_argument(
        "--image-mode", default="RGB", choices=["RGB", "L"],
        help="mode in which images are to be loaded")
    parser.add_argument("--resume", help="path to logdir from which training shall resume")

    args = parser.parse_args()
    # merge values from the config file into the parsed CLI arguments
    args = parse_config(args.config, args)

    # comm = chainermn.create_communicator(communicator_name='flat')
    comm = chainermn.create_communicator()
    # each MPI worker uses the GPU matching its local (intra-node) rank
    args.gpu = comm.intra_rank
    print(args.gpu)

    if args.resume is not None:
        log_dir = os.path.relpath(args.resume)
    else:
        log_dir = os.path.join("logs", args.log_dir, "{}_{}".format(datetime.datetime.now().isoformat(), args.log_name))
    args.log_dir = log_dir

    # set dtype
    chainer.global_config.dtype = 'float32'

    if comm.rank == 0:
        # create log dir
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)

    report_keys = ["epoch", "iteration", "loss/localizer/loss"]

    if args.use_memory_manager:
        # datasets are served by an external shared-memory dataset server
        memory_manager = DatasetClient()
        memory_manager.connect()
        train_kwargs = {"memory_manager": memory_manager, "base_name": "train_file"}
        # recognition_kwargs = {"memory_manager": memory_manager, "base_name": "text_recognition_file"}
        validation_kwargs = {"memory_manager": memory_manager, "base_name": "val_file"}
    else:
        train_kwargs = {"npz_file": args.train_file}
        # recognition_kwargs = {"npz_file": args.text_recognition_file}
        validation_kwargs = {"npz_file": args.val_file}

    # only rank 0 loads the data; it is then scattered to every worker
    if comm.rank == 0:
        train_dataset = TextRecognitionImageDataset(
            char_map=args.char_map,
            image_size=args.image_size,
            root=os.path.dirname(args.train_file),
            dtype=chainer.get_dtype(),
            use_imgaug=args.use_imgaug,
            transform_probability=0.4,
            keep_aspect_ratio=True,
            image_mode=args.image_mode,
            **train_kwargs,
        )
        validation_dataset = TextRecognitionImageDataset(
            char_map=args.char_map,
            image_size=args.image_size,
            root=os.path.dirname(args.val_file),
            dtype=chainer.get_dtype(),
            transform_probability=0,
            keep_aspect_ratio=True,
            image_mode=args.image_mode,
            **validation_kwargs,
        )
    else:
        train_dataset, validation_dataset = None, None

    train_dataset = scatter_dataset(train_dataset, comm)
    validation_dataset = scatter_dataset(validation_dataset, comm)

    # uncomment all commented parts of the code to train the model with extra recognizer training
    # text_recognition_dataset = TextRecognitionImageCharCropDataset(
    #     char_map=args.char_map,
    #     image_size=args.target_size,
    #     root=os.path.dirname(args.text_recognition_file),
    #     dtype=chainer.get_dtype(),
    #     transform_probability=0,
    #     image_mode=args.image_mode,
    #     gpu_id=args.gpu,
    #     reverse=False,
    #     resize_after_load=False,
    #     **recognition_kwargs,
    # )

    data_iter = chainer.iterators.MultithreadIterator(train_dataset, args.batch_size)
    validation_iter = chainer.iterators.MultithreadIterator(validation_dataset, args.batch_size, repeat=False)
    # text_recognition_iter = chainer.iterators.MultithreadIterator(text_recognition_dataset, max(args.batch_size, 32))

    localizer = LSTMTextLocalizer(
        Size(*args.target_size),
        num_bboxes_to_localize=train_dataset.num_chars_per_word,
        num_layers=args.num_layers,
        dropout_ratio=args.rotation_dropout_ratio,
    )
    if args.resume_localizer is not None:
        load_pretrained_model(args.resume_localizer, localizer)

    recognizer = TransformerTextRecognizer(
        train_dataset.num_chars_per_word,
        train_dataset.num_words_per_image,
        train_dataset.num_classes,
        train_dataset.bos_token,
        num_layers=args.num_layers,
    )
    if args.resume_recognizer is not None:
        load_pretrained_model(args.resume_recognizer, recognizer)

    models = [localizer, recognizer]

    # only rank 0 writes tensorboard summaries
    if comm.rank == 0:
        tensorboard_handle = SummaryWriter(log_dir=args.log_dir, graph=None)
    else:
        tensorboard_handle = None

    localizer_optimizer = RAdam(alpha=args.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9)
    localizer_optimizer = chainermn.create_multi_node_optimizer(localizer_optimizer, comm)
    localizer_optimizer.setup(localizer)
    localizer_optimizer.add_hook(chainer.optimizer_hooks.GradientClipping(2))
    if args.save_gradient_information:
        localizer_optimizer.add_hook(
            TensorboardGradientPlotter(tensorboard_handle, args.log_interval),
        )

    recognizer_optimizer = RAdam(alpha=args.learning_rate)
    recognizer_optimizer = chainermn.create_multi_node_optimizer(recognizer_optimizer, comm)
    recognizer_optimizer.setup(recognizer)

    optimizers = [localizer_optimizer, recognizer_optimizer]

    # log train information every time we encounter a new epoch or args.log_interval iterations have been done
    log_interval_trigger = (
        lambda trainer:
        (trainer.updater.is_new_epoch or trainer.updater.iteration % args.log_interval == 0)
        and trainer.updater.iteration > 0
    )

    updater_args = {
        "iterator": {
            'main': data_iter,
            # 'rec': text_recognition_iter,
        },
        "optimizer": {
            "opt_gen": localizer_optimizer,
            "opt_rec": recognizer_optimizer,
        },
        "tensorboard_handle": tensorboard_handle,
        "tensorboard_log_interval": log_interval_trigger,
        "recognizer_update_interval": 1,
        "device": args.gpu,
    }

    updater = TransformerTextRecognitionUpdater(models=[localizer, recognizer], **updater_args)

    trainer = chainer.training.Trainer(updater, (args.num_epoch, 'epoch'), out=args.log_dir)

    # configuration that is dumped into the log once training has started
    data_to_log = {
        'log_dir': args.log_dir,
        'image_size': args.image_size,
        'num_layers': args.num_layers,
        'num_chars': train_dataset.num_chars_per_word,
        'num_words': train_dataset.num_words_per_image,
        'num_classes': train_dataset.num_classes,
        'keep_aspect_ratio': train_dataset.keep_aspect_ratio,
        'localizer': get_import_info(localizer),
        'recognizer': get_import_info(recognizer),
        'bos_token': train_dataset.bos_token,
    }
    for argument in filter(lambda x: not x.startswith('_'), dir(args)):
        data_to_log[argument] = getattr(args, argument)

    def backup_train_config(stats_cpu):
        # write the full train configuration into the log on the first log write
        if stats_cpu['iteration'] == args.log_interval:
            stats_cpu.update(data_to_log)

    if comm.rank == 0:
        for model in models:
            trainer.extend(
                extensions.snapshot_object(model, model.__class__.__name__ + '_{.updater.iteration}.npz'),
                trigger=lambda trainer: trainer.updater.is_new_epoch or trainer.updater.iteration % args.snapshot_interval == 0,
            )
        trainer.extend(
            extensions.snapshot(filename='trainer_snapshot', autoload=args.resume is not None),
            trigger=(args.snapshot_interval, 'iteration'))

        evaluation_function = TextRecognitionEvaluatorFunction(
            localizer, recognizer, args.gpu, train_dataset.blank_label, train_dataset.char_map)
        trainer.extend(
            TextRecognitionTensorboardEvaluator(
                validation_iter,
                localizer,
                device=args.gpu,
                eval_func=evaluation_function,
                tensorboard_handle=tensorboard_handle,
                num_iterations=200,
            ),
            trigger=(args.test_interval, 'iteration'),
        )

        # every epoch run the model on test datasets
        test_dataset_prefix = "test_dataset_"
        test_datasets = [arg for arg in dir(args) if arg.startswith(test_dataset_prefix)]
        for test_dataset_name in test_datasets:
            print(f"setting up testing for {test_dataset_name[len(test_dataset_prefix):]} dataset")
            dataset_path = getattr(args, test_dataset_name)
            if args.use_memory_manager:
                test_kwargs = {"memory_manager": memory_manager, "base_name": test_dataset_name}
            else:
                test_kwargs = {"npz_file": dataset_path}
            test_dataset = TextRecognitionImageDataset(
                char_map=args.char_map,
                image_size=args.image_size,
                root=os.path.dirname(dataset_path),
                dtype=chainer.get_dtype(),
                transform_probability=0,
                keep_aspect_ratio=True,
                image_mode=args.image_mode,
                **test_kwargs,
            )
            test_iter = chainer.iterators.MultithreadIterator(test_dataset, args.batch_size, repeat=False)
            trainer.extend(
                TextRecognitionTensorboardEvaluator(
                    test_iter,
                    localizer,
                    device=args.gpu,
                    eval_func=evaluation_function,
                    tensorboard_handle=tensorboard_handle,
                    base_key=test_dataset_name[len(test_dataset_prefix):]),
                trigger=(args.snapshot_interval, 'iteration'))

        # updater code is backed up alongside the model definitions
        models.append(updater)
        logger = Logger(
            os.path.dirname(os.path.realpath(__file__)),
            args.log_dir,
            postprocess=backup_train_config,
            trigger=log_interval_trigger,
            exclusion_filters=['*logs*', '*.pyc', '__pycache__', '.git*'],
            resume=args.resume is not None,
        )

        if args.test_image is not None:
            plot_image = train_dataset.load_image(args.test_image)
            gt_bbox = None
        else:
            plot_image = validation_dataset.get_example(0)['image']
            gt_bbox = None

        bbox_plotter = TextRecognitionBBoxPlotter(
            plot_image,
            os.path.join(args.log_dir, 'bboxes'),
            args.target_size,
            send_bboxes=True,
            upstream_port=args.port,
            visualization_anchors=[
                ["visual_backprop_anchors"],
            ],
            device=args.gpu,
            render_extracted_rois=True,
            num_rois_to_render=4,
            sort_rois=False,
            show_visual_backprop_overlay=True,
            visual_backprop_index=0,
            show_backprop_and_feature_vis=True,
            gt_bbox=gt_bbox,
            render_pca=False,
            log_name=args.log_name,
            char_map=train_dataset.char_map,
            blank_label=train_dataset.blank_label,
            predictors={
                "localizer": localizer,
                "recognizer": recognizer,
            },
        )
        trainer.extend(bbox_plotter, trigger=(10, 'iteration'))

        trainer.extend(logger, trigger=log_interval_trigger)
        trainer.extend(extensions.PrintReport(report_keys, log_report='Logger'), trigger=log_interval_trigger)

        # learning rate shift after each epoch
        trainer.extend(
            extensions.ExponentialShift("alpha", 0.1, optimizer=localizer_optimizer),
            trigger=(1, 'epoch'))

        trainer.extend(extensions.ProgressBar(update_interval=10))

        if args.dump_graph:
            trainer.extend(extensions.dump_graph('loss/localizer/loss', out_name='model.dot'))

        # interactive prompt that can be used to issue commands while training is in progress
        open_interactive_prompt(
            bbox_plotter=bbox_plotter,
            optimizer=optimizers,
        )

    # all ranks participate in the actual training loop
    trainer.run()
def main():
    """Entry point: train a sheep localizer together with a ResNet assessor.

    Parses CLI arguments, builds the train/reference (and optionally
    validation) datasets, sets up localizer + discriminator with Adam
    optimizers, and runs a chainer Trainer with snapshotting, mAP evaluation,
    bbox plotting and logging extensions, plus an interactive prompt.
    """
    parser = argparse.ArgumentParser(description="Train a sheep localizer")
    parser.add_argument("train_file", help="path to train csv")
    # FIX: closing parenthesis was missing in this help string
    parser.add_argument("val_file", help="path to validation file (if you do not want to do validation just enter gibberish here)")
    parser.add_argument("reference_file", help="path to reference images with different zoom levels")
    parser.add_argument("--no-validation", dest='validation', action='store_false', default=True, help="don't do validation")
    parser.add_argument("--image-size", type=int, nargs=2, default=(224, 224), help="input size for localizer")
    parser.add_argument("--target-size", type=int, nargs=2, default=(75, 75), help="crop size for each image")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="batch size for training")
    # FIX: help text read "gpu if to use" (typo for "id")
    parser.add_argument("-g", "--gpu", type=int, default=-1, help="gpu id to use (-1 means cpu)")
    parser.add_argument("--lr", "--learning-rate", dest="learning_rate", type=float, default=0.001, help="learning rate")
    parser.add_argument("-l", "--log-dir", default='sheep_logs', help="path to log dir")
    parser.add_argument("--ln", "--log-name", default="test", help="name of log")
    parser.add_argument("--num-epoch", type=int, default=100, help="number of epochs to train")
    parser.add_argument("--snapshot-interval", type=int, default=1000, help="number of iterations after which a snapshot will be taken")
    parser.add_argument("--no-snapshot-every-epoch", dest="snapshot_every_epoch", action='store_false', default=True, help="Do not take a snapshot on every epoch")
    parser.add_argument("--log-interval", type=int, default=100, help="log interval")
    parser.add_argument("--port", type=int, default=1337, help="port that is used by bbox plotter to send predictions on test image")
    parser.add_argument("--test-image", help="path to test image that is to be used with bbox plotter")
    parser.add_argument("--anchor-image", help="path to anchor image used for metric learning")
    parser.add_argument("--rl", dest="resume_localizer", help="path to snapshot that is to be used to resume training of localizer")
    parser.add_argument("--rd", dest="resume_discriminator", help="path to snapshot that is to be used to pre-initialize discriminator")
    parser.add_argument("--use-resnet-18", action='store_true', default=False, help="Use Resnet-18 for localization")
    parser.add_argument("--localizer-target", type=float, default=1.0, help="target iou for localizer to reach in the interval [0,1]")
    parser.add_argument("--no-imgaug", action='store_false', dest='use_imgaug', default=True, help="disable image augmentation with `imgaug`, but use naive image augmentation instead")

    args = parser.parse_args()

    report_keys = ["epoch", "iteration", "loss_localizer", "loss_dis", "map", "mean_iou"]

    # the train file is either a json with image paths or a csv used directly
    if args.train_file.endswith('.json'):
        train_image_paths = load_train_paths(args.train_file)
    else:
        train_image_paths = args.train_file

    train_dataset = ImageDataset(
        train_image_paths,
        os.path.dirname(args.train_file),
        image_size=args.image_size,
        dtype=np.float32,
        use_imgaug=args.use_imgaug,
        transform_probability=0.5,
    )

    if args.reference_file == 'mnist':
        # special case: use MNIST images as reference data (fixed 28x28 crops)
        reference_dataset = get_mnist(withlabel=False, ndim=3, rgb_format=True)[0]
        args.target_size = (28, 28)
    else:
        reference_dataset = LabeledImageDataset(
            args.reference_file,
            os.path.dirname(args.reference_file),
            image_size=args.target_size,
            dtype=np.float32,
            label_dtype=np.float32,
        )

    if args.validation:
        if args.val_file.endswith('.json'):
            validation_data = load_train_paths(args.val_file, with_label=True)
        else:
            validation_data = args.val_file
        validation_dataset = LabeledImageDataset(validation_data, os.path.dirname(args.val_file), image_size=args.image_size)
        validation_iter = chainer.iterators.MultithreadIterator(validation_dataset, args.batch_size, repeat=False)

    data_iter = chainer.iterators.MultithreadIterator(train_dataset, args.batch_size)
    reference_iter = chainer.iterators.MultithreadIterator(reference_dataset, args.batch_size)

    localizer_class = SheepLocalizer if args.use_resnet_18 else Resnet50SheepLocalizer
    localizer = localizer_class(args.target_size)
    if args.resume_localizer is not None:
        load_pretrained_model(args.resume_localizer, localizer)

    discriminator_output_dim = 1
    discriminator = ResnetAssessor(output_dim=discriminator_output_dim)
    if args.resume_discriminator is not None:
        load_pretrained_model(args.resume_discriminator, discriminator)

    models = [localizer, discriminator]

    localizer_optimizer = chainer.optimizers.Adam(alpha=args.learning_rate, amsgrad=True)
    localizer_optimizer.setup(localizer)

    discriminator_optimizer = chainer.optimizers.Adam(alpha=args.learning_rate, amsgrad=True)
    discriminator_optimizer.setup(discriminator)

    optimizers = [localizer_optimizer, discriminator_optimizer]

    updater_args = {
        "iterator": {
            'main': data_iter,
            'real': reference_iter,
        },
        "device": args.gpu,
        "optimizer": {
            "opt_gen": localizer_optimizer,
            "opt_dis": discriminator_optimizer,
        },
        "create_pca": False,
        "resume_discriminator": args.resume_discriminator,
        "localizer_target": args.localizer_target,
    }

    updater = SheepAssessor(
        models=[localizer, discriminator],
        **updater_args
    )

    log_dir = os.path.join(args.log_dir, "{}_{}".format(datetime.datetime.now().isoformat(), args.ln))
    args.log_dir = log_dir

    # create log dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir, exist_ok=True)

    trainer = chainer.training.Trainer(updater, (args.num_epoch, 'epoch'), out=args.log_dir)

    # configuration that is dumped into the log once training has started
    data_to_log = {
        'log_dir': args.log_dir,
        'image_size': args.image_size,
        'updater': [updater.__class__.__name__, 'updater.py'],
        'discriminator': [discriminator.__class__.__name__, 'discriminator.py'],
        'discriminator_output_dim': discriminator_output_dim,
        'localizer': [localizer.__class__.__name__, 'localizer.py']
    }
    for argument in filter(lambda x: not x.startswith('_'), dir(args)):
        data_to_log[argument] = getattr(args, argument)

    def backup_train_config(stats_cpu):
        # write the full train configuration into the log on the first log write
        if stats_cpu['iteration'] == args.log_interval:
            stats_cpu.update(data_to_log)

    for model in models:
        trainer.extend(
            extensions.snapshot_object(model, model.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=lambda trainer: trainer.updater.is_new_epoch if args.snapshot_every_epoch else trainer.updater.iteration % args.snapshot_interval == 0,
        )

    # log train information every time we encounter a new epoch or args.log_interval iterations have been done
    log_interval_trigger = (lambda trainer: trainer.updater.is_new_epoch or trainer.updater.iteration % args.log_interval == 0)

    sheep_evaluator = SheepMAPEvaluator(localizer, args.gpu)
    if args.validation:
        trainer.extend(
            Evaluator(validation_iter, localizer, device=args.gpu, eval_func=sheep_evaluator),
            trigger=log_interval_trigger,
        )

    # updater code is backed up alongside the model definitions
    models.append(updater)
    logger = Logger(
        [get_definition_filepath(model) for model in models],
        args.log_dir,
        postprocess=backup_train_config,
        trigger=log_interval_trigger,
        dest_file_names=['localizer.py', 'discriminator.py', 'updater.py'],
    )

    if args.test_image is not None:
        plot_image = load_image(args.test_image, args.image_size)
        gt_bbox = None
    else:
        if args.validation:
            plot_image, gt_bbox, _ = validation_dataset.get_example(0)
        else:
            plot_image = train_dataset.get_example(0)
            gt_bbox = None

    bbox_plotter = BBOXPlotter(
        plot_image,
        os.path.join(args.log_dir, 'bboxes'),
        args.target_size,
        send_bboxes=True,
        upstream_port=args.port,
        visualization_anchors=[
            ["visual_backprop_anchors"],
        ],
        device=args.gpu,
        render_extracted_rois=True,
        num_rois_to_render=4,
        show_visual_backprop_overlay=False,
        show_backprop_and_feature_vis=True,
        gt_bbox=gt_bbox,
        render_pca=True,
        log_name=args.ln,
    )
    trainer.extend(bbox_plotter, trigger=(1, 'iteration'))

    trainer.extend(logger, trigger=log_interval_trigger)
    trainer.extend(
        extensions.PrintReport(report_keys, log_report='Logger'),
        trigger=log_interval_trigger
    )
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.dump_graph('loss_localizer', out_name='model.dot'))

    # interactive prompt that can be used to issue commands while training is in progress
    open_interactive_prompt(
        bbox_plotter=bbox_plotter,
        optimizer=optimizers,
    )

    trainer.run()