示例#1
0
    def get_model(self):
        """Build the network symbol and, if requested, restore a checkpoint.

        ``self.symbol`` is always built first via ``symbol.get_symbol`` from the
        configuration attributes stored on ``self``. When
        ``self.load_model_prefix`` is set and ``self.load_epoch`` is positive,
        the checkpoint at ``load_model_dir/load_model_prefix`` for that epoch is
        loaded. It replaces ``self.symbol`` and populates ``self.arg_params`` and
        ``self.aux_params``.
        """
        # Construct the symbol from the project's symbol.py definition.
        self.symbol = symbol.get_symbol(
            self,
            self.label_num,
            self.ignore_label,
            self.aspp,
            self.aspp_stride,
            self.atrous_type,
            self.bn_use_global_stats,
            self.relu_type,
        )

        # Only restore from disk when a prefix is configured and the epoch is > 0.
        restore_checkpoint = self.load_model_prefix is not None and self.load_epoch > 0
        if restore_checkpoint:
            checkpoint_prefix = os.path.join(self.load_model_dir, self.load_model_prefix)
            loaded = mx.model.load_checkpoint(checkpoint_prefix, self.load_epoch)
            self.symbol, self.arg_params, self.aux_params = loaded
示例#2
0
# NOTE(review): `parser` is created above this chunk and is not visible here.
# `args.batch_size` (used below) presumably comes from a --batch-size option
# defined there as well; confirm.
# Default is 80% of a hard-coded total of 2140 examples, matching the 0.2
# eval_ratio passed to FileIter below. TODO: confirm 2140 matches the dataset.
parser.add_argument('--num-examples', type=int, default=int(2140 * 0.8),
                    help='the number of training examples')
# NOTE(review): only the log directory is configured here; whether a log
# file is actually written is not visible in this chunk.
parser.add_argument('--log-dir', type=str, default="/tmp/",
                    help='directory of the log file')

# No default: args.load_epoch is None when the option is omitted.
parser.add_argument('--load-epoch', type=int,
                    help="load the model on an epoch using the model-prefix")
# No default: args.save_model_prefix is None when the option is omitted.
parser.add_argument('--save-model-prefix', type=str,
                    help='the prefix of the model to save')
# TODO: compute statistics (e.g. the mean) of the training data.
args = parser.parse_args()

# Build the network symbol with 30 outputs.
import symbol
net = symbol.get_symbol(output_dim = 30)

from data import FileIter

# Training iterator: is_val=False with eval_ratio=0.2, presumably the 80%
# of the data left after holding out the validation split. Confirm against
# data.FileIter.
train = FileIter(
         eval_ratio = 0.2, 
         is_val = False,
         data_name = "data",
         batch_size = args.batch_size,
         label_name = "lr_label"
        )
val = FileIter(
     eval_ratio = 0.2, 
     is_val = True,
     data_name = "data",
import mxnet as mx
import dataLoader, deploy, symbol, debug

# Build the network with a 10-class output (symbol.get_symbol is project-local).
network = symbol.get_symbol(num_classes=10)

# Wrap the symbol in a Module bound to GPU 0. The fc1/fc2 weights and biases
# are passed as fixed_param_names, which MXNet's Module keeps fixed (not
# updated) during training.
net = mx.mod.Module(
    symbol=network,
    context=mx.gpu(0),
    fixed_param_names=["fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"],
)

# Training and validation iterators come from the project's dataLoader module.
trainDataIter, valDataIter = dataLoader.get_data_iter()

# Total number of training epochs; also used in the saved parameter filename.
num_epoch = 120

try:
	net.fit(train_data = trainDataIter,
			eval_data = valDataIter,
			epoch_end_callback = debug.epoch_end_callback,
			eval_end_callback = debug.eval_end_callback,
			eval_metric = deploy.get_eval_metric(),
			optimizer = "sgd",
			initializer = deploy.get_initializer(),
			num_epoch = num_epoch,
			begin_epoch = deploy.get_begin_epoch()
			)
			
	results = net.score(valDataIter, deploy.get_eval_metric(), reset = True)
	print "[Info] Validation Result:", results
	results = net.score(trainDataIter, deploy.get_eval_metric(), reset = True)
	print "[Info] Training Result:", results
	
	print "[Info] Saving Parameters..."
	net.save_params('models/test2_{}.params'.format(str(num_epoch)))
示例#4
0
 def actualize(self, ev, env):
     """
     Return the result of ``symbol.get_symbol(self.name)`` as a new list.

     The lookup result is iterated once and its items are copied into a fresh
     ``list``, so callers may mutate the returned value freely. The dispatched
     event ``ev`` and the environment ``env`` are both ignored.
     """
     # list(...) is the idiomatic plain copy (same result as [s for s in ...]).
     return list(symbol.get_symbol(self.name))
示例#5
0
 def __function_name(self):
     """Return the first symbol-lookup entry for the event's function.

     Looks up ``self.event.function`` via ``symbol.get_symbol``. Returns the
     first element of the result, or ``None`` when the result is falsy (for
     example empty or ``None``).
     """
     matches = symbol.get_symbol(self.event.function)
     return matches[0] if matches else None
示例#6
0
 def actualize(self, ev, env):
     """
     Return the result of ``symbol.get_symbol(self.name)`` as a new list.

     The lookup result is iterated once and its items are copied into a fresh
     ``list``, so callers may mutate the returned value freely. The dispatched
     event ``ev`` and the environment ``env`` are both ignored.
     """
     # list(...) is the idiomatic plain copy (same result as [s for s in ...]).
     return list(symbol.get_symbol(self.name))
示例#7
0
def resolve(s):
    """Return the items produced by ``symbol.get_symbol(s)`` as a ``set``.

    Duplicates are collapsed and ordering is discarded; the items must be
    hashable.
    """
    return set(symbol.get_symbol(s))