def pred_directory(args):
    """指定ディレクトリについてモデル予測"""
    # ### test data load ### #
    d_cls = get_train_valid_test.LabeledDataset(
        [args["img_rows"], args["img_cols"], args["channels"]],
        args["batch_size"],
        valid_batch_size=args["batch_size"],
    )
    if args["is_flow"]:
        # Load preprocessed images, labels, and file paths from the specified directory
        d_cls.X_test, d_cls.y_test, test_paths = base_dataset.load_my_data(
            args["test_data_dir"],
            classes=args["classes"],
            img_height=args["img_rows"],
            img_width=args["img_cols"],
            channel=args["channels"],
            is_pytorch=False,
        )
        d_cls.create_test_generator()

    elif args["is_flow_from_directory"]:
        d_cls.create_my_generator_flow_from_directory(
            args["train_data_dir"],
            args["classes"],
            test_data_dir=args["test_data_dir"],
            color_mode=args["color_mode"],
            class_mode=args["class_mode"],
            my_IDG_options={"rescale": 1 / 255.0},
        )

    # Wrapper that converts a binary-label generator into a multi-task generator
    if args["n_multitask"] > 1 and args["multitask_pred_n_node"] == 1:
        d_cls.test_gen = get_train_valid_test.binary_generator_multi_output_wrapper(
            d_cls.test_gen
        )

    # Predict with the test generator, applying test-time augmentation (TTA),
    # using the checkpoint with the best validation accuracy
    # load_model = keras.models.load_model(os.path.join(args['output_dir'], 'best_val_loss.h5'))
    load_model = keras.models.load_model(
        os.path.join(args["output_dir"], "best_val_accuracy.h5")
    )
    pred_tta = base_predict.predict_tta_generator(
        load_model,
        d_cls.test_gen,
        TTA=args["TTA"],
        TTA_rotate_deg=args["TTA_rotate_deg"],
        TTA_crop_num=args["TTA_crop_num"],
        TTA_crop_size=args["TTA_crop_size"],
        resize_size=[args["img_rows"], args["img_cols"]],
    )
    pred_tta_df = base_predict.get_predict_generator_results(
        pred_tta, d_cls.test_gen, classes_list=args["classes"]
    )
    # Build the confusion matrix from the predicted classes
    base_predict.conf_matrix_from_pred_classes_generator(
        pred_tta_df, args["classes"], args["output_dir"]
    )
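
# --- Example (hypothetical): calling pred_directory --------------------------
# A minimal sketch of how pred_directory might be invoked. The paths, class
# names, and TTA settings below are placeholder assumptions, not values taken
# from this repository; only the dict keys are the ones the function reads.
#
# pred_args = {
#     "img_rows": 224, "img_cols": 224, "channels": 3,
#     "batch_size": 32,
#     "is_flow": False, "is_flow_from_directory": True,
#     "train_data_dir": "data/train",   # used to fix the class indices
#     "test_data_dir": "data/test",
#     "classes": ["class_a", "class_b"],
#     "color_mode": "rgb", "class_mode": "categorical",
#     "n_multitask": 1, "multitask_pred_n_node": 1,
#     "output_dir": "output",           # must contain best_val_accuracy.h5
#     "TTA": "flip",                    # placeholder TTA settings
#     "TTA_rotate_deg": 0, "TTA_crop_num": 0, "TTA_crop_size": [224, 224],
# }
# pred_directory(pred_args)
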

def trial_train_directory(trial, args):
    """Run one Optuna trial: build directory generators, fit the model, and return the History."""
    keras.backend.clear_session()
    # ### train validation data load ### #
    d_cls = get_train_valid_test.LabeledDataset(
        [args["img_rows"], args["img_cols"], args["channels"]],
        args["batch_size"],
        valid_batch_size=args["batch_size"],
        train_samples=len(util.find_img_files(args["train_data_dir"])),
        valid_samples=len(util.find_img_files(args["validation_data_dir"])),
    )
    if args["is_flow"]:
        # Load preprocessed images, labels, and file paths from the specified directories
        d_cls.X_train, d_cls.y_train, train_paths = base_dataset.load_my_data(
            args["train_data_dir"],
            classes=args["classes"],
            img_height=args["img_rows"],
            img_width=args["img_cols"],
            channel=args["channels"],
            is_pytorch=False,
        )
        d_cls.X_valid, d_cls.y_valid, valid_paths = base_dataset.load_my_data(
            args["validation_data_dir"],
            classes=args["classes"],
            img_height=args["img_rows"],
            img_width=args["img_cols"],
            channel=args["channels"],
            is_pytorch=False,
        )
        d_cls.X_train, d_cls.X_valid = d_cls.X_train * 255.0, d_cls.X_valid * 255.0
        d_cls.create_my_generator_flow(my_IDG_options=args["my_IDG_options"])

    elif args["is_flow_from_directory"]:
        d_cls.create_my_generator_flow_from_directory(
            args["train_data_dir"],
            args["classes"],
            valid_data_dir=args["validation_data_dir"],
            color_mode=args["color_mode"],
            class_mode=args["class_mode"],
            my_IDG_options=args["my_IDG_options"],
        )
        # d_cls.train_gen_augmentor = d_cls.create_augmentor_util_from_directory(args['train_data_dir']
        #                                                                       , args['batch_size']
        #                                                                       , augmentor_options=args['train_augmentor_options'])

    # Wrapper that converts a binary-label generator into a multi-task generator
    if args["n_multitask"] > 1 and args["multitask_pred_n_node"] == 1:
        d_cls.train_gen = get_train_valid_test.binary_generator_multi_output_wrapper(
            d_cls.train_gen)
        d_cls.valid_gen = get_train_valid_test.binary_generator_multi_output_wrapper(
            d_cls.valid_gen)

    # ### model ### #
    os.makedirs(args["output_dir"], exist_ok=True)
    if args["choice_model"] == "model_paper":
        model = model_paper.create_paper_cnn(
            input_shape=(args["img_cols"], args["img_rows"], args["channels"]),
            num_classes=args["num_classes"],
            activation=args["activation"],
        )
    else:
        model, orig_model = define_model.get_fine_tuning_model(
            args["output_dir"],
            args["img_rows"],
            args["img_cols"],
            args["channels"],
            args["num_classes"],
            args["choice_model"],
            trainable=args["trainable"],
            fcpool=args["fcpool"],
            fcs=args["fcs"],
            drop=args["drop"],
            activation=args["activation"],
            weights=args["weights"],
        )
    optim = define_model.get_optimizers(choice_optim=args["choice_optim"],
                                        lr=args["lr"],
                                        decay=args["decay"])
    model.compile(loss=args["loss"],
                  optimizer=optim,
                  metrics=args["metrics"])

    cb = my_callback.get_base_cb(
        args["output_dir"],
        args["num_epoch"],
        early_stopping=20,  # alternatively: args["num_epoch"] // 3
        monitor="val_" + args["metrics"][0],
        metric=args["metrics"][0],
    )
    cb.append(OptunaCallback(trial, True))

    # ### train ### #
    hist = model.fit(
        d_cls.train_gen,
        steps_per_epoch=d_cls.init_train_steps_per_epoch,
        epochs=args["num_epoch"],
        validation_data=d_cls.valid_gen,
        validation_steps=d_cls.init_valid_steps_per_epoch,
        verbose=2,  # 1: progress bar on stdout, 2: one line per epoch (minimal output)
        callbacks=cb,
    )

    return hist

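# --- Example (hypothetical): using trial_train_directory with Optuna ---------
# A minimal sketch of how the trial function above could serve as an Optuna
# objective. It assumes args["metrics"][0] == "accuracy", so the training
# history exposes "val_accuracy", and build_trial_args is a hypothetical helper
# that turns trial.suggest_* calls into the args dict expected here.
#
# import optuna
#
# def objective(trial):
#     args = build_trial_args(trial)  # hypothetical helper, not defined in this file
#     hist = trial_train_directory(trial, args)
#     return max(hist.history["val_accuracy"])
#
# study = optuna.create_study(direction="maximize")
# study.optimize(objective, n_trials=20)
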
def train_directory(args):
    """指定ディレクトリについてgenerator作ってモデル学習"""
    print("train_directory")
    # ### train validation data load ### #
    d_cls = get_train_valid_test.LabeledDataset(
        [args["img_rows"], args["img_cols"], args["channels"]],
        args["batch_size"],
        valid_batch_size=args["batch_size"],
        train_samples=len(util.find_img_files(args["train_data_dir"])),
        valid_samples=len(util.find_img_files(args["validation_data_dir"])),
    )
    if args["is_flow"]:
        # Load preprocessed images, labels, and file paths from the specified directories
        d_cls.X_train, d_cls.y_train, train_paths = base_dataset.load_my_data(
            args["train_data_dir"],
            classes=args["classes"],
            img_height=args["img_rows"],
            img_width=args["img_cols"],
            channel=args["channels"],
            is_pytorch=False,
        )
        d_cls.X_valid, d_cls.y_valid, valid_paths = base_dataset.load_my_data(
            args["validation_data_dir"],
            classes=args["classes"],
            img_height=args["img_rows"],
            img_width=args["img_cols"],
            channel=args["channels"],
            is_pytorch=False,
        )
        d_cls.X_train, d_cls.X_valid = d_cls.X_train * 255.0, d_cls.X_valid * 255.0
        d_cls.create_my_generator_flow(my_IDG_options=args["my_IDG_options"])

    elif args["is_flow_from_directory"]:
        d_cls.create_my_generator_flow_from_directory(
            args["train_data_dir"],
            args["classes"],
            valid_data_dir=args["validation_data_dir"],
            color_mode=args["color_mode"],
            class_mode=args["class_mode"],
            my_IDG_options=args["my_IDG_options"],
        )
        # d_cls.train_gen_augmentor = d_cls.create_augmentor_util_from_directory(args['train_data_dir']
        #                                                                       , args['batch_size']
        #                                                                       , augmentor_options=args['train_augmentor_options'])

    # Wrapper that converts a binary-label generator into a multi-task generator
    if args["n_multitask"] > 1 and args["multitask_pred_n_node"] == 1:
        d_cls.train_gen = get_train_valid_test.binary_generator_multi_output_wrapper(
            d_cls.train_gen)
        d_cls.valid_gen = get_train_valid_test.binary_generator_multi_output_wrapper(
            d_cls.valid_gen)

    # ### model ### #
    os.makedirs(args["output_dir"], exist_ok=True)
    if args["choice_model"] == "model_paper":
        model = model_paper.create_paper_cnn(
            input_shape=(args["img_cols"], args["img_rows"], args["channels"]),
            num_classes=args["num_classes"],
            activation=args["activation"],
        )
    else:
        model, orig_model = define_model.get_fine_tuning_model(
            args["output_dir"],
            args["img_rows"],
            args["img_cols"],
            args["channels"],
            args["num_classes"],
            args["choice_model"],
            trainable=args["trainable"],
            fcpool=args["fcpool"],
            fcs=args["fcs"],
            drop=args["drop"],
            activation=args["activation"],
            weights=args["weights"],
        )
    optim = define_model.get_optimizers(choice_optim=args["choice_optim"],
                                        lr=args["lr"],
                                        decay=args["decay"])
    model.compile(loss=args["loss"], optimizer=optim, metrics=args["metrics"])

    cb = my_callback.get_base_cb(
        args["output_dir"],
        args["num_epoch"],
        early_stopping=args["num_epoch"] // 4,
        monitor="val_" + args["metrics"][0],
        metric=args["metrics"][0],
    )

    # lr_finder
    if args["is_lr_finder"] == True:
        # 最適な学習率確認して関数抜ける
        lr_finder.run(
            model,
            d_cls.train_gen,
            args["batch_size"],
            d_cls.init_train_steps_per_epoch,
            output_dir=args["output_dir"],
        )
        return

    # ### train ### #
    start_time = time.time()
    hist = model.fit(
        d_cls.train_gen,
        steps_per_epoch=d_cls.init_train_steps_per_epoch,
        epochs=args["num_epoch"],
        validation_data=d_cls.valid_gen,
        validation_steps=d_cls.init_valid_steps_per_epoch,
        verbose=2,  # 1: progress bar on stdout, 2: one line per epoch (minimal output)
        callbacks=cb,
    )
    end_time = time.time()
    print("Elapsed Time : {:.2f}sec".format(end_time - start_time))

    model.save(os.path.join(args["output_dir"], "model_last_epoch.h5"))

    plot_log.plot_results(
        args["output_dir"],
        os.path.join(args["output_dir"], "tsv_logger.tsv"),
        acc_metric=args["metrics"][0],
    )

    return hist
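

if __name__ == "__main__":
    # Minimal, hypothetical invocation of train_directory. Every value below is
    # a placeholder assumption (directory layout, model/optimizer option strings,
    # fine-tuning settings); only the keys match what the functions above read.
    example_args = {
        "img_rows": 224, "img_cols": 224, "channels": 3,
        "batch_size": 32,
        "train_data_dir": "data/train",
        "validation_data_dir": "data/validation",
        "is_flow": False, "is_flow_from_directory": True,
        "classes": ["class_a", "class_b"],
        "color_mode": "rgb", "class_mode": "categorical",
        "my_IDG_options": {"rescale": 1 / 255.0, "horizontal_flip": True},
        "n_multitask": 1, "multitask_pred_n_node": 1,
        "output_dir": "output",
        "choice_model": "VGG16",  # placeholder; must be a name define_model supports
        "num_classes": 2,
        "activation": "softmax",
        "trainable": "all",  # placeholder fine-tuning options
        "fcpool": "GlobalAveragePooling2D",
        "fcs": [256],
        "drop": 0.5,
        "weights": "imagenet",
        "choice_optim": "sgd",  # placeholder; must be a name get_optimizers supports
        "lr": 1e-4, "decay": 0.0,
        "loss": "categorical_crossentropy",
        "metrics": ["accuracy"],
        "num_epoch": 30,
        "is_lr_finder": False,
    }
    train_directory(example_args)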