def train_directory(args):
    """指定ディレクトリについてgenerator作ってモデル学習"""
    print("train_directory")
    # ### train validation data load ### #
    d_cls = get_train_valid_test.LabeledDataset(
        [args["img_rows"], args["img_cols"], args["channels"]],
        args["batch_size"],
        valid_batch_size=args["batch_size"],
        train_samples=len(util.find_img_files(args["train_data_dir"])),
        valid_samples=len(util.find_img_files(args["validation_data_dir"])),
    )
    if args["is_flow"]:
        # Load preprocessed images, labels, and file paths from the given directories
        d_cls.X_train, d_cls.y_train, train_paths = base_dataset.load_my_data(
            args["train_data_dir"],
            classes=args["classes"],
            img_height=args["img_rows"],
            img_width=args["img_cols"],
            channel=args["channels"],
            is_pytorch=False,
        )
        d_cls.X_valid, d_cls.y_valid, valid_paths = base_dataset.load_my_data(
            args["validation_data_dir"],
            classes=args["classes"],
            img_height=args["img_rows"],
            img_width=args["img_cols"],
            channel=args["channels"],
            is_pytorch=False,
        )
        # Scale pixel values back up to the 0-255 range
        d_cls.X_train, d_cls.X_valid = d_cls.X_train * 255.0, d_cls.X_valid * 255.0
        d_cls.create_my_generator_flow(my_IDG_options=args["my_IDG_options"])

    elif args["is_flow_from_directory"]:
        d_cls.create_my_generator_flow_from_directory(
            args["train_data_dir"],
            args["classes"],
            valid_data_dir=args["validation_data_dir"],
            color_mode=args["color_mode"],
            class_mode=args["class_mode"],
            my_IDG_options=args["my_IDG_options"],
        )
        # d_cls.train_gen_augmentor = d_cls.create_augmentor_util_from_directory(args['train_data_dir']
        #                                                                       , args['batch_size']
        #                                                                       , augmentor_options=args['train_augmentor_options'])

    # Wrap the binary-label generators so they yield multitask (multi-output) batches
    if args["n_multitask"] > 1 and args["multitask_pred_n_node"] == 1:
        d_cls.train_gen = get_train_valid_test.binary_generator_multi_output_wrapper(
            d_cls.train_gen)
        d_cls.valid_gen = get_train_valid_test.binary_generator_multi_output_wrapper(
            d_cls.valid_gen)

    # ### model ### #
    os.makedirs(args["output_dir"], exist_ok=True)
    if args["choice_model"] == "model_paper":
        model = model_paper.create_paper_cnn(
            input_shape=(args["img_cols"], args["img_rows"], args["channels"]),
            num_classes=args["num_classes"],
            activation=args["activation"],
        )
    else:
        model, orig_model = define_model.get_fine_tuning_model(
            args["output_dir"],
            args["img_rows"],
            args["img_cols"],
            args["channels"],
            args["num_classes"],
            args["choice_model"],
            trainable=args["trainable"],
            fcpool=args["fcpool"],
            fcs=args["fcs"],
            drop=args["drop"],
            activation=args["activation"],
            weights=args["weights"],
        )
    optim = define_model.get_optimizers(choice_optim=args["choice_optim"],
                                        lr=args["lr"],
                                        decay=args["decay"])
    model.compile(loss=args["loss"], optimizer=optim, metrics=args["metrics"])

    cb = my_callback.get_base_cb(
        args["output_dir"],
        args["num_epoch"],
        early_stopping=args["num_epoch"] // 4,
        monitor="val_" + args["metrics"][0],
        metric=args["metrics"][0],
    )

    # lr_finder
    if args["is_lr_finder"] == True:
        # 最適な学習率確認して関数抜ける
        lr_finder.run(
            model,
            d_cls.train_gen,
            args["batch_size"],
            d_cls.init_train_steps_per_epoch,
            output_dir=args["output_dir"],
        )
        return

    # ### train ### #
    start_time = time.time()
    hist = model.fit(
        d_cls.train_gen,
        steps_per_epoch=d_cls.init_train_steps_per_epoch,
        epochs=args["num_epoch"],
        validation_data=d_cls.valid_gen,
        validation_steps=d_cls.init_valid_steps_per_epoch,
        verbose=2,  # 1: progress-bar logging to stdout, 2: minimal per-epoch output
        callbacks=cb,
    )
    end_time = time.time()
    print("Elapsed Time : {:.2f}sec".format(end_time - start_time))

    model.save(os.path.join(args["output_dir"], "model_last_epoch.h5"))

    plot_log.plot_results(
        args["output_dir"],
        os.path.join(args["output_dir"], "tsv_logger.tsv"),
        acc_metric=args["metrics"][0],
    )

    return hist
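

# Hedged usage sketch (not part of the original script). The hypothetical
# helper below returns a sample ``args`` dict for the generator-based
# train_directory() above; the keys mirror what that function reads, and every
# value (paths, image shape, model/optimizer names) is an illustrative
# assumption to be replaced with real settings.
def _example_generator_train_args():
    """Return a sample ``args`` dict, e.g. train_directory(_example_generator_train_args())."""
    return {
        "img_rows": 224, "img_cols": 224, "channels": 3,
        "batch_size": 32,
        "train_data_dir": "data/train",
        "validation_data_dir": "data/validation",
        "classes": ["class_0", "class_1"], "num_classes": 2,
        "is_flow": False, "is_flow_from_directory": True,
        "color_mode": "rgb", "class_mode": "categorical",
        "my_IDG_options": {"rescale": 1.0 / 255.0},
        "n_multitask": 1, "multitask_pred_n_node": 1,
        "output_dir": "output",
        "choice_model": "VGG16", "weights": "imagenet", "trainable": "all",
        "fcpool": "avg", "fcs": [256], "drop": 0.5,
        "activation": "softmax",
        "choice_optim": "sgd", "lr": 1e-3, "decay": 0.0,
        "loss": "categorical_crossentropy", "metrics": ["accuracy"],
        "num_epoch": 50, "is_lr_finder": False,
    }

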
def train_directory(args):
    """指定ディレクトリについてモデル学習"""
    print('train_directory')
    # ### train validation data load ### #
    X_train, y_train, X_valid, y_valid, _, _ = get_dataset(
        classes=args['classes'], dataset_dir=args['data_dir'])

    # ImageDataGenerator seems unusable unless the input is a 4-D tensor, so the generator code below is left disabled
    #train_datagen = my_generator.MyImageDataGenerator(**args['my_IDG_options'])
    #train_gen = train_datagen.flow(X_train, y_train, batch_size=args['batch_size'])
    #valid_datagen = ImageDataGenerator()
    #valid_gen = valid_datagen.flow(X_valid, y_valid, batch_size=1)

    # ### model ### #
    if args['choice_model'] == 'resnet_2d':
        model = model_2d.create_resnet_2d(input_shape=(args['img_cols'],
                                                       args['channels']),
                                          num_classes=args['num_classes'],
                                          activation=args['activation'])
    else:
        model = model_2d.create_vgg_2d(input_shape=(args['img_cols'],
                                                    args['channels']),
                                       num_classes=args['num_classes'],
                                       activation=args['activation'])

    optim = define_model.get_optimizers(choice_optim=args['choice_optim'],
                                        lr=args['lr'],
                                        decay=args['decay'])
    model.compile(loss=args['loss'], optimizer=optim, metrics=args['metrics'])

    os.makedirs(args['output_dir'], exist_ok=True)
    cb = my_callback.get_base_cb(
        args['output_dir'],
        args['num_epoch'],
        early_stopping=args['num_epoch'] // 3,
        monitor='val_' + args['metrics'][0],
        metric=args['metrics'][0],
    )

    # ### train ### #
    start_time = time.time()
    hist = model.fit(
        #train_gen,
        X_train,
        y_train,
        #steps_per_epoch=X_train.shape[0] // args['batch_size'],
        batch_size=args['batch_size'],
        epochs=args['num_epoch'],
        #validation_data=valid_gen, validation_steps=X_valid.shape[0] // 1,
        validation_data=(X_valid, y_valid),
        verbose=2,  # 1: progress-bar logging to stdout, 2: minimal per-epoch output
        callbacks=cb)
    end_time = time.time()
    print("Elapsed Time : {:.2f}sec".format(end_time - start_time))

    model.save(os.path.join(args['output_dir'], 'model_last_epoch.h5'))

    plot_log.plot_results(args['output_dir'],
                          os.path.join(args['output_dir'], 'tsv_logger.tsv'),
                          acc_metric=args['metrics'][0])

    return hist
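

# Hedged usage sketch (not part of the original script): one way to call the
# 2D-model train_directory() defined above with an ``args`` dict whose keys
# match what the function reads. All values below (paths, shapes, model name,
# hyperparameters) are placeholder assumptions.
if __name__ == '__main__':
    example_args = {
        'classes': ['class_0', 'class_1'],
        'num_classes': 2,
        'data_dir': 'data',            # dataset directory passed to get_dataset()
        'output_dir': 'output',
        'choice_model': 'resnet_2d',   # any other value falls back to create_vgg_2d
        'img_cols': 256,
        'channels': 1,
        'activation': 'softmax',
        'choice_optim': 'adam',
        'lr': 1e-3,
        'decay': 0.0,
        'loss': 'categorical_crossentropy',
        'metrics': ['accuracy'],
        'batch_size': 32,
        'num_epoch': 50,
    }
    train_directory(example_args)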