def run(args):
    """Configure MindSpore, build LeNet5 and dispatch on ``args.mode``.

    Always sets the MindSpore context (graph mode on ``args.device``),
    builds the network, loss and optimizer, then performs one of:

    * ``'init'``  -- save the freshly initialised weights to
      ``seeds/<unix-seconds>.ckpt``;
    * ``'train'`` -- optionally load ``args.init_ckpt`` into the network,
      then train on ``<args.data_path>/train``;
    * ``'test'``  -- evaluate a list of checkpoints on
      ``<args.data_path>/test``. With KungFu enabled, only workers with
      rank <= 0 run the evaluation.

    Args:
        args: parsed command-line namespace. Attributes read directly:
            ``device``, ``use_bn``, ``log_tensor``, ``mode``, ``data_path``,
            ``device_batch_size``, ``init_ckpt``, ``use_kungfu`` and
            ``ckpt_files``. It is also forwarded to ``build_optimizer``,
            ``create_dataset``, ``train``, ``get_ckpt_dir`` and ``test``.

    NOTE(review): the reviewed source contains a second ``def run(args)``
    after this one; if both live in the same module, the later definition
    shadows this one.
    """
    ms.context.set_context(
        mode=ms.context.GRAPH_MODE,
        device_target=args.device,
        save_graphs=False,
    )
    net = LeNet5(
        num_class=10,
        num_channel=3,
        use_bn=args.use_bn,
        dbg_log_tensor=args.log_tensor,
    )
    # sparse=True: labels are given as class indices rather than one-hot.
    loss = ms.nn.loss.SoftmaxCrossEntropyWithLogits(
        sparse=True,
        reduction='mean',
    )
    opt = build_optimizer(args, net)

    if args.mode == 'init':
        # '%d' truncates the timestamp to whole seconds, so two inits within
        # the same second target the same file name.
        save_checkpoint(
            net,
            ckpt_file_name=os.path.join('seeds', '%d.ckpt' % (time.time())),
        )

    if args.mode == 'train':
        ds_train = create_dataset(
            args=args,
            data_path=os.path.join(args.data_path, 'train'),
            batch_size=args.device_batch_size,
        )
        if args.init_ckpt:
            print('using init checkpoint %s' % (args.init_ckpt))
            load_ckpt(net, args.init_ckpt)
        train(args, net, loss, opt, ds_train)

    if args.mode == 'test':
        if args.use_kungfu:
            # Workers with rank > 0 skip evaluation entirely.
            rank = kfops.kungfu_current_rank()
            if rank > 0:
                return
        ds_test = create_dataset(
            args=args,
            data_path=os.path.join(args.data_path, 'test'),
            batch_size=args.device_batch_size,
        )
        if args.ckpt_files:
            # An explicit comma-separated list takes precedence.
            checkpoints = args.ckpt_files.split(',')
        else:
            checkpoint_dir = get_ckpt_dir(args)
            print('checkpoint_dir: %s' % (checkpoint_dir))
            # glob order is filesystem-dependent, hence sorted() (which
            # already returns a list). The order is lexicographic, not
            # numeric. os.path.join avoids a doubled separator when
            # checkpoint_dir ends with '/'.
            checkpoints = sorted(
                glob.glob(os.path.join(checkpoint_dir, '*.ckpt'))
            )
        print('will test %d checkpoints' % (len(checkpoints)))
        test(args, net, loss, opt, ds_test, checkpoints)
def run(args):
    """Set up MindSpore, build LeNet5 and act according to ``args.mode``.

    The MindSpore context is always put into graph mode on ``args.device``
    before the model, loss and optimizer are created. Then:

    * ``'init'``  -- the untrained weights are written to
      ``seeds/<unix-seconds>.ckpt``.
    * ``'train'`` -- data from ``<args.data_path>/train`` is used for
      training, starting from ``args.init_ckpt`` when that is set.
    * ``'test'``  -- checkpoints are evaluated on ``<args.data_path>/test``.
      They come either from the comma-separated ``args.ckpt_files`` or from
      ``get_ckpt_file_name(args, step)`` for steps 10, 20, 30 and 40.

    Args:
        args: parsed command-line namespace, also forwarded to
            ``build_optimizer``, ``train``, ``get_ckpt_file_name`` and
            ``test``.
    """
    ms.context.set_context(
        mode=ms.context.GRAPH_MODE,
        device_target=args.device,
        save_graphs=False,
    )
    model = LeNet5(num_class=10, num_channel=3, use_bn=args.use_bn)
    criterion = ms.nn.loss.SoftmaxCrossEntropyWithLogits(
        sparse=True,
        reduction='mean',
    )
    optimizer = build_optimizer(args, model)

    if args.mode == 'init':
        seed_file = os.path.join('seeds', '%d.ckpt' % time.time())
        save_checkpoint(model, ckpt_file_name=seed_file)

    if args.mode == 'train':
        train_dir = os.path.join(args.data_path, 'train')
        train_set = create_dataset(
            data_path=train_dir,
            batch_size=args.device_batch_size,
        )
        if args.init_ckpt:
            print('using init checkpoint %s' % args.init_ckpt)
            load_ckpt(model, args.init_ckpt)
        train(args, model, criterion, optimizer, train_set)

    if args.mode == 'test':
        test_dir = os.path.join(args.data_path, 'test')
        test_set = create_dataset(
            data_path=test_dir,
            batch_size=args.device_batch_size,
        )
        if not args.ckpt_files:
            # Fall back to the checkpoints saved at fixed training steps.
            checkpoints = []
            for step in (10, 20, 30, 40):
                checkpoints.append(get_ckpt_file_name(args, step))
        else:
            checkpoints = args.ckpt_files.split(',')
        print('will test %d checkpoints' % len(checkpoints))
        test(args, model, criterion, optimizer, test_set, checkpoints)