Example #1
def train(NET, TRAIN, VAL):

    # Random Seed
    rng = cfg.getRandomState()
    image.resetRandomState()

    # Load pretrained model
    if cfg.PRETRAINED_MODEL_NAME:
        snapshot = io.loadModel(cfg.PRETRAINED_MODEL_NAME)
        NET = io.loadParams(NET, snapshot['params'])

    # Load teacher models
    teach_funcs = []
    if len(cfg.TEACHER) > 0:
        for t in cfg.TEACHER:
            snapshot = io.loadModel(t)
            TEACHER = snapshot['net']
            teach_funcs.append(birdnet.test_function(TEACHER, hasTargets=False))

    # Compile Theano functions
    train_net = birdnet.train_function(NET)
    test_net = birdnet.test_function(NET)

    # Status
    log.i("START TRAINING...")

    # Train for some epochs...
    for epoch in range(cfg.EPOCH_START, cfg.EPOCHS + 1):

        try:

            # Stop?
            if cfg.DO_BREAK:
                break

            # Clear stats for every epoch
            stats.clearStats()
            stats.setValue('sample_count', len(TRAIN) + len(VAL))

            # Start timer
            stats.tic('epoch_time')
            
            # Shuffle dataset (this way we get "new" batches every epoch)
            TRAIN = shuffle(TRAIN, random_state=rng)

            # Iterate over TRAIN batches of images
            for image_batch, target_batch in bg.nextBatch(TRAIN):

                # Show progress
                stats.showProgress(epoch)

                # If we have teacher models, average their soft predictions and use them as targets (knowledge distillation)
                if len(teach_funcs) > 0:
                    target_batch = np.zeros((len(teach_funcs), target_batch.shape[0], target_batch.shape[1]), dtype='float32')
                    for i in range(len(teach_funcs)):
                        target_batch[i] = teach_funcs[i](image_batch)
                    target_batch = np.mean(target_batch, axis=0)
                
                # Calling the training function returns the current loss
                loss = train_net(image_batch, target_batch, lr.dynamicLearningRate(cfg.LR_SCHEDULE, epoch))
                stats.setValue('train loss', loss, 'append')
                stats.setValue('batch_count', 1, 'add')

                # Stop?
                if cfg.DO_BREAK:
                    break

            # Iterate over VAL batches of images
            for image_batch, target_batch in bg.nextBatch(VAL, False, True):

                # Calling the test function returns the net output, loss and accuracy
                prediction_batch, loss, acc = test_net(image_batch, target_batch)
                stats.setValue('val loss', loss, 'append')
                stats.setValue('val acc', acc, 'append')
                stats.setValue('batch_count', 1, 'add')
                stats.setValue('lrap', [metrics.lrap(prediction_batch, target_batch)], 'add')

                # Show progress
                stats.showProgress(epoch)

                # Stop?
                if cfg.DO_BREAK:
                    break

            # Show stats for epoch
            stats.showProgress(epoch, done=True)
            stats.toc('epoch_time')
            log.r(('TRAIN LOSS:', np.mean(stats.getValue('train loss'))), new_line=False)
            log.r(('VAL LOSS:', np.mean(stats.getValue('val loss'))), new_line=False)
            log.r(('VAL ACC:', int(np.mean(stats.getValue('val acc')) * 10000) / 100.0, '%'), new_line=False)          
            log.r(('MLRAP:', int(np.mean(stats.getValue('lrap')) * 1000) / 1000.0), new_line=False)
            log.r(('TIME:', stats.getValue('epoch_time'), 's'))

            # Save snapshot?
            if not epoch % cfg.SNAPSHOT_EPOCHS:
                io.saveModel(NET, cfg.CLASSES, epoch)
                io.saveParams(NET, cfg.CLASSES, epoch)

            # New best net?
            if np.mean(stats.getValue('lrap')) > stats.getValue('best_mlrap'):
                stats.setValue('best_net', NET, static=True)
                stats.setValue('best_epoch', epoch, static=True)
                stats.setValue('best_mlrap', np.mean(stats.getValue('lrap')), static=True)

            # Early stopping?
            if epoch - stats.getValue('best_epoch') >= cfg.EARLY_STOPPING_WAIT:
                log.i('EARLY STOPPING!')
                break

            # Stop?
            if cfg.DO_BREAK:
                break

        except KeyboardInterrupt:
            log.i('KeyboardInterrupt')
            cfg.DO_BREAK = True
            break

    # Status
    log.i('TRAINING DONE!')
    log.r(('BEST MLRAP:', stats.getValue('best_mlrap'), 'EPOCH:', stats.getValue('best_epoch')))

    # Save best model and return
    io.saveParams(stats.getValue('best_net'), cfg.CLASSES, stats.getValue('best_epoch'))
    return io.saveModel(stats.getValue('best_net'), cfg.CLASSES, stats.getValue('best_epoch'))
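
Note: lr.dynamicLearningRate is defined elsewhere in the project. A minimal sketch of what an epoch-keyed schedule lookup could look like, assuming cfg.LR_SCHEDULE maps epoch numbers to learning rates (the actual implementation may decay or interpolate differently):

def dynamic_learning_rate(schedule, epoch, default=0.001):
    # Hypothetical stand-in for lr.dynamicLearningRate: use the rate from
    # the latest schedule entry at or before the current epoch.
    rate = default
    for e in sorted(schedule):
        if e <= epoch:
            rate = schedule[e]
    return rate

# Example: dynamic_learning_rate({1: 0.01, 10: 0.001}, 12) returns 0.001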
Example #2
def test(SNAPSHOTS):

    # Do we have more than one snapshot?
    if not isinstance(SNAPSHOTS, (list, tuple)):
        SNAPSHOTS = [SNAPSHOTS]

    # Load snapshots
    test_functions = []
    for s in SNAPSHOTS:

        # Settings
        NET = s['net']
        cfg.CLASSES = s['classes']
        cfg.IM_DIM = s['im_dim']
        cfg.IM_SIZE = s['im_size']

        # Compile test function
        test_net = birdnet.test_function(NET, hasTargets=False, layer_index=-1)
        test_functions.append(test_net)

    # Parse Testset
    TEST = parseTestSet()

    # Status
    log.i('START TESTING...')
    stats.clearStats()
    stats.tic('test_time')

    # Make predictions
    for spec_batch, labels, filename in bg.threadedGenerator(
            getSpecBatches(TEST)):

        try:

            # Status
            stats.tic('pred_time')

            # Average the predictions of all snapshots (simple ensemble)
            prediction_batch = []
            for test_func in test_functions:
                if len(prediction_batch) == 0:
                    prediction_batch = test_func(spec_batch)
                else:
                    prediction_batch += test_func(spec_batch)
            prediction_batch /= len(test_functions)

            # Eliminate the scores for 'Noise'
            if 'Noise' in cfg.CLASSES:
                prediction_batch[:, cfg.CLASSES.index('Noise')] = np.min(
                    prediction_batch)

            # Prediction pooling
            p_pool = predictionPooling(prediction_batch)

            # Get class labels
            p_labels = {}
            for i in range(p_pool.shape[0]):
                p_labels[cfg.CLASSES[i]] = p_pool[i]

            # Sort by score
            p_sorted = sorted(p_labels.items(),
                              key=operator.itemgetter(1),
                              reverse=True)

            # Calculate MLRAP (MRR for single labels)
            targets = np.zeros(p_pool.shape[0], dtype='float32')
            for label in labels:
                if label in cfg.CLASSES:
                    targets[cfg.CLASSES.index(label)] = 1.0
            lrap = metrics.lrap(np.expand_dims(p_pool, 0),
                                np.expand_dims(targets, 0))
            stats.setValue('lrap', lrap, mode='append')

            # Show sample stats
            log.i(filename, new_line=True)
            log.i(('\tLABELS:', labels), new_line=True)
            log.i(('\tTOP PREDICTION:', p_sorted[0][0],
                   int(p_sorted[0][1] * 1000) / 10.0, '%'),
                  new_line=True)
            log.i(('\tLRAP:', int(lrap * 1000) / 1000.0), new_line=False)
            log.i(('\tMLRAP:',
                   int(np.mean(stats.getValue('lrap')) * 1000) / 1000.0),
                  new_line=True)

            # Save some stats
            if p_sorted[0][0] == labels[0]:
                stats.setValue('top1_correct', 1, 'add')
                stats.setValue('top1_confidence', p_sorted[0][1], 'append')
            else:
                stats.setValue('top1_incorrect', 1, 'add')
            stats.toc('pred_time')
            stats.setValue('time_per_batch', stats.getValue('pred_time'),
                           'append')

        except KeyboardInterrupt:
            cfg.DO_BREAK = True
            break
        except Exception:
            log.e('ERROR WHILE TESTING!')
            continue

    # Stats
    stats.toc('test_time')
    log.i(('TESTING DONE!', 'TIME:', stats.getValue('test_time'), 's'))
    log.r(('FINAL MLRAP:', np.mean(stats.getValue('lrap'))))
    log.r(('TOP 1 ACCURACY:',
           max(
               0,
               float(stats.getValue('top1_correct')) /
               (stats.getValue('top1_correct') +
                stats.getValue('top1_incorrect')))))
    log.r(('TOP 1 MEAN CONFIDENCE:',
           max(0, np.mean(stats.getValue('top1_confidence')))))
    log.r(('TIME PER BATCH:',
           int(np.mean(stats.getValue('time_per_batch')) * 1000), 'ms'))

    return np.mean(stats.getValue('lrap')), int(
        np.mean(stats.getValue('time_per_batch')) * 1000)
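
Two helpers used above are defined elsewhere in the project. metrics.lrap computes label ranking average precision; a minimal stand-in, assuming it wraps scikit-learn's label_ranking_average_precision_score (which matches the "MRR for single labels" comment):

import numpy as np
from sklearn.metrics import label_ranking_average_precision_score

def lrap(prediction_batch, target_batch):
    # Label ranking average precision over a batch; for single-label
    # targets this reduces to the mean reciprocal rank.
    return label_ranking_average_precision_score(np.asarray(target_batch),
                                                 np.asarray(prediction_batch))

predictionPooling reduces the chunk-level predictions of one file to a single score vector (p_pool is indexed per class above). A plausible sketch, assuming a simple mean/max pooling mode; the project's actual pooling strategy may differ:

import numpy as np

def predictionPooling(prediction_batch, mode='mean'):
    # Hypothetical pooling over the batch axis: several chunk-level
    # predictions for one file become one vector of class scores.
    prediction_batch = np.asarray(prediction_batch)
    if mode == 'max':
        return np.max(prediction_batch, axis=0)
    return np.mean(prediction_batch, axis=0)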