def get_resnet(args):
    """Build an ImageNet ResNet symbol for the requested depth.

    On first use, clones tornadomeet/ResNet next to this file (via
    ``os.system``) so that ``symbol_resnet.resnet`` can be imported.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``depth`` (18/34/50/101/152/200) and ``num_classes``.

    Returns
    -------
    The ResNet network symbol.

    Raises
    ------
    ValueError
        If ``args.depth`` is not one of the supported depths.
    """
    resnet_path = os.path.join(curr_path, "./ResNet")
    if not os.path.isdir(resnet_path):
        # First run: fetch the reference symbol implementation.
        # NOTE(review): shells out to git without error checking — best effort.
        os.system("git clone https://github.com/tornadomeet/ResNet")
    sys.path.insert(0, resnet_path)
    from symbol_resnet import resnet

    # Per-stage unit counts for each published depth.
    units_by_depth = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    if args.depth not in units_by_depth:
        raise ValueError(
            "no experiments done on depth {}, you can do it yourself".format(
                args.depth))
    units = units_by_depth[args.depth]
    # Depths >= 50 use bottleneck units with 4x-wider stage outputs.
    bottle_neck = args.depth >= 50
    filter_list = ([64, 256, 512, 1024, 2048] if bottle_neck
                   else [64, 64, 128, 256, 512])
    symbol = resnet(units=units,
                    num_stage=4,
                    filter_list=filter_list,
                    num_class=args.num_classes,
                    data_type="imagenet",
                    bottle_neck=bottle_neck,
                    bn_mom=.9,
                    workspace=512)
    return symbol
def get_resnet(args):
    """Return an ImageNet ResNet symbol matching ``args.depth``.

    Clones tornadomeet/ResNet on first use so ``symbol_resnet`` imports.
    Raises ValueError for depths without a known unit configuration.
    """
    resnet_path = os.path.join(curr_path, "./ResNet")
    if not os.path.isdir(resnet_path):
        # first run: fetch the upstream symbol definition
        os.system("git clone https://github.com/tornadomeet/ResNet")
    sys.path.insert(0, resnet_path)
    from symbol_resnet import resnet

    # stage repetition counts keyed by network depth
    depth_to_units = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    if args.depth not in depth_to_units:
        raise ValueError(
            "no experiments done on detph {}, you can do it youself".format(
                args.depth))
    units = depth_to_units[args.depth]
    deep = args.depth >= 50  # bottleneck variant with wider filters
    return resnet(units=units,
                  num_stage=4,
                  filter_list=[64, 256, 512, 1024, 2048] if deep
                  else [64, 64, 128, 256, 512],
                  num_class=args.num_classes,
                  data_type="imagenet",
                  bottle_neck=True if deep else False,
                  bn_mom=.9,
                  workspace=512)
