def test_shift_gpu(self):
    model = self.create_model()
    data_parallel_model_utils.ShiftActivationDevices(
        model,
        activations=["fc4", "fc5"],
        shifts={0: 4, 1: 4, 2: 5, 3: 5},
    )
    for op in model.param_init_net.Proto().op:
        for outp in op.output:
            prefix = outp.split("/")[0]
            if outp.split("/")[-1] in {'fc4_w', 'fc5_w', 'fc4_b', 'fc5_b'}:
                if prefix == 'gpu_0' or prefix == 'gpu_1':
                    self.assertEqual(op.device_option.cuda_gpu_id, 4)
                else:
                    self.assertEqual(op.device_option.cuda_gpu_id, 5)
            if outp.split("/")[-1] in {'fc1_w', 'fc2_w', 'fc3_b', 'fc3_w'}:
                gpu_id = int(prefix.split("_")[-1])
                self.assertEqual(gpu_id, op.device_option.cuda_gpu_id)

    # Test that we can run the net
    if workspace.NumCudaDevices() >= 6:
        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)
        workspace.RunNet(model.net.Proto().name)
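
# A minimal, self-contained sketch of the "gpu_<id>/<blob>" naming convention
# the assertions above rely on. `expected_device` is a hypothetical helper
# (not part of data_parallel_model_utils); the shift map mirrors the
# shifts={0: 4, 1: 4, 2: 5, 3: 5} argument used in the test.
def expected_device(blob_name, shifted_layers, shifts):
    """Return the device a parameter blob should land on after shifting."""
    prefix, base = blob_name.split("/")      # e.g. "gpu_0", "fc4_w"
    gpu_id = int(prefix.split("_")[-1])      # original device index
    layer = base.rsplit("_", 1)[0]           # "fc4_w" -> "fc4"
    return shifts[gpu_id] if layer in shifted_layers else gpu_id

SHIFTS = {0: 4, 1: 4, 2: 5, 3: 5}
assert expected_device("gpu_1/fc4_w", {"fc4", "fc5"}, SHIFTS) == 4
assert expected_device("gpu_2/fc5_b", {"fc4", "fc5"}, SHIFTS) == 5
assert expected_device("gpu_3/fc1_w", {"fc4", "fc5"}, SHIFTS) == 3  # unshifted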
def Train(args):
    # Either use specified device list or generate one
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = list(range(args.num_gpus))
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Verify valid batch size
    total_batch_size = args.batch_size
    batch_per_device = total_batch_size // num_gpus
    assert total_batch_size % num_gpus == 0, \
        "Number of GPUs must divide batch size"

    # Round down epoch size to closest multiple of batch size across machines
    global_batch_size = total_batch_size * args.num_shards
    epoch_iters = int(args.epoch_size / global_batch_size)
    assert epoch_iters > 0, \
        "Epoch size must be larger than batch size times shard count"

    args.epoch_size = epoch_iters * global_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    # Create ModelHelper object
    train_arg_scope = {
        'order': 'NCHW',
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
        'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
    }
    train_model = model_helper.ModelHelper(
        name="ban-pc-resnet50", arg_scope=train_arg_scope
    )

    num_shards = args.num_shards
    shard_id = args.shard_id

    # Expect interfaces to be comma separated.
    # Use of multiple network interfaces is not yet complete,
    # so simply use the first one in the list.
    interfaces = args.distributed_interfaces.split(",")

    # Rendezvous using MPI when run with mpirun
    if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
        num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        if num_shards > 1:
            rendezvous = dict(
                kv_handler=None,
                num_shards=num_shards,
                shard_id=shard_id,
                engine="GLOO",
                transport=args.distributed_transport,
                interface=interfaces[0],
                mpi_rendezvous=True,
                exit_nets=None,
            )
    elif num_shards > 1:
        # Create rendezvous for distributed computation
        store_handler = "store_handler"
        if args.redis_host is not None:
            # Use Redis for rendezvous if Redis host is specified
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RedisStoreHandlerCreate", [], [store_handler],
                    host=args.redis_host,
                    port=args.redis_port,
                    prefix=args.run_id,
                )
            )
        else:
            # Use filesystem for rendezvous otherwise
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "FileStoreHandlerCreate", [], [store_handler],
                    path=args.file_store_path,
                    prefix=args.run_id,
                )
            )

        rendezvous = dict(
            kv_handler=store_handler,
            shard_id=shard_id,
            num_shards=num_shards,
            engine="GLOO",
            transport=args.distributed_transport,
            interface=interfaces[0],
            exit_nets=None,
        )
    else:
        rendezvous = None

    # Model configs for constructing the model
    with open(args.model_config) as f:
        model_config = yaml.safe_load(f)

    # Model building functions
    def create_target_model_ops(model, loss_scale):
        initializer = (PseudoFP16Initializer if args.dtype == 'float16'
                       else Initializer)
        with brew.arg_scope([brew.conv, brew.fc],
                            WeightInitializer=initializer,
                            BiasInitializer=initializer,
                            enable_tensor_core=args.enable_tensor_core,
                            float16_compute=args.float16_compute):
            pred = add_se_model(model, model_config, "data", is_test=False)

        if args.dtype == 'float16':
            pred = model.net.HalfToFloat(pred, pred + '_fp32')

        loss = add_softmax_loss(model, pred, 'label')
        brew.accuracy(model, ['softmax', 'label'], 'accuracy')
        return [loss]

    def add_optimizer(model):
        '''
        stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)
        optimizer.add_weight_decay(model, args.weight_decay)
        opt = optimizer.build_multi_precision_sgd(
            model,
            args.base_learning_rate,
            momentum=0.9,
            nesterov=1,
            policy="step",
            stepsize=stepsz,
            gamma=0.1
        )
        '''
        optimizer.add_weight_decay(model, args.weight_decay)
        opt = optimizer.build_multi_precision_sgd(
            model,
            base_learning_rate=args.base_learning_rate,
            momentum=model_config['solver']['momentum'],
            nesterov=model_config['solver']['nesterov'],
            policy=model_config['solver']['lr_policy'],
            power=model_config['solver']['power'],
            max_iter=model_config['solver']['max_iter'],
        )
        return opt

    # Define add_image_input function.
    # Depends on the "train_data" argument.
    # Note that the reader will be shared between all GPUs.
    reader = train_model.CreateDB(
        "reader",
        db=args.train_data,
        db_type=args.db_type,
        num_shards=num_shards,
        shard_id=shard_id,
    )

    def add_image_input(model):
        AddImageInput(
            model,
            reader,
            batch_size=batch_per_device,
            img_size=args.image_size,
            dtype=args.dtype,
            is_test=False,
        )

    def add_post_sync_ops(model):
        """Add ops applied after initial parameter sync."""
        for param_info in model.GetOptimizationParamInfo(model.GetParams()):
            if param_info.blob_copy is not None:
                model.param_init_net.HalfToFloat(
                    param_info.blob,
                    param_info.blob_copy[core.DataType.FLOAT]
                )

    # Create parallelized model
    data_parallel_model.Parallelize(
        train_model,
        input_builder_fun=add_image_input,
        forward_pass_builder_fun=create_target_model_ops,
        optimizer_builder_fun=add_optimizer,
        post_sync_builder_fun=add_post_sync_ops,
        devices=gpus,
        rendezvous=rendezvous,
        optimize_gradient_memory=False,
        cpu_device=args.use_cpu,
        shared_model=args.use_cpu,
        combine_spatial_bn=args.use_cpu,
    )

    if args.model_parallel:
        # Shift half of the activations to another GPU
        assert workspace.NumCudaDevices() >= 2 * args.num_gpus
        activations = data_parallel_model_utils.GetActivationBlobs(train_model)
        data_parallel_model_utils.ShiftActivationDevices(
            train_model,
            activations=activations[len(activations) // 2:],
            shifts={g: args.num_gpus + g for g in range(args.num_gpus)},
        )

    data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    # Add test model, if specified
    test_model = None
    if (args.test_data is not None):
        log.info("----- Create test net ----")
        test_arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': True,
        }
        test_model = model_helper.ModelHelper(
            name="ban-pc-resnet50_test",
            arg_scope=test_arg_scope,
            init_params=False
        )

        test_reader = test_model.CreateDB(
            "test_reader",
            db=args.test_data,
            db_type=args.db_type,
        )

        def test_input_fn(model):
            AddImageInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=True,
            )

        data_parallel_model.Parallelize(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_target_model_ops,
            post_sync_builder_fun=add_post_sync_ops,
            param_update_builder_fun=None,
            devices=gpus,
            cpu_device=args.use_cpu,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    epoch = 0
    # load the pre-trained model and reset epoch
    if args.load_model_path is not None:
        LoadModel(args.load_model_path, train_model)

        # Sync the model params
        data_parallel_model.FinalizeAfterCheckpoint(train_model)

        # reset epoch. load_model_path should end with *_X.mdl,
        # where X is the epoch number
        last_str = args.load_model_path.split('_')[-1]
        if last_str.endswith('.mdl'):
            epoch = int(last_str[:-4])
            log.info("Reset epoch to {}".format(epoch))
        else:
            log.warning("The format of load_model_path doesn't match!")

    expname = "log/{}/resnet50_gpu{}_b{}_L{}_lr{:.2f}_v2".format(
        args.dataset_name,
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )

    explog = experiment_util.ModelTrainerLog(expname, args)

    # Load pretrained param_init_net
    load_init_net_multigpu(args)

    # Run the training one epoch at a time
    best_accuracy = 0
    while epoch < args.num_epochs:
        epoch, best_accuracy = RunEpoch(
            args,
            epoch,
            train_model,
            test_model,
            total_batch_size,
            num_shards,
            expname,
            explog,
            best_accuracy,
        )

        # Save the model for each epoch
        SaveModel(args, train_model, epoch)

        model_path = "%s/%s_" % (
            args.file_store_path,
            args.save_model_name
        )
        # remove the saved model from the previous epoch if it exists
        if os.path.isfile(model_path + str(epoch - 1) + ".mdl"):
            os.remove(model_path + str(epoch - 1) + ".mdl")
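
# A minimal entry-point sketch for the trainer above, assuming an argparse
# CLI whose flag names mirror the args.* attributes Train() reads. Only a
# representative subset of the flags is shown, and the defaults are
# illustrative, not taken from the original script.
def main():
    import argparse
    parser = argparse.ArgumentParser(description="Caffe2 SE-ResNet-50 trainer")
    parser.add_argument("--gpus", type=str, default=None,
                        help="Comma-separated GPU ids, e.g. '0,1,2,3'")
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Total batch size across all devices")
    parser.add_argument("--epoch_size", type=int, default=1500000)
    parser.add_argument("--num_epochs", type=int, default=90)
    parser.add_argument("--num_shards", type=int, default=1)
    parser.add_argument("--shard_id", type=int, default=0)
    parser.add_argument("--train_data", type=str, required=True)
    parser.add_argument("--test_data", type=str, default=None)
    parser.add_argument("--model_config", type=str, required=True,
                        help="YAML file with solver/model settings")
    Train(parser.parse_args())


if __name__ == '__main__':
    main()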
def Train(args):
    # Either use specified device list or generate one
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = list(range(args.num_gpus))
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Verify valid batch size
    total_batch_size = args.batch_size
    batch_per_device = total_batch_size // num_gpus
    global_batch_size = total_batch_size * args.num_shards
    epoch_iters = int(args.epoch_size / global_batch_size)
    args.epoch_size = epoch_iters * global_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    train_arg_scope = {
        'order': 'NCHW',
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
        'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
    }
    train_model = model_helper.ModelHelper(
        name="resnet101", arg_scope=train_arg_scope
    )

    num_shards = args.num_shards
    shard_id = args.shard_id
    interfaces = args.distributed_interfaces.split(",")

    # Rendezvous using MPI when run with mpirun
    if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
        num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        if num_shards > 1:
            rendezvous = dict(
                kv_handler=None,
                num_shards=num_shards,
                shard_id=shard_id,
                engine="GLOO",
                transport=args.distributed_transport,
                interface=interfaces[0],
                mpi_rendezvous=True,
                exit_nets=None,
            )
    elif num_shards > 1:
        store_handler = "store_handler"
        if args.redis_host is not None:
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RedisStoreHandlerCreate", [], [store_handler],
                    host=args.redis_host,
                    port=args.redis_port,
                    prefix=args.run_id,
                ))
        else:
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "FileStoreHandlerCreate", [], [store_handler],
                    path=args.file_store_path,
                    prefix=args.run_id,
                ))
        rendezvous = dict(
            kv_handler=store_handler,
            shard_id=shard_id,
            num_shards=num_shards,
            engine="GLOO",
            transport=args.distributed_transport,
            interface=interfaces[0],
            exit_nets=None,
        )
    else:
        rendezvous = None

    def create_resnet101_model_ops(model, loss_scale):
        initializer = (pFP16Initializer if args.dtype == 'float16'
                       else Initializer)
        with brew.arg_scope([brew.conv, brew.fc],
                            WeightInitializer=initializer,
                            BiasInitializer=initializer,
                            enable_tensor_core=args.enable_tensor_core,
                            float16_compute=args.float16_compute):
            pred = resnet.create_resnet101(
                model,
                "data",
                num_input_channels=args.num_channels,
                num_labels=args.num_labels,
                no_bias=True,
                no_loss=True,
            )
        if args.dtype == 'float16':
            pred = model.net.HalfToFloat(pred, pred + '_fp32')
        softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
                                              ['softmax', 'loss'])
        loss = model.Scale(loss, scale=loss_scale)
        brew.accuracy(model, [softmax, "label"], "accuracy")
        return [loss]

    def add_optimizer(model):
        stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)
        if args.float16_compute:
            opt = optimizer.build_fp16_sgd(
                model,
                args.base_learning_rate,
                momentum=0.9,
                nesterov=1,
                weight_decay=args.weight_decay,
                policy="step",
                stepsize=stepsz,
                gamma=0.1,
            )
        else:
            optimizer.add_weight_decay(model, args.weight_decay)
            opt = optimizer.build_multi_precision_sgd(
                model,
                args.base_learning_rate,
                momentum=0.9,
                nesterov=1,
                policy="step",
                stepsize=stepsz,
                gamma=0.1,
            )
        return opt

    if args.train_data == "null":
        def add_image_input(model):
            AddNullInput(
                model,
                None,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
            )
    else:
        reader = train_model.CreateDB(
            "reader",
            db=args.train_data,
            db_type=args.db_type,
            num_shards=num_shards,
            shard_id=shard_id,
        )

        def add_image_input(model):
            AddImageInput(
                model,
                reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=False,
            )

    def add_post_sync_ops(model):
        for param_info in model.GetOptimizationParamInfo(model.GetParams()):
            if param_info.blob_copy is not None:
                model.param_init_net.HalfToFloat(
                    param_info.blob,
                    param_info.blob_copy[core.DataType.FLOAT])

    data_parallel_model.Parallelize(
        train_model,
        input_builder_fun=add_image_input,
        forward_pass_builder_fun=create_resnet101_model_ops,
        optimizer_builder_fun=add_optimizer,
        post_sync_builder_fun=add_post_sync_ops,
        devices=gpus,
        rendezvous=rendezvous,
        optimize_gradient_memory=False,
        cpu_device=args.use_cpu,
        shared_model=args.use_cpu,
    )

    if args.model_parallel:
        activations = data_parallel_model_utils.GetActivationBlobs(train_model)
        data_parallel_model_utils.ShiftActivationDevices(
            train_model,
            activations=activations[len(activations) // 2:],
            shifts={g: args.num_gpus + g for g in range(args.num_gpus)},
        )

    data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    test_model = None
    if (args.test_data is not None):
        log.info("----- Create test net ----")
        test_arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': True,
        }
        test_model = model_helper.ModelHelper(
            name="resnet101_test", arg_scope=test_arg_scope, init_params=False
        )

        test_reader = test_model.CreateDB(
            "test_reader",
            db=args.test_data,
            db_type=args.db_type,
        )

        def test_input_fn(model):
            AddImageInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=True,
            )

        data_parallel_model.Parallelize(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_resnet101_model_ops,
            post_sync_builder_fun=add_post_sync_ops,
            param_update_builder_fun=None,
            devices=gpus,
            cpu_device=args.use_cpu,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    epoch = 0
    if args.load_model_path is not None:
        LoadModel(args.load_model_path, train_model)
        data_parallel_model.FinalizeAfterCheckpoint(train_model)

        last_str = args.load_model_path.split('_')[-1]
        if last_str.endswith('.mdl'):
            epoch = int(last_str[:-4])
            log.info("Reset epoch to {}".format(epoch))
        else:
            log.warning("The format of load_model_path doesn't match!")

    expname = "resnet101_gpu%d_b%d_L%d_lr%.2f_v2" % (
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )
    explog = experiment_util.ModelTrainerLog(expname, args)

    while epoch < args.num_epochs:
        epoch = RunEpoch(
            args, epoch, train_model, test_model, total_batch_size,
            num_shards, expname, explog
        )

    # final save
    SaveModel(workspace, train_model)
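
# The epoch-reset logic above assumes checkpoints named "<prefix>_<epoch>.mdl".
# A self-contained sketch of that convention; `epoch_from_checkpoint` is a
# hypothetical helper mirroring the split('_') / endswith('.mdl') parsing
# used in Train().
def epoch_from_checkpoint(path):
    """Return the epoch encoded in a '*_X.mdl' path, or None if absent."""
    last = path.split('_')[-1]
    return int(last[:-4]) if last.endswith('.mdl') else None

assert epoch_from_checkpoint("checkpoints/resnet101_12.mdl") == 12
assert epoch_from_checkpoint("checkpoints/resnet101_12.chk") is None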
def Train(args):
    # Either use specified device list or generate one
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = list(range(args.num_gpus))
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Verify valid batch size
    total_batch_size = args.batch_size
    batch_per_device = total_batch_size // num_gpus
    assert total_batch_size % num_gpus == 0, \
        "Number of GPUs must divide batch size"

    # Round down epoch size to closest multiple of batch size across machines
    global_batch_size = total_batch_size * args.num_shards
    epoch_iters = int(args.epoch_size / global_batch_size)
    assert epoch_iters > 0, \
        "Epoch size must be larger than batch size times shard count"

    args.epoch_size = epoch_iters * global_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    # Create ModelHelper object
    train_arg_scope = {
        'order': 'NCHW',
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
        'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
    }
    train_model = model_helper.ModelHelper(
        name="resnet50", arg_scope=train_arg_scope
    )

    num_shards = args.num_shards
    shard_id = args.shard_id

    # Expect interfaces to be comma separated.
    # Use of multiple network interfaces is not yet complete,
    # so simply use the first one in the list.
    interfaces = args.distributed_interfaces.split(",")

    # Rendezvous using MPI when run with mpirun
    if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
        num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        if num_shards > 1:
            rendezvous = dict(
                kv_handler=None,
                num_shards=num_shards,
                shard_id=shard_id,
                engine="GLOO",
                transport=args.distributed_transport,
                interface=interfaces[0],
                mpi_rendezvous=True,
                exit_nets=None,
            )
    elif num_shards > 1:
        # Create rendezvous for distributed computation
        store_handler = "store_handler"
        if args.redis_host is not None:
            # Use Redis for rendezvous if Redis host is specified
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RedisStoreHandlerCreate", [], [store_handler],
                    host=args.redis_host,
                    port=args.redis_port,
                    prefix=args.run_id,
                ))
        else:
            # Use filesystem for rendezvous otherwise
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "FileStoreHandlerCreate", [], [store_handler],
                    path=args.file_store_path,
                    prefix=args.run_id,
                ))

        rendezvous = dict(
            kv_handler=store_handler,
            shard_id=shard_id,
            num_shards=num_shards,
            engine="GLOO",
            transport=args.distributed_transport,
            interface=interfaces[0],
            exit_nets=None,
        )
    else:
        rendezvous = None

    # Model building functions
    # def create_resnet50_model_ops(model, loss_scale):
    #     initializer = (PseudoFP16Initializer if args.dtype == 'float16'
    #                    else Initializer)
    #     with brew.arg_scope([brew.conv, brew.fc],
    #                         WeightInitializer=initializer,
    #                         BiasInitializer=initializer,
    #                         enable_tensor_core=args.enable_tensor_core,
    #                         float16_compute=args.float16_compute):
    #         pred = resnet.create_resnet50(
    #             # args.layers,
    #             model,
    #             "data",
    #             num_input_channels=args.num_channels,
    #             num_labels=args.num_labels,
    #             no_bias=True,
    #             no_loss=True,
    #         )
    #     if args.dtype == 'float16':
    #         pred = model.net.HalfToFloat(pred, pred + '_fp32')
    #     softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
    #                                           ['softmax', 'loss'])
    #     loss = model.Scale(loss, scale=loss_scale)
    #     brew.accuracy(model, [softmax, "label"], "accuracy")
    #     return [loss]

    def create_model_ops(model, loss_scale):
        return create_model_ops_testable(model, loss_scale, is_test=False)

    def create_model_ops_test(model, loss_scale):
        return create_model_ops_testable(model, loss_scale,
                                         is_test=True)

    # Model building functions
    def create_model_ops_testable(model, loss_scale, is_test=False):
        initializer = (PseudoFP16Initializer if args.dtype == 'float16'
                       else Initializer)

        with brew.arg_scope([brew.conv, brew.fc],
                            WeightInitializer=initializer,
                            BiasInitializer=initializer,
                            enable_tensor_core=args.enable_tensor_core,
                            float16_compute=args.float16_compute):
            if args.model == "cifar10":
                if args.image_size != 32:
                    log.warn("Cifar10 expects a 32x32 image.")
                pred = models.cifar10.create_cifar10(
                    model,
                    "data",
                    image_channels=args.num_channels,
                    num_classes=args.num_labels,
                    image_height=args.image_size,
                    image_width=args.image_size,
                )
            elif args.model == "resnet32x32":
                if args.image_size != 32:
                    log.warn("ResNet32x32 expects a 32x32 image.")
                pred = models.resnet.create_resnet32x32(
                    model,
                    "data",
                    num_layers=args.num_layers,
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test,
                )
            elif args.model == "resnet":
                if args.image_size != 224:
                    log.warn(
                        "ResNet expects a 224x224 image. input image = %d" %
                        args.image_size)
                pred = resnet.create_resnet50(
                    # args.layers,
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    no_bias=True,
                    no_loss=True,
                )
            elif args.model == "vgg":
                if args.image_size != 224:
                    log.warn("VGG expects a 224x224 image.")
                pred = vgg.create_vgg(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    num_layers=args.num_layers,
                    is_test=is_test,
                )
            elif args.model == "googlenet":
                if args.image_size != 224:
                    log.warn("GoogLeNet expects a 224x224 image.")
                pred = googlenet.create_googlenet(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test,
                )
            elif args.model == "alexnet":
                if args.image_size != 224:
                    log.warn("Alexnet expects a 224x224 image.")
                pred = alexnet.create_alexnet(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test,
                )
            elif args.model == "alexnetv0":
                if args.image_size != 224:
                    log.warn("Alexnet v0 expects a 224x224 image.")
                pred = alexnet.create_alexnetv0(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test,
                )
            else:
                raise NotImplementedError(
                    "Network {} not found.".format(args.model))

        if args.dtype == 'float16':
            pred = model.net.HalfToFloat(pred, pred + '_fp32')

        softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
                                              ['softmax', 'loss'])
        loss = model.Scale(loss, scale=loss_scale)
        brew.accuracy(model, [softmax, "label"], "accuracy")
        return [loss]

    def add_optimizer(model):
        stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)

        if args.float16_compute:
            # TODO: merge with multi-precision optimizer
            opt = optimizer.build_fp16_sgd(
                model,
                args.base_learning_rate,
                momentum=0.9,
                nesterov=1,
                weight_decay=args.weight_decay,  # weight decay included
                policy="step",
                stepsize=stepsz,
                gamma=0.1,
            )
        else:
            optimizer.add_weight_decay(model, args.weight_decay)
            opt = optimizer.build_multi_precision_sgd(
                model,
                args.base_learning_rate,
                momentum=0.9,
                nesterov=1,
                policy="step",
                stepsize=stepsz,
                gamma=0.1,
            )
        print("info:===============================" + str(opt))
        return opt

    # Define add_image_input function.
    # Depends on the "train_data" argument.
    # Note that the reader will be shared between all GPUs.
if args.train_data == "null": def add_image_input(model): AddNullInput( model, None, batch_size=batch_per_device, img_size=args.image_size, dtype=args.dtype, ) else: reader = train_model.CreateDB( "reader", db=args.train_data, db_type=args.db_type, num_shards=num_shards, shard_id=shard_id, ) def add_image_input(model): AddImageInput( model, reader, batch_size=batch_per_device, img_size=args.image_size, dtype=args.dtype, is_test=False, ) def add_post_sync_ops(model): """Add ops applied after initial parameter sync.""" for param_info in model.GetOptimizationParamInfo(model.GetParams()): if param_info.blob_copy is not None: model.param_init_net.HalfToFloat( param_info.blob, param_info.blob_copy[core.DataType.FLOAT]) # Create parallelized model data_parallel_model.Parallelize(train_model, input_builder_fun=add_image_input, forward_pass_builder_fun=create_model_ops, optimizer_builder_fun=add_optimizer, post_sync_builder_fun=add_post_sync_ops, devices=gpus, rendezvous=rendezvous, optimize_gradient_memory=False, cpu_device=args.use_cpu, shared_model=args.use_cpu, combine_spatial_bn=args.use_cpu, use_nccl=args.use_nccl) if args.model_parallel: # Shift half of the activations to another GPU assert workspace.NumCudaDevices() >= 2 * args.num_gpus activations = data_parallel_model_utils.GetActivationBlobs(train_model) data_parallel_model_utils.ShiftActivationDevices( train_model, activations=activations[len(activations) // 2:], shifts={g: args.num_gpus + g for g in range(args.num_gpus)}, ) data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False) workspace.RunNetOnce(train_model.param_init_net) workspace.CreateNet(train_model.net) if "GLOO_ALGORITHM" in os.environ and os.environ[ "GLOO_ALGORITHM"] == "PHUB": #i need to communicate to PHub about the elements that need aggregation, #as well as their sizes. #at this stage, all i need is the name of keys and my key ID. grad_names = list(reversed(train_model._grad_names)) phubKeyNames = ["allreduce_{}_status".format(x) for x in grad_names] caffe2GradSizes = dict( zip([ data_parallel_model.stripBlobName(name) + "_grad" for name in train_model._parameters_info.keys() ], [x.size for x in train_model._parameters_info.values()])) phubKeySizes = [str(caffe2GradSizes[x]) for x in grad_names] if rendezvous["shard_id"] == 0: #only id 0 needs to send to rendezvous. 
            r = redis.StrictRedis()
            # For each key, assign an ID.
            joinedStr = ",".join(phubKeyNames)
            r.set("[PLink]IntegrationKeys", joinedStr)
            joinedStr = ",".join(phubKeySizes)
            r.set("[PLink]IntegrationKeySizes", joinedStr)

    # Add test model, if specified
    test_model = None
    if (args.test_data is not None):
        log.info("----- Create test net ----")
        test_arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': True,
        }
        test_model = model_helper.ModelHelper(
            name="resnet50_test", arg_scope=test_arg_scope, init_params=False
        )

        test_reader = test_model.CreateDB(
            "test_reader",
            db=args.test_data,
            db_type=args.db_type,
        )

        def test_input_fn(model):
            AddImageInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=True,
            )

        data_parallel_model.Parallelize(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_model_ops_test,
            post_sync_builder_fun=add_post_sync_ops,
            param_update_builder_fun=None,
            devices=gpus,
            cpu_device=args.use_cpu,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    epoch = 0
    # load the pre-trained model and reset epoch
    if args.load_model_path is not None:
        LoadModel(args.load_model_path, train_model)

        # Sync the model params
        data_parallel_model.FinalizeAfterCheckpoint(train_model)

        # reset epoch. load_model_path should end with *_X.mdl,
        # where X is the epoch number
        last_str = args.load_model_path.split('_')[-1]
        if last_str.endswith('.mdl'):
            epoch = int(last_str[:-4])
            log.info("Reset epoch to {}".format(epoch))
        else:
            log.warning("The format of load_model_path doesn't match!")

    expname = "resnet50_gpu%d_b%d_L%d_lr%.2f_v2" % (
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )
    explog = experiment_util.ModelTrainerLog(expname, args)

    # Run the training one epoch at a time
    while epoch < args.num_epochs:
        epoch = RunEpoch(
            args, epoch, train_model, test_model, total_batch_size,
            num_shards, expname, explog
        )

        # Save the model for each epoch
        SaveModel(args, train_model, epoch)

        model_path = "%s/%s_" % (args.file_store_path, args.save_model_name)
        # remove the saved model from the previous epoch if it exists
        if os.path.isfile(model_path + str(epoch - 1) + ".mdl"):
            os.remove(model_path + str(epoch - 1) + ".mdl")
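
# A worked sketch of the "step" LR schedule built in add_optimizer above:
# stepsz covers 30 epochs of iterations on this shard, and the rate is
# multiplied by gamma=0.1 at every step boundary. All values below are
# illustrative, not defaults of the original script.
epoch_size = 1281167        # e.g. ImageNet-1k images per epoch
total_batch_size = 256
num_shards = 2
stepsz = int(30 * epoch_size / total_batch_size / num_shards)  # ~75068 iters

base_lr = 0.1


def lr_at(iteration):
    """Learning rate after the step policy has been applied."""
    return base_lr * (0.1 ** (iteration // stepsz))


assert lr_at(0) == 0.1                     # epochs 0-29 on this shard
assert abs(lr_at(stepsz) - 0.01) < 1e-12   # dropped once at 30 epochs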