def start_new_training(self):
    """Run one training instance end to end, then hand the result to clean_up.

    Steps, in order: bind SIGINT/SIGTERM to ``self.signal_handler``, create
    the AVA-SDK training instance, prepare the train/solver/sample-set
    configuration and the model, then call ``self.mod.fit``.  Any exception
    raised along the way is logged and its traceback written to stderr
    instead of propagating.  ``self.clean_up`` is always called afterwards
    with an empty ``err_msg`` on success or the failure text otherwise.
    """
    try:
        # Receiving one of these signals means the user chose to stop the
        # training instance; such a run is treated as having ended normally.
        handled_signals = (
            signal.SIGINT,
            signal.SIGTERM,
        )
        for sig in handled_signals:
            try:
                signal.signal(sig, self.signal_handler)
                logger.info("Bind signal '%s' success to %s",
                            sig, self.signal_handler)
            except Exception as bind_err:
                # Best effort: a signal that cannot be bound is only reported.
                logger.warning("Bind signal '%s' failed, err: %s",
                               sig, bind_err)

        # AVA-SDK: create the training instance for this run.
        self.train_ins = train.TrainInstance()
        logger.info("start new tarining, training_ins_id: %s",
                    self.train_ins.get_training_ins_id())

        logger.info("prepare_train_config")
        self.prepare_train_config()

        logger.info("prepare_solver_config")
        self.prepare_solver_config()

        logger.info("prepare_sampleset_config")
        self.prepare_sampleset_data()

        logger.info("prepare_model")
        self.prepare_model_riheng()

        # Solver settings are merged into (and override) the train config;
        # note this updates self.train_config in place.
        merged = self.train_config
        merged.update(self.solver_config)
        fit_args = dict((key, merged.get(key)) for key in FIT_KWARGS_KEYS)
        logger.info("fit args: %s" % fit_args)

        xavier = mx.init.Xavier(rnd_type='gaussian',
                                factor_type="in",
                                magnitude=2)
        self.mod.fit(self.train_data,
                     eval_data=self.val_data,
                     initializer=xavier,
                     **fit_args)
        logger.info("training finish")
    except Exception as err:
        err_msg = "training failed, err: %s" % (err)
        logger.info(err_msg)
        traceback.print_exc(file=sys.stderr)
    else:
        # Empty message marks a successful run.
        err_msg = ""

    self.clean_up(err_msg=err_msg)
def start_new_training():
    """Run a Caffe training as one AVA-SDK training instance.

    Binds SIGINT/SIGTERM to ``signal_handler``, creates the training
    instance, rewrites the LeNet solver prototxt so that its
    ``snapshot_prefix`` points at the instance's snapshot path, launches
    ``caffe train`` on GPU 0 and waits for it to exit.  A non-zero exit code
    is turned into an exception.

    Exceptions raised by the training steps are caught, logged, and their
    traceback written to stderr.  The outcome is reported through
    ``train_ins.done(err_msg=...)``: an empty message on success, the failure
    text otherwise.  If the training instance itself could not be created
    there is nothing to report to, and the function simply returns.
    """
    SUPPORTED_SIGNALS = (signal.SIGINT, signal.SIGTERM,)
    for signum in SUPPORTED_SIGNALS:
        try:
            signal.signal(signum, signal_handler)
            logger.info("Bind signal '%s' success to %s",
                        signum, signal_handler)
        except Exception as identifier:
            # Best effort: a signal that cannot be bound is only reported.
            logger.warning(
                "Bind signal '%s' failed, err: %s", signum, identifier)

    # Bound before the try block so the except handler can safely test it
    # even when train.TrainInstance() itself raises.
    train_ins = None
    try:
        # AVA-SDK training instance
        train_ins = train.TrainInstance()
        err_msg = ''

        # Load the stock solver and redirect its snapshots to the path
        # provided by the training instance.
        solver_param = caffe_pb2.SolverParameter()
        with open('/workspace/model/lenet_solver.prototxt', 'r') as f:
            text_format.Merge(f.read(), solver_param)
        solver_param.snapshot_prefix = train_ins.get_snapshot_base_path()
        logger.info("saving to %s", solver_param.snapshot_prefix)

        fixed_solver = train_ins.get_base_path() + "/solver.prototxt"
        with open(fixed_solver, 'w') as f:
            f.write(str(solver_param))
        logger.info("write fixed solver to %s", fixed_solver)

        # AVA-SDK start caffe process
        training_cmd = ['caffe', 'train', '-solver', fixed_solver, '-gpu', '0']
        proc = cmd.startproc(training_cmd)
        logger.info("Started %s", proc)

        # AVA-SDK add caffe callback
        cmd.logproc(proc, [train_ins.get_monitor_callback("caffe")])

        exit_code = proc.wait()
        logger.info("Finished proc with code %s", exit_code)

        logger.info("Gracefully shutdown after 5s, wait cleaner ...")
        time.sleep(5)
        logger.info("Done.")

        if exit_code != 0:
            logger.error(
                "training exit code [%d] != 0, raise Exception", exit_code)
            raise Exception("training exit code [%d] != 0" % (exit_code))

        train_ins.done(err_msg=err_msg)
    except Exception as err:
        err_msg = "training failed, err: %s" % (err)
        logger.info(err_msg)
        traceback.print_exc(file=sys.stderr)

        if train_ins is None:
            # The instance was never created; nothing to mark as done.
            return
        train_ins.done(err_msg=err_msg)