def get_model(cfg):
    """Build a wide residual network (depth 28, widen factor 8) with one
    softmax label variable per task.

    Label variables are named ``softmax0_label`` ..
    ``softmax{MAX_LABELS-1}_label``; the head is attached by
    ``get_output``.
    """
    labels = [mx.sym.var('softmax{}_label'.format(idx))
              for idx in range(cfg.MAX_LABELS)]
    # Wide-ResNet configuration: depth 28, widening factor k = 8,
    # basic (non-bottleneck) units, (28 - 4) / 6 = 4 units per stage.
    depth, widen = 28, 8
    units = [(depth - 4) // 6] * 3
    net = resnet(units=units,
                 num_stage=3,
                 filter_list=[16, 16 * widen, 32 * widen, 64 * widen],
                 bottle_neck=False)
    return get_output(net, labels, cfg)
def get_predict_net(self):
    """Return the inference network: a ResNet-18 producing 128-d embeddings."""
    data = mx.symbol.Variable("data")
    # Same backbone configuration as training, minus the triplet plumbing.
    return resnet(data=data,
                  units=[2, 2, 2, 2],
                  num_stage=4,
                  filter_list=[64, 64, 128, 256, 512],
                  num_class=128,
                  data_type="imagenet",
                  bottle_neck=False,
                  bn_mom=0.9,
                  workspace=512)
def get_network(batch_size):
    """Triplet-embedding training network with a margin hinge loss.

    Anchor, positive and negative batches are concatenated along axis 0,
    pushed through one shared ResNet-18 (128-d embeddings), split back
    apart, and the per-sample loss
    ``max(0, one - (||a - n||^2 - ||a - p||^2))`` is wrapped in MakeLoss.
    ``one`` is an input variable (the margin), reshaped to a column.
    """
    anchor = mx.symbol.Variable("anchor")
    positive = mx.symbol.Variable("positive")
    negative = mx.symbol.Variable("negative")
    # One forward pass serves all three roles.
    concat = mx.symbol.Concat(*[anchor, positive, negative], dim=0,
                              name="concat")
    share_net = resnet(data=concat,
                       units=[2, 2, 2, 2],
                       num_stage=4,
                       filter_list=[64, 64, 128, 256, 512],
                       num_class=128,
                       data_type="imagenet",
                       bottle_neck=False,
                       bn_mom=0.9,
                       workspace=512)
    one = mx.symbol.Variable("one")
    one = mx.symbol.Reshape(data=one, shape=(-1, 1))

    def third(k):
        # Embeddings of the k-th third of the concatenated batch.
        return mx.symbol.slice_axis(share_net, axis=0,
                                    begin=k * batch_size,
                                    end=(k + 1) * batch_size)

    fa, fp, fn = third(0), third(1), third(2)

    def sq_dist(lhs, rhs):
        # Row-wise squared Euclidean distance, kept as a column vector.
        diff = lhs - rhs
        return mx.symbol.sum(diff * diff, axis=1, keepdims=1)

    # Hinge: penalize whenever negative is not farther than positive by "one".
    loss = one - (sq_dist(fa, fn) - sq_dist(fa, fp))
    loss = mx.symbol.Activation(data=loss, act_type='relu')
    return mx.symbol.MakeLoss(loss)
def main():
    """Train a ResNet on CIFAR-10 or ImageNet using the MXNet Module API.

    All configuration is read from the module-level ``args`` namespace.
    Builds the symbol for the requested depth, sets up distributed record
    iterators and the LR schedule, then runs ``mod.fit``.  Checkpoints are
    written to ``model/resnet-<data_type>-<depth>-<rank>``.

    Raises
    ------
    ValueError
        For unsupported depths or data types.
    """
    if args.data_type == "cifar10":
        args.aug_level = 1
        args.num_classes = 10
        # depth should be one of 110, 164, 1001, ..., i.e. fit (depth-2) % 9 == 0
        if (args.depth - 2) % 9 == 0 and args.depth >= 164:
            # Bottleneck (pre-activation) units, 9 layers per unit.
            # BUG FIX: use // so unit counts stay integral on Python 3.
            per_unit = [(args.depth - 2) // 9]
            filter_list = [16, 64, 128, 256]
            bottle_neck = True
        elif (args.depth - 2) % 6 == 0 and args.depth < 164:
            per_unit = [(args.depth - 2) // 6]
            filter_list = [16, 16, 32, 64]
            bottle_neck = False
        else:
            raise ValueError(
                "no experiments done on depth {}, you can do it yourself".
                format(args.depth))
        units = per_unit * 3
        symbol = resnet(units=units,
                        num_stage=3,
                        filter_list=filter_list,
                        num_class=args.num_classes,
                        data_type="cifar10",
                        bottle_neck=bottle_neck,
                        bn_mom=args.bn_mom,
                        workspace=args.workspace,
                        memonger=args.memonger)
    elif args.data_type == "imagenet":
        args.num_classes = 1000
        # Per-stage unit counts for each published depth.
        units_by_depth = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
            200: [3, 24, 36, 3],
            269: [3, 30, 48, 8],
        }
        if args.depth not in units_by_depth:
            raise ValueError(
                "no experiments done on depth {}, you can do it yourself".
                format(args.depth))
        units = units_by_depth[args.depth]
        symbol = resnet(units=units,
                        num_stage=4,
                        filter_list=[64, 256, 512, 1024, 2048]
                        if args.depth >= 50 else [64, 64, 128, 256, 512],
                        num_class=args.num_classes,
                        data_type="imagenet",
                        bottle_neck=args.depth >= 50,
                        bn_mom=args.bn_mom,
                        workspace=args.workspace,
                        memonger=args.memonger)
    else:
        raise ValueError("do not support {} yet".format(args.data_type))
    kv = mx.kvstore.create(args.kv_store)
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')
    ]
    # Batches per epoch per worker; at least 1 so the scheduler is valid.
    epoch_size = max(
        int(args.num_examples / args.batch_size / kv.num_workers), 1)
    begin_epoch = args.model_load_epoch if args.model_load_epoch else 0
    if not os.path.exists("./model"):
        os.mkdir("./model")
    model_prefix = "model/resnet-{}-{}-{}".format(args.data_type, args.depth,
                                                  kv.rank)
    checkpoint = mx.callback.do_checkpoint(model_prefix)
    arg_params = None
    aux_params = None
    if args.retrain:
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, args.model_load_epoch)
    if args.memonger:
        import memonger
        symbol = memonger.search_plan(
            symbol,
            data=(args.batch_size, 3, 32, 32) if args.data_type == "cifar10"
            else (args.batch_size, 3, 224, 224))
    train = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(args.data_dir, "cifar10_train.rec")
        if args.data_type == 'cifar10' else
        os.path.join(args.data_dir, "train_256_q90.rec")
        if args.aug_level == 1 else
        os.path.join(args.data_dir, "train_480_q90.rec"),
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        data_shape=(3, 32, 32) if args.data_type == "cifar10" else
        (3, 224, 224),
        batch_size=args.batch_size,
        pad=4 if args.data_type == "cifar10" else 0,
        fill_value=127,  # only used when pad is valid
        rand_crop=True,
        max_random_scale=1.0,  # 480 with imagnet, 32 with cifar10
        min_random_scale=1.0 if args.data_type == "cifar10" else
        1.0 if args.aug_level == 1 else 0.533,  # 256.0/480.0
        max_aspect_ratio=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 0.25,
        random_h=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 36,  # 0.4*90
        random_s=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 50,  # 0.4*127
        random_l=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 50,  # 0.4*127
        max_rotate_angle=0 if args.aug_level <= 2 else 10,
        max_shear_ratio=0 if args.aug_level <= 2 else 0.1,
        rand_mirror=True,
        shuffle=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    val = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(args.data_dir, "cifar10_val.rec")
        if args.data_type == 'cifar10' else os.path.join(
            args.data_dir, "val_256_q90.rec"),
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        batch_size=args.batch_size,
        data_shape=(3, 32, 32) if args.data_type == "cifar10" else
        (3, 224, 224),
        rand_crop=False,
        rand_mirror=False,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    lr_scheduler = multi_factor_scheduler(begin_epoch, epoch_size,
                                          step=[120, 160], factor=0.1) \
        if args.data_type == 'cifar10' else \
        multi_factor_scheduler(begin_epoch, epoch_size,
                               step=[30, 60, 90], factor=0.1)
    mod = mx.mod.Module(symbol=symbol, context=devs)
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        # BUG FIX: key used to be 'momentum ' (trailing space), which is not
        # a valid optimizer argument.
        'momentum': args.mom
    }
    mod.fit(
        train,
        eval_data=val,
        eval_metric=['acc', 'ce'] if args.data_type == 'cifar10' else
        ['acc', mx.metric.create('top_k_accuracy', top_k=5)],
        arg_params=arg_params,
        aux_params=aux_params,
        num_epoch=200 if args.data_type == "cifar10" else 120,
        begin_epoch=begin_epoch,
        optimizer='nag',
        # optimizer = 'sgd',
        # BUG FIX: optimizer_params was built but never passed to fit(), so
        # the configured lr/momentum/wd/schedule were silently ignored.
        optimizer_params=optimizer_params,
        initializer=mx.init.Xavier(rnd_type='gaussian',
                                   factor_type="in",
                                   magnitude=2),
        kvstore=kv,
        batch_end_callback=mx.callback.Speedometer(args.batch_size,
                                                   args.frequent),
        epoch_end_callback=checkpoint)
