def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_all = self.saved_tensors
    need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        need_input_grad,
        need_weight_grad,
        need_bias_grad)

    if need_input_grad:
        # synchronizing stats used to calculate input gradient
        sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
        sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

        # wait on the async communication to finish
        sum_dy = synchronize(sum_dy_handle)
        sum_dy_xmu = synchronize(sum_dy_xmu_handle)

        if _SYNC_BN_V2 or _SYNC_BN_V3:
            count_all_sum = count_all.sum()
            mean_dy = sum_dy / count_all_sum
            mean_dy_xmu = sum_dy_xmu / count_all_sum
        else:
            # before 1.5.0, sum_dy was the sum of means from every worker, so we just
            # need to divide it by the number of workers
            mean_dy = sum_dy / size()
            mean_dy_xmu = sum_dy_xmu / size()

        # backward pass for gradient calculation
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu)
    else:
        grad_input = None

    # synchronizing grad_weight / grad_bias is not needed as distributed
    # training would handle the all-reduce
    if weight is None or not need_weight_grad:
        grad_weight = None
    if weight is None or not need_bias_grad:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
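
# For reference: the fused torch.batch_norm_backward_elemt call above computes the
# standard batch-norm input gradient. A rough pure-PyTorch equivalent is sketched
# below (an illustrative reimplementation, not the actual kernel; it assumes an
# (N, C, ...) input with per-channel statistics):
def _batch_norm_backward_elemt_reference(grad_output, saved_input, mean, invstd,
                                         weight, mean_dy, mean_dy_xmu):
    # reshape the per-channel stats so they broadcast over (N, C, ...)
    shape = [1, -1] + [1] * (saved_input.dim() - 2)
    mean = mean.reshape(shape)
    invstd = invstd.reshape(shape)
    mean_dy = mean_dy.reshape(shape)
    mean_dy_xmu = mean_dy_xmu.reshape(shape)
    scale = invstd if weight is None else invstd * weight.reshape(shape)
    # dL/dx = scale * (dy - mean(dy) - (x - mu) * invstd^2 * mean(dy * (x - mu)))
    return scale * (grad_output - mean_dy
                    - (saved_input - mean) * invstd * invstd * mean_dy_xmu)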
def upsnet_train():
    if is_master:
        logger.info('training config:{}\n'.format(pprint.pformat(config)))
    gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')]
    num_replica = hvd.size() if config.train.use_horovod else len(gpus)
    num_gpus = 1 if config.train.use_horovod else len(gpus)

    # create model
    train_model = eval(config.symbol)().cuda()

    # create optimizer: we use a custom optimizer and pass lr=1 so that each
    # parameter group can carry its own relative learning rate
    params_lr = train_model.get_params_lr()
    optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd)
    if config.train.use_horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters())
    optimizer.zero_grad()

    # create data loaders
    train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'),
                                                 flip=config.train.flip, result_path=final_output_path)
    val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'),
                                               flip=False, result_path=final_output_path, phase='val')
    if config.train.use_horovod:
        train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size,
                                                   sampler=train_sampler, num_workers=num_gpus * 4,
                                                   drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size,
                                                 sampler=val_sampler, num_workers=num_gpus * 4,
                                                 drop_last=False, collate_fn=val_dataset.collate)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size,
                                                   shuffle=config.train.shuffle, num_workers=num_gpus * 4,
                                                   drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size,
                                                 shuffle=False, num_workers=num_gpus * 4,
                                                 drop_last=False, collate_fn=val_dataset.collate)

    # prepare iteration counter, callbacks and metrics
    curr_iter = config.train.begin_iteration
    batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)]
    metrics = []
    metrics_name = []
    if config.network.has_rpn:
        metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss')])
        metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss'])
    if config.network.has_rcnn:
        metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss')])
        metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss'])
    if config.network.has_mask_head:
        metrics.extend([AvgMetric(name='mask_loss')])
        metrics_name.extend(['mask_loss'])
    if config.network.has_fcn_head:
        metrics.extend([AvgMetric(name='fcn_loss')])
        metrics_name.extend(['fcn_loss'])
        if config.train.fcn_with_roi_loss:
            metrics.extend([AvgMetric(name='fcn_roi_loss')])
            metrics_name.extend(['fcn_roi_loss'])
    if config.network.has_panoptic_head:
        metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss')])
        metrics_name.extend(['panoptic_accuracy', 'panoptic_loss'])

    # resume from a checkpoint or load pretrained weights
    if config.train.resume:
        train_model.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')), resume=True)
        optimizer.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth')))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
    else:
        if is_master:
            train_model.load_state_dict(torch.load(config.network.pretrained))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)

    if not config.train.use_horovod:
        train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0])

    if is_master:
        batch_end_callback[0](0, 0)

    train_model.eval()

    # start training
    while curr_iter < config.train.max_iteration:
        if config.train.use_horovod:
            train_sampler.set_epoch(curr_iter)
            if config.network.use_syncbn:
                train_model.train()
                if config.network.backbone_freeze_at > 0:
                    train_model.freeze_backbone(config.network.backbone_freeze_at)
                if config.network.backbone_fix_bn:
                    train_model.resnet_backbone.eval()

            for inner_iter, batch in enumerate(train_loader):
                data, label, _ = batch
                for k, v in data.items():
                    data[k] = v if not torch.is_tensor(v) else v.cuda()
                for k, v in label.items():
                    label[k] = v if not torch.is_tensor(v) else v.cuda()

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(data, label)

                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                # all-reduce the logged losses asynchronously, then wait on the handles
                losses = [allreduce_async(loss, name='train_total_loss')]
                for l in metrics_name:
                    losses.append(allreduce_async(output[l].mean(), name=l))
                loss = hvd.synchronize(losses[0]).item()
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = hvd.synchronize(losses[i + 1]).item()
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                    metric.update(_, _, loss)

                curr_iter += 1
                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)
                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
        else:
            # non-Horovod path: the model is a DataParallel wrapper, so we
            # assemble one sub-batch per GPU before each step
            inner_iter = 0
            train_iterator = iter(train_loader)
            while inner_iter + num_gpus <= len(train_loader):
                batch = []
                for gpu_id in gpus:
                    data, label, _ = next(train_iterator)
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    batch.append((data, label))
                    inner_iter += 1

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(*batch)

                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean()
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                losses = [loss.item()]
                for l in metrics_name:
                    losses.append(output[l].mean().item())
                loss = losses[0]
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = losses[i + 1]
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                    metric.update(_, _, loss)

                curr_iter += 1
                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)
                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))

            # drain leftover batches so the loader workers shut down cleanly
            while True:
                try:
                    next(train_iterator)
                except StopIteration:
                    break

        for metric in metrics:
            metric.reset()

        # evaluation at the end of each epoch
        if config.train.eval_data:
            train_model.eval()
            if config.train.use_horovod:
                for inner_iter, batch in enumerate(val_loader):
                    data, label, _ = batch
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    with torch.no_grad():
                        output = train_model(data, label)
                    for metric, l in zip(metrics, metrics_name):
                        loss = hvd.allreduce(output[l].mean()).item()
                        if is_master:
                            metric.update(_, _, loss)
            else:
                inner_iter = 0
                val_iterator = iter(val_loader)
                while inner_iter + len(gpus) <= len(val_loader):
                    batch = []
                    for gpu_id in gpus:
                        data, label, _ = next(val_iterator)
                        for k, v in data.items():
                            data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        for k, v in label.items():
                            label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        batch.append((data, label))
                        inner_iter += 1
                    with torch.no_grad():
                        output = train_model(*batch)
                    losses = [output[l].mean().item() for l in metrics_name]
                    for metric, loss in zip(metrics, losses):
                        if is_master:
                            metric.update(_, _, loss)

                # drain leftover batches so the loader workers shut down cleanly
                while True:
                    try:
                        next(val_iterator)
                    except StopIteration:
                        break

            s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader))
            for metric in metrics:
                m, v = metric.get()
                s += 'Val-%s=%f,\t' % (m, v)
                if is_master:
                    writer.add_scalar('val_' + m, v, curr_iter)
                metric.reset()
            if is_master:
                logger.info(s)

        # end-of-epoch snapshot
        if is_master and config.train.use_horovod:
            logger.info('taking snapshot ...')
            torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
        elif not config.train.use_horovod:
            logger.info('taking snapshot ...')
            torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
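
# adjust_learning_rate() is defined elsewhere in the repo and returns the scheduled
# rate that is passed to the custom SGD's step(lr). Since the optimizer is built with
# lr=1 and per-group factors from get_params_lr(), the returned value acts as a global
# multiplier. A hypothetical sketch of such a schedule (linear warmup followed by step
# decay; every config field below is an assumption, not the repo's actual code):
def adjust_learning_rate_sketch(optimizer, curr_iter, config):
    # optimizer is unused here because the rate is applied via optimizer.step(lr)
    lr = config.train.lr
    if curr_iter < config.train.warmup_iteration:
        # linearly ramp from warmup_factor * lr up to the base lr
        alpha = curr_iter / float(config.train.warmup_iteration)
        lr *= config.train.warmup_factor * (1 - alpha) + alpha
    for decay_iter in config.train.decay_iteration:
        # divide by 10 at each decay milestone already passed
        if curr_iter >= decay_iter:
            lr *= 0.1
    return lr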
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_all = self.saved_tensors
    need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        need_input_grad,
        need_weight_grad,
        need_bias_grad)

    if need_input_grad:
        # synchronizing stats used to calculate input gradient
        sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
        sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

        # wait on the async communication to finish
        sum_dy = synchronize(sum_dy_handle)
        sum_dy_xmu = synchronize(sum_dy_xmu_handle)

        if _SYNC_BN_V4:
            # from 1.9.0 on we need a count tensor on all devices;
            # count_all is calculated as the total count across all ranks in the forward function
            count_all = count_all.to(dtype=torch.int, device=grad_output.device)
        elif _SYNC_BN_V2 or _SYNC_BN_V3:
            # before 1.9.0 we need the count as an integer to compute the means
            count = count_all.sum()
        else:
            # before 1.5.0, sum_dy was the sum of means from every worker, so we just
            # need to divide it by the number of workers
            count = size()

        # backward pass for gradient calculation
        # we are calling into a non-public, undocumented function whose signature
        # changed in 1.9.0: https://github.com/pytorch/pytorch/issues/57900
        if _SYNC_BN_V4:
            # from 1.9.0 on, the sums and the count tensor are expected
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy,
                sum_dy_xmu,
                count_all)
        else:
            # before 1.9.0, the mean parameters are expected, not the sums and count
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy / count,
                sum_dy_xmu / count)
    else:
        grad_input = None

    # synchronizing grad_weight / grad_bias is not needed as distributed
    # training would handle the all-reduce
    if weight is None or not need_weight_grad:
        grad_weight = None
    if weight is None or not need_bias_grad:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
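
# The _SYNC_BN_V* flags referenced above gate on the PyTorch version at import time.
# A minimal sketch of how such flags can be derived; the exact version boundaries
# below are assumptions inferred from the comments in backward(), not verified:
from distutils.version import LooseVersion
import torch

_SYNC_BN_V2 = LooseVersion(torch.__version__) >= LooseVersion('1.5.0')  # count tensor saved for backward
_SYNC_BN_V3 = LooseVersion(torch.__version__) >= LooseVersion('1.6.0')
_SYNC_BN_V4 = LooseVersion(torch.__version__) >= LooseVersion('1.9.0')  # batch_norm_backward_elemt takes sums + count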