Exemplo n.º 1
0
    def test_metrics(self):
        """Cross-check emlib.metrics against smetrics on random 0/1 labels.

        Two random binary arrays of shape (2, 5, 5) stand in for the true
        and estimated labels.  The precision, recall and F1 values returned
        by emlib.metrics must agree (to assertAlmostEqual tolerance) with
        the index-1 entries of the corresponding smetrics outputs computed
        on the flattened arrays.
        """
        dims = (2, 5, 5)
        truth = np.random.randint(0, 2, size=dims)
        guess = np.random.randint(0, 2, size=dims)

        # The first two return values of emlib.metrics are not checked here.
        _, _, prec, recall, f1 = emlib.metrics(truth, guess, display=False)

        # smetrics is handed 1-D views of the same data.
        ref_prec, ref_recall, ref_f1, _ = smetrics(truth.ravel(),
                                                   guess.ravel())

        comparisons = ((prec, ref_prec), (recall, ref_recall), (f1, ref_f1))
        for ours, reference in comparisons:
            self.assertAlmostEqual(ours, reference[1])
Exemplo n.º 2
0
    def test_metrics(self):
        """Verify emlib.metrics agrees with smetrics on random binary data.

        Draws a pair of random 0/1 arrays of shape (2, 5, 5), scores them
        with emlib.metrics, and asserts that the resulting precision, recall
        and F1 match entry [1] of the matching smetrics results obtained
        from the flattened arrays.
        """
        # Drawn in order: labels first, then estimates.
        labels, estimates = (np.random.randint(0, 2, size=(2, 5, 5))
                             for _ in range(2))

        scores = emlib.metrics(labels, estimates, display=False)
        _confusion, _accuracy, prec, recall, f1 = scores

        flat_labels = labels.reshape(labels.size)
        flat_estimates = estimates.reshape(estimates.size)
        prec_ref, recall_ref, f1_ref, _support = smetrics(flat_labels,
                                                          flat_estimates)

        self.assertAlmostEqual(prec, prec_ref[1])
        self.assertAlmostEqual(recall, recall_ref[1])
        self.assertAlmostEqual(f1, f1_ref[1])
Exemplo n.º 3
0
def _train_network(args):
    """ Main CNN training loop.

    Creates PyCaffe objects and calls train_one_epoch until the solver
    wrapper reports that training is complete.  After every epoch the
    network is evaluated on the validation volume and metrics are printed.

    Parameters
    ----------
    args : object
        Parsed command-line options.  The attributes read here are:
        solver, outDir, rotateData, emTrainFile, labelsTrainFile,
        trainSlices, emValidFile, labelsValidFile, validSlices and
        omitLabels.

    Side effects
    ------------
    Creates the output directory if needed and writes into it:
      * final.caffemodel -- the trained network weights
      * the last epoch's border-pruned validation probabilities, via
        np.save and scipy.io.savemat (skipped if no epoch was run).
    Progress messages are printed to stdout.

    Raises
    ------
    AssertionError
        If the inferred batch dimensions do not describe square tiles.
    """
    #----------------------------------------
    # parse information from the prototxt files
    #----------------------------------------
    solverParam = caffe_pb2.SolverParameter()
    with open(args.solver) as solverFile:
        text_format.Merge(solverFile.read(), solverParam)

    # The net prototxt is parsed as well; the result is not used further
    # below, but a malformed file is rejected here before any work starts.
    netFn = solverParam.net
    netParam = caffe_pb2.NetParameter()
    with open(netFn) as netFile:
        text_format.Merge(netFile.read(), netParam)

    batchDim = emlib.infer_data_dimensions(netFn)
    assert(batchDim[2] == batchDim[3])  # tiles must be square
    print('[emCNN]: batch shape: %s' % str(batchDim))

    if args.outDir:
        outDir = args.outDir  # overrides snapshot prefix
    else:
        outDir = str(solverParam.snapshot_prefix)   # unicode -> str
    if not os.path.isdir(outDir):
        # makedirs (rather than mkdir) so nested output paths also work
        os.makedirs(outDir)

    # choose a synthetic data generating function
    rotate = bool(args.rotateData)

    def syn_func(V):
        return _xform_minibatch(V, arbitraryRotation=rotate)

    if rotate:
        print('[emCNN]:   WARNING: applying arbitrary rotations to data.  This may degrade performance in some cases...\n')

    #----------------------------------------
    # Create the Caffe solver
    # Note this assumes a relatively recent PyCaffe
    #----------------------------------------
    solver = caffe.SGDSolver(args.solver)
    solverMD = SGDSolverMemoryData(solver, solverParam)
    solverMD.print_network()

    #----------------------------------------
    # Load data
    #----------------------------------------
    bs = border_size(batchDim)
    print("[emCNN]: tile radius is: %d" % bs)

    print("[emCNN]: loading training data...")
    Xtrain, Ytrain = _load_data(args.emTrainFile,
                                args.labelsTrainFile,
                                tileRadius=bs,
                                onlySlices=args.trainSlices,
                                omitLabels=args.omitLabels)

    print("[emCNN]: loading validation data...")
    Xvalid, Yvalid = _load_data(args.emValidFile,
                                args.labelsValidFile,
                                tileRadius=bs,
                                onlySlices=args.validSlices,
                                omitLabels=args.omitLabels)

    #----------------------------------------
    # Do training; save results
    #----------------------------------------
    omitLabels = set(args.omitLabels).union([-1])   # always omit -1
    currEpoch = 1
    Prob = None   # stays None if the loop below never executes
    sys.stdout.flush()

    while not solverMD.is_training_complete():
        print("[emCNN]: Starting epoch %d" % currEpoch)

        train_one_epoch(solverMD, Xtrain, Ytrain,
                        batchDim, outDir,
                        omitLabels=omitLabels,
                        data_augment=syn_func)

        currEpoch += 1

        print("[emCNN]: Making predictions on validation data...")
        # Only evaluate voxels whose validation label is non-negative.
        # Builtin bool is used because the np.bool alias was removed from
        # recent NumPy releases.
        Mask = np.ones(Xvalid.shape, dtype=bool)
        Mask[Yvalid < 0] = False
        Prob = predict(solver.net, Xvalid, Mask, batchDim)

        # discard mirrored edges and form class estimates
        Yhat = np.argmax(Prob, 0)
        Yhat[~Mask] = -1
        Prob = prune_border_4d(Prob, bs)
        Yhat = prune_border_3d(Yhat, bs)

        # compute some metrics
        print('[emCNN]: Validation set performance:')
        emlib.metrics(prune_border_3d(Yvalid, bs), Yhat, display=True)

    solver.net.save(str(os.path.join(outDir, 'final.caffemodel')))

    if Prob is None:
        # Training was already complete on entry, so no validation pass was
        # made and there are no probabilities to write.
        print('[emCNN]: no training epoch was run; skipping prediction output.')
    else:
        # NOTE: np.save appends '.npy' when the name lacks it, so this file
        # ends up on disk as 'Yhat.npz.npy' (a plain array, not an .npz
        # archive).  The name is kept unchanged for existing consumers.
        # Despite the 'Yhat' naming, the data written is Prob (the pruned
        # probabilities from the final epoch), not the class estimates.
        np.save(os.path.join(outDir, 'Yhat.npz'), Prob)
        scipy.io.savemat(os.path.join(outDir, 'Yhat.mat'), {'Yhat': Prob})
    print('[emCNN]: training complete.')