def main():
    """Train a ResNet on CIFAR-10 or ImageNet with the legacy FeedForward API.

    Configuration comes from the module-level ``args`` namespace.  Builds the
    symbol, record iterators and a MultiFactor LR schedule, then calls
    ``model.fit``.  Checkpoints go to ``model/resnet-<data_type>-<depth>``.
    Raises ValueError for unsupported depths or data types.
    """
    if args.data_type == "cifar10":
        # depth should be one of 110, 164, 1001,...,which is should fit (args.depth-2)%9 == 0
        if ((args.depth - 2) % 9 == 0):
            # NOTE(review): '/' yields a float on Python 3, making units a
            # list of floats — presumably this script targets Python 2
            # semantics ('//'); confirm before running under Python 3.
            per_unit = [(args.depth - 2) / 9]
            units = per_unit * 3
            # Bottleneck units only from depth 164 up; filter_list is the
            # bottleneck one in both cases here (kept as written).
            symbol = resnet(units=units,
                            num_stage=3,
                            filter_list=[16, 64, 128, 256],
                            num_class=10,
                            data_type="cifar10",
                            bottle_neck=True if args.depth >= 164 else False,
                            bn_mom=args.bn_mom,
                            workspace=512)
        else:
            raise ValueError(
                "no experiments done on detph {}, you can do it youself".
                format(args.depth))
    elif args.data_type == "imagenet":
        # Per-stage unit counts for the published ImageNet depths.
        if args.depth == 18:
            units = [2, 2, 2, 2]
        elif args.depth == 34:
            units = [3, 4, 6, 3]
        elif args.depth == 50:
            units = [3, 4, 6, 3]
        elif args.depth == 101:
            units = [3, 4, 23, 3]
        elif args.depth == 152:
            units = [3, 8, 36, 3]
        elif args.depth == 200:
            units = [3, 24, 36, 3]
        else:
            raise ValueError(
                "no experiments done on detph {}, you can do it youself".
                format(args.depth))
        # Depths >= 50 use bottleneck units with wider filters.
        symbol = resnet(units=units,
                        num_stage=4,
                        filter_list=[64, 256, 512, 1024, 2048]
                        if args.depth >= 50 else [64, 64, 128, 256, 512],
                        num_class=1000,
                        data_type="imagenet",
                        bottle_neck=True if args.depth >= 50 else False,
                        bn_mom=args.bn_mom,
                        workspace=512)
    else:
        raise ValueError("do not support {} yet".format(args.data_type))
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')
    ]
    # Batches per epoch (single-worker view; no division by num_workers here).
    epoch_size = max(int(args.num_examples / args.batch_size), 1)
    if not os.path.exists("./model"):
        os.mkdir("./model")
    checkpoint = mx.callback.do_checkpoint("model/resnet-{}-{}".format(
        args.data_type, args.depth))
    kv = mx.kvstore.create(args.kv_store)
    arg_params = None
    aux_params = None
    if args.retrain:
        # Resume from an existing checkpoint at model_load_epoch.
        _, arg_params, aux_params = mx.model.load_checkpoint(
            "model/resnet-{}-{}".format(args.data_type, args.depth),
            args.model_load_epoch)
    train = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(args.data_dir, "train_480_q90.rec"),
        # path_imgrec = os.path.join(args.data_dir, "train_256_q90.rec"),
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        data_shape=(3, 32, 32) if args.data_type == "cifar10" else
        (3, 224, 224),
        batch_size=args.batch_size,
        pad=4 if args.data_type == "cifar10" else 0,
        fill_value=127,  # only used when pad is valid
        rand_crop=True,
        max_random_scale=1.0 if args.data_type == "cifar10" else 1.0,  # 480
        min_random_scale=1.0 if args.data_type == "cifar10" else
        0.533,  # 256.0/480.0
        max_aspect_ratio=0 if args.data_type == "cifar10" else 0.25,
        random_h=0 if args.data_type == "cifar10" else 36,  # 0.4*90
        random_s=0 if args.data_type == "cifar10" else 50,  # 0.4*127
        random_l=0 if args.data_type == "cifar10" else 50,  # 0.4*127
        rand_mirror=True,
        shuffle=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    val = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(args.data_dir, "val_256_q90.rec"),
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        batch_size=args.batch_size,
        data_shape=(3, 32, 32) if args.data_type == "cifar10" else
        (3, 224, 224),
        rand_crop=False,
        rand_mirror=False,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    model = mx.model.FeedForward(
        ctx=devs,
        symbol=symbol,
        arg_params=arg_params,
        aux_params=aux_params,
        num_epoch=200 if args.data_type == "cifar10" else 120,
        begin_epoch=args.model_load_epoch if args.model_load_epoch else 0,
        learning_rate=args.lr,
        momentum=args.mom,
        wd=args.wd,
        optimizer='nag',
        # optimizer = 'sgd',
        initializer=mx.init.Xavier(rnd_type='gaussian',
                                   factor_type="in",
                                   magnitude=2),
        # LR drops at epochs 120/160 (cifar10) or 30/60/90 (imagenet),
        # expressed in batch counts for MultiFactorScheduler.
        lr_scheduler=mx.lr_scheduler.MultiFactorScheduler(
            step=[120 * epoch_size, 160 * epoch_size], factor=0.1)
        if args.data_type == 'cifar10' else
        mx.lr_scheduler.MultiFactorScheduler(
            step=[30 * epoch_size, 60 * epoch_size, 90 * epoch_size],
            factor=0.1),
    )
    model.fit(X=train,
              eval_data=val,
              eval_metric=['acc'] if args.data_type == 'cifar10' else
              ['acc', mx.metric.create('top_k_accuracy', top_k=5)],
              kvstore=kv,
              batch_end_callback=mx.callback.Speedometer(args.batch_size, 50),
              epoch_end_callback=checkpoint)