def start_new_training():
    """Run the Faster R-CNN training script as one AVA-SDK training instance.

    Binds SIGINT/SIGTERM to ``signal_handler``, creates the training
    instance, launches ``detect_py_faster_rcnn.py`` as a subprocess on GPU 0
    with output and data paths taken from the instance, and waits for it to
    exit.  A non-zero exit code is turned into an exception.

    Exceptions raised by the training steps are caught, logged, and their
    traceback written to stderr.  The outcome is reported through
    ``train_ins.done(err_msg=...)``: an empty message on success, the failure
    text otherwise.  If the training instance itself could not be created
    there is nothing to report to, and the function simply returns.
    """
    SUPPORTED_SIGNALS = (signal.SIGINT, signal.SIGTERM,)
    for signum in SUPPORTED_SIGNALS:
        try:
            signal.signal(signum, signal_handler)
            logger.info("Bind signal '%s' success to %s",
                        signum, signal_handler)
        except Exception as identifier:
            # Best effort: a signal that cannot be bound is only reported.
            logger.warning(
                "Bind signal '%s' failed, err: %s", signum, identifier)

    # Bound before the try block so the except handler can safely test it
    # even when train.TrainInstance() itself raises.
    train_ins = None
    try:
        # AVA-SDK training instance
        train_ins = train.TrainInstance()
        err_msg = ''

        # AVA-SDK start training process
        out_dir = train_ins.get_snapshot_base_path()
        # Both the roidb pickle and --train_base_path live under the
        # training set's "cache" directory.
        train_cache_dir = train_ins.get_trainset_base_path() + '/cache'
        roidb_path = train_cache_dir + "/gt_roidb.pkl"
        training_cmd = ['python', 'detect_py_faster_rcnn.py',
                        '--solver', 'vgg_solver.prototxt',
                        '--gpu', '0',
                        '--output_path', out_dir,
                        '--ava_roidb_path', roidb_path,
                        '--train_base_path', train_cache_dir]
        proc = cmd.startproc(training_cmd)
        logger.info("Started %s", proc)

        # AVA-SDK add caffe callback
        cmd.logproc(proc, [train_ins.get_monitor_callback("caffe")])

        exit_code = proc.wait()
        logger.info("Finished proc with code %s", exit_code)

        logger.info("Gracefully shutdown after 5s, wait cleaner ...")
        time.sleep(5)
        logger.info("Done.")

        if exit_code != 0:
            logger.error(
                "training exit code [%d] != 0, raise Exception", exit_code)
            raise Exception("training exit code [%d] != 0" % (exit_code))

        train_ins.done(err_msg=err_msg)
    except Exception as err:
        err_msg = "training failed, err: %s" % (err)
        logger.info(err_msg)
        traceback.print_exc(file=sys.stderr)

        if train_ins is None:
            # The instance was never created; nothing to mark as done.
            return
        train_ins.done(err_msg=err_msg)
