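    # NOTE: this method assumes the enclosing module provides (roughly):
    #   import math, os, time
    #   import pandas as pd
    #   from keras.models import Sequential, load_model
    #   from keras.layers import Conv2D, MaxPool2D, Flatten, Dense
    #   from keras.optimizers import Adam, RMSprop
    #   from keras.callbacks import ModelCheckpoint
    #   from keras.preprocessing.image import ImageDataGenerator
    #   from keras.applications import (VGG16, ResNet50, DenseNet121,
    #                                   InceptionV3, InceptionResNetV2,
    #                                   Xception)
    # plus the project-local helpers utils, generate_batch, get_siamese_model,
    # contrastive_loss and accuracy. Exact import paths may differ in the
    # original project (e.g. tensorflow.keras).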
    def train(self):
        start_time = time.time()

        # Set up data generators
        if self.is_siamese:
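            # Siamese training uses the project's custom generate_batch
            # generator instead of Keras' flow_from_directory.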
            train_generator = generate_batch(dataset_dir=self.train_dir,
                                             batch_size=self.batch_size,
                                             image_size=self.image_size)
            valid_generator = generate_batch(dataset_dir=self.valid_dir,
                                             batch_size=self.batch_size,
                                             image_size=self.image_size)

            self.num_classes = len(utils.get_dir_names(self.train_dir))

        else:
            train_datagen = ImageDataGenerator(
                rescale=1. / 255,
                rotation_range=20,
                width_shift_range=0.2,
                height_shift_range=0.2,
                horizontal_flip=True,
            )

            valid_datagen = ImageDataGenerator(rescale=1. / 255)

            train_generator = train_datagen.flow_from_directory(
                directory=self.train_dir,
                target_size=(self.image_size, self.image_size),
                batch_size=self.batch_size,
            )

            valid_generator = valid_datagen.flow_from_directory(
                directory=self.valid_dir,
                target_size=(self.image_size, self.image_size),
                batch_size=self.batch_size,
            )
            self.num_classes = len(train_generator.class_indices)

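        # Select the optimizer class from the config string (defaults to
        # Adam); it is instantiated with the learning rate at compile time.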
        optimizer = Adam
        if self.optimizer == "RMSProp":
            optimizer = RMSprop
        model = None
        # Continue training from a saved model, or build a new one from scratch
        if self.model_path is not None:
            model = load_model(self.model_path)
        else:
            if self.model_name == "VGG16":
                model_base = VGG16(include_top=False,
                                   input_shape=self.input_shape)
            elif self.model_name == "ResNet50":
                model_base = ResNet50(include_top=False,
                                      input_shape=self.input_shape)
            elif self.model_name == "DenseNet121":
                model_base = DenseNet121(include_top=False,
                                         input_shape=self.input_shape)
            elif self.model_name == "InceptionV3":
                model_base = InceptionV3(include_top=False,
                                         input_shape=self.input_shape)
            elif self.model_name == "InceptionResNetV2":
                model_base = InceptionResNetV2(include_top=False,
                                               input_shape=self.input_shape)
            elif self.model_name == "Xception":
                model_base = Xception(include_top=False,
                                      input_shape=self.input_shape)
            elif self.model_name == "Scratch":
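                # Small VGG-style CNN trained from scratch; all of its layers
                # remain trainable (num_trainable_layer is set to the full
                # depth below).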
                model_base = Sequential()
                model_base.add(
                    Conv2D(32,
                           kernel_size=(3, 3),
                           activation="relu",
                           input_shape=self.input_shape))
                model_base.add(
                    Conv2D(32, kernel_size=(3, 3), activation="relu"))
                model_base.add(MaxPool2D())
                model_base.add(
                    Conv2D(64, kernel_size=(3, 3), activation="relu"))
                model_base.add(
                    Conv2D(64, kernel_size=(3, 3), activation="relu"))
                model_base.add(MaxPool2D())
                model_base.add(
                    Conv2D(128, kernel_size=(3, 3), activation="relu"))
                model_base.add(
                    Conv2D(128, kernel_size=(3, 3), activation="relu"))
                model_base.add(
                    Conv2D(128, kernel_size=(3, 3), activation="relu"))
                model_base.add(MaxPool2D())
                model_base.add(
                    Conv2D(256, kernel_size=(3, 3), activation="relu"))
                model_base.add(
                    Conv2D(256, kernel_size=(3, 3), activation="relu"))
                model_base.add(
                    Conv2D(256, kernel_size=(3, 3), activation="relu"))
                model_base.add(MaxPool2D())
                self.num_trainable_layer = len(model_base.layers)
            else:
                print("Model name {} is not valid".format(self.model_name))
                return 0

            # Freeze all but the last num_trainable_layer layers of the base
            num_frozen = max(
                0, len(model_base.layers) - self.num_trainable_layer)
            for layer in model_base.layers[:num_frozen]:
                layer.trainable = False

            # Show the trainable status of each layer
            print("\nAll layers of {}".format(self.model_name))
            for layer in model_base.layers:
                print("Layer : {} - Trainable : {}".format(
                    layer, layer.trainable))

            model = Sequential()
            model.add(model_base)
            model.add(Flatten())
            # model.add(Dense(50, activation="relu"))
            # model.add(Dropout(0.25))
            model.add(Dense(self.num_classes, activation="softmax"))

            # Compile model; "accuracy" keeps the metric names in line with
            # the val_accuracy checkpoint and history keys used below
            model.compile(loss="categorical_crossentropy",
                          metrics=["accuracy"],
                          optimizer=optimizer(lr=self.lr))

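        # If requested, wrap the classifier with the project's
        # get_siamese_model helper and recompile with the contrastive loss
        # and its custom accuracy metric.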
        if self.is_siamese:
            model = get_siamese_model(model)
            model.compile(loss=contrastive_loss,
                          metrics=[accuracy],
                          optimizer=optimizer(lr=self.lr))

        print("\nFinal model summary")
        model.summary()

        # classes = [_ for _ in range(self.num_classes)]
        # for c in train_generator.class_indices:
        #     classes[train_generator.class_indices[c]] = c
        #
        # model.classes = classes

        # Define callbacks
        save_model_dir = os.path.join(self.save_dir,
                                      "Model_{}".format(self.model_name))
        utils.make_dirs(save_model_dir)
        # loss_path = os.path.join(save_model_dir, "epochs_{epoch:02d}-val_loss_{val_loss:.2f}.h5")
        # loss_checkpoint = ModelCheckpoint(
        #     filepath=loss_path,
        #     monitor="val_loss",
        #     verbose=1,
        #     save_best_only=True
        # )

        acc_path = os.path.join(
            save_model_dir, "epochs_{epoch:02d}-val_acc_{val_accuracy:.2f}.h5")
        acc_checkpoint = ModelCheckpoint(filepath=acc_path,
                                         monitor="val_accuracy",
                                         verbose=1,
                                         save_best_only=True)
        callbacks = [acc_checkpoint]

        # Train model
        print("Start training model from {} ...".format(
            "{} pretrained".format(self.model_name)
            if self.model_path is None else self.model_path))

        # Steps are rounded up so the final partial batch is still used; the
        # validation steps use the validation generator's own batch size.
        if self.is_siamese:
            history = model.fit_generator(
                generator=train_generator,
                steps_per_epoch=math.ceil(self.num_classes / self.batch_size),
                epochs=self.num_epochs,
                validation_data=valid_generator,
                validation_steps=math.ceil(self.num_classes / self.batch_size),
                callbacks=callbacks)
        else:
            history = model.fit_generator(
                generator=train_generator,
                steps_per_epoch=math.ceil(train_generator.samples /
                                          train_generator.batch_size),
                epochs=self.num_epochs,
                validation_data=valid_generator,
                validation_steps=math.ceil(valid_generator.samples /
                                           valid_generator.batch_size),
                callbacks=callbacks)

        # Save model
        save_path = os.path.join(save_model_dir, "final_model.h5")
        model.save(save_path)

        # Save history
        acc = history.history["accuracy"]
        val_acc = history.history["val_accuracy"]
        loss, val_loss = history.history["loss"], history.history["val_loss"]
        train_stats = dict(Loss=loss,
                           Valid_Loss=val_loss,
                           Accuracy=acc,
                           Valid_Accuracy=val_acc)
        df = pd.DataFrame(train_stats)
        save_path = os.path.join(self.save_dir, "History.csv")
        utils.save_csv(df, save_path)

        exec_time = time.time() - start_time
        print("\nTraining model {} done. Time: {:.2f} seconds".format(
            "{} pretrained".format(self.model_name)
            if self.model_path is None else self.model_path, exec_time))
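    # Hypothetical usage sketch (the trainer class name and constructor
    # arguments are assumptions; they are not shown in this excerpt):
    #
    #   trainer = Trainer(train_dir="data/train", valid_dir="data/valid",
    #                     save_dir="output", model_name="ResNet50",
    #                     model_path=None, image_size=224, batch_size=32,
    #                     num_epochs=20, lr=1e-4, optimizer="Adam",
    #                     num_trainable_layer=5, is_siamese=False)
    #   trainer.train()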