Beispiel #1
0
def train(**kwargs):
    """Train an EfficientDet model and save its weights after every epoch.

    All configuration is read from ``kwargs``. Keys used:
        save_dir: directory for the per-epoch weight files (created, with
            parents, if missing).
        n_classes, efficientdet, bidirectional, freeze_backbone: passed to
            ``efficientdet.models.EfficientDet`` (``efficientdet`` as ``D``);
            the backbone is initialised with ImageNet weights.
        checkpoint: path of weights to load before training, or ``None``.
        format, train_dataset, images_path, classes_names, batch_size:
            training-data options; ``classes_names`` is a comma-separated
            string of class names.
        val_dataset: optional annotations path; when truthy the model is
            evaluated on it after every epoch.
        learning_rate: Adam learning rate.
        epochs: number of epochs to train.

    After each epoch the weights are written to
    ``save_dir/<bifpn|fpn>_<format>_efficientdet_weights_<epoch>.tf``.
    """
    save_checkpoint_dir = Path(kwargs['save_dir'])
    save_checkpoint_dir.mkdir(exist_ok=True, parents=True)

    model = efficientdet.models.EfficientDet(
        kwargs['n_classes'],
        D=kwargs['efficientdet'],
        bidirectional=kwargs['bidirectional'],
        freeze_backbone=kwargs['freeze_backbone'],
        weights='imagenet')

    if kwargs['checkpoint'] is not None:
        print('Loading checkpoint from {}...'.format(kwargs['checkpoint']))
        model.load_weights(kwargs['checkpoint'])

    # Options shared by the training and validation datasets: both use the
    # model's square input resolution and the same class list.
    class_names = kwargs['classes_names'].split(',')
    im_size = (model.config.input_size, ) * 2

    ds, class2idx = efficientdet.data.build_ds(
        format=kwargs['format'],
        annots_path=kwargs['train_dataset'],
        images_path=kwargs['images_path'],
        im_size=im_size,
        class_names=class_names,
        batch_size=kwargs['batch_size'])

    val_ds = None
    if kwargs['val_dataset']:
        # Validation uses half the training batch size, clamped to at least
        # 1: with batch_size == 1 a plain `// 2` would request batches of 0.
        val_batch_size = max(1, kwargs['batch_size'] // 2)
        val_ds, _ = efficientdet.data.build_ds(
            format=kwargs['format'],
            annots_path=kwargs['val_dataset'],
            images_path=kwargs['images_path'],
            class_names=class_names,
            im_size=im_size,
            shuffle=False,
            batch_size=val_batch_size)

    anchors = generate_anchors(model.anchors_config, model.config.input_size)

    # NOTE(review): with Keras optimizers, clipnorm clips each weight's
    # gradient individually to this L2 norm. 0.001 is very small and may
    # throttle updates far more than learning_rate suggests; confirm it is
    # intended.
    optimizer = tf.optimizers.Adam(learning_rate=kwargs['learning_rate'],
                                   clipnorm=0.001)

    # The checkpoint file-name parts do not change between epochs.
    model_type = 'bifpn' if kwargs['bidirectional'] else 'fpn'
    data_format = kwargs['format']

    for epoch in range(kwargs['epochs']):

        engine.train_single_epoch(model=model,
                                  anchors=anchors,
                                  dataset=ds,
                                  optimizer=optimizer,
                                  loss_fn=loss_fn,
                                  num_classes=kwargs['n_classes'],
                                  epoch=epoch)

        if val_ds is not None:
            engine.evaluate(model=model, dataset=val_ds, class2idx=class2idx)

        fname = (save_checkpoint_dir /
                 f'{model_type}_{data_format}_efficientdet_weights_{epoch}.tf')
        model.save_weights(str(fname))
Beispiel #2
0
def train(**kwargs):
    """Train an EfficientDet model and save a checkpoint after every epoch.

    The model is obtained in one of three ways, checked in this order:
      1. ``kwargs['checkpoint']`` is not None: loaded with
         ``efficientdet.checkpoint.load``.
      2. ``kwargs['from_pretrained']`` is not None: loaded with
         ``efficientdet.EfficientDet.from_pretrained`` for ``n_classes``
         classes, and every layer (and the model) is marked trainable.
      3. Otherwise: a fresh ``efficientdet.models.EfficientDet`` is built
         from ``n_classes``, ``efficientdet`` (as ``D``), ``bidirectional``
         and ``freeze_backbone``, with ImageNet backbone weights.

    Other keys read from ``kwargs``:
        save_dir: root directory for checkpoints (created, with parents, if
            missing).
        format, train_dataset, images_path, classes_names, batch_size:
            dataset options; ``classes_names`` is a comma-separated string.
            The training set is built with ``data_augmentation=True``.
        val_dataset: optional annotations path; when truthy the model is
            evaluated on it every ``validate_freq`` epochs.
        validate_freq: validation period in epochs. Only read when
            ``val_dataset`` is set, and it must then be >= 1.
        learning_rate, w_scheduler, alpha, epochs, grad_accum_steps:
            optimisation settings. When ``w_scheduler`` is truthy the learning
            rate follows ``efficientdet.optim.EfficientDetLRScheduler``;
            otherwise it is the constant ``learning_rate``. SGD with momentum
            0.9 is used.
        print_freq: forwarded to the training loop as ``print_every``.

    After each epoch the model is saved with ``efficientdet.checkpoint.save``
    (together with ``kwargs``) to
    ``save_dir/<efficientdet>_<bifpn|fpn>_<format>_<epoch>``.

    Raises:
        ValueError: if ``val_dataset`` is set and ``validate_freq`` < 1.
    """
    save_checkpoint_dir = Path(kwargs['save_dir'])
    save_checkpoint_dir.mkdir(exist_ok=True, parents=True)

    if kwargs['checkpoint'] is not None:
        print('Loading checkpoint from {}...'.format(kwargs['checkpoint']))
        model = efficientdet.checkpoint.load(kwargs['checkpoint'])
    elif kwargs['from_pretrained'] is not None:
        model = (efficientdet.EfficientDet.from_pretrained(
            kwargs['from_pretrained'], num_classes=kwargs['n_classes']))
        # Fine-tune the whole network, not only the new head.
        for layer in model.layers:
            layer.trainable = True
        model.trainable = True
        print('Training from a pretrained model...')
        print('This will override any configuration related to EfficientNet'
              ' using the defined in the pretrained model.')
    else:
        model = efficientdet.models.EfficientDet(
            kwargs['n_classes'],
            D=kwargs['efficientdet'],
            bidirectional=kwargs['bidirectional'],
            freeze_backbone=kwargs['freeze_backbone'],
            weights='imagenet')

    # Options shared by the training and validation datasets.
    class_names = kwargs['classes_names'].split(',')
    im_size = (model.config.input_size, ) * 2

    ds, class2idx = efficientdet.data.build_ds(
        format=kwargs['format'],
        annots_path=kwargs['train_dataset'],
        images_path=kwargs['images_path'],
        im_size=im_size,
        class_names=class_names,
        batch_size=kwargs['batch_size'],
        data_augmentation=True)

    val_ds = None
    validate_freq = None
    if kwargs['val_dataset']:
        validate_freq = kwargs['validate_freq']
        # Fail fast. Otherwise `(epoch + 1) % 0` would raise
        # ZeroDivisionError only after a full training epoch had run.
        if validate_freq < 1:
            raise ValueError(
                f'validate_freq must be >= 1, got {validate_freq!r}')
        # Half the training batch size, clamped to at least 1: with
        # batch_size == 1 a plain `// 2` would request batches of 0.
        val_batch_size = max(1, kwargs['batch_size'] // 2)
        val_ds, _ = efficientdet.data.build_ds(
            format=kwargs['format'],
            annots_path=kwargs['val_dataset'],
            images_path=kwargs['images_path'],
            class_names=class_names,
            im_size=im_size,
            shuffle=False,
            data_augmentation=False,
            batch_size=val_batch_size)

    anchors = generate_anchors(model.anchors_config, model.config.input_size)

    if kwargs['w_scheduler']:
        # NOTE(review): the third argument is presumably optimizer steps per
        # epoch. If so, `// k + 1` over-counts by one when ds_len(ds) is
        # divisible by grad_accum_steps. Confirm against the scheduler and
        # engine before changing it.
        lr = efficientdet.optim.EfficientDetLRScheduler(
            kwargs['learning_rate'],
            kwargs['epochs'], (ds_len(ds) // kwargs['grad_accum_steps']) + 1,
            alpha=kwargs['alpha'])
    else:
        lr = kwargs['learning_rate']

    optimizer = tf.optimizers.SGD(learning_rate=lr, momentum=0.9)

    # The checkpoint directory-name parts do not change between epochs.
    # NOTE(review): they come from kwargs even when the model was loaded via
    # `checkpoint` or `from_pretrained`, so they may not describe the
    # architecture actually being trained.
    model_type = 'bifpn' if kwargs['bidirectional'] else 'fpn'
    data_format = kwargs['format']
    arch = kwargs['efficientdet']

    for epoch in range(kwargs['epochs']):

        engine.train_single_epoch(model=model,
                                  anchors=anchors,
                                  dataset=ds,
                                  optimizer=optimizer,
                                  grad_accum_steps=kwargs['grad_accum_steps'],
                                  loss_fn=loss_fn,
                                  num_classes=kwargs['n_classes'],
                                  epoch=epoch,
                                  print_every=kwargs['print_freq'])

        if val_ds is not None and (epoch + 1) % validate_freq == 0:
            engine.evaluate(model=model, dataset=val_ds, class2idx=class2idx)

        save_dir = (save_checkpoint_dir /
                    f'{arch}_{model_type}_{data_format}_{epoch}')
        efficientdet.checkpoint.save(model, kwargs, save_dir)