if __name__ == '__main__':
    # Script entry point: builds the command-line parser, creates the AVA-SDK
    # training instance and prepares the snapshot prefix and the MXNet
    # batch-end monitor callback.
    # NOTE(review): parser.parse_args() is not called in the visible part of
    # this block, and none of the values bound below are consumed here --
    # presumably the rest of the script uses them; confirm.

    # download data
    #(train_fname, val_fname) = download_cifar10()

    # parse args
    parser = argparse.ArgumentParser(description="train cifar10",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--data-train', help='training data, recdio file', type=str)
    parser.add_argument(
        '--data-val', help='validation data, recdio file', type=str)

    # Within the training environment of a training job, each training run
    # is called a "training instance".
    train_ins = train.TrainInstance()

    # Add monitoring: snapshots are written under the instance's snapshot
    # base path, with a fixed interval of one epoch.
    snapshot_prefix = train_ins.get_snapshot_base_path() + "/snapshot"
    snapshot_interval_epochs = 1
    #snapshot_interval_epochs = params.get_value(
    #    "intervals.snapshotIntervalEpochs", default=1)

    # add CALLBACK
    batch_end_cb = train_ins.get_monitor_callback(
        "mxnet",
        batch_size=128, # args.batch_size
        batch_freq=10)
    #args.batch_end_callback = batch_end_cb

    # Test value.
    # NOTE(review): 128 * 2 looks like batch size times a device count of 2 --
    # confirm against the code that consumes it.
    actual_batch_size = 128 * 2
def start_new_training():
    """Run an MXNet training via ``fit.fit`` as one AVA-SDK training instance.

    Binds SIGINT/SIGTERM to ``signal_handler``, builds the argument parser
    from the ``fit`` and ``data`` helper modules, overrides its defaults,
    parses the command line, creates the training instance, attaches its
    monitor callback as ``args.batch_end_callback``, loads the network symbol
    from ``symbols.<args.network>`` and trains it.

    Exceptions raised by these steps are caught, logged, and their traceback
    written to stderr.  The outcome is reported through
    ``train_ins.done(err_msg=...)``: an empty message on success, the failure
    text otherwise.  If the failure happened before the training instance
    was created there is nothing to report to, and the function simply
    returns.
    """
    # binding signals
    SUPPORTED_SIGNALS = (
        signal.SIGINT,
        signal.SIGTERM,
    )
    for signum in SUPPORTED_SIGNALS:
        try:
            signal.signal(signum, signal_handler)
            logger.info("Bind signal '%s' success to %s",
                        signum, signal_handler)
        except Exception as identifier:
            # Best effort: a signal that cannot be bound is only reported.
            logger.warning("Bind signal '%s' failed, err: %s",
                           signum, identifier)

    # Bound before the try block so the except handler can safely test it
    # when parser setup fails before train.TrainInstance() is reached.
    train_ins = None
    try:
        # parse args
        parser = argparse.ArgumentParser(
            description="train imagenet-1k",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        fit.add_fit_args(parser)
        data.add_data_args(parser)
        data.add_data_aug_args(parser)
        # use a large aug level
        data.set_data_aug_level(parser, 3)
        # NOTE(review): the defaults below (10 classes, 60000 examples,
        # 3x28x28 images) do not match the "imagenet-1k" description above --
        # confirm which data set this launcher targets.
        parser.set_defaults(
            # network
            network='resnet',
            num_layers=50,
            # data
            num_classes=10,
            num_examples=60000,
            image_shape='3,28,28',
            # if input image has min size k, suggest to use
            # 256.0/x, e.g. 0.533 for 480
            min_random_scale=1,
            # train
            num_epochs=80,
            lr_step_epochs='30,60',
            dtype='float32',
            batch_size=32)
        args = parser.parse_args()

        # AVA-SDK new an Instance
        train_ins = train.TrainInstance()

        # add CALLBACK
        batch_end_cb = train_ins.get_monitor_callback(
            "mxnet",
            batch_size=args.batch_size,
            batch_freq=10)
        args.batch_end_callback = batch_end_cb

        # load network
        from importlib import import_module
        net = import_module('symbols.' + args.network)
        sym = net.get_symbol(**vars(args))

        # train
        fit.fit(args, sym, data.get_rec_iter)
        logger.info("training finish")

        err_msg = ""
        if train_ins is None:
            return
        train_ins.done(err_msg=err_msg)
    except Exception as err:
        err_msg = "training failed, err: %s" % (err)
        logger.info(err_msg)
        traceback.print_exc(file=sys.stderr)

        if train_ins is None:
            # The instance was never created; nothing to mark as done.
            return
        train_ins.done(err_msg=err_msg)