def initialize(self, model, inputs_module_destinations,
               configuration_maps, master_addr, rank,
               local_rank, num_ranks_in_server):
    self.send_ranks = {}
    self.receive_ranks = {}
    self.rank = rank
    self.local_rank = local_rank
    self.stage = None
    self.tensor_tags = {}
    self.forward_minibatch_id = 0
    self.backward_minibatch_id = 0
    self.criterion_input_name = str(model[-1][1][0])

    tensor_tag = 1
    for (_, input_tensors, output_tensors) in model:
        for input_tensor in input_tensors:
            if input_tensor not in self.tensor_tags:
                self.tensor_tags[input_tensor] = tensor_tag
                tensor_tag += 1
        for output_tensor in output_tensors:
            if output_tensor not in self.tensor_tags:
                self.tensor_tags[output_tensor] = tensor_tag
                tensor_tag += 1
    for target_tensor_name in sorted(self.target_tensor_names):
        self.tensor_tags[target_tensor_name] = tensor_tag
        tensor_tag += 1
    self.tensor_tags["ack"] = tensor_tag
    tensor_tag += 1

    module_to_stage_map = configuration_maps['module_to_stage_map']
    stage_to_rank_map = configuration_maps['stage_to_rank_map']
    stage_to_depth_map = configuration_maps['stage_to_depth_map']

    if module_to_stage_map is None:
        # If stage assignments are not specified, resort to running all
        # layers on a single machine.
        assert self.rank is None
        self.modules_with_dependencies = ModulesWithDependencies(model)
        self.is_criterion = True
        self.rank_in_stage = 0
        self.num_ranks = 1
        self.num_ranks_in_first_stage = 1
        self.num_ranks_in_previous_stage = 0
        self.num_ranks_in_next_stage = 0
        self.num_stages = 1
        self.num_ranks_in_stage = 1
        self.num_warmup_minibatches = 0
        self.comm_handler = None
    else:
        assert len(module_to_stage_map) == len(model)
        assert self.rank is not None

        stage_to_module_map = collections.defaultdict(list)
        for module in range(len(module_to_stage_map)):
            stage_to_module_map[module_to_stage_map[module]].append(module)

        rank_to_stage_map = {}
        for stage in stage_to_rank_map:
            for rank in stage_to_rank_map[stage]:
                rank_to_stage_map[rank] = stage

        # Now, use this mapping to determine the modules contained in
        # each stage.
        assert 0 <= self.rank < len(rank_to_stage_map)
        self.num_ranks = len(rank_to_stage_map)
        self.num_stages = len(stage_to_module_map)
        self.stage = rank_to_stage_map[self.rank]
        self.rank_in_stage = stage_to_rank_map[self.stage].index(self.rank)
        self.num_ranks_in_stage = len(stage_to_rank_map[self.stage])
        self.num_ranks_in_first_stage = len(stage_to_rank_map[0])
        self.num_ranks_in_previous_stage = 0
        self.ranks_in_previous_stage = []
        if self.stage > 0:
            self.num_ranks_in_previous_stage = len(
                stage_to_rank_map[self.stage - 1])
            self.ranks_in_previous_stage = stage_to_rank_map[self.stage - 1]
        self.num_ranks_in_next_stage = 0
        self.ranks_in_next_stage = []
        if self.stage < self.num_stages - 1:
            self.num_ranks_in_next_stage = len(
                stage_to_rank_map[self.stage + 1])
            self.ranks_in_next_stage = stage_to_rank_map[self.stage + 1]

        modules = stage_to_module_map[self.stage]
        self.modules_with_dependencies = ModulesWithDependencies(
            [model[module] for module in modules])
        self.is_criterion = self.stage == (self.num_stages - 1)

        if stage_to_depth_map is not None:
            self.num_warmup_minibatches = stage_to_depth_map[
                str(self.stage)]
        else:
            self.num_warmup_minibatches = self.num_ranks - 1
            for i in range(self.stage):
                self.num_warmup_minibatches -= len(
                    stage_to_rank_map[i])
            self.num_warmup_minibatches = self.num_warmup_minibatches // \
                self.num_ranks_in_stage

        # To determine where tensors should be sent and received, first
        # determine the "producing" and "consuming" module IDs of each
        # tensor. We then use the corresponding machine ranks to send
        # and receive tensors.
        master_port = 12345
        self.comm_handler = communication.CommunicationHandler(
            master_addr=master_addr,
            master_port=master_port,
            rank=self.rank,
            local_rank=self.local_rank,
            num_ranks_in_server=num_ranks_in_server,
            world_size=self.num_ranks,
            fp16=self.fp16,
            backend=self.distributed_backend)

        for i in range(len(model)):
            for j in range(i + 1, len(model)):
                for tensor_name in model[i][2]:
                    if tensor_name in model[j][1]:
                        if module_to_stage_map[i] == \
                            module_to_stage_map[j]:
                            continue
                        # For now, assume that each stage is served by only
                        # a single machine.
                        if module_to_stage_map[j] == self.stage:
                            self.receive_ranks[tensor_name] = \
                                stage_to_rank_map[module_to_stage_map[i]]
                        if module_to_stage_map[i] == self.stage:
                            self.send_ranks[tensor_name] = \
                                stage_to_rank_map[module_to_stage_map[j]]

        for model_inputs in inputs_module_destinations.keys():
            destination_stage = module_to_stage_map[
                inputs_module_destinations[model_inputs]]
            if destination_stage > self.stage:
                self.send_ranks[model_inputs] = \
                    self.ranks_in_next_stage

            if 0 < self.stage <= destination_stage:
                self.receive_ranks[model_inputs] = \
                    self.ranks_in_previous_stage

            if destination_stage > 0:
                if model_inputs not in self.tensor_tags:
                    self.tensor_tags[model_inputs] = tensor_tag
                    tensor_tag += 1

    modules = self.modules_with_dependencies.modules()
    for i in range(len(modules)):
        modules[i] = modules[i].cuda()
        if self.fp16:
            import apex.fp16_utils as fp16_utils
            modules[i] = fp16_utils.BN_convert_float(modules[i].half())

    # Initialize all groups in the same order on every worker.
    if stage_to_rank_map is not None:
        groups = []
        for stage in range(self.num_stages):
            ranks = stage_to_rank_map[stage]
            if len(ranks) > 1:
                groups.append(dist.new_group(ranks=ranks))
            else:
                groups.append(None)
        group = groups[self.stage]
    else:
        group = None

    # self.modules_with_dependencies contains a list of PyTorch
    # modules, along with a list of user-defined input and output
    # tensor names. We use our module_executor.ModuleExecutor
    # class to wrap these dependencies, and use run_forward and
    # run_backward methods downstream.
    num_parameters = 0
    for i in range(len(modules)):
        if group is not None:
            if ((i < (len(modules) - 1) and self.is_criterion)
                or not self.is_criterion):
                num_parameters += \
                    sum(x.size()[0] * x.size()[1]
                        if len(x.size()) > 1 else x.size()[0]
                        for x in modules[i].parameters() if x.size())
                modules[i] = torch.nn.parallel.DistributedDataParallel(
                    modules[i],
                    process_group=group,
                    device_ids=[local_rank],
                    output_device=local_rank)
    if self.num_ranks_in_stage > 1:
        module_size = 4. * num_parameters
        print("Replicating stage: ranks=%d, module_size=%.3f" % (
            self.num_ranks_in_stage, module_size))

    if self.fp16:
        self.master_parameters = []
        self.model_parameters = []
        for i in range(len(modules)):
            import apex.fp16_utils as fp16_utils
            module_parameters, module_master_parameters = \
                fp16_utils.prep_param_lists(modules[i])
            self.master_parameters.extend(module_master_parameters)
            self.model_parameters.extend(module_parameters)
    else:
        self.master_parameters = list(self.parameters())
        self.model_parameters = None

    if self.comm_handler is not None:
        self.comm_handler.initialize(
            self.receive_ranks,
            self.send_ranks,
            self.tensor_tags,
            self.target_tensor_names,
            self.training_tensor_dtypes,
            self.rank_in_stage,
            self.num_ranks_in_stage,
            self.ranks_in_previous_stage,
            self.ranks_in_next_stage)
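# Example (hypothetical values, not taken from the source): the
# configuration maps consumed by initialize() are plain dicts, typically
# loaded from a JSON config. A four-module model split into two stages,
# with the first stage replicated across ranks 0-1, might look like:
example_configuration_maps = {
    'module_to_stage_map': [0, 0, 1, 1],        # module i -> stage id
    'stage_to_rank_map': {0: [0, 1], 1: [2]},   # stage id -> worker ranks
    'stage_to_depth_map': None,  # None => warmup minibatch counts are derived
}
# With this mapping, rank 2 runs the last stage and therefore owns the
# criterion (is_criterion == True), while ranks 0 and 1 wrap their modules
# in DistributedDataParallel over a two-rank process group.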
if pretrained_path:
    logger.info(f'Resume from {pretrained_path}')
    # Parameters saved in checkpoint via model_path.
    if device == torch.device('cpu'):
        param = torch.load(pretrained_path, map_location='cpu')
    else:
        param = torch.load(pretrained_path)
    model.load_state_dict(param)
    del param

# fp16
if fp16:
    from apex import fp16_utils
    model = fp16_utils.BN_convert_float(model.half())
    optimizer = fp16_utils.FP16_Optimizer(optimizer,
                                          verbose=False,
                                          dynamic_loss_scale=True)
    logger.info('Apply fp16')

# Restore model
if resume:
    model_path = output_dir.joinpath('model_tmp.pth')
    logger.info(f'Resume from {model_path}')
    param = torch.load(model_path)
    model.load_state_dict(param)
    del param

    opt_path = output_dir.joinpath('opt_tmp.pth')
    param = torch.load(opt_path)
    optimizer.load_state_dict(param)
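# The resume branch above expects 'model_tmp.pth' and 'opt_tmp.pth' in
# output_dir. A minimal sketch of the matching save step (the helper name
# and its placement are assumptions; the actual save code is not shown):
def save_checkpoint(model, optimizer, output_dir):
    # Persist raw state_dicts so the load side can feed them to
    # load_state_dict() directly.
    torch.save(model.state_dict(), output_dir.joinpath('model_tmp.pth'))
    torch.save(optimizer.state_dict(), output_dir.joinpath('opt_tmp.pth'))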
def train(model, state, path, annotations, val_path, val_annotations,
          resize, max_size, jitter, batch_size, iterations, val_iterations,
          mixed_precision, lr, warmup, milestones, gamma, is_master=True,
          world=1, use_dali=True, verbose=True, metrics_url=None,
          logdir=None):
    'Train the model on the given dataset'

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path, jitter, max_size, batch_size, model.stride,
        world, annotations, training=True)
    if verbose:
        print(data_iterator)

    # Prepare model
    nn_model = model
    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()
    if mixed_precision:
        model = fp16_utils.BN_convert_float(model.half())
    if world > 1:
        model = DistributedDataParallel(model, delay_allreduce=True)
    model.train()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=0.0001,
                    momentum=0.9)
    if mixed_precision:
        optimizer = fp16_utils.FP16_Optimizer(optimizer,
                                              static_loss_scale=128.,
                                              verbose=False)
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])
    scheduler = LambdaLR(optimizer.optimizer if mixed_precision
                         else optimizer, schedule)

    if verbose:
        print(' device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available()
            else 'gpu' if world == 1 else 'gpus'))
        print(' batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            scheduler.step(iteration)

            # Forward pass
            profiler.start('fw')
            if mixed_precision:
                data = data.half()
            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if mixed_precision:
                optimizer.backward(cls_loss + box_loss)
            else:
                (cls_loss + box_loss).backward()
            optimizer.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), \
                box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))
            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60
                              or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(
                        profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(
                        batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate,
                                      iteration)
                del box_loss, focal_loss

                if metrics_url:
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate})

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations
                                    or iteration % val_iterations == 0):
                infer(nn_model, val_path, None, resize, max_size, batch_size,
                      annotations=val_annotations,
                      mixed_precision=mixed_precision,
                      is_master=is_master, world=world, use_dali=use_dali,
                      verbose=False)
                model.train()

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()
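# A quick sanity check of the schedule() closure defined in train()
# (hypothetical hyperparameters: warmup=100, milestones=(500, 800),
# gamma=0.1). LambdaLR multiplies the base lr by the returned factor.
def example_schedule(train_iter, warmup=100, milestones=(500, 800),
                     gamma=0.1):
    if warmup and train_iter <= warmup:
        return 0.9 * train_iter / warmup + 0.1  # linear ramp, 0.1 -> 1.0
    return gamma ** len([m for m in milestones if m <= train_iter])

# Expected factors: mid-warmup, past warmup, one milestone, both milestones.
for it, expected in [(50, 0.55), (300, 1.0), (600, 0.1), (900, 0.01)]:
    assert abs(example_schedule(it) - expected) < 1e-9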