Example #1
def mlp_model(x_train, y_train, x_val, y_val, params):
    model = Sequential()
    model.add(
        Dense(params['layer_size'],
              activation=params['activation'],
              input_dim=x_train.shape[1],
              kernel_regularizer=l2(params['regularization'])))
    model.add(Dropout(params['dropout']))
    for i in range(params['layers'] - 1):
        model.add(
            Dense(params['layer_size'],
                  activation=params['activation'],
                  kernel_regularizer=l2(params['regularization'])))
        model.add(Dropout(params['dropout']))
    model.add(Dense(2, activation='softmax'))
    model.compile(
        optimizer=params['optimizer'](params['lr']),
        loss=params['loss_functions'],
        # loss=params['loss_functions']([params['weights1'], params['weights2']]),
        metrics=['accuracy', Recall(), Precision(), f1])
    history = model.fit(x_train,
                        y_train,
                        batch_size=params['batch_size'],
                        validation_data=(x_val, y_val),
                        epochs=100,
                        callbacks=[
                            EarlyStopping(monitor='val_accuracy',
                                          patience=5,
                                          min_delta=0.01)
                        ],
                        verbose=0)
    return history, model
def validate_on_batch(model, X, y, accuracy_metric, loss_metric):
    ŷ = model(X, training=False)
    precision_metric = Precision()
    recall_metric = Recall()
    accuracy = accuracy_metric(y, ŷ)
    loss = loss_metric(y, ŷ)
    precision = precision_metric(y, ŷ)
    recall = recall_metric(y, ŷ)
    return accuracy.numpy(), loss.numpy(), precision.numpy(), recall.numpy()
def train_on_batch(model, optimizer, X, y, accuracy_metric, loss_metric):
    apply_gradient_descent(X, model, optimizer, y)
    ŷ = model(X, training=True)
    # Calculate metrics after the PSO weight update
    precision_metric = Precision()
    recall_metric = Recall()
    accuracy = accuracy_metric(y, ŷ)
    loss = loss_metric(y, ŷ)
    precision = precision_metric(y, ŷ)
    recall = recall_metric(y, ŷ)
    # Update training metric.
    return accuracy.numpy(), loss.numpy(), precision.numpy(), recall.numpy()
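A minimal usage sketch for mlp_model above. Every key and value in this params dict is an illustrative assumption rather than part of the original example, and x_train/y_train/x_val/y_val stand for your own arrays (labels one-hot encoded with two classes to match the softmax output layer).

from tensorflow.keras.optimizers import Adam

# Hypothetical hyperparameters; the keys mirror the params[...] lookups above.
params = {
    'layer_size': 64,                  # units per hidden layer
    'layers': 3,                       # total number of hidden layers
    'activation': 'relu',
    'regularization': 1e-4,            # L2 factor
    'dropout': 0.2,
    'optimizer': Adam,                 # optimizer class, instantiated as optimizer(lr) inside
    'lr': 1e-3,
    'loss_functions': 'categorical_crossentropy',
    'batch_size': 32,
}
history, model = mlp_model(x_train, y_train, x_val, y_val, params)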
Example #4
def create_model(input_dim, l2=1e-3, lr=1e-5):
    model = Sequential([
        Dense(17,
              input_dim=input_dim,
              kernel_regularizer=regularizers.l2(l2),
              activation='relu'),
        Dense(1, kernel_regularizer=regularizers.l2(l2), activation='sigmoid')
    ])

    metrics = ['accuracy', AUC(), Precision(), Recall()]

    model.compile(optimizer=Adam(learning_rate=lr),
                  loss='binary_crossentropy',
                  metrics=metrics)

    return model
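A hypothetical call to create_model above; the arrays and shapes here are assumptions for illustration only (binary labels to match the sigmoid output and binary_crossentropy loss).

import numpy as np

# Hypothetical data: 1000 rows, 20 features, binary labels.
X_train = np.random.rand(1000, 20).astype('float32')
y_train = np.random.randint(0, 2, size=(1000,))

model = create_model(input_dim=X_train.shape[1], l2=1e-3, lr=1e-5)
model.fit(X_train, y_train, validation_split=0.2, epochs=5, batch_size=32)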
Example #5
def train_classifier():
    data = get_data()

    classifier = Sequential()
    classifier.add(
        Dense(100,
              activation=tf.nn.relu,
              input_shape=(FLAGS.sentence_embedding_size, )))
    # Note: range(1 - 1) is empty, so this loop never adds the extra hidden layers.
    for i in range(1 - 1):
        classifier.add(
            Dense(100,
                  activation='relu',
                  kernel_regularizer=tf.keras.regularizers.l2(0.3)))
        classifier.add(Dropout(0.5))
    classifier.add(Dense(2, activation='softmax'))
    classifier.compile(optimizer=Adagrad(0.01),
                       loss='categorical_crossentropy',
                       metrics=['accuracy',
                                Recall(),
                                Precision(), f1])

    classifier.summary()

    helper._print_header('Training classifier')

    classifier.fit(data['train'][0],
                   data['train'][1],
                   batch_size=FLAGS.classifier_batch_size,
                   validation_data=(data['val'][0], data['val'][1]),
                   epochs=200,
                   callbacks=[
                       EarlyStopping(monitor='val_accuracy',
                                     patience=25,
                                     min_delta=0.01),
                       SaveBestModelCallback()
                   ],
                   verbose=2)
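Both this example and Example #1 pass a custom f1 metric to compile without showing its definition. The sketch below is one common backend-based implementation; it is an assumption about what f1 might look like, not the original code.

from tensorflow.keras import backend as K

def f1(y_true, y_pred):
    # Batch-level F1 from thresholded predictions; K.epsilon() guards against division by zero.
    y_pred_pos = K.round(K.clip(y_pred, 0, 1))
    true_pos = K.sum(K.round(K.clip(y_true * y_pred_pos, 0, 1)))
    precision = true_pos / (K.sum(y_pred_pos) + K.epsilon())
    recall = true_pos / (K.sum(K.round(K.clip(y_true, 0, 1))) + K.epsilon())
    return 2 * precision * recall / (precision + recall + K.epsilon())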
Example #6
def train_lstm_model(params: Dict,
                     full_text: str,
                     model_path=None,
                     weights_path=None):
    """
    Train function, builds the model, with metrics and checkpoints, import model
    config and train with number of epochs and parameters provided in config

    Args:
        params: json_config for the model
        full_text: full text of data to process
        model_path: optional path to previously saved model
        weights_path: optional path to previously saved weights

    Returns:
        model history
    """

    # load char2int encoder
    char2int_encoder = load_json_file(params['char2int_encoder_path'])

    # Load model from previous training session
    if model_path and weights_path:
        model = load_model_from_json_and_weights(model_path, weights_path)
    # Create new model if no previous one
    else:
        lstm_model = LSTMModel(sequence_length=params['sequence_length'],
                               step=params['step'],
                               lstm_units=params['lstm_units'],
                               char_encoder=char2int_encoder)
        model = lstm_model.build_model()

    # Set optimizer
    optimizer = RMSprop()

    # Metrics
    precision = Precision()
    recall = Recall()
    categorical_accuracy = CategoricalAccuracy()
    metrics = [precision, recall, categorical_accuracy]

    model.compile(optimizer=optimizer,
                  loss=params['loss'],
                  metrics=metrics,
                  run_eagerly=False)

    # Define callbacks
    if weights_path:
        last_epoch = max([
            int(
                re.search(r"weights\.0?(?P<epoch>\d\d?)-",
                          filename).group("epoch"))
            for filename in os.listdir(params['model_path'])
            if filename.endswith("hdf5")
        ])
        file_path = params["model_path"] + '/weights.' + str(
            last_epoch) + '-{epoch:02d}-{val_loss:.2f}.hdf5'
    else:
        file_path = params[
            "model_path"] + '/weights.{epoch:02d}-{val_loss:.2f}.hdf5'
    checkpoint = ModelCheckpoint(monitor='val_loss',
                                 filepath=file_path,
                                 verbose=1,
                                 save_freq='epoch')
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.5,
                                  patience=1,
                                  verbose=1,
                                  mode='auto',
                                  min_delta=0.0001,
                                  cooldown=0,
                                  min_lr=0)
    callbacks_fit = [checkpoint, reduce_lr]

    # Save model json
    if not model_path:
        with open(params["model_path"] + '/model_result.json',
                  'w') as json_file:
            json_file.write(model.to_json())

    x, y = extract_data_with_labels(full_text, params, char2int_encoder)

    # Fit model
    logging.info('Start training')
    history = model.fit(x,
                        y,
                        batch_size=params['batch_size'],
                        epochs=params['epochs'],
                        verbose=1,
                        callbacks=callbacks_fit,
                        validation_split=0.2)

    # Print results
    history = history.history
    logging.info(history)
    model.save_weights(params["model_path"] + '/model_result_weights.h5')

    return history['val_categorical_accuracy'], history['val_loss']
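An illustrative invocation of train_lstm_model; the file paths below are assumptions, and the required config keys simply mirror the params[...] lookups in the function above.

params = load_json_file('config/lstm_params.json')
# The config must provide 'char2int_encoder_path', 'sequence_length', 'step', 'lstm_units',
# 'loss', 'model_path', 'batch_size' and 'epochs', matching the lookups above.
full_text = open('data/corpus.txt', encoding='utf-8').read()

val_acc, val_loss = train_lstm_model(params, full_text)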
Example #7
    def compile(self,
                embedding_dim=4,
                u_hidden_units=[128, 64, 16],
                i_hidden_units=[64, 32, 16],
                activation='relu',
                dropout=0.3,
                loss=tf.keras.losses.BinaryCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(1e-4),
                metrics=[Precision(), AUC()],
                summary=True,
                **kwargs):
        """
                Args:
                    embedding_dim: 静态离散特征embedding参数,embedding[0]表示输入值离散空间上限,embedding[1]表示输出向量维度
                    u_hidden_units: 用户塔MLP层神经元参数,最少2层
                    i_hidden_units: 物品塔MLP层神经元参数,最少2层
                    activation: MLP层激活函数
                    dropout: dropout系数
                    loss: 损失函数
                    optimizer: 优化器
                    metrics: 效果度量函数
                    summary: 是否输出summary信息
        """
        self.embedding_dim = embedding_dim
        self.u_hidden_units = u_hidden_units
        self.i_hidden_units = i_hidden_units
        self.activation = activation
        self.dropout = dropout

        # Define the input layers
        user_input_features = OrderedDict()
        item_input_features = OrderedDict()
        user_input_features['u_continue_cols'] = tf.keras.layers.Input(
            shape=len(self.u_continue_cols),
            name='u_continue_cols_input')  # user numeric features
        for col in self.u_discrete_cols:
            user_input_features[col] = tf.keras.layers.Input(
                shape=1, name=col + '_input')  # user categorical features
        item_input_features['i_continue_cols'] = tf.keras.layers.Input(
            shape=len(self.i_continue_cols),
            name='i_continue_cols_input')  # item numeric features
        for col in self.i_discrete_cols:
            item_input_features[col] = tf.keras.layers.Input(
                shape=1, name=col + '_input')  # item categorical features
        for col in self.u_history_cols:
            user_input_features[col] = tf.keras.layers.Input(
                shape=self.u_history_col_ts_step[col],
                name=col + '_input')  # user's item-history sequence features

        # Build the two-tower structure
        user_vector_list = []
        item_vector_list = []

        # u_dense = tf.keras.layers.BatchNormalization()(user_input_features['u_continue_cols'])
        u_dense = user_input_features['u_continue_cols']
        user_vector_list.append(u_dense)

        # i_dense = tf.keras.layers.BatchNormalization()(item_input_features['i_continue_cols'])
        i_dense = item_input_features['i_continue_cols']
        item_vector_list.append(i_dense)

        for col in self.u_discrete_cols:
            le = self.le_dict[col]
            user_vector_list.append(
                tf.reshape(
                    tf.keras.layers.Embedding(len(le.classes_) + 10,
                                              self.embedding_dim,
                                              name=col + '_embedding')(
                                                  user_input_features[col]),
                    [-1, self.embedding_dim]))
        share_embedding_dict = {}
        for col in self.i_discrete_cols:
            le = self.le_dict[col]
            if col not in self.u_history_col_names:  # this feature does not appear in the user history
                item_vector_list.append(
                    tf.reshape(
                        tf.keras.layers.Embedding(
                            len(le.classes_) + 10,
                            self.embedding_dim,
                            name=col + '_embedding')(item_input_features[col]),
                        [-1, self.embedding_dim]))
            else:
                embedding_dim = int(len(le.classes_)**0.25) + 1  # choose the dimension dynamically
                embedding_layer = tf.keras.layers.Embedding(
                    len(le.classes_) + 10,
                    embedding_dim,
                    name=col + '_embedding')  # the item-id embedding layer is shared by the user and item towers
                share_embedding_dict[col] = embedding_layer
                item_vector_list.append(
                    tf.reshape(embedding_layer(item_input_features[col]),
                               [-1, embedding_dim]))

        for i in range(len(self.u_history_cols)):
            item_col_name = self.u_history_col_names[i]
            le = self.le_dict[item_col_name]
            embedding_dim = int(len(le.classes_)**0.25) + 1  # choose the dimension dynamically
            lstm_out_dim = int(
                (embedding_dim *
                 self.u_history_col_ts_step[self.u_history_cols[i]]) / 2)

            embedding_series = share_embedding_dict[item_col_name](
                user_input_features[self.u_history_cols[i]])
            # embedding_series = tf.keras.layers.BatchNormalization()(embedding_series)
            user_vector_list.append(
                tf.keras.layers.LSTM(units=lstm_out_dim)(embedding_series))

        user_embedding = tf.keras.layers.concatenate(user_vector_list, axis=1)
        item_embedding = tf.keras.layers.concatenate(item_vector_list, axis=1)

        for i in range(len(self.u_hidden_units[:-1])):
            # user_embedding = tf.keras.layers.BatchNormalization()(user_embedding)
            user_embedding = tf.keras.layers.Dense(
                self.u_hidden_units[i],
                activation=self.activation)(user_embedding)
            user_embedding = tf.keras.layers.Dropout(
                self.dropout)(user_embedding)
        user_embedding = tf.keras.layers.Dense(
            self.u_hidden_units[-1], name='user_embedding',
            activation='tanh')(user_embedding)
        # user_embedding = tf.keras.layers.BatchNormalization()(user_embedding)

        for i in range(len(self.i_hidden_units[:-1])):
            # item_embedding = tf.keras.layers.BatchNormalization()(item_embedding)
            item_embedding = tf.keras.layers.Dense(
                self.i_hidden_units[i],
                activation=self.activation)(item_embedding)
            item_embedding = tf.keras.layers.Dropout(
                self.dropout)(item_embedding)
        item_embedding = tf.keras.layers.Dense(
            self.i_hidden_units[-1], name='item_embedding',
            activation='tanh')(item_embedding)
        # item_embedding = tf.keras.layers.BatchNormalization()(item_embedding)

        # Inner product of the two tower vectors as the output
        output = tf.expand_dims(
            tf.reduce_sum(user_embedding * item_embedding, axis=1), 1)
        # output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(output)
        output = tf.sigmoid(output)

        user_input = list(user_input_features.values())
        self.user_input_len = len(user_input)
        item_input = list(item_input_features.values())
        self.item_input_len = len(item_input)
        inputs_list = user_input + item_input
        self.model = tf.keras.models.Model(inputs=inputs_list, outputs=output)

        self.model.compile(loss=loss,
                           optimizer=optimizer,
                           metrics=metrics,
                           **kwargs)
        if summary:
            self.model.summary()
        log_dir = "logs/fit/" + datetime.datetime.now().strftime(
            "%Y%m%d-%H%M%S")
        self.tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir, histogram_freq=1)
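The two-tower scoring at the end of the compile method above reduces to an inner product followed by a sigmoid. A tiny standalone sketch of just that operation; the batch size and embedding dimension are assumed values.

import tensorflow as tf

# Hypothetical tower outputs: [batch, d] tensors.
user_embedding = tf.random.normal([8, 16])
item_embedding = tf.random.normal([8, 16])
score = tf.sigmoid(
    tf.expand_dims(tf.reduce_sum(user_embedding * item_embedding, axis=1), 1))
print(score.shape)  # (8, 1)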
Example #8
#     # define subplot
#     pyplot.subplot(330 + 1 + i)
#     # plot raw pixel data
#     pyplot.imshow(train_x[i], cmap=pyplot.get_cmap('gray'))
# # show the figure
# pyplot.show()

loss, train_y, test_y = data_set.get_dataset_labels(train_y, test_y)

model = build_unet([IMAGE_DIMS[0], IMAGE_DIMS[1], 1],
                   len(data_set.class_names))

opt = Adam(learning_rate=hyperparameters.init_lr)
model.compile(loss='binary_crossentropy',
              optimizer=opt,
              metrics=['accuracy', Precision(),
                       Recall()])

start_time = time.time()
H = training_loop(model,
                  opt,
                  hyperparameters,
                  train_x,
                  train_y,
                  test_x,
                  test_y,
                  meta_heuristic=hyperparameters.meta_heuristic,
                  meta_heuristic_order=hyperparameters.meta_heuristic_order)
time_taken = timedelta(seconds=(time.time() - start_time))

# H =model.fit(x=aug.flow(train_x, train_y, batch_size=hyperparameters.batch_size), validation_data=(test_x, test_y),
Example #9
    def compile(self,
                embedding_dim=4,
                rnn_cells=64,
                hidden_units=[64, 16],
                activation='relu',
                dropout=0.3,
                loss=tf.keras.losses.BinaryCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(1e-4),
                metrics=[Precision(), AUC()],
                summary=True,
                **kwargs):
        """
                Args:
                    embedding_dim: 静态离散特征embedding参数,表示输出向量维度,输入词典维度由训练数据自动计算产生
                    rnn_cells: 时序连续特征输出神经元个数
                    hidden_units: MLP层神经元参数,最少2层
                    activation: MLP层激活函数
                    dropout: dropout系数
                    loss: 损失函数
                    optimizer: 优化器
                    metrics: 效果度量函数
                    summary: 是否输出summary信息
                """
        self.embedding_dim = embedding_dim
        self.rnn_cells = rnn_cells
        self.hidden_units = hidden_units
        self.activation = activation
        self.dropout = dropout

        if not self.scalar_rnn:
            print("数据Scalar尚未初始化,请先调用pre_processing方法进行数据预处理,然后才能编译模型!")
            return

        # Define the input layers
        input_features = OrderedDict()

        input_features['input_rnn_continue'] = tf.keras.layers.Input(
            shape=(self.ts_step, len(self.rnn_continue_X_cols)),
            name='input_rnn_continue')  # continuous time-series data
        if self.static_continue_X_cols:
            input_features['input_static_continue'] = tf.keras.layers.Input(
                shape=len(self.static_continue_X_cols),
                name='input_static_continue')  # continuous static data
        for col in self.static_discrete_X_cols:
            input_features[col] = tf.keras.layers.Input(shape=1,
                                                        name=col)  # static discrete features

        # Build the network structure
        rnn_dense = [
            tf.keras.layers.LSTM(units=self.rnn_cells)(
                input_features['input_rnn_continue'])
        ]
        static_dense = []
        if self.static_continue_X_cols:
            static_dense = [input_features['input_static_continue']]
        static_discrete = []
        for col in self.static_discrete_X_cols:
            vol_size = len(self.le_dict[col].classes_) + 1
            vec = tf.keras.layers.Embedding(vol_size, self.embedding_dim)(
                input_features[col])
            static_discrete.append(tf.reshape(vec, [-1, self.embedding_dim]))

        concated_vec = rnn_dense + static_dense + static_discrete
        if len(concated_vec) == 1:
            x = concated_vec[0]
        else:
            x = tf.keras.layers.concatenate(concated_vec, axis=1)

        # Fully connected layers after feature concatenation
        for i in range(len(hidden_units)):
            x = tf.keras.layers.Dense(hidden_units[i],
                                      activation=activation)(x)
            x = tf.keras.layers.Dropout(dropout)(x)

        output = tf.keras.layers.Dense(1, activation='sigmoid',
                                       name='action')(x)

        inputs_list = list(input_features.values())
        self.model = tf.keras.models.Model(inputs=inputs_list, outputs=output)

        self.model.compile(loss=loss,
                           optimizer=optimizer,
                           metrics=metrics,
                           **kwargs)
        if summary:
            self.model.summary()
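The static discrete features above are embedded and then flattened with tf.reshape. A standalone sketch of that pattern; the vocabulary size and embedding dimension are assumed values.

import tensorflow as tf

vocab_size, embedding_dim = 100, 4                                      # assumed values
col_input = tf.keras.layers.Input(shape=1, name='example_col')          # one id per row
vec = tf.keras.layers.Embedding(vocab_size, embedding_dim)(col_input)   # shape [batch, 1, embedding_dim]
flat = tf.reshape(vec, [-1, embedding_dim])                             # shape [batch, embedding_dim]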
Example #10
File: sbcnm.py  Project: yhangang/lapras
    def compile(self,
                embedding_dim=4,
                u_hidden_units=[128, 64, 16],
                i_hidden_units=[64, 32, 16],
                activation='relu',
                dropout=0.3,
                loss=tf.keras.losses.BinaryCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(1e-4),
                metrics=[Precision(), AUC()],
                summary=True,
                **kwargs):
        """
                Args:
                    embedding_dim: 静态离散特征embedding参数,embedding[0]表示输入值离散空间上限,embedding[1]表示输出向量维度
                    u_hidden_units: 用户塔MLP层神经元参数,最少2层
                    i_hidden_units: 物品塔MLP层神经元参数,最少2层
                    activation: MLP层激活函数
                    dropout: dropout系数
                    loss: 损失函数
                    optimizer: 优化器
                    metrics: 效果度量函数
                    summary: 是否输出summary信息
        """
        self.embedding_dim = embedding_dim
        self.u_hidden_units = u_hidden_units
        self.i_hidden_units = i_hidden_units
        self.activation = activation
        self.dropout = dropout

        # Define the input layers
        user_input_features = OrderedDict()
        item_input_features = OrderedDict()
        user_input_features['u_continue_cols'] = tf.keras.layers.Input(
            shape=len(self.u_continue_cols),
            name='u_continue_cols_input')  # user numeric features
        for col in self.u_discrete_cols:
            user_input_features[col] = tf.keras.layers.Input(
                shape=1, name=col + '_input')  # user categorical features
        item_input_features['i_continue_cols'] = tf.keras.layers.Input(
            shape=len(self.i_continue_cols),
            name='i_continue_cols_input')  # item numeric features
        for col in self.i_discrete_cols:
            item_input_features[col] = tf.keras.layers.Input(
                shape=1, name=col + '_input')  # item categorical features

        for col in self.u_history_cols:
            user_input_features[col] = tf.keras.layers.Input(
                shape=self.u_history_col_ts_step[col],
                name=col + '_input')  # user's item-history sequence features
        for col in self.u_lstm_cols:
            user_input_features[col] = tf.keras.layers.Input(
                shape=self.u_lstm_col_ts_step[col],
                name=col + '_input')  # user LSTM sequence features

        # Build the two-tower structure
        user_vector_list = []
        item_vector_list = []

        # u_dense = tf.keras.layers.BatchNormalization()(user_input_features['u_continue_cols'])
        u_dense = user_input_features['u_continue_cols']
        user_vector_list.append(u_dense)

        # i_dense = tf.keras.layers.BatchNormalization()(item_input_features['i_continue_cols'])
        i_dense = item_input_features['i_continue_cols']
        item_vector_list.append(i_dense)

        for col in self.u_discrete_cols:
            user_vector_list.append(
                tf.reshape(
                    tf.keras.layers.Embedding(
                        self.u_discrete_col_input_dim[col],
                        self.embedding_dim,
                        name=col + '_embedding')(user_input_features[col]),
                    [-1, self.embedding_dim]))

        for col in self.i_discrete_cols:
            item_vector_list.append(
                tf.reshape(
                    tf.keras.layers.Embedding(
                        self.i_discrete_col_input_dim[col],
                        self.embedding_dim,
                        name=col + '_embedding')(item_input_features[col]),
                    [-1, self.embedding_dim]))

        for col in self.u_history_cols:
            pooling_embedding = tf.keras.layers.GlobalAveragePooling1D(
                data_format='channels_last',
                name=col + '_pooling')(user_input_features[col])
            user_vector_list.append(pooling_embedding)

        for col in self.u_lstm_cols:
            # Compute the LSTM output dimension dynamically
            lstm_out_dim = int(self.u_lstm_col_ts_step[col][0] *
                               self.u_lstm_col_ts_step[col][1] / 4)
            lstm_vector = tf.keras.layers.LSTM(units=lstm_out_dim,
                                               name=col + '_LSTM')(
                                                   user_input_features[col])
            user_vector_list.append(lstm_vector)

        user_embedding = tf.keras.layers.concatenate(user_vector_list, axis=1)
        item_embedding = tf.keras.layers.concatenate(item_vector_list, axis=1)

        for i in range(len(self.u_hidden_units[:-1])):
            # user_embedding = tf.keras.layers.BatchNormalization()(user_embedding)
            user_embedding = tf.keras.layers.Dense(
                self.u_hidden_units[i],
                activation=self.activation)(user_embedding)
            user_embedding = tf.keras.layers.Dropout(
                self.dropout)(user_embedding)
        user_embedding = tf.keras.layers.Dense(
            self.u_hidden_units[-1], name='user_embedding',
            activation='tanh')(user_embedding)

        for i in range(len(self.i_hidden_units[:-1])):
            # item_embedding = tf.keras.layers.BatchNormalization()(item_embedding)
            item_embedding = tf.keras.layers.Dense(
                self.i_hidden_units[i],
                activation=self.activation)(item_embedding)
            item_embedding = tf.keras.layers.Dropout(
                self.dropout)(item_embedding)
        item_embedding = tf.keras.layers.Dense(
            self.i_hidden_units[-1], name='item_embedding',
            activation='tanh')(item_embedding)

        # Inner product of the two tower vectors as the output
        output = tf.expand_dims(
            tf.reduce_sum(user_embedding * item_embedding, axis=1), 1)
        output = tf.sigmoid(output)

        user_input = list(user_input_features.values())
        self.user_input_len = len(user_input)
        item_input = list(item_input_features.values())
        self.item_input_len = len(item_input)
        inputs_list = user_input + item_input
        self.model = tf.keras.models.Model(inputs=inputs_list, outputs=output)

        self.model.compile(loss=loss,
                           optimizer=optimizer,
                           metrics=metrics,
                           **kwargs)
        if summary:
            self.model.summary()
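The LSTM branch above indexes u_lstm_col_ts_step[col] at positions [0] and [1], which suggests a (timesteps, features) tuple per column. Below is a standalone sketch of that branch under that assumption; the concrete numbers are illustrative.

import tensorflow as tf

timesteps, features = 10, 8                       # assumed per-column shape
lstm_out_dim = int(timesteps * features / 4)      # mirrors the heuristic above
seq_input = tf.keras.layers.Input(shape=(timesteps, features), name='col_input')
lstm_vector = tf.keras.layers.LSTM(units=lstm_out_dim, name='col_LSTM')(seq_input)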
Example #11
def main(argv):
    del argv
    # path
    data_dir = os.path.join(BASE_DIR, 'dataset', FLAGS.dataset)
    exp_dir = os.path.join(data_dir, 'exp', FLAGS.exp_name)
    model_dir = os.path.join(exp_dir, 'ckpt')
    log_dir = exp_dir
    os.makedirs(model_dir, exist_ok=True)
    # os.makedirs(log_dir, exist_ok=True)
    model_path = os.path.join(model_dir, 'model-{epoch:04d}.ckpt.h5')

    # logging
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.DEBUG,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(
        '------------------------------experiment start------------------------------------'
    )

    for i in (
            'exp_name',
            'dataset',
            'model',
            'mode',
            'lr',
    ):
        logging.info(
            '%s: %s' %
            (i, FLAGS.get_flag_value(i, '########VALUE MISSED#########')))
    logging.info(FLAGS.flag_values_dict())

    # resume from checkpoint
    largest_epoch = 0
    if FLAGS.resume == 'ckpt':
        chkpts = tf.io.gfile.glob(model_dir + '/*.ckpt.h5')
        if len(chkpts):
            largest_epoch = sorted([int(i[-12:-8]) for i in chkpts],
                                   reverse=True)[0]
            print('resume from epoch', largest_epoch)
            weight_path = model_path.format(epoch=largest_epoch)
        else:
            weight_path = None
    elif len(FLAGS.resume):
        assert os.path.isfile(FLAGS.resume)
        weight_path = FLAGS.resume
    else:
        weight_path = None

    dataset = importlib.import_module(
        'dataset.%s.data_loader' %
        FLAGS.dataset).DataLoader(**FLAGS.flag_values_dict())
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = globals()[FLAGS.model](**FLAGS.flag_values_dict())
        # model = alexnet()
        if FLAGS.resume and weight_path:
            logging.info('resume from previous ckp: %s' % largest_epoch)
            model.load_weights(weight_path)
        # model.layers[1].trainable = False
        loss = globals()[FLAGS.loss]
        model.compile(
            optimizer=SGDW(momentum=0.9,
                           learning_rate=FLAGS.lr,
                           weight_decay=FLAGS.weight_decay),
            loss=loss,
            metrics=[
                "accuracy",
                Recall(),
                Precision(),
                MeanIoU(num_classes=FLAGS.classes)
            ],
        )
        # if 'train' in FLAGS.mode:
        #     model.summary()
        logging.info('There are %s layers in model' % len(model.layers))
        if FLAGS.freeze_layers > 0:
            logging.info('Freeze first %s layers' % FLAGS.freeze_layers)
            for i in model.layers[:FLAGS.freeze_layers]:
                i.trainable = False
        verbose = 1 if FLAGS.debug is True else 2
        if 'train' in FLAGS.mode:
            callbacks = [
                model_checkpoint(filepath=model_path,
                                 monitor=FLAGS.model_checkpoint_monitor),
                tensorboard(log_dir=os.path.join(exp_dir, 'tb-logs')),
                early_stopping(monitor=FLAGS.model_checkpoint_monitor,
                               patience=FLAGS.early_stopping_patience),
                lr_schedule(name=FLAGS.lr_schedule, epochs=FLAGS.epoch)
            ]
            file_writer = tf.summary.create_file_writer(
                os.path.join(exp_dir, 'tb-logs', "metrics"))
            file_writer.set_as_default()
            train_ds = dataset.get(
                'train')  # get first to calculate train size
            model.fit(
                train_ds,
                epochs=FLAGS.epoch,
                validation_data=dataset.get('valid'),
                callbacks=callbacks,
                initial_epoch=largest_epoch,
                verbose=verbose,
            )

            # evaluate before train on valid
            # result = model.evaluate(
            #     dataset.get('test'),
            # )
            # logging.info('evaluate before train on valid result:')
            # for i in range(len(result)):
            #     logging.info('%s:\t\t%s' % (model.metrics_names[i], result[i]))
        if 'test' in FLAGS.mode:
            # fit on the validation split
            # model.fit(
            #     dataset.get('valid'),
            #     epochs=3,
            #     # callbacks=callbacks,
            #     verbose=verbose
            # )
            # model.save_weights(os.path.join(model_dir, 'model.h5'))
            # evaluate on the test split
            result = model.evaluate(dataset.get('test'), )
            logging.info('evaluate result:')
            for i in range(len(result)):
                logging.info('%s:\t\t%s' % (model.metrics_names[i], result[i]))
            # TODO: remove previous checkpoint
        if 'predict' in FLAGS.mode:
            files = read_txt(
                os.path.join(BASE_DIR,
                             'dataset/%s/predict.txt' % FLAGS.dataset))
            output_dir = FLAGS.predict_output_dir
            os.makedirs(output_dir, exist_ok=True)
            i = 0
            ds = dataset.get('predict')
            for batch in ds:
                predict = model.predict(batch)
                for p in predict:
                    if i % 1000 == 0:
                        logging.info('curr: %s/%s' % (i, len(files)))
                    p_r = tf.squeeze(tf.argmax(
                        p, axis=-1)).numpy().astype('int16')
                    p_r = (p_r + 1) * 100
                    p_im = Image.fromarray(p_r)
                    im_path = os.path.join(
                        output_dir, '%s.png' % files[i].split('/')[-1][:-4])
                    p_im.save(im_path)
                    i += 1
        if FLAGS.task == 'visualize_result':
            dataset.visualize_evaluate(model, FLAGS.mode)
def define_model(input=(28, 28, 1), classes=10):
    input = Input(shape=input)
    x = Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_uniform')(input)
    x = MaxPooling2D((2, 2))(x)
    x = Flatten()(x)
    x = Dense(256, activation='relu', kernel_initializer='he_uniform')(x)
    output = Dense(classes, activation='softmax')(x)
    model = Model(inputs=input, outputs=output)

    opt = SGD(learning_rate=hyperparameters.init_lr, momentum=0.9)

    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy', Precision(), Recall()])
    return model, opt
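A hypothetical smoke test of define_model on MNIST-shaped data. The dataset loading is an assumption added for illustration; hyperparameters.init_lr must already be defined, as in the original code.

import tensorflow as tf

(train_x, train_y), _ = tf.keras.datasets.mnist.load_data()
train_x = train_x.reshape(-1, 28, 28, 1).astype('float32') / 255.0
train_y = tf.keras.utils.to_categorical(train_y, 10)   # one-hot labels for categorical_crossentropy

model, opt = define_model()
model.fit(train_x, train_y, epochs=1, batch_size=128, verbose=2)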