def main():
    """Distributed ResNet training entry point (cluster/gsync variant).

    Configuration comes from the module-level ``args`` namespace.  Supports
    several kvstore modes (bsp/asp/gsp/ssp), per-host data partitioning for
    "dist_gsync", file or console logging, and checkpointing under
    ``/home/<user>/mxnet_model``.  NOTE(review): ``model.fit`` below is
    called with non-standard kwargs (val_eval_metric, hostname, dataset,
    staleness, network_name, lr) — this presumably targets a locally
    modified MXNet; confirm against that fork.
    """
    if args.data_type == "cifar10":
        args.aug_level = 1
        args.num_classes = 10
        # depth should be one of 110, 164, 1001,...,which is should fit (args.depth-2)%9 == 0
        if ((args.depth - 2) % 9 == 0 and args.depth >= 164):
            # NOTE(review): '/' gives float units on Python 3 — presumably a
            # Python 2 script ('//'); confirm before porting.
            per_unit = [(args.depth - 2) / 9]
            filter_list = [16, 64, 128, 256]
            bottle_neck = True
        elif ((args.depth - 2) % 6 == 0 and args.depth < 164):
            per_unit = [(args.depth - 2) / 6]
            filter_list = [16, 16, 32, 64]
            bottle_neck = False
        else:
            raise ValueError(
                "no experiments done on detph {}, you can do it youself".
                format(args.depth))
        units = per_unit * 3
        symbol = resnet(units=units,
                        num_stage=3,
                        filter_list=filter_list,
                        num_class=args.num_classes,
                        data_type="cifar10",
                        bottle_neck=bottle_neck,
                        bn_mom=args.bn_mom,
                        workspace=args.workspace,
                        memonger=args.memonger)
    elif args.data_type == "imagenet":
        args.num_classes = 1000
        # Per-stage unit counts for the published ImageNet depths.
        if args.depth == 18:
            units = [2, 2, 2, 2]
        elif args.depth == 34:
            units = [3, 4, 6, 3]
        elif args.depth == 50:
            units = [3, 4, 6, 3]
        elif args.depth == 101:
            units = [3, 4, 23, 3]
        elif args.depth == 152:
            units = [3, 8, 36, 3]
        elif args.depth == 200:
            units = [3, 24, 36, 3]
        elif args.depth == 269:
            units = [3, 30, 48, 8]
        else:
            raise ValueError(
                "no experiments done on detph {}, you can do it youself".
                format(args.depth))
        symbol = resnet(units=units,
                        num_stage=4,
                        filter_list=[64, 256, 512, 1024, 2048]
                        if args.depth >= 50 else [64, 64, 128, 256, 512],
                        num_class=args.num_classes,
                        data_type="imagenet",
                        bottle_neck=True if args.depth >= 50 else False,
                        bn_mom=args.bn_mom,
                        workspace=args.workspace,
                        memonger=args.memonger)
    else:
        raise ValueError("do not support {} yet".format(args.data_type))
    kv = mx.kvstore.create(args.kv_store)
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')
    ]
    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        # Log to a file under args.log_dir (created if missing).
        log_file = args.log_file
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, log_file)
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        logger.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)
    # Map kvstore mode to the short name used in checkpoint paths.
    kv_store_type = ""
    if args.kv_store == "dist_sync":
        kv_store_type = "bsp"
    elif args.kv_store == "dist_async":
        kv_store_type = "asp"
    elif args.kv_store == "dist_gsync":
        kv_store_type = "gsp"
    elif args.kv_store == "dist_ssync":
        kv_store_type = "ssp"
    begin_epoch = args.model_load_epoch if args.model_load_epoch else 0
    user = getpass.getuser()
    # Checkpoints live under the invoking user's home directory.
    if not os.path.exists("/home/{}/mxnet_model/model/{}/resnet{}/{}".format(
            user, args.data_type, args.depth, kv_store_type)):
        os.makedirs("/home/{}/mxnet_model/model/{}/resnet{}/{}".format(
            user, args.data_type, args.depth, kv_store_type))
    model_prefix = "/home/{}/mxnet_model/model/{}/resnet{}/{}/{}-{}-resnet{}-{}".format(
        user, args.data_type, args.depth, kv_store_type, kv_store_type,
        args.data_type, args.depth, kv.rank)
    checkpoint = None if not args.savemodel else mx.callback.do_checkpoint(
        model_prefix)
    arg_params = None
    aux_params = None
    if args.retrain:
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, args.model_load_epoch)
    if args.memonger:
        import memonger
        symbol = memonger.search_plan(
            symbol,
            data=(args.batch_size, 3, 32, 32) if args.data_type == "cifar10"
            else (args.batch_size, 3, 224, 224))
    # Training-data partition defaults: whole dataset on this worker;
    # validation is always sharded across workers.
    splits = 1
    part = 0
    val_splits = kv.num_workers
    val_part = kv.rank
    '''yegeyan 2016.10.6'''
    if args.kv_store == "dist_sync" or args.kv_store == "dist_async" or args.kv_store == "dist_ssync":
        #if args.kv_store == "dist_sync":
        splits = kv.num_workers
        part = kv.rank
    if args.kv_store == "dist_gsync":
        if args.data_allocator == 1:
            # Manual per-host data ranges: (part, splits) here are
            # begin/end fractions taken from the cluster-specific args.
            if args.hostname == "gpu-cluster-1":
                part = args.cluster1_begin
                splits = args.cluster1_end
            elif args.hostname == "gpu-cluster-2":
                part = args.cluster2_begin
                splits = args.cluster2_end
            elif args.hostname == "gpu-cluster-3":
                part = args.cluster3_begin
                splits = args.cluster3_end
            elif args.hostname == "gpu-cluster-4":
                part = args.cluster4_begin
                splits = args.cluster4_end
            else:
                part = args.cluster5_begin
                splits = args.cluster5_end
            args.data_proportion = splits - part
        else:
            splits = kv.num_workers
            part = kv.rank
    # yegeyan 2017.1.15
    epoch_size = args.num_examples / args.batch_size
    model_args = {}
    if args.kv_store == 'dist_sync' or args.kv_store == 'dist_async' or args.kv_store == 'dist_ssync':
        #if args.kv_store == 'dist_sync':
        epoch_size /= kv.num_workers
        model_args['epoch_size'] = epoch_size
    '''yegeyan 2016.12.13'''
    if args.kv_store == 'dist_gsync':
        if args.data_allocator == 1:
            # Scale epoch length by this host's share of the data.
            epoch_size *= args.data_proportion
            model_args['epoch_size'] = epoch_size
        else:
            epoch_size /= kv.num_workers
            model_args['epoch_size'] = epoch_size
    if 'lr_factor' in args and args.lr_factor < 1:
        # NOTE(review): ``batch_num`` is not defined anywhere in this
        # function — presumably a module-level name; this branch would
        # raise NameError otherwise. Confirm.
        model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step=max(int(batch_num * args.lr_factor_epoch),
                     1),  # yegeyan 2016.12.13
            factor=args.lr_factor)
    if 'clip_gradient' in args and args.clip_gradient is not None:
        model_args['clip_gradient'] = args.clip_gradient
    eval_metrics = ['accuracy']
    ## TopKAccuracy only allows top_k > 1
    for top_k in [5, 10, 20]:
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=top_k))
    # yegeyan 2017.1.4
    val_eval_metrics = ['accuracy']
    ## TopKAccuracy only allows top_k > 1
    for top_k in [5, 10, 20]:
        val_eval_metrics.append(mx.metric.create('top_k_accuracy',
                                                 top_k=top_k))
    train = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(args.data_dir, "train.rec")
        if args.data_type == 'cifar10' else
        os.path.join(args.data_dir, "train_480.rec")
        if args.aug_level == 1 else
        os.path.join(args.data_dir, "train_480.rec"),
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        data_shape=(3, 32, 32) if args.data_type == "cifar10" else
        (3, 224, 224),
        batch_size=args.batch_size,
        pad=4 if args.data_type == "cifar10" else 0,
        fill_value=127,  # only used when pad is valid
        rand_crop=True,
        max_random_scale=1.0,  # 480 with imagnet, 32 with cifar10
        min_random_scale=1.0 if args.data_type == "cifar10" else
        1.0 if args.aug_level == 1 else 0.533,  # 256.0/480.0
        max_aspect_ratio=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 0.25,
        random_h=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 36,  # 0.4*90
        random_s=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 50,  # 0.4*127
        random_l=0 if args.data_type == "cifar10" else
        0 if args.aug_level == 1 else 50,  # 0.4*127
        max_rotate_angle=0 if args.aug_level <= 2 else 10,
        max_shear_ratio=0 if args.aug_level <= 2 else 0.1,
        rand_mirror=True,
        shuffle=True,
        preprocess_threads=4,
        num_parts=splits,
        part_index=part)
    val = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(args.data_dir, "test.rec")
        if args.data_type == 'cifar10' else
        os.path.join(args.data_dir, "val_480.rec"),
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        batch_size=args.batch_size,
        data_shape=(3, 32, 32) if args.data_type == "cifar10" else
        (3, 224, 224),
        rand_crop=False,
        rand_mirror=False,
        preprocess_threads=4,
        num_parts=val_splits,
        part_index=val_part)
    model = mx.model.FeedForward(
        ctx=devs,
        symbol=symbol,
        arg_params=arg_params,
        aux_params=aux_params,
        num_epoch=args.num_epochs,
        begin_epoch=begin_epoch,
        learning_rate=args.lr,
        momentum=args.mom,
        wd=args.wd,
        #optimizer = 'nag',
        optimizer='sgd',
        initializer=mx.init.Xavier(rnd_type='gaussian',
                                   factor_type="in",
                                   magnitude=2),
        # LR drops at epochs 220/260/280 (cifar10) or 30/60/90 (imagenet).
        lr_scheduler=multi_factor_scheduler(
            begin_epoch, epoch_size, step=[220, 260, 280], factor=0.1)
        if args.data_type == 'cifar10' else multi_factor_scheduler(
            begin_epoch, epoch_size, step=[30, 60, 90], factor=0.1),
        **model_args)
    model.fit(X=train,
              eval_data=val,
              eval_metric=eval_metrics,
              val_eval_metric=val_eval_metrics,
              kvstore=kv,
              batch_end_callback=mx.callback.Speedometer(args.batch_size, 50),
              epoch_end_callback=checkpoint,
              hostname=socket.gethostbyname_ex(socket.gethostname())[0],
              dataset=args.data_type,
              staleness=args.staleness,
              network_name="resnet_" + str(args.depth),
              lr=args.lr)  #yegeyan 2017.5.15
def main():
    """Train a bottleneck ResNet on CIFAR-10 with the Module API.

    Configuration comes from the module-level ``args`` namespace.  Only
    depths with (depth - 2) % 9 == 0 are accepted (the non-bottleneck
    branch is commented out).  Checkpoints are written under ``./model``.
    """
    # NOTE(review): '/' yields float unit counts on Python 3 — presumably a
    # Python 2 script ('//'); confirm before porting.
    if (args.depth-2) % 9 == 0:  # and args.depth >= 164:
        per_unit = [(args.depth-2) / 9]
        filter_list = [16, 64, 128, 256]
        bottle_neck = True
    # elif (args.depth-2) % 6 == 0 and args.depth < 164:
    #     per_unit = [(args.depth-2) / 6]
    #     filter_list = [16, 16, 32, 64]
    #     bottle_neck = False
    else:
        raise ValueError(
            "no experiments done on detph {}, you can do it youself".format(args.depth))
    units = per_unit*3
    symbol = resnet(units=units, num_stage=3, filter_list=filter_list,
                    num_class=args.num_classes, bottle_neck=bottle_neck,
                    bn_mom=args.bn_mom, workspace=args.workspace,
                    memonger=args.memonger)
    kv = mx.kvstore.create(args.kv_store)
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]
    # Batches per epoch per worker; at least 1 so the scheduler is valid.
    epoch_size = max(
        int(args.num_examples / args.batch_size / kv.num_workers), 1)
    begin_epoch = args.model_load_epoch if args.model_load_epoch else 0
    if not os.path.exists("./model"):
        os.mkdir("./model")
    # NOTE(review): ``data_type`` is not defined in this function —
    # presumably a module-level constant (e.g. "cifar10"); this line would
    # raise NameError otherwise. Confirm.
    model_prefix = "model/resnet-{}-{}-{}".format(
        data_type, args.depth, kv.rank)
    checkpoint = mx.callback.do_checkpoint(model_prefix)
    arg_params = None
    aux_params = None
    if args.retrain:
        # Resume from an existing checkpoint at model_load_epoch.
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, args.model_load_epoch)
    if args.memonger:
        import memonger
        symbol = memonger.search_plan(
            symbol, data=(args.batch_size, 3, 32, 32))
    train = mx.io.ImageRecordIter(
        path_imgrec = os.path.join(args.data_dir, "cifar10_train.rec"),
        label_width = 1,
        data_shape = (3, 32, 32),
        num_parts = kv.num_workers,
        part_index = kv.rank,
        shuffle = True,
        batch_size = args.batch_size,
        rand_crop = True,
        fill_value = 127,  # only used when pad is valid
        pad = 4,
        rand_mirror = True,
    )
    val = mx.io.ImageRecordIter(
        path_imgrec = os.path.join(args.data_dir, "cifar10_val.rec"),
        label_width = 1,
        data_shape = (3, 32, 32),
        num_parts = kv.num_workers,
        part_index = kv.rank,
        batch_size = args.batch_size,
    )
    model = mx.mod.Module(
        symbol = symbol,
        data_names = ('data', ),
        label_names = ('softmax_label', ),
        context = devs,
    )
    model.fit(
        train_data = train,
        eval_data = val,
        eval_metric = ['acc'],
        epoch_end_callback = checkpoint,
        batch_end_callback = mx.callback.Speedometer(args.batch_size,
                                                     args.frequent),
        kvstore = kv,
        optimizer = 'nag',
        # LR drops by 10x at epoch 80.
        optimizer_params = (('learning_rate', args.lr),
                            ('momentum', args.mom), ('wd', args.wd), (
                                'lr_scheduler',
                                multi_factor_scheduler(begin_epoch,
                                                       epoch_size,
                                                       step=[80],
                                                       factor=0.1))),
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                                     magnitude=2),
        arg_params = arg_params,
        aux_params = aux_params,
        allow_missing = True,
        begin_epoch = begin_epoch,
        num_epoch = args.end_epoch,
    )