示例#1
0
    def __init__(self, solver_prototxt, output_dir, pretrained_model=None):
        """Build the SGD solver and load its solver parameters.

        Args:
            solver_prototxt: Path to the solver prototxt. It is passed to
                ``caffe.SGDSolver`` and also read as text to populate
                ``self.solver_param``.
            output_dir: Stored on ``self.output_dir``; not otherwise used
                in this method.
            pretrained_model: Optional weights file. When it is not
                ``None``, it is copied into ``self.solver.net`` with
                ``copy_from``.
        """
        self.output_dir = output_dir

        self.solver = caffe.SGDSolver(solver_prototxt)

        if pretrained_model is not None:
            # Format the message *before* printing: when ``print`` is a
            # function it returns None, so ``print(...).format(...)`` would
            # raise AttributeError instead of showing the path.
            print('Loading pretrained model '
                  'weights from {:s}'.format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parse the solver prototxt text into a SolverParameter message.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)
示例#2
0
# --------------------------------------------------------
# Seg-FCN for Dragon
# Copyright (c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
""" Test an FCN-8s (PASCAL VOC) network """

import dragon.vm.caffe as caffe
import score
import numpy as np

# Snapshot file whose weights are copied into the net before scoring.
weights = 'snapshot/train_iter_100000.caffemodel'


def main():
    """Copy the snapshot weights into a solver's net and run ``seg_tests``."""
    # Select GPU mode and device 0 before the solver is constructed.
    caffe.set_mode_gpu()
    caffe.set_device(0)

    fcn_solver = caffe.SGDSolver('solver.prototxt')
    fcn_solver.net.copy_from(weights)

    # The entries are read from the text file as strings and handed,
    # together with the solver and the output location, to the scorer.
    valid_entries = np.loadtxt('../data/seg11valid.txt', dtype=str)
    score.seg_tests(fcn_solver, 'D:/seg', valid_entries)


if __name__ == '__main__':
    main()
示例#3
0
# --------------------------------------------------------
# Cifar-10 for Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
""" Train a cifar-10 net """

import dragon.vm.caffe as caffe

def main():
    """Run 70000 steps of the full cifar-10 solver, then take a snapshot."""
    # GPU mode is selected before the solver is constructed.
    caffe.set_mode_gpu()

    cifar_solver = caffe.SGDSolver('cifar10_full_solver.prototxt')
    cifar_solver.step(70000)
    cifar_solver.snapshot()


if __name__ == '__main__':
    main()
示例#4
0
# --------------------------------------------------------
# Cifar-10 for Dragon
# Copyright(c) 2017 SeetaTech
# Written by Ting Pan
# --------------------------------------------------------
""" Train a cifar-10 net """

import dragon.vm.caffe as caffe

def main():
    """Run 5000 steps of the quick cifar-10 solver, then take a snapshot."""
    # GPU mode is selected before the solver is constructed.
    caffe.set_mode_gpu()

    cifar_solver = caffe.SGDSolver('cifar10_quick_solver.prototxt')
    cifar_solver.step(5000)
    cifar_solver.snapshot()


if __name__ == '__main__':
    main()