Ejemplo n.º 1
0
def train(config):
    """Train the triplet-loss model described by ``config`` (TF1 graph/session API).

    Builds the training data generator and the model through ``ModelBuilder``,
    trains with ``fit_generator`` while checkpointing the best weights (judged
    by training loss) into ``ckpt-<model name>/``, and finally saves the full
    model to ``<model name>_<timestamp>.h5`` in the working directory.

    No validation data is used by this variant.

    Args:
        config: Nested dict. Keys read directly here are
            ``config['model']['name']`` and the ``config['train']`` entries
            ``file_list``, ``pretrained_weights``, ``learning_rate``,
            ``nb_epochs`` and ``start_epoch``. The whole dict is also passed to
            ``ModelBuilder``, which may read further keys.

    Note:
        Relies on names defined elsewhere in the module: ``ModelBuilder``,
        ``ModelCheckpoint``, ``TensorBoard``, ``tf``, ``tf_config`` (session
        config), ``triple_loss``, ``p_loss`` and ``n_loss``.
    """
    model_name = config['model']['name']

    train_file_list = config['train']['file_list']
    pretrained_weights = config['train']['pretrained_weights']
    learning_rate = config['train']['learning_rate']
    nb_epochs = config['train']['nb_epochs']
    start_epoch = config['train']['start_epoch']

    builder = ModelBuilder(config)

    # NOTE(review): fit_generator is called without steps_per_epoch, which
    # presumably means build_datagen returns a keras Sequence — confirm.
    train_gen = builder.build_datagen(train_file_list)

    # Checkpoint directory; exist_ok avoids the check-then-create race.
    ckpt_dir = 'ckpt-' + model_name
    os.makedirs(ckpt_dir, exist_ok=True)

    timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    ckpt_path = os.path.join(
        ckpt_dir,
        'weights-%s-%s-{epoch:02d}-{loss:.5f}.hdf5' % (model_name, timestr))
    checkpoint = ModelCheckpoint(
        filepath=ckpt_path,
        # There is no validation data, so training loss is the only
        # quantity available to select the best checkpoint.
        monitor='loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        period=1)

    # Logs for TensorBoard.
    tensorboard = TensorBoard(log_dir='logs', histogram_freq=0)

    # The 'weights' directory is created here but not written by this function.
    os.makedirs('weights', exist_ok=True)

    # Build and train inside a dedicated graph/session registered with Keras.
    train_graph = tf.Graph()
    train_sess = tf.Session(graph=train_graph, config=tf_config)
    tf.keras.backend.set_session(train_sess)
    with train_graph.as_default():
        model = builder.build_model()
        model.compile(
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            loss=triple_loss,
            metrics=[p_loss, n_loss])
        model.summary()

        # Optionally resume from the weights of an unfinished training run.
        if pretrained_weights:
            model.load_weights(pretrained_weights)

        model.fit_generator(
            generator=train_gen,
            initial_epoch=start_epoch,
            epochs=nb_epochs,
            callbacks=[checkpoint, tensorboard],
            use_multiprocessing=False,
            workers=16)
        model_file = '%s_%s.h5' % (model_name, timestr)
        model.save(model_file)
        print('save model to %s' % (model_file))
Ejemplo n.º 2
0
def train(config):
    """Train a classifier from ``config`` under ``tf.distribute.MirroredStrategy`` (TF 2.x).

    Builds the train and validation batch generators through ``ModelBuilder``
    and wraps them as auto-sharded ``tf.data`` datasets. It then trains with
    ``model.fit``, saving the best weights (by ``val_accuracy``, checked every
    5 epochs) under ``ckpt-<model name>/``. Finally it saves the full model to
    ``<model name>_<timestamp>.h5`` in the working directory.

    Args:
        config: Nested dict. Keys read directly here:
            ``model``: ``input_width``, ``input_height``, ``labels``, ``name``,
            ``class_num``;
            ``train``: ``data_dir``, ``file_list``, ``pretrained_weights``,
            ``learning_rate``, ``nb_epochs``, ``start_epoch``;
            ``valid``: ``data_dir``, ``file_list``.
            The whole dict is also passed to ``ModelBuilder``, which may read
            further keys (e.g. the batch size).

    Note:
        Relies on names defined elsewhere in the module: ``ModelBuilder``,
        ``ModelCheckpoint``, ``TensorBoard`` and ``tf``.
    """
    input_width = config['model']['input_width']
    input_height = config['model']['input_height']
    label_file = config['model']['labels']
    model_name = config['model']['name']
    class_num = config['model']['class_num']

    train_data_dir = config['train']['data_dir']
    train_file_list = config['train']['file_list']
    pretrained_weights = config['train']['pretrained_weights']
    learning_rate = config['train']['learning_rate']
    nb_epochs = config['train']['nb_epochs']
    start_epoch = config['train']['start_epoch']

    valid_data_dir = config['valid']['data_dir']
    valid_file_list = config['valid']['file_list']

    builder = ModelBuilder(config)

    # NOTE(review): images are declared as (width, height, 3); Keras image
    # iterators usually yield (height, width, 3). This only matters for
    # non-square inputs — confirm against ModelBuilder.from_frame.
    image_shape = (input_width, input_height, 3)

    train_gen = builder.build_datagen(train_file_list)
    train_gen.save_labels(label_file)
    train_batches, train_steps_per_epoch = train_gen.from_frame(directory=train_data_dir)
    train_ds = _generator_to_dataset(train_batches, image_shape, class_num)

    valid_gen = builder.build_datagen(valid_file_list, with_aug=False)
    valid_batches, valid_steps_per_epoch = valid_gen.from_frame(directory=valid_data_dir)
    valid_ds = _generator_to_dataset(valid_batches, image_shape, class_num)

    # Checkpoint directory; exist_ok avoids the check-then-create race.
    ckpt_dir = 'ckpt-' + model_name
    os.makedirs(ckpt_dir, exist_ok=True)

    timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    ckpt_path = os.path.join(
        ckpt_dir,
        'weights-%s-%s-{epoch:02d}-{val_accuracy:.2f}.hdf5' % (model_name, timestr))
    # `period` (epochs) is kept instead of an integer `save_freq` (batches):
    # val_accuracy is only logged at epoch end, so a batch-based save_freq
    # would never see the monitored value.
    checkpoint = ModelCheckpoint(filepath=ckpt_path,
                                 monitor='val_accuracy',
                                 verbose=1,
                                 save_best_only=True,
                                 save_weights_only=True,
                                 period=5)

    # Logs for TensorBoard.
    tensorboard = TensorBoard(log_dir='logs', histogram_freq=0)

    # The 'weights' directory is created here but not written by this function.
    os.makedirs('weights', exist_ok=True)

    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))

    # Model creation and compilation must happen inside the strategy scope.
    with strategy.scope():
        model = builder.build_model()

        # The dataset spec declares labels as [batch, class_num] (one-hot
        # style). Categorical cross-entropy is therefore the matching loss for
        # every class count. The sparse variant expects integer class ids,
        # not [batch, class_num] targets.
        model.compile(optimizer=tf.optimizers.Adam(learning_rate=learning_rate),
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        model.summary()

        # Optionally resume from the weights of an unfinished training run.
        if pretrained_weights:
            model.load_weights(pretrained_weights)

        # batch_size / workers / use_multiprocessing are omitted: for
        # tf.data input, batching comes from the dataset and the worker
        # options apply only to generator/Sequence input.
        model.fit(train_ds,
                  steps_per_epoch=train_steps_per_epoch,
                  validation_data=valid_ds,
                  validation_steps=valid_steps_per_epoch,
                  initial_epoch=start_epoch,
                  epochs=nb_epochs,
                  callbacks=[checkpoint, tensorboard])
        model_file = '%s_%s.h5' % (model_name, timestr)
        model.save(model_file)
        print('save model to %s' % (model_file))


def _generator_to_dataset(batch_iter, image_shape, class_num):
    """Wrap a batch iterator as a float32 ``tf.data.Dataset`` with DATA auto-sharding.

    Args:
        batch_iter: Iterator yielding ``(images, labels)`` batches. The same
            object is returned every time the dataset is (re)iterated.
        image_shape: Per-example image shape, e.g. ``(w, h, 3)``.
        class_num: Width of the label vector for each example.

    Returns:
        A ``tf.data.Dataset`` whose elements have shapes
        ``[None, *image_shape]`` and ``[None, class_num]``.
    """
    dataset = tf.data.Dataset.from_generator(
        lambda: batch_iter,
        output_types=(tf.float32, tf.float32),
        # Batch dimension left unknown so a smaller final batch is accepted
        # instead of failing the static shape check.
        output_shapes=([None] + list(image_shape), [None, class_num]))
    options = tf.data.Options()
    # A generator-backed dataset cannot be sharded by file, so shard by data.
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    return dataset.with_options(options)