def test(num_versions, assertion_ground_truth):
    # N is batch size; D_in is input dimension;
    # D_H is hidden dimension; D_out is output dimension.
    N, D_in, D_H, D_out = 4, 4, 4, 4

    # Create random input and output tensors.
    x1 = torch.randn(N, D_in).cuda().requires_grad_(True)
    y1 = torch.randn(N, D_out).cuda()

    # Create copies of these tensors.
    x2 = x1.clone().detach().requires_grad_(True)
    y2 = y1.clone()
    x3 = x1.clone().detach().requires_grad_(True)
    y3 = y1.clone()

    model = torch.nn.Sequential(torch.nn.Linear(D_in, D_H),
                                torch.nn.ReLU(),
                                torch.nn.Linear(D_H, D_out)).cuda()
    loss_fn = torch.nn.MSELoss()
    optimizer = sgd.SGDWithWeightStashing([model], model.parameters(),
                                          num_versions=num_versions,
                                          lr=1e-1)

    inputs = [x1, x2, x3]

    # Compute the prediction and loss function using the same weights
    # and inputs.
    y1_pred = model(x1)
    y2_pred = model(x2)
    y3_pred = model(x3)
    assert torch.equal(y1_pred, y2_pred)
    assert torch.equal(y1_pred, y3_pred)
    losses = [loss_fn(y1_pred, y1), loss_fn(y2_pred, y2), loss_fn(y3_pred, y3)]

    x_grads = []
    for loss, x in zip(losses, inputs):
        optimizer.zero_grad()
        optimizer.load_old_params()
        loss.backward()
        x_grads.append(x.grad.clone().detach())
        optimizer.load_new_params()
        optimizer.step()

    # Assert that the right weight versions are used to compute the
    # gradients.
    assert (torch.equal(x_grads[0], x_grads[1]) == assertion_ground_truth[0])
    assert (torch.equal(x_grads[0], x_grads[2]) == assertion_ground_truth[1])
    assert not torch.equal(y1_pred, model(x1))
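
# A hedged usage sketch (this driver is not part of the original test file):
# the expected ground truths below follow from how weight stashing works.
# With enough stashed versions every backward pass runs against the exact
# weights used in its forward pass, so all three input gradients match;
# with a single version, later backward passes see already-updated weights
# and the gradients diverge. The truth values assume a FIFO stash holding
# `num_versions` weight copies, oldest loaded first.
if __name__ == '__main__':
    test(num_versions=3, assertion_ground_truth=[True, True])
    test(num_versions=2, assertion_ground_truth=[True, False])
    test(num_versions=1, assertion_ground_truth=[False, False])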

def main():
    global args, best_prec1
    args = parser.parse_args()

    # Only the last rank logs at INFO level; all other ranks stay at WARNING.
    if int(args.rank) == int(args.world_size) - 1:
        log_level = logging.INFO
    else:
        log_level = logging.WARNING
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s")
    logging.info(f'Find median: {args.find_median}')
    logging.warning(f'rank: {args.rank}, local_rank: {args.local_rank}')

    torch.cuda.set_device(args.local_rank)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # determine shapes of all tensors in passed-in model
    if args.arch == 'inception_v3':
        input_size = [args.batch_size, 3, 299, 299]
    else:
        # input_size = [args.batch_size, 3, 224, 224]
        # input_size = [args.batch_size, 3, 32, 32]
        input_size = [args.batch_size, 200]
    training_tensor_shapes = {
        "input0": input_size,
        "target": [args.batch_size]
    }
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}
    for i, (stage, inputs, outputs) in enumerate(model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        for module_input in inputs:
            # The first stage consumes int64 token ids; downstream stages
            # consume float32 activations.
            if i == 0:
                input_tensor = torch.zeros(
                    tuple(training_tensor_shapes[module_input]),
                    dtype=torch.int64)
            else:
                input_tensor = torch.zeros(
                    tuple(training_tensor_shapes[module_input]),
                    dtype=torch.float32)
            input_tensors.append(input_tensor)
        with torch.no_grad():
            logging.debug(f'[{i}] input tensor shape: {input_tensors[0].shape}')
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(
            [args.eval_batch_size] + training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        with open(args.config_path, 'r') as f:
            json_config_file = json.load(f)
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        if configuration_maps['stage_to_rank_map'] is not None:
            # JSON keys are strings; convert stage ids back to ints.
            configuration_maps['stage_to_rank_map'] = {
                int(k): v
                for (k, v) in configuration_maps['stage_to_rank_map'].items()
            }
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.IMAGE_CLASSIFICATION,
        port=args.port,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        logging.info("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        logging.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    if args.spectrain:
        if args.log_dir is not None:
            args.log_dir += '_spectrain_v1'
        logging.info('Using spectrain_v1')
        if args.square:
            if args.log_dir is not None:
                args.log_dir += '_square'
            logging.info('s = version difference ^ 2')
        else:
            logging.info('s = version difference')
        optimizer = sgd.SGDWithSpectrainCHC(
            r.modules(),
            r.master_parameters,
            r.model_parameters,
            args.loss_scale,
            # num_versions=num_versions,
            num_versions=1,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        logging.info('Not using spectrain')
        optimizer = sgd.SGDWithWeightStashing(
            r.modules(),
            r.master_parameters,
            r.model_parameters,
            args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    logging.info(f'log_dir: {args.log_dir}')

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    from keras.preprocessing.sequence import pad_sequences
    from keras.datasets import imdb
    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=10000)
    x_train = pad_sequences(x_train, maxlen=200, padding="post", truncating="post")
    x_test = pad_sequences(x_test, maxlen=200, padding="post", truncating="post")
    logging.info(f'x_train: {x_train.shape}, x_test: {x_test.shape}')
    train_dataset = TensorDataset(torch.LongTensor(x_train),
                                  torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.LongTensor(x_test),
                                torch.LongTensor(y_test))

    global writer
    if dist.get_rank() == dist.get_world_size() - 1:
        # writer = SummaryWriter(args.log_dir)
        pass

    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            distributed_sampler = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    logging.info(f'type(train_loader) = {type(train_loader)}')
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            # only the last stage (num_stages - 1) computes a real prec@1
            if r.stage != r.num_stages - 1:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, args.checkpoint_dir, r.stage)
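
# Illustration (hypothetical helper, not in the source): under a 1F1B
# pipeline schedule the number of warmup minibatches a stage admits equals
# the number of stages after it, so the versions-to-stash count
# (num_warmup_minibatches + 1) shrinks toward 1 at the last stage. This
# assumes a straight pipeline with one rank per stage, matching the
# "machines following the current stage" comment above.
def illustrate_num_versions(num_stages):
    for stage in range(num_stages):
        num_warmup_minibatches = num_stages - stage - 1
        print(f"stage {stage}: num_versions = {num_warmup_minibatches + 1}")

# illustrate_num_versions(4) prints 4, 3, 2, 1 versions for stages 0..3.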

def main():
    global args, best_prec1
    args = parser.parse_args()

    # Special case handling for GNMT model
    l2_promote()

    torch.cuda.set_device(args.local_rank)

    # build tokenizer
    tokenizer = Tokenizer(os.path.join(args.data_dir, config.VOCAB_FNAME))

    # define loss function
    criterion = build_gnmt_criterion(vocab_size=tokenizer.vocab_size,
                                     padding_idx=config.PAD,
                                     smoothing=0.1)

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    input_size = [args.max_length_train, args.batch_size]
    training_tensor_shapes = {
        "input0": input_size,
        "input1": [args.batch_size],
        "input2": input_size,
        "target": [args.max_length_train * args.batch_size],
        "target_length": [args.batch_size]
    }
    dtypes = {
        "input0": torch.int64,
        "input1": torch.int64,
        "input2": torch.int64,
        "target": torch.int64,
        "target_length": torch.int32
    }
    inputs_module_destinations = {"input0": 0, "input1": 0, "input2": 0}
    target_tensor_names = {"target", "target_length"}
    for module_id, (stage, inputs, outputs) in enumerate(
            model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        for module_input in inputs:
            if module_input in inputs_module_destinations:
                inputs_module_destinations[module_input] = module_id
            input_tensor = torch.ones(
                tuple(training_tensor_shapes[module_input]),
                dtype=dtypes[module_input]).cuda()
            input_tensors.append(input_tensor)
        stage.cuda()
        # PyTorch should not maintain metadata for a backward pass on
        # synthetic inputs. Without the following line, the runtime is
        # as much as 1.5x slower in a full DP configuration.
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(training_tensor_shapes[key])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        with open(args.config_path, 'r') as f:
            json_config_file = json.load(f)
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        if configuration_maps['stage_to_rank_map'] is not None:
            # JSON keys are strings; convert stage ids back to ints.
            configuration_maps['stage_to_rank_map'] = {
                int(k): v
                for (k, v) in configuration_maps['stage_to_rank_map'].items()
            }
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.TRANSLATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    # TODO: make this configurable by args
    use_adam_optimizer = True
    if use_adam_optimizer:
        optimizer = adam.AdamWithWeightStashing(
            modules=r.modules(),
            master_parameters=r.master_parameters,
            model_parameters=r.model_parameters,
            loss_scale=args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            betas=(0.9, 0.999),
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        optimizer = sgd.SGDWithWeightStashing(
            modules=r.modules(),
            master_parameters=r.master_parameters,
            model_parameters=r.model_parameters,
            loss_scale=args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_dataset = LazyParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_TRAIN_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_TRAIN_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=False,
        max_size=None)

    val_dataset = ParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_VAL_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_VAL_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=True)

    distributed_sampler = False
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            distributed_sampler = True

    # TODO: fix random seeds
    train_loader = train_dataset.get_loader(
        batch_size=args.batch_size,
        seeds=range(args.epochs),
        batch_first=False,
        shuffle=True,
        bucketing=not args.no_bucketing,
        num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        rank=r.rank_in_stage if r.stage == 0 else 0)

    val_loader = val_dataset.get_loader(
        batch_size=args.batch_size,
        batch_first=False,
        shuffle=True,
        num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        seeds=range(args.epochs),
        rank=r.rank_in_stage if r.stage == 0 else 0)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_loader.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args.epochs, r, args.lr_policy)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            # only the last stage (num_stages - 1) computes a real prec@1
            if r.stage != r.num_stages - 1:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                        'tokenizer': tokenizer.get_state()
                    }, args.checkpoint_dir, r.stage, epoch)
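
# A small sketch (hypothetical helper, not from the source) of the
# per-stage checkpoint naming used by the resume logic above: each pipeline
# stage reads and writes its own "<prefix>.<stage>.pth.tar" file, so stages
# can checkpoint independently.
def checkpoint_path(prefix, stage):
    return "%s.%d.pth.tar" % (prefix, stage)

# checkpoint_path("/ckpt/gnmt", 2) -> "/ckpt/gnmt.2.pth.tar"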

def main():
    global args, best_prec1
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # determine shapes of all tensors in passed-in model
    if args.arch == 'inception_v3':
        input_size = [args.batch_size, 3, 299, 299]
    else:
        input_size = [args.batch_size, 3, 224, 224]
    training_tensor_shapes = {
        "input0": input_size,
        "target": [args.batch_size]
    }
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}
    for (stage, inputs, outputs) in model[:-1]:  # Skip last layer (loss).
        input_tensors = []
        for module_input in inputs:
            input_tensor = torch.zeros(
                tuple(training_tensor_shapes[module_input]),
                dtype=torch.float32)
            input_tensors.append(input_tensor)
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(
            [args.eval_batch_size] + training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        with open(args.config_path, 'r') as f:
            json_config_file = json.load(f)
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        if configuration_maps['stage_to_rank_map'] is not None:
            # JSON keys are strings; convert stage ids back to ints.
            configuration_maps['stage_to_rank_map'] = {
                int(k): v
                for (k, v) in configuration_maps['stage_to_rank_map'].items()
            }
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.IMAGE_CLASSIFICATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    optimizer = sgd.SGDWithWeightStashing(r.modules(),
                                          r.master_parameters,
                                          r.model_parameters,
                                          args.loss_scale,
                                          num_versions=num_versions,
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay,
                                          verbose_freq=args.verbose_frequency,
                                          macrobatch=args.macrobatch)
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.arch == 'inception_v3':
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(299),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.synthetic_data:
            train_dataset = SyntheticDataset((3, 299, 299), len(train_dataset))
    else:
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.synthetic_data:
            train_dataset = SyntheticDataset((3, 224, 224), len(train_dataset))

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            distributed_sampler = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            # only the last stage (num_stages - 1) computes a real prec@1
            if r.stage != r.num_stages - 1:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, args.checkpoint_dir, r.stage)
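
# A minimal sketch of a SyntheticDataset compatible with its use above (the
# real class lives elsewhere in this repo; the constructor signature here is
# inferred from the call sites and is an assumption): fixed-shape random
# inputs and cycling integer labels, so non-first stages can exercise the
# input pipeline without reading images from disk.
import torch
from torch.utils.data import Dataset

class SyntheticDatasetSketch(Dataset):
    def __init__(self, input_size, length, num_classes=1000):
        self.input_size = input_size      # e.g. (3, 224, 224)
        self.length = length              # mirror the real dataset's size
        self.num_classes = num_classes

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Random input tensor plus a deterministic dummy label.
        return torch.randn(*self.input_size), index % self.num_classes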