# NOTE(review): database connection parameters are hard-coded (including the
# password); consider moving them to configuration or environment variables.
_DB_HOST = '192.168.1.90'
_DB_USER = 'test'
_DB_PASSWORD = 'test'
_DB_NAME = 'zb_label'

# Values this script writes to zb_train.status.
_STATUS_TRAINING = 2
_STATUS_DONE = 3
_STATUS_FAILED = -1


def _update_train_status(ins_id, status):
    """Set zb_train.status to `status` for the row whose id is `ins_id`.

    A fresh connection is opened on every call because training can run long
    enough for a previously opened connection to be timed out and closed by
    the server. The connection is closed even if the query or commit raises.
    """
    # NOTE(review): SQL is built with str.format and ins_id comes from the
    # command line; switch to a parameterized query if sql.MySQLConnection
    # supports one.
    sqlstr = 'update zb_train set status = {} where id = {}'.format(status, ins_id)
    dbconn = sql.MySQLConnection(_DB_HOST, _DB_USER, _DB_PASSWORD, _DB_NAME)
    dbconn.connect()
    try:
        dbconn.query(sqlstr)
        dbconn.commit()
    finally:
        dbconn.close()


def _remove_path(path):
    """Delete `path` (directory tree, file, or symlink) and wait until it is gone."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)


def main():
    """Prepare a training instance, run training, and record its status.

    Loads the instance identified by ``args.id``, generates and loads its
    config, generates a model definition if the instance has none, copies the
    test net definition into ``<devkit>/results``, generates factory.py,
    train.sh and the VOCCode link, validates the instance, and then runs
    ``experiments/scripts/train.sh``.

    zb_train.status is set to 2 before training starts, to 3 on success, and
    to -1 if training (non-zero return from train.sh), accuracy computation,
    or the success update raises. Exits with status 1 if the instance cannot
    be found or fails validation.
    """
    import sys

    args = parse_args()
    ins_id = args.id
    instance = getTrainingInstance(ins_id)
    if instance is None:
        print('Error: no training instance with id {}.'.format(ins_id))
        sys.exit(1)

    generateCFG(instance)
    cfg_from_file(instance.cfg_file)
    if instance.set_cfgs is not None:
        cfg_from_list(instance.set_cfgs)

    # Generate a model definition if the instance does not name one yet.
    if instance.net == '':
        instance.setDefaultNet()
        models_path = os.path.join(instance.devkit, 'models')
        if os.path.lexists(models_path):
            # Delete synchronously so stale files cannot race with, or mix
            # into, the freshly generated model below.
            _remove_path(models_path)
        print('Generating model ' + models_path)
        generate_custom_net.main(instance.cls_num, models_path,
                                 instance.steps, instance.lr)

    # Copy the test-time net definition into <devkit>/results.
    net_src = os.path.join(instance.devkit, instance.net,
                           'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
    net_dst = os.path.join(instance.devkit, 'results')
    if not os.path.exists(net_dst):
        os.makedirs(net_dst)
    print('Copying {} to {}'.format(net_src, net_dst))
    shutil.copy(net_src, net_dst)

    generateFactory(instance)

    import generate_train_sh
    generate_train_sh.main(instance)

    # Make a symbolic link to VOCCode inside the devkit.
    generateVOCCode(instance.devkit)

    # Explicit `== False` kept on purpose: only a falsy-equal-to-False result
    # (False/0) aborts; a None return does not.
    if instance.validate() == False:
        print('Error in training instance.')
        sys.exit(1)

    _update_train_status(ins_id, _STATUS_TRAINING)

    try:
        ret = os.system('experiments/scripts/train.sh')
        if ret != 0:
            raise RuntimeError(
                'experiments/scripts/train.sh returned {}'.format(ret))
        # TODO: persist the accuracy (e.g. JSON in zb_train.accuracy); it is
        # computed here but not stored yet.
        getAccuracy(ins_id)
        _update_train_status(ins_id, _STATUS_DONE)
    except Exception as e:
        # Previously this branch executed the wrong query (the success/
        # in-progress update) instead of marking the job as failed.
        _update_train_status(ins_id, _STATUS_FAILED)
        print(e)