def main():
    global args, best_prec1
    args = parser.parse_args()

    if int(args.rank) == int(args.world_size) - 1:
        log_level = logging.INFO
    else:
        log_level = logging.WARNING
        # log_level = logging.INFO
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s")
    logging.info(f'Find median: {args.find_median}')
    logging.warning(f'rank:{args.rank}, local_rank:{args.local_rank}')

    torch.cuda.set_device(args.local_rank)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # determine shapes of all tensors in passed-in model
    if args.arch == 'inception_v3':
        input_size = [args.batch_size, 3, 299, 299]
    else:
        # input_size = [args.batch_size, 3, 224, 224]
        # input_size = [args.batch_size, 3, 32, 32]
        input_size = [args.batch_size, 200]
    training_tensor_shapes = {"input0": input_size, "target": [args.batch_size]}
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}

    for i, (stage, inputs, outputs) in enumerate(model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        for input in inputs:
            if i == 0:
                input_tensor = torch.zeros(
                    tuple(training_tensor_shapes[input]), dtype=torch.int64)
            else:
                input_tensor = torch.zeros(
                    tuple(training_tensor_shapes[input]), dtype=torch.float32)
            input_tensors.append(input_tensor)
        with torch.no_grad():
            logging.debug(f'[{i}] input tensor shape: {input_tensors[0].shape}')
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple([args.eval_batch_size] +
                                        training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        json_config_file = json.load(open(args.config_path, 'r'))
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        configuration_maps['stage_to_rank_map'] = {
            int(k): v
            for (k, v) in configuration_maps['stage_to_rank_map'].items()
        }
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model, distributed_backend=args.distributed_backend,
        fp16=args.fp16, loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr, rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.IMAGE_CLASSIFICATION,
        port=args.port, enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        logging.info("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        logging.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    # optimizer = sgd.SGDWithWeightStashing(r.modules(), r.master_parameters,
    if args.spectrain:
        if args.log_dir is not None:
            args.log_dir += '_spectrain_v1'
        logging.info('Using spectrain_v1')
        if args.square:
            if args.log_dir is not None:
                args.log_dir += '_square'
            logging.info('s = version difference ^ 2')
        else:
            logging.info('s = version difference')
        optimizer = sgd.SGDWithSpectrainCHC(
            r.modules(), r.master_parameters, r.model_parameters,
            args.loss_scale,
            # num_versions=num_versions,
            num_versions=1,
            lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        logging.info('Not using spectrain')
        optimizer = sgd.SGDWithWeightStashing(
            r.modules(), r.master_parameters, r.model_parameters,
            args.loss_scale, num_versions=num_versions,
            lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    logging.info(f'log_dir: {args.log_dir}')

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # logging.info(f'args.arch = {args.arch}')
    # logging.info(f'args.synthetic_data = {args.synthetic_data}')

    from keras.preprocessing.sequence import pad_sequences
    from keras.datasets import imdb
    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=10000)
    x_train = pad_sequences(x_train, maxlen=200, padding="post", truncating="post")
    x_test = pad_sequences(x_test, maxlen=200, padding="post", truncating="post")
    print(x_train.shape, x_test.shape)
    train_dataset = TensorDataset(torch.LongTensor(x_train),
                                  torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.LongTensor(x_test),
                                torch.LongTensor(y_test))
    # logging.info(f'rank[{args.rank}] type(train_dataset) = {type(train_dataset)}')
    # exit()

    global writer
    if dist.get_rank() == dist.get_world_size() - 1:
        # writer = SummaryWriter(args.log_dir)
        pass

    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            distributed_sampler = True

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None), num_workers=args.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)
    logging.info(f'type(train_loader) = {type(train_loader)}')
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.eval_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler,
        drop_last=True)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            if r.stage != r.num_stages:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, args.checkpoint_dir, r.stage)
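# Illustrative only: every main() in this section expects args.config_path to point at a
# JSON file containing "module_to_stage_map", "stage_to_rank_map", and
# "stage_to_depth_map". The concrete layout below (four model modules split across two
# stages, one rank each) is a hypothetical example, not a configuration shipped with the
# repository; it is never called by the training path.
def _example_pipeline_config(path):
    import json

    config = {
        "module_to_stage_map": [0, 0, 1, 1],
        # JSON object keys are strings; main() converts them back to int
        "stage_to_rank_map": {"0": [0], "1": [1]},
        "stage_to_depth_map": {"0": 1, "1": 0},
    }
    with open(path, 'w') as f:
        json.dump(config, f)
    return config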
def main():
    global args, best_prec1
    args = parser.parse_args()
    args.data = args.data_dir

    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.local_rank}"

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build criterion
    criterion = task.build_criterion(args)

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    max_positions = (args.max_source_positions, args.max_target_positions)
    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens,
                                                        max_positions)
    inputs = dummy_batch['net_input']
    input0 = inputs['src_tokens']
    input1 = inputs['prev_output_tokens']
    target = dummy_batch['target']
    training_tensor_shapes = {
        "input0": list(input0.size()),
        "input1": list(input1.size()),
        "target": list(target.size()),
        "ntokens": [1]
    }
    dtypes = {
        "input0": input0.dtype,
        "input1": input1.dtype,
        "target": target.dtype,
        "ntokens": torch.float32
    }
    inputs_module_destinations = {"input0": 0, "input1": 0}
    target_tensor_names = {"target", "ntokens"}

    for module_id, (stage, inputs, outputs) in enumerate(
            model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        for module_input in inputs:
            if module_input in inputs_module_destinations:
                inputs_module_destinations[module_input] = module_id
            input_tensor = torch.ones(
                tuple(training_tensor_shapes[module_input]),
                dtype=dtypes[module_input]).cuda()
            input_tensors.append(input_tensor)
        stage.cuda()
        # PyTorch should not maintain metadata for a backward pass on
        # synthetic inputs. Without the following line, the runtime is
        # as much as 1.5x slower in a full DP configuration.
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(training_tensor_shapes[key])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        json_config_file = json.load(open(args.config_path, 'r'))
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        configuration_maps['stage_to_rank_map'] = {
            int(k): v
            for (k, v) in configuration_maps['stage_to_rank_map'].items()
        }
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model, distributed_backend=args.distributed_backend,
        fp16=args.fp16, loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr, rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.TRANSLATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = os.path.join(
            args.checkpoint_dir,
            f"checkpoint.{r.stage}.pth.tar.epoch.{args.start_epoch}")
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    # TODO: make this configurable by args
    use_adam_optimizer = True
    if use_adam_optimizer:
        optimizer = adam.Adam(r.master_parameters, lr=args.lr,
                              betas=(0.9, 0.98),
                              weight_decay=args.weight_decay)
    else:
        optimizer = sgd.SGD(r.master_parameters, lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    scheduler = lr_scheduler.build_lr_scheduler(args, optimizer)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # epoch_itr = data.EpochBatchIterator(
    #     dataset=task.dataset(args.train_subset),
    #     max_tokens=args.max_tokens,
    #     max_sentences=args.max_sentences_valid,
    #     max_positions=max_positions,
    #     ignore_invalid_inputs=True,
    #     required_batch_size_multiple=8,
    #     seed=1,
    #     num_shards=1,
    #     shard_id=0,
    # )
    def epoch_itr():
        return task.dataset('train').get_dummy_batch(args.max_tokens,
                                                     max_positions)

    distributed_sampler = False
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            distributed_sampler = True

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_loader.sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(epoch_itr, r, optimizer, epoch, scheduler)

            # evaluate on validation set
            # prec1 = validate(val_loader, r, epoch)
            prec1 = 0
            if r.stage != r.num_stages:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict()
                    }, args.checkpoint_dir, r.stage, epoch)
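# A minimal, self-contained sketch (not part of the training path) of the shape-inference
# pass that each main() above performs: push a dummy batch through every stage under
# torch.no_grad() and record output shapes/dtypes so the runtime knows what to exchange
# between stages. The two nn.Linear stages here are placeholders, not the real
# partitioned model.
def _example_shape_inference():
    import torch
    import torch.nn as nn

    stages = [("out0", nn.Linear(8, 16)), ("out1", nn.Linear(16, 4))]
    shapes = {"input0": [2, 8]}
    dtypes = {"input0": torch.float32}
    tensor = torch.zeros(shapes["input0"], dtype=dtypes["input0"])
    with torch.no_grad():
        for name, stage in stages:
            out = stage(tensor)
            shapes[name] = list(out.size())
            dtypes[name] = out.dtype
            # the recorded shape is what the next stage would receive over the wire
            tensor = torch.zeros(shapes[name], dtype=dtypes[name])
    return shapes, dtypes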
def main():
    global args, best_prec1
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # determine shapes of all tensors in passed-in model
    if args.arch == 'inception_v3':
        input_size = [args.batch_size, 3, 299, 299]
    else:
        input_size = [args.batch_size, 3, 224, 224]
    training_tensor_shapes = {"input0": input_size, "target": [args.batch_size]}
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}

    for (stage, inputs, outputs) in model[:-1]:  # Skip last layer (loss).
        input_tensors = []
        for input in inputs:
            input_tensor = torch.zeros(tuple(training_tensor_shapes[input]),
                                       dtype=torch.float32)
            input_tensors.append(input_tensor)
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple([args.eval_batch_size] +
                                        training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        json_config_file = json.load(open(args.config_path, 'r'))
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        configuration_maps['stage_to_rank_map'] = {
            int(k): v
            for (k, v) in configuration_maps['stage_to_rank_map'].items()
        }
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model, distributed_backend=args.distributed_backend,
        fp16=args.fp16, loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr, rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.IMAGE_CLASSIFICATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    optimizer = sgd.SGDWithWeightStashing(
        r.modules(), r.master_parameters, r.model_parameters,
        args.loss_scale, num_versions=num_versions,
        lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay,
        verbose_freq=args.verbose_frequency,
        macrobatch=args.macrobatch)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.arch == 'inception_v3':
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(299),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.synthetic_data:
            train_dataset = SyntheticDataset((3, 299, 299), len(train_dataset))
    else:
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.synthetic_data:
            train_dataset = SyntheticDataset((3, 224, 224), len(train_dataset))

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            distributed_sampler = True

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None), num_workers=args.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.eval_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler,
        drop_last=True)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            if r.stage != r.num_stages:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, args.checkpoint_dir, r.stage)
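# The image-classification main() above can substitute a SyntheticDataset for the real
# ImageFolder when args.synthetic_data is set. The class below is only a sketch of what
# such a dataset might look like (random pixels, random labels); the SyntheticDataset
# actually used by this repository may be implemented differently.
class _ExampleSyntheticDataset(torch.utils.data.Dataset):
    def __init__(self, input_size, length, num_classes=1000):
        self.input_size = input_size  # e.g. (3, 224, 224)
        self.length = length          # mimic the real dataset's length
        self.num_classes = num_classes

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = torch.randn(*self.input_size)
        target = int(torch.randint(self.num_classes, (1,)).item())
        return image, target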
def main():
    global args, best_prec1
    args = parser.parse_args()

    # Special case handling for GNMT model
    l2_promote()

    torch.cuda.set_device(args.local_rank)

    # build tokenizer
    tokenizer = Tokenizer(os.path.join(args.data_dir, config.VOCAB_FNAME))

    # define loss function
    criterion = build_gnmt_criterion(vocab_size=tokenizer.vocab_size,
                                     padding_idx=config.PAD,
                                     smoothing=0.1)

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    input_size = [args.max_length_train, args.batch_size]
    training_tensor_shapes = {
        "input0": input_size,
        "input1": [args.batch_size],
        "input2": input_size,
        "target": [args.max_length_train * args.batch_size],
        "target_length": [args.batch_size]
    }
    dtypes = {
        "input0": torch.int64,
        "input1": torch.int64,
        "input2": torch.int64,
        "target": torch.int64,
        "target_length": torch.int32
    }
    inputs_module_destinations = {"input0": 0, "input1": 0, "input2": 0}
    target_tensor_names = {"target", "target_length"}

    for module_id, (stage, inputs, outputs) in enumerate(
            model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        for module_input in inputs:
            if module_input in inputs_module_destinations:
                inputs_module_destinations[module_input] = module_id
            input_tensor = torch.ones(
                tuple(training_tensor_shapes[module_input]),
                dtype=dtypes[module_input]).cuda()
            input_tensors.append(input_tensor)
        stage.cuda()
        # PyTorch should not maintain metadata for a backward pass on
        # synthetic inputs. Without the following line, the runtime is
        # as much as 1.5x slower in a full DP configuration.
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(training_tensor_shapes[key])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        json_config_file = json.load(open(args.config_path, 'r'))
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        configuration_maps['stage_to_rank_map'] = {
            int(k): v
            for (k, v) in configuration_maps['stage_to_rank_map'].items()
        }
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model, distributed_backend=args.distributed_backend,
        fp16=args.fp16, loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr, rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.TRANSLATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    # TODO: make this configurable by args
    use_adam_optimizer = True
    if use_adam_optimizer:
        optimizer = adam.AdamWithWeightStashing(
            modules=r.modules(),
            master_parameters=r.master_parameters,
            model_parameters=r.model_parameters,
            loss_scale=args.loss_scale,
            num_versions=num_versions,
            lr=args.lr, betas=(0.9, 0.999),
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        optimizer = sgd.SGDWithWeightStashing(
            modules=r.modules(),
            master_parameters=r.master_parameters,
            model_parameters=r.model_parameters,
            loss_scale=args.loss_scale,
            num_versions=num_versions,
            lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_dataset = LazyParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_TRAIN_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_TRAIN_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=False,
        max_size=None)

    val_dataset = ParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_VAL_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_VAL_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=True)

    distributed_sampler = False
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            distributed_sampler = True

    # TODO: fix random seeds
    train_loader = train_dataset.get_loader(
        batch_size=args.batch_size,
        seeds=range(args.epochs),
        batch_first=False,
        shuffle=True,
        bucketing=not args.no_bucketing,
        num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        rank=r.rank_in_stage if r.stage == 0 else 0)

    val_loader = val_dataset.get_loader(
        batch_size=args.batch_size,
        batch_first=False,
        shuffle=True,
        num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        seeds=range(args.epochs),
        rank=r.rank_in_stage if r.stage == 0 else 0)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_loader.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args.epochs, r, args.lr_policy)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            if r.stage != r.num_stages:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                        'tokenizer': tokenizer.get_state()
                    }, args.checkpoint_dir, r.stage, epoch)
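# Weight-stashing note, mirroring the comment repeated in each main() above: with input
# pipelining enabled, the optimizer keeps num_warmup_minibatches + 1 weight versions,
# i.e. one per stage that follows the current stage plus the current weights. The helper
# below only illustrates that arithmetic for a hypothetical straight (non-replicated)
# pipeline; the real value comes from runtime.StageRuntime.
def _example_num_versions(num_stages, stage_id, no_input_pipelining=False):
    if no_input_pipelining:
        return 1
    num_warmup_minibatches = num_stages - stage_id - 1  # stages after this one
    return num_warmup_minibatches + 1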
def main():
    global args, best_prec1
    args = parser.parse_args()

    if rank >= args.world_size:
        return

    print("initialising device...")
    local_rank = rank % args.num_ranks_in_server
    print("workers = ", args.workers)

    writer = None
    if args.log_dir:
        writer = SummaryWriter(log_dir=args.log_dir)

    ##### ENABLING GPU DIRECT HERE THROUGH A HACK ###
    args.num_ranks_in_server = args.world_size

    torch.cuda.set_device(local_rank)
    print("local rank {} device {}".format(local_rank,
                                           torch.cuda.current_device()))
    args.rank = rank  # my change

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)
    print("Local rank {} imported module".format(local_rank))

    # determine shapes of all tensors in passed-in model
    target_size = [args.batch_size]
    if args.dataset_name == "ImageNet":
        if args.arch == 'inception_v3':
            input_size = [args.batch_size, 3, 299, 299]
        else:
            input_size = [args.batch_size, 3, 224, 224]
        first_stage_input_dtype = torch.float32
    elif args.dataset_name == "MNIST":
        input_size = [args.batch_size, 1, 28, 28]
        first_stage_input_dtype = torch.float32
    elif args.dataset_name == "CIFAR10":
        input_size = [args.batch_size, 3, 32, 32]
        first_stage_input_dtype = torch.float32
    elif args.dataset_name in ["wikitext-2", "wikitext-103"]:
        input_size = [args.batch_size, args.bptt_len]
        first_stage_input_dtype = torch.int64
        target_size = [args.batch_size * args.bptt_len]
    else:
        print("Dataset {} not supported".format(args.dataset_name))

    training_tensor_shapes = {"input0": input_size, "target": target_size}
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}

    stage_number = 0
    for (stage, inputs, outputs) in model[:-1]:  # Skip last layer (loss).
        input_tensors = []
        for input in inputs:
            if stage_number == 0:
                input_dtype = first_stage_input_dtype
            else:
                input_dtype = torch.float32
            input_tensor = torch.zeros(tuple(training_tensor_shapes[input]),
                                       dtype=input_dtype).cuda()
            input_tensors.append(input_tensor)
        stage_number += 1
        stage.cuda()
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype
        del output_tensors
        del input_tensors
        stage.cpu()
    # print("local rank {} finished 1 forward pass...".format(local_rank))

    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple([args.eval_batch_size] +
                                        training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        json_config_file = json.load(open(args.config_path, 'r'))
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        configuration_maps['stage_to_rank_map'] = {
            int(k): v
            for (k, v) in configuration_maps['stage_to_rank_map'].items()
        }
        # print("========================")
        # print(configuration_maps['stage_to_rank_map'])
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    if args.data_prl:
        print("Modifying stage to rank map to be data parallel")
        stage_to_rank_map = configuration_maps['stage_to_rank_map']
        for k in stage_to_rank_map:
            stage_to_rank_map[k] = list(range(args.world_size))

    print("Local rank {} Staging runtime....".format(local_rank))
    if args.language_modelling:
        model_type = runtime.LANGUAGE_MODELLING
    else:
        model_type = runtime.IMAGE_CLASSIFICATION

    r = runtime.StageRuntime(
        model=model, distributed_backend=args.distributed_backend,
        fp16=args.fp16, loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr, rank=args.rank,
        local_rank=local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=model_type,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    # print_msg(args.rank, "number of versions" + str(num_versions))
    if args.language_modelling:
        optimizer = adam.AdamWithWeightStashing(
            r.modules(), r.master_parameters, r.model_parameters,
            args.loss_scale, num_versions=num_versions,
            lr=args.lr, weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        # optimizer = sgd.SGDWithWeightStashing(r.modules(), r.master_parameters,
        #                                       r.model_parameters, args.loss_scale,
        #                                       num_versions=num_versions,
        #                                       lr=args.lr,
        #                                       momentum=args.momentum,
        #                                       weight_decay=args.weight_decay,
        #                                       verbose_freq=args.verbose_frequency,
        #                                       macrobatch=args.macrobatch)
        optimizer = adam.AdamWithWeightStashing(
            r.modules(), r.master_parameters, r.model_parameters,
            args.loss_scale, num_versions=num_versions,
            lr=args.lr, weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    print(args.dataset_name)
    if args.dataset_name == "ImageNet":
        if args.arch == 'inception_v3':
            if args.synthetic_data:
                train_dataset = SyntheticDatasetImageClassification(
                    (3, 299, 299), 10000)
            else:
                traindir = os.path.join(args.data_dir, 'train')
                train_dataset = datasets.ImageFolder(
                    traindir,
                    transforms.Compose([
                        transforms.RandomResizedCrop(299),
                        transforms.ToTensor(),
                        normalize,
                    ]))
        else:
            print("Initialising dataset..")
            if args.synthetic_data:
                train_dataset = SyntheticDatasetImageClassification(
                    (3, 224, 224), 1281168)  # modified
            else:
                traindir = os.path.join(args.data_dir, 'train')
                train_dataset = datasets.ImageFolder(
                    traindir,
                    transforms.Compose([
                        transforms.RandomResizedCrop(224),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        normalize,
                    ]))
        if args.synthetic_data:
            val_dataset = SyntheticDatasetImageClassification((3, 224, 224),
                                                              10000)
        else:
            valdir = os.path.join(args.data_dir, 'val')
            val_dataset = datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]))
            # val_dataset = SyntheticDatasetImageClassification((3, 224, 224), 10000)
    elif args.dataset_name == "MNIST":
        train_dataset = datasets.MNIST(
            args.data_dir, download=True, train=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, )),
            ]))
        val_dataset = datasets.MNIST(
            args.data_dir, download=True, train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),  # first, convert image to PyTorch tensor
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
    elif args.dataset_name == "CIFAR10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_dataset = datasets.CIFAR10(root=args.data_dir, train=True,
                                         transform=transform)
        val_dataset = datasets.CIFAR10(root=args.data_dir, train=False,
                                       transform=transform)
    elif args.dataset_name in ["wikitext-2", "wikitext-103"]:
        tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        if not args.synthetic_data:
            train_dataset = huggingface.get_dataset(
                args.dataset_name, tokenizer, 'train', num_workers=1,
                bptt_len=args.bptt_len, cache_dir=args.data_dir)
            val_dataset = huggingface.get_dataset(
                args.dataset_name, tokenizer, 'validation', num_workers=1,
                bptt_len=args.bptt_len, cache_dir=args.data_dir)
        else:
            if args.dataset_name == "wikitext-2":
                train_length = 36718
            else:
                train_length = 1801350
            train_dataset = SyntheticDatasetLanguageModelling(
                tokenizer.vocab_size, args.bptt_len, train_length)
            val_dataset = SyntheticDatasetLanguageModelling(
                tokenizer.vocab_size, args.bptt_len, 3760)

    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=num_ranks_in_first_stage,
                rank=args.rank % num_ranks_in_first_stage)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, num_replicas=num_ranks_in_first_stage,
                rank=args.rank % num_ranks_in_first_stage)
            distributed_sampler = True

    print("initialising data loaders")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None), num_workers=args.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler,
        drop_last=True)
    print(
        f"Rank {args.rank}: Length of train loader: {len(train_loader)} "
        f"Length of dataset: {len(train_dataset)} BPTT_LEN {args.bptt_len} "
        f"BATCH SIZE {args.batch_size}")

    # else:
    #     train_loader = None
    #     val_loader = None
    # if args.rank == 0:
    #     lengths = torch.LongTensor([len(train_loader), len(val_loader)]).cuda()
    # else:
    #     lengths = torch.zeros((2)).long().cuda()
    lengths = torch.LongTensor([len(train_loader), len(val_loader)])
    if rank == 0:
        quantities = [len(configuration_maps['stage_to_rank_map'][0])]
        for i in range(len(configuration_maps['stage_to_rank_map']) - 1):
            curr = len(configuration_maps['stage_to_rank_map'][i])
            curr *= len(configuration_maps['stage_to_rank_map'][i + 1])
            quantities.append(curr)
        print(quantities)
        lcm = np.lcm.reduce(quantities)
        print(f"new length should be a multiple of {lcm}")
        old_length = lengths[0].item()
        lengths[0] = (lengths[0] // lcm) * lcm
        print(f"Rank {args.rank} : Old Train length {old_length} "
              f"Adjusted Length {lengths[0]}")
        old_length = lengths[1].item()
        lengths[1] = (lengths[1] // lcm) * lcm
        print(f"Rank {args.rank} Old Val length {old_length} "
              f"Adjusted Length {lengths[1]}")
        dist.broadcast(lengths, src=0)
    else:
        dist.broadcast(lengths, src=0)

    num_ranks_in_first_stage = len(configuration_maps['stage_to_rank_map'][0])
    lengths[0] *= num_ranks_in_first_stage
    lengths[1] *= num_ranks_in_first_stage
    lengths[0] = lengths[0] // r.num_ranks_in_stage
    lengths[1] = lengths[1] // r.num_ranks_in_stage
    train_len = lengths[0]
    val_len = lengths[1]
    print(f"rank {args.rank}, Adjusted train length {train_len}, "
          f"Adjusted val length {val_len}")
    # exit()

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if args.rank == 0 and distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch, model_type, lengths,
                  writer)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch, lengths, model_type)