# Example #1
# 0
        net,
        updater,
        log_dir,
        fields_to_print,
        epochs=args.epochs,
        snapshot_interval=args.snapshot_interval,
        print_interval=args.log_interval,
        extra_extensions=(
            evaluator,
            epoch_evaluator,
            model_snapshotter,
            bbox_plotter,
            (curriculum, (args.test_interval, 'iteration')),
        ),
        postprocess=log_postprocess,
        do_logging=args.no_log,
        model_files=[
            get_definition_filepath(localization_net),
            get_definition_filepath(recognition_net),
            get_definition_filepath(net),
        ]
    )

    # create interactive prompt that can be used to issue commands while the training is in progress
    open_interactive_prompt(
        bbox_plotter=bbox_plotter[0],
        curriculum=curriculum,
    )

    trainer.run()
# Example #2
# 0
def main():
    """Train a KISS scene-text model (localizer + transformer recognizer).

    Parses CLI arguments, merges in additional settings from a config file,
    sets up multi-process data-parallel training via chainermn, builds
    train/validation/test datasets, the localization and recognition
    networks with their optimizers, and runs a chainer Trainer with
    snapshotting, evaluation, tensorboard and bbox-plotting extensions.
    """
    # ---- command line interface -------------------------------------------
    parser = argparse.ArgumentParser(
        description="Train a KISS model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("log_name", help="name of log")
    parser.add_argument("-c",
                        "--config",
                        default="config.cfg",
                        help="path to config file to use")
    parser.add_argument("-g",
                        "--gpu",
                        nargs='+',
                        default=["-1"],
                        # NOTE(review): "gpu if to use" is likely a typo for
                        # "gpu id to use" in this help string.
                        help="gpu if to use (-1 means cpu)")
    parser.add_argument("-l",
                        "--log-dir",
                        default='tests',
                        help="path to log dir")
    parser.add_argument(
        "--snapshot-interval",
        type=int,
        default=10000,
        help="number of iterations after which a snapshot will be taken")
    parser.add_argument("--log-interval",
                        type=int,
                        default=100,
                        help="log interval")
    parser.add_argument(
        "--port",
        type=int,
        default=1337,
        help=
        "port that is used by bbox plotter to send predictions on test image")
    parser.add_argument(
        "--rl",
        dest="resume_localizer",
        help=
        "path to snapshot that is to be used to resume training of localizer")
    parser.add_argument(
        "--rr",
        dest="resume_recognizer",
        # NOTE(review): "that us to be used" is likely a typo for
        # "that is to be used" in this help string.
        help="path to snapshot that us to be used to pre-initialize recognizer"
    )
    parser.add_argument("--num-layers",
                        type=int,
                        default=18,
                        help="Resnet Variant to use")
    parser.add_argument(
        "--no-imgaug",
        action='store_false',
        dest='use_imgaug',
        default=True,
        help=
        "disable image augmentation with `imgaug`, but use naive image augmentation instead"
    )
    parser.add_argument(
        "--rdr",
        "--rotation-dropout-ratio",
        dest="rotation_dropout_ratio",
        type=float,
        default=0,
        help="ratio for dropping rotation params in text localization network")
    parser.add_argument("--save-gradient-information",
                        action='store_true',
                        default=False,
                        help="enable tensorboard gradient plotter")
    parser.add_argument("--dump-graph",
                        action='store_true',
                        default=False,
                        help="dump computational graph to file")
    parser.add_argument("--image-mode",
                        default="RGB",
                        choices=["RGB", "L"],
                        help="mode in which images are to be loaded")
    parser.add_argument("--resume",
                        help="path to logdir from which training shall resume")

    args = parser.parse_args()
    # Merge extra settings from the config file into args.  Attributes used
    # below but not declared by this parser (train_file, val_file, char_map,
    # image_size, target_size, batch_size, learning_rate, num_epoch,
    # test_interval, use_memory_manager, test_image, test_dataset_*)
    # presumably come from this call — TODO confirm against parse_config.
    args = parse_config(args.config, args)

    # comm = chainermn.create_communicator(communicator_name='flat')
    # One training process per device; intra_rank selects the GPU on this
    # host (overrides the -g command line option).
    comm = chainermn.create_communicator()
    args.gpu = comm.intra_rank
    print(args.gpu)

    # Resume into the existing log dir, otherwise create a fresh
    # timestamped one under "logs/<log_dir>/".
    if args.resume is not None:
        log_dir = os.path.relpath(args.resume)
    else:
        log_dir = os.path.join(
            "logs", args.log_dir,
            "{}_{}".format(datetime.datetime.now().isoformat(), args.log_name))
    args.log_dir = log_dir

    # set dtype
    chainer.global_config.dtype = 'float32'

    if comm.rank == 0:
        # create log dir (only the master process touches the filesystem)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)

    # columns shown by the PrintReport extension
    report_keys = ["epoch", "iteration", "loss/localizer/loss"]

    # Datasets are either served by a shared memory-manager process or
    # loaded directly from npz files.
    if args.use_memory_manager:
        memory_manager = DatasetClient()
        memory_manager.connect()

        train_kwargs = {
            "memory_manager": memory_manager,
            "base_name": "train_file"
        }
        # recognition_kwargs = {"memory_manager": memory_manager, "base_name": "text_recognition_file"}
        validation_kwargs = {
            "memory_manager": memory_manager,
            "base_name": "val_file"
        }
    else:
        train_kwargs = {"npz_file": args.train_file}
        # recognition_kwargs = {"npz_file": args.text_recognition_file}
        validation_kwargs = {"npz_file": args.val_file}

    # Only rank 0 constructs the datasets; they are scattered to all
    # workers via scatter_dataset below.
    if comm.rank == 0:
        train_dataset = TextRecognitionImageDataset(
            char_map=args.char_map,
            image_size=args.image_size,
            root=os.path.dirname(args.train_file),
            dtype=chainer.get_dtype(),
            use_imgaug=args.use_imgaug,
            transform_probability=0.4,
            keep_aspect_ratio=True,
            image_mode=args.image_mode,
            **train_kwargs,
        )

        validation_dataset = TextRecognitionImageDataset(
            char_map=args.char_map,
            image_size=args.image_size,
            root=os.path.dirname(args.val_file),
            dtype=chainer.get_dtype(),
            # no augmentation for validation data
            transform_probability=0,
            keep_aspect_ratio=True,
            image_mode=args.image_mode,
            **validation_kwargs,
        )
    else:
        train_dataset, validation_dataset = None, None

    train_dataset = scatter_dataset(train_dataset, comm)
    validation_dataset = scatter_dataset(validation_dataset, comm)

    # uncomment all commented parts of the code to train the model with extra recognizer training
    # text_recognition_dataset = TextRecognitionImageCharCropDataset(
    #     char_map=args.char_map,
    #     image_size=args.target_size,
    #     root=os.path.dirname(args.text_recognition_file),
    #     dtype=chainer.get_dtype(),
    #     transform_probability=0,
    #     image_mode=args.image_mode,
    #     gpu_id=args.gpu,
    #     reverse=False,
    #     resize_after_load=False,
    #     **recognition_kwargs,
    # )

    data_iter = chainer.iterators.MultithreadIterator(train_dataset,
                                                      args.batch_size)
    validation_iter = chainer.iterators.MultithreadIterator(validation_dataset,
                                                            args.batch_size,
                                                            repeat=False)
    # text_recognition_iter = chainer.iterators.MultithreadIterator(text_recognition_dataset, max(args.batch_size, 32))

    # ---- models -----------------------------------------------------------
    localizer = LSTMTextLocalizer(
        Size(*args.target_size),
        num_bboxes_to_localize=train_dataset.num_chars_per_word,
        num_layers=args.num_layers,
        dropout_ratio=args.rotation_dropout_ratio,
    )
    if args.resume_localizer is not None:
        load_pretrained_model(args.resume_localizer, localizer)

    recognizer = TransformerTextRecognizer(
        train_dataset.num_chars_per_word,
        train_dataset.num_words_per_image,
        train_dataset.num_classes,
        train_dataset.bos_token,
        num_layers=args.num_layers,
    )

    if args.resume_recognizer is not None:
        load_pretrained_model(args.resume_recognizer, recognizer)

    models = [localizer, recognizer]

    # tensorboard writing happens on the master process only
    if comm.rank == 0:
        tensorboard_handle = SummaryWriter(log_dir=args.log_dir, graph=None)
    else:
        tensorboard_handle = None

    # ---- optimizers -------------------------------------------------------
    localizer_optimizer = RAdam(alpha=args.learning_rate,
                                beta1=0.9,
                                beta2=0.98,
                                eps=1e-9)
    # wrap for multi-node gradient averaging
    localizer_optimizer = chainermn.create_multi_node_optimizer(
        localizer_optimizer, comm)
    localizer_optimizer.setup(localizer)
    localizer_optimizer.add_hook(chainer.optimizer_hooks.GradientClipping(2))

    if args.save_gradient_information:
        localizer_optimizer.add_hook(
            TensorboardGradientPlotter(tensorboard_handle,
                                       args.log_interval), )

    recognizer_optimizer = RAdam(alpha=args.learning_rate)
    recognizer_optimizer = chainermn.create_multi_node_optimizer(
        recognizer_optimizer, comm)
    recognizer_optimizer.setup(recognizer)

    optimizers = [localizer_optimizer, recognizer_optimizer]

    # log train information everytime we encouter a new epoch or args.log_interval iterations have been done
    log_interval_trigger = (
        lambda trainer:
        (trainer.updater.is_new_epoch or trainer.updater.iteration % args.
         log_interval == 0) and trainer.updater.iteration > 0)

    updater_args = {
        "iterator": {
            'main': data_iter,
            # 'rec': text_recognition_iter,
        },
        "optimizer": {
            "opt_gen": localizer_optimizer,
            "opt_rec": recognizer_optimizer,
        },
        "tensorboard_handle": tensorboard_handle,
        "tensorboard_log_interval": log_interval_trigger,
        "recognizer_update_interval": 1,
        "device": args.gpu,
    }

    updater = TransformerTextRecognitionUpdater(models=[localizer, recognizer],
                                                **updater_args)

    trainer = chainer.training.Trainer(updater, (args.num_epoch, 'epoch'),
                                       out=args.log_dir)

    # static run configuration that is injected into the first log entry
    # by backup_train_config below
    data_to_log = {
        'log_dir': args.log_dir,
        'image_size': args.image_size,
        'num_layers': args.num_layers,
        'num_chars': train_dataset.num_chars_per_word,
        'num_words': train_dataset.num_words_per_image,
        'num_classes': train_dataset.num_classes,
        'keep_aspect_ratio': train_dataset.keep_aspect_ratio,
        'localizer': get_import_info(localizer),
        'recognizer': get_import_info(recognizer),
        'bos_token': train_dataset.bos_token,
    }

    # also record every (public) parsed argument
    for argument in filter(lambda x: not x.startswith('_'), dir(args)):
        data_to_log[argument] = getattr(args, argument)

    def backup_train_config(stats_cpu):
        # write the config into the very first logged stats entry only
        if stats_cpu['iteration'] == args.log_interval:
            stats_cpu.update(data_to_log)

    # ---- extensions (master process only) ---------------------------------
    if comm.rank == 0:
        for model in models:
            trainer.extend(
                extensions.snapshot_object(
                    model,
                    model.__class__.__name__ + '_{.updater.iteration}.npz'),
                # snapshot on every new epoch and every snapshot_interval
                # iterations
                trigger=lambda trainer: trainer.updater.is_new_epoch or trainer
                .updater.iteration % args.snapshot_interval == 0,
            )

        # full trainer snapshot; autoload restores it when --resume is given
        trainer.extend(extensions.snapshot(filename='trainer_snapshot',
                                           autoload=args.resume is not None),
                       trigger=(args.snapshot_interval, 'iteration'))

        evaluation_function = TextRecognitionEvaluatorFunction(
            localizer, recognizer, args.gpu, train_dataset.blank_label,
            train_dataset.char_map)

        trainer.extend(
            TextRecognitionTensorboardEvaluator(
                validation_iter,
                localizer,
                device=args.gpu,
                eval_func=evaluation_function,
                tensorboard_handle=tensorboard_handle,
                num_iterations=200,
            ),
            trigger=(args.test_interval, 'iteration'),
        )

        # every epoch run the model on test datasets
        # (all args attributes named "test_dataset_*" are treated as paths)
        test_dataset_prefix = "test_dataset_"
        test_datasets = [
            arg for arg in dir(args) if arg.startswith(test_dataset_prefix)
        ]
        for test_dataset_name in test_datasets:
            print(
                f"setting up testing for {test_dataset_name[len(test_dataset_prefix):]} dataset"
            )

            dataset_path = getattr(args, test_dataset_name)
            if args.use_memory_manager:
                test_kwargs = {
                    "memory_manager": memory_manager,
                    "base_name": test_dataset_name
                }
            else:
                test_kwargs = {"npz_file": dataset_path}

            test_dataset = TextRecognitionImageDataset(
                char_map=args.char_map,
                image_size=args.image_size,
                root=os.path.dirname(dataset_path),
                dtype=chainer.get_dtype(),
                transform_probability=0,
                keep_aspect_ratio=True,
                image_mode=args.image_mode,
                **test_kwargs,
            )
            test_iter = chainer.iterators.MultithreadIterator(test_dataset,
                                                              args.batch_size,
                                                              repeat=False)
            trainer.extend(TextRecognitionTensorboardEvaluator(
                test_iter,
                localizer,
                device=args.gpu,
                eval_func=evaluation_function,
                tensorboard_handle=tensorboard_handle,
                base_key=test_dataset_name[len(test_dataset_prefix):]),
                           trigger=(args.snapshot_interval, 'iteration'))

        # NOTE(review): `models` is not used after this point in this
        # function; presumably kept for parity with other training scripts —
        # verify before relying on it.
        models.append(updater)
        logger = Logger(
            os.path.dirname(os.path.realpath(__file__)),
            args.log_dir,
            postprocess=backup_train_config,
            trigger=log_interval_trigger,
            exclusion_filters=['*logs*', '*.pyc', '__pycache__', '.git*'],
            resume=args.resume is not None,
        )

        # image fed to the bbox plotter: a user-supplied test image, or the
        # first validation example
        if args.test_image is not None:
            plot_image = train_dataset.load_image(args.test_image)
            gt_bbox = None
        else:
            plot_image = validation_dataset.get_example(0)['image']
            gt_bbox = None

        bbox_plotter = TextRecognitionBBoxPlotter(
            plot_image,
            os.path.join(args.log_dir, 'bboxes'),
            args.target_size,
            send_bboxes=True,
            upstream_port=args.port,
            visualization_anchors=[
                ["visual_backprop_anchors"],
            ],
            device=args.gpu,
            render_extracted_rois=True,
            num_rois_to_render=4,
            sort_rois=False,
            show_visual_backprop_overlay=True,
            visual_backprop_index=0,
            show_backprop_and_feature_vis=True,
            gt_bbox=gt_bbox,
            render_pca=False,
            log_name=args.log_name,
            char_map=train_dataset.char_map,
            blank_label=train_dataset.blank_label,
            predictors={
                "localizer": localizer,
                "recognizer": recognizer,
            },
        )
        trainer.extend(bbox_plotter, trigger=(10, 'iteration'))

        trainer.extend(logger, trigger=log_interval_trigger)
        trainer.extend(extensions.PrintReport(report_keys,
                                              log_report='Logger'),
                       trigger=log_interval_trigger)

        # learning rate shift after each epoch
        trainer.extend(extensions.ExponentialShift(
            "alpha", 0.1, optimizer=localizer_optimizer),
                       trigger=(1, 'epoch'))

        trainer.extend(extensions.ProgressBar(update_interval=10))

        if args.dump_graph:
            trainer.extend(
                extensions.dump_graph('loss/localizer/loss',
                                      out_name='model.dot'))

        # interactive prompt for issuing commands while training runs
        open_interactive_prompt(
            bbox_plotter=bbox_plotter,
            optimizer=optimizers,
        )

    trainer.run()
def main():
    """Train a sheep localizer against an assessor/discriminator.

    Parses CLI arguments, builds the train, reference and (optional)
    validation datasets, sets up localizer and discriminator networks with
    Adam optimizers, and runs a chainer Trainer with snapshot, evaluation,
    logging and bbox-plotting extensions.
    """
    # ---- command line interface -------------------------------------------
    parser = argparse.ArgumentParser(description="Train a sheep localizer")
    parser.add_argument("train_file", help="path to train csv")
    # NOTE(review): the help string below is missing its closing parenthesis.
    parser.add_argument("val_file", help="path to validation file (if you do not want to do validation just enter gibberish here")
    parser.add_argument("reference_file", help="path to reference images with different zoom levels")
    parser.add_argument("--no-validation", dest='validation', action='store_false', default=True, help="don't do validation")
    parser.add_argument("--image-size", type=int, nargs=2, default=(224, 224), help="input size for localizer")
    parser.add_argument("--target-size", type=int, nargs=2, default=(75, 75), help="crop size for each image")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="batch size for training")
    # NOTE(review): "gpu if to use" is likely a typo for "gpu id to use".
    parser.add_argument("-g", "--gpu", type=int, default=-1, help="gpu if to use (-1 means cpu)")
    parser.add_argument("--lr", "--learning-rate", dest="learning_rate", type=float, default=0.001, help="learning rate")
    parser.add_argument("-l", "--log-dir", default='sheep_logs', help="path to log dir")
    parser.add_argument("--ln", "--log-name", default="test", help="name of log")
    parser.add_argument("--num-epoch", type=int, default=100, help="number of epochs to train")
    parser.add_argument("--snapshot-interval", type=int, default=1000, help="number of iterations after which a snapshot will be taken")
    parser.add_argument("--no-snapshot-every-epoch", dest="snapshot_every_epoch", action='store_false', default=True, help="Do not take a snapshot on every epoch")
    parser.add_argument("--log-interval", type=int, default=100, help="log interval")
    parser.add_argument("--port", type=int, default=1337, help="port that is used by bbox plotter to send predictions on test image")
    parser.add_argument("--test-image", help="path to test image that is to be used with bbox plotter")
    parser.add_argument("--anchor-image", help="path to anchor image used for metric learning")
    parser.add_argument("--rl", dest="resume_localizer", help="path to snapshot that is to be used to resume training of localizer")
    parser.add_argument("--rd", dest="resume_discriminator", help="path to snapshot that is to be used to pre-initialize discriminator")
    parser.add_argument("--use-resnet-18", action='store_true', default=False, help="Use Resnet-18 for localization")
    parser.add_argument("--localizer-target", type=float, default=1.0, help="target iou for localizer to reach in the interval [0,1]")
    parser.add_argument("--no-imgaug", action='store_false', dest='use_imgaug', default=True, help="disable image augmentation with `imgaug`, but use naive image augmentation instead")

    args = parser.parse_args()

    # columns shown by the PrintReport extension
    report_keys = ["epoch", "iteration", "loss_localizer", "loss_dis", "map", "mean_iou"]

    # ---- datasets ---------------------------------------------------------
    # The train file is either a json list of image paths or a csv path
    # passed through to the dataset as-is.
    if args.train_file.endswith('.json'):
        train_image_paths = load_train_paths(args.train_file)
    else:
        train_image_paths = args.train_file

    train_dataset = ImageDataset(
        train_image_paths,
        os.path.dirname(args.train_file),
        image_size=args.image_size,
        dtype=np.float32,
        use_imgaug=args.use_imgaug,
        transform_probability=0.5,
    )

    # 'mnist' is a shortcut: use MNIST digits as the reference images and
    # force the matching 28x28 crop size
    if args.reference_file == 'mnist':
        reference_dataset = get_mnist(withlabel=False, ndim=3, rgb_format=True)[0]
        args.target_size = (28, 28)
    else:
        reference_dataset = LabeledImageDataset(
            args.reference_file,
            os.path.dirname(args.reference_file),
            image_size=args.target_size,
            dtype=np.float32,
            label_dtype=np.float32,
        )

    if args.validation:
        if args.val_file.endswith('.json'):
            validation_data = load_train_paths(args.val_file, with_label=True)
        else:
            validation_data = args.val_file

        validation_dataset = LabeledImageDataset(validation_data, os.path.dirname(args.val_file), image_size=args.image_size)
        validation_iter = chainer.iterators.MultithreadIterator(validation_dataset, args.batch_size, repeat=False)

    data_iter = chainer.iterators.MultithreadIterator(train_dataset, args.batch_size)
    reference_iter = chainer.iterators.MultithreadIterator(reference_dataset, args.batch_size)

    # ---- models -----------------------------------------------------------
    localizer_class = SheepLocalizer if args.use_resnet_18 else Resnet50SheepLocalizer
    localizer = localizer_class(args.target_size)

    if args.resume_localizer is not None:
        load_pretrained_model(args.resume_localizer, localizer)

    discriminator_output_dim = 1
    discriminator = ResnetAssessor(output_dim=discriminator_output_dim)
    if args.resume_discriminator is not None:
        load_pretrained_model(args.resume_discriminator, discriminator)
    models = [localizer, discriminator]

    # ---- optimizers -------------------------------------------------------
    localizer_optimizer = chainer.optimizers.Adam(alpha=args.learning_rate, amsgrad=True)
    localizer_optimizer.setup(localizer)

    discriminator_optimizer = chainer.optimizers.Adam(alpha=args.learning_rate, amsgrad=True)
    discriminator_optimizer.setup(discriminator)

    optimizers = [localizer_optimizer, discriminator_optimizer]

    updater_args = {
        "iterator": {
            'main': data_iter,
            'real': reference_iter,
        },
        "device": args.gpu,
        "optimizer": {
            "opt_gen": localizer_optimizer,
            "opt_dis": discriminator_optimizer,
        },
        "create_pca": False,
        "resume_discriminator": args.resume_discriminator,
        "localizer_target": args.localizer_target,
    }

    updater = SheepAssessor(
        models=[localizer, discriminator],
        **updater_args
    )

    # timestamped log directory under args.log_dir
    log_dir = os.path.join(args.log_dir, "{}_{}".format(datetime.datetime.now().isoformat(), args.ln))
    args.log_dir = log_dir
    # create log dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir, exist_ok=True)

    trainer = chainer.training.Trainer(updater, (args.num_epoch, 'epoch'), out=args.log_dir)

    # static run configuration injected into the first log entry by
    # backup_train_config below
    data_to_log = {
        'log_dir': args.log_dir,
        'image_size': args.image_size,
        'updater': [updater.__class__.__name__, 'updater.py'],
        'discriminator': [discriminator.__class__.__name__, 'discriminator.py'],
        'discriminator_output_dim': discriminator_output_dim,
        'localizer': [localizer.__class__.__name__, 'localizer.py']
    }

    # also record every (public) parsed argument
    for argument in filter(lambda x: not x.startswith('_'), dir(args)):
        data_to_log[argument] = getattr(args, argument)

    def backup_train_config(stats_cpu):
        # write the config into the very first logged stats entry only
        if stats_cpu['iteration'] == args.log_interval:
            stats_cpu.update(data_to_log)

    # Snapshot either on each epoch boundary (default) or, with
    # --no-snapshot-every-epoch, every snapshot_interval iterations.
    for model in models:
        trainer.extend(
            extensions.snapshot_object(model, model.__class__.__name__ + '_{.updater.iteration}.npz'),
            trigger=lambda trainer: trainer.updater.is_new_epoch if args.snapshot_every_epoch else trainer.updater.iteration % args.snapshot_interval == 0,
        )

    # log train information everytime we encouter a new epoch or args.log_interval iterations have been done
    log_interval_trigger = (lambda trainer:
                            trainer.updater.is_new_epoch or trainer.updater.iteration % args.log_interval == 0)

    sheep_evaluator = SheepMAPEvaluator(localizer, args.gpu)
    if args.validation:
        trainer.extend(
            Evaluator(validation_iter, localizer, device=args.gpu, eval_func=sheep_evaluator),
            trigger=log_interval_trigger,
        )

    # the Logger backs up the source files of all models plus the updater
    models.append(updater)
    logger = Logger(
        [get_definition_filepath(model) for model in models],
        args.log_dir,
        postprocess=backup_train_config,
        trigger=log_interval_trigger,
        dest_file_names=['localizer.py', 'discriminator.py', 'updater.py'],
    )

    # image fed to the bbox plotter: a user-supplied test image, the first
    # validation example (with its gt bbox), or the first train example
    if args.test_image is not None:
        plot_image = load_image(args.test_image, args.image_size)
        gt_bbox = None
    else:
        if args.validation:
            plot_image, gt_bbox, _ = validation_dataset.get_example(0)
        else:
            plot_image = train_dataset.get_example(0)
            gt_bbox = None

    bbox_plotter = BBOXPlotter(
        plot_image,
        os.path.join(args.log_dir, 'bboxes'),
        args.target_size,
        send_bboxes=True,
        upstream_port=args.port,
        visualization_anchors=[
            ["visual_backprop_anchors"],
        ],
        device=args.gpu,
        render_extracted_rois=True,
        num_rois_to_render=4,
        show_visual_backprop_overlay=False,
        show_backprop_and_feature_vis=True,
        gt_bbox=gt_bbox,
        render_pca=True,
        log_name=args.ln,
    )
    trainer.extend(bbox_plotter, trigger=(1, 'iteration'))

    trainer.extend(
        logger,
        trigger=log_interval_trigger
    )
    trainer.extend(
        extensions.PrintReport(report_keys, log_report='Logger'),
        trigger=log_interval_trigger
    )

    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.dump_graph('loss_localizer', out_name='model.dot'))

    # interactive prompt for issuing commands while training runs
    open_interactive_prompt(
        bbox_plotter=bbox_plotter,
        optimizer=optimizers,
    )

    trainer.run()