def upsnet_train():

    if is_master:
        logger.info('training config:{}\n'.format(pprint.pformat(config)))
    gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')]
    num_replica = hvd.size() if config.train.use_horovod else len(gpus)
    num_gpus = 1 if config.train.use_horovod else len(gpus)

    # create models
    train_model = eval(config.symbol)().cuda()

    # create optimizer
    params_lr = train_model.get_params_lr()
    # we use a custom optimizer and pass lr=1 to support different lr for different weights
    optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd)
    if config.train.use_horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters())
    optimizer.zero_grad()

    # create data loaders
    train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'),
                                                 flip=config.train.flip, result_path=final_output_path)
    val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'),
                                               flip=False, result_path=final_output_path, phase='val')
    if config.train.use_horovod:
        train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size,
                                                   sampler=train_sampler, num_workers=num_gpus * 4,
                                                   drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size,
                                                 sampler=val_sampler, num_workers=num_gpus * 4,
                                                 drop_last=False, collate_fn=val_dataset.collate)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size,
                                                   shuffle=config.train.shuffle, num_workers=num_gpus * 4,
                                                   drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size,
                                                 shuffle=False, num_workers=num_gpus * 4,
                                                 drop_last=False, collate_fn=val_dataset.collate)

    # preparing
    curr_iter = config.train.begin_iteration
    batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)]
    metrics = []
    metrics_name = []
    if config.network.has_rpn:
        metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss')])
        metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss'])
    if config.network.has_rcnn:
        metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss')])
        metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss'])
    if config.network.has_mask_head:
        metrics.extend([AvgMetric(name='mask_loss')])
        metrics_name.extend(['mask_loss'])
    if config.network.has_fcn_head:
        metrics.extend([AvgMetric(name='fcn_loss')])
        metrics_name.extend(['fcn_loss'])
        if config.train.fcn_with_roi_loss:
            metrics.extend([AvgMetric(name='fcn_roi_loss')])
            metrics_name.extend(['fcn_roi_loss'])
    if config.network.has_panoptic_head:
        metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss')])
        metrics_name.extend(['panoptic_accuracy', 'panoptic_loss'])

    # resume from a snapshot, or start from pretrained weights
    if config.train.resume:
        train_model.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')), resume=True)
        optimizer.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth')))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
    else:
        if is_master:
            train_model.load_state_dict(torch.load(config.network.pretrained))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)

    if not config.train.use_horovod:
        train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0])

    if is_master:
        batch_end_callback[0](0, 0)

    train_model.eval()

    # start training
    while curr_iter < config.train.max_iteration:
        if config.train.use_horovod:
            train_sampler.set_epoch(curr_iter)
            if config.network.use_syncbn:
                train_model.train()
                if config.network.backbone_freeze_at > 0:
                    train_model.freeze_backbone(config.network.backbone_freeze_at)
                if config.network.backbone_fix_bn:
                    train_model.resnet_backbone.eval()

            for inner_iter, batch in enumerate(train_loader):
                data, label, _ = batch
                for k, v in data.items():
                    data[k] = v if not torch.is_tensor(v) else v.cuda()
                for k, v in label.items():
                    label[k] = v if not torch.is_tensor(v) else v.cuda()

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(data, label)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                # average losses across workers for logging
                losses = [allreduce_async(loss, name='train_total_loss')]
                for l in metrics_name:
                    losses.append(allreduce_async(output[l].mean(), name=l))

                loss = hvd.synchronize(losses[0]).item()
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = hvd.synchronize(losses[i + 1]).item()
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)

                curr_iter += 1
                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)
                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
        else:
            inner_iter = 0
            train_iterator = iter(train_loader)
            # each step consumes one mini-batch per GPU
            while inner_iter + num_gpus <= len(train_loader):
                batch = []
                for gpu_id in gpus:
                    data, label, _ = next(train_iterator)
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    batch.append((data, label))
                    inner_iter += 1

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(*batch)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean()
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                losses = [loss.item()]
                for l in metrics_name:
                    losses.append(output[l].mean().item())

                loss = losses[0]
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = losses[i + 1]
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)

                curr_iter += 1
                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)
                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))

            # drain the iterator so DataLoader worker processes shut down cleanly
            while True:
                try:
                    next(train_iterator)
                except StopIteration:
                    break

        for metric in metrics:
            metric.reset()

        if config.train.eval_data:
            train_model.eval()
            if config.train.use_horovod:
                for inner_iter, batch in enumerate(val_loader):
                    data, label, _ = batch
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    with torch.no_grad():
                        output = train_model(data, label)
                    for metric, l in zip(metrics, metrics_name):
                        loss = hvd.allreduce(output[l].mean()).item()
                        if is_master:
                            metric.update(_, _, loss)
            else:
                inner_iter = 0
                val_iterator = iter(val_loader)
                while inner_iter + len(gpus) <= len(val_loader):
                    batch = []
                    for gpu_id in gpus:
                        data, label, _ = next(val_iterator)
                        for k, v in data.items():
                            data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        for k, v in label.items():
                            label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        batch.append((data, label))
                        inner_iter += 1
                    with torch.no_grad():
                        output = train_model(*batch)
                    for metric, l in zip(metrics, metrics_name):
                        loss = output[l].mean().item()
                        if is_master:
                            metric.update(_, _, loss)
                # drain the iterator so DataLoader worker processes shut down cleanly
                while True:
                    try:
                        next(val_iterator)
                    except StopIteration:
                        break

            s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader))
            for metric in metrics:
                m, v = metric.get()
                s += 'Val-%s=%f,\t' % (m, v)
                if is_master:
                    writer.add_scalar('val_' + m, v, curr_iter)
                metric.reset()
            if is_master:
                logger.info(s)

    # final snapshot
    if is_master and config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
    elif not config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
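
# ---------------------------------------------------------------------------
# Both training loops construct SGD with lr=1 and later call
# `optimizer.step(lr)`, so each parameter group carries only a *relative*
# learning-rate multiplier (e.g. 2x for biases) and the scheduled value is
# supplied at step time. The class below is a minimal, self-contained sketch
# of how such an optimizer can be written; `ScaledSGD` is a hypothetical name
# for illustration, not the exact SGD class shipped with this repo.
# ---------------------------------------------------------------------------
import torch
from torch.optim.optimizer import Optimizer


class ScaledSGD(Optimizer):
    """SGD whose step() takes the scheduled lr; groups store multipliers."""

    def __init__(self, params, lr=1, momentum=0, weight_decay=0):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super(ScaledSGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, lr):
        for group in self.param_groups:
            # effective lr = scheduled lr passed to step() * per-group multiplier
            effective_lr = lr * group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if group['weight_decay'] != 0:
                    d_p = d_p.add(p, alpha=group['weight_decay'])
                if group['momentum'] != 0:
                    buf = self.state[p].get('momentum_buffer')
                    if buf is None:
                        buf = self.state[p]['momentum_buffer'] = d_p.clone().detach()
                    else:
                        buf.mul_(group['momentum']).add_(d_p)
                    d_p = buf
                p.add_(d_p, alpha=-effective_lr)


# Hypothetical usage, mirroring the pattern in the loops above:
#   opt = ScaledSGD([{'params': weights, 'lr': 1.0},
#                    {'params': biases,  'lr': 2.0}], lr=1, momentum=0.9)
#   opt.step(0.01)  # weights move at 0.01, biases at 0.02
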
def main():
    """Training for pixel-wise embeddings by pixel-segment contrastive
    learning loss for DensePose.
    """
    # Retrieve experiment configurations.
    args = parse_args('Training for pixel-wise embeddings for DensePose.')

    # Retrieve GPU information.
    device_ids = [int(i) for i in config.gpus.split(',')]
    gpu_ids = [torch.device('cuda', i) for i in device_ids]
    num_gpus = len(gpu_ids)

    # Create logger and tensorboard writer.
    summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir)
    color_map = vis_utils.load_color_map(config.dataset.color_map_path)

    model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth')
    optimizer_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.state.pth')

    # Create data loaders.
    train_dataset = DenseposeDataset(
        data_dir=args.data_dir,
        data_list=args.data_list,
        img_mean=config.network.pixel_means,
        img_std=config.network.pixel_stds,
        size=config.train.crop_size,
        random_crop=config.train.random_crop,
        random_scale=config.train.random_scale,
        random_mirror=config.train.random_mirror,
        training=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=config.train.shuffle,
        num_workers=num_gpus * config.num_threads,
        drop_last=False,
        collate_fn=train_dataset.collate_fn)

    # Create models.
    if config.network.backbone_types == 'panoptic_pspnet_101':
        embedding_model = resnet_101_pspnet(config).cuda()
    else:
        raise ValueError('Unsupported backbone: ' + config.network.backbone_types)

    if config.network.prediction_types == 'segsort':
        prediction_model = segsort(config).cuda()
    else:
        raise ValueError('Unsupported prediction head: ' + config.network.prediction_types)

    # Use synchronized batch norm.
    if config.network.use_syncbn:
        embedding_model = convert_model(embedding_model).cuda()
        prediction_model = convert_model(prediction_model).cuda()

    # Use a customized optimizer and pass lr=1 to support different lr for
    # different weights.
    optimizer = SGD(
        embedding_model.get_params_lr() + prediction_model.get_params_lr(),
        lr=1,
        momentum=config.train.momentum,
        weight_decay=config.train.weight_decay)
    optimizer.zero_grad()

    # Load pre-trained weights.
    curr_iter = config.train.begin_iteration
    if config.train.resume:
        model_path = model_path_template.format(curr_iter)
        print('Resume training from {:s}'.format(model_path))
        embedding_model.load_state_dict(
            torch.load(model_path)['embedding_model'], resume=True)
        prediction_model.load_state_dict(
            torch.load(model_path)['prediction_model'], resume=True)
        optimizer.load_state_dict(
            torch.load(optimizer_path_template.format(curr_iter)))
    elif config.network.pretrained:
        print('Loading pre-trained model: {:s}'.format(config.network.pretrained))
        embedding_model.load_state_dict(torch.load(config.network.pretrained))
    else:
        print('Training from scratch')

    # Distribute model weights to multi-gpus.
    embedding_model = DataParallel(embedding_model,
                                   device_ids=device_ids,
                                   gather_output=False)
    prediction_model = DataParallel(prediction_model,
                                    device_ids=device_ids,
                                    gather_output=False)
    if config.network.use_syncbn:
        patch_replication_callback(embedding_model)
        patch_replication_callback(prediction_model)

    for module in embedding_model.modules():
        if isinstance(module, _BatchNorm) or isinstance(module, _ConvNd):
            print(module.training, module)
    print(embedding_model)
    print(prediction_model)

    # Create memory banks.
    memory_banks = {}

    # Start training.
    train_iterator = iter(train_loader)
    iterator_index = 0
    pbar = tqdm(range(curr_iter, config.train.max_iteration))
    for curr_iter in pbar:
        # Check whether enough data remains for a full multi-gpu step;
        # otherwise, re-initialize the data iterator.
        if iterator_index + num_gpus >= len(train_loader):
            train_iterator = iter(train_loader)
            iterator_index = 0

        # Feed-forward.
        image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
            train_iterator, gpu_ids)
        iterator_index += num_gpus

        # Generate embeddings, clusterings and prototypes.
        embeddings = embedding_model(*zip(image_batch, label_batch))

        # Synchronize cluster indices and compute prototypes.
        c_inds = [emb['cluster_index'] for emb in embeddings]
        cb_inds = [emb['cluster_batch_index'] for emb in embeddings]
        cs_labs = [emb['cluster_semantic_label'] for emb in embeddings]
        ci_labs = [emb['cluster_instance_label'] for emb in embeddings]
        c_embs = [emb['cluster_embedding'] for emb in embeddings]
        c_embs_with_loc = [emb['cluster_embedding_with_loc'] for emb in embeddings]
        (prototypes, prototypes_with_loc,
         prototype_semantic_labels, prototype_instance_labels,
         prototype_batch_indices, cluster_indices) = (
            model_utils.gather_clustering_and_update_prototypes(
                c_embs, c_embs_with_loc,
                c_inds, cb_inds,
                cs_labs, ci_labs,
                'cuda:{:d}'.format(num_gpus - 1)))

        for i in range(len(label_batch)):
            label_batch[i]['prototype'] = prototypes[i]
            label_batch[i]['prototype_with_loc'] = prototypes_with_loc[i]
            label_batch[i]['prototype_semantic_label'] = prototype_semantic_labels[i]
            label_batch[i]['prototype_instance_label'] = prototype_instance_labels[i]
            label_batch[i]['prototype_batch_index'] = prototype_batch_indices[i]
            embeddings[i]['cluster_index'] = cluster_indices[i]

        #semantic_tags = model_utils.gather_and_update_datas(
        #    [lab['semantic_tag'] for lab in label_batch],
        #    'cuda:{:d}'.format(num_gpus - 1))
        #for i in range(len(label_batch)):
        #    label_batch[i]['semantic_tag'] = semantic_tags[i]
        #    label_batch[i]['prototype_semantic_tag'] = torch.index_select(
        #        semantic_tags[i],
        #        0,
        #        label_batch[i]['prototype_batch_index'])

        # Add memory banks to label batch.
        for k in memory_banks.keys():
            for i in range(len(label_batch)):
                assert label_batch[i].get(k, None) is None
                label_batch[i][k] = [m.to(gpu_ids[i]) for m in memory_banks[k]]

        # Compute loss.
        outputs = prediction_model(*zip(embeddings, label_batch))
        outputs = scatter_gather.gather(outputs, gpu_ids[0])
        losses = []
        for k in ['sem_ann_loss', 'sem_occ_loss', 'img_sim_loss', 'feat_aff_loss']:
            loss = outputs.get(k, None)
            if loss is not None:
                outputs[k] = loss.mean()
                losses.append(outputs[k])
        loss = sum(losses)
        acc = outputs['accuracy'].mean()

        # Write to tensorboard summary.
        writer = (summary_writer
                  if curr_iter % config.train.tensorboard_step == 0 else None)
        if writer is not None:
            summary_vis = []
            summary_val = {}
            # Gather labels to cpu.
            cpu_label_batch = scatter_gather.gather(label_batch, -1)
            summary_vis.append(vis_utils.convert_label_to_color(
                cpu_label_batch['semantic_label'], color_map))
            summary_vis.append(vis_utils.convert_label_to_color(
                cpu_label_batch['instance_label'], color_map))

            # Gather outputs to cpu.
            vis_names = ['embedding']
            cpu_embeddings = scatter_gather.gather(
                [{k: emb.get(k, None) for k in vis_names} for emb in embeddings], -1)
            for vis_name in vis_names:
                if cpu_embeddings.get(vis_name, None) is not None:
                    summary_vis.append(vis_utils.embedding_to_rgb(
                        cpu_embeddings[vis_name], 'pca'))

            val_names = ['sem_ann_loss', 'sem_occ_loss', 'img_sim_loss',
                         'feat_aff_loss', 'accuracy']
            for val_name in val_names:
                if outputs.get(val_name, None) is not None:
                    summary_val[val_name] = outputs[val_name].mean().to('cpu')

            vis_utils.write_image_to_tensorboard(summary_writer,
                                                 summary_vis,
                                                 summary_vis[-1].shape[-2:],
                                                 curr_iter)
            vis_utils.write_scalars_to_tensorboard(summary_writer,
                                                   summary_val,
                                                   curr_iter)

        # Backward propagation.
        if config.train.lr_policy == 'step':
            lr = train_utils.lr_step(config.train.base_lr,
                                     curr_iter,
                                     config.train.decay_iterations,
                                     config.train.warmup_iteration)
        else:
            lr = train_utils.lr_poly(config.train.base_lr,
                                     curr_iter,
                                     config.train.max_iteration,
                                     config.train.warmup_iteration)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step(lr)

        # Update memory banks.
        with torch.no_grad():
            for k in label_batch[0].keys():
                if 'prototype' in k and 'memory' not in k:
                    memory = label_batch[0][k].clone().detach()
                    memory_key = 'memory_' + k
                    if memory_key not in memory_banks.keys():
                        memory_banks[memory_key] = []
                    memory_banks[memory_key].append(memory)
                    # Keep only the most recent entries.
                    if len(memory_banks[memory_key]) > config.train.memory_bank_size:
                        memory_banks[memory_key] = memory_banks[memory_key][1:]

            # Update batch labels.
            for k in ['memory_prototype_batch_index']:
                memory_labels = memory_banks.get(k, None)
                if memory_labels is not None:
                    for i, memory_label in enumerate(memory_labels):
                        memory_labels[i] += config.train.batch_size * num_gpus

        # Snapshot the trained model.
        if ((curr_iter + 1) % config.train.snapshot_step == 0
                or curr_iter == config.train.max_iteration - 1):
            model_state_dict = {
                'embedding_model': embedding_model.module.state_dict(),
                'prediction_model': prediction_model.module.state_dict()}
            torch.save(model_state_dict, model_path_template.format(curr_iter))
            torch.save(optimizer.state_dict(),
                       optimizer_path_template.format(curr_iter))

        # Print loss in the progress bar.
        line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
            loss.item(), acc.item(), lr)
        pbar.set_description(line)
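
# ---------------------------------------------------------------------------
# main() selects its learning rate via `train_utils.lr_step` / `lr_poly` and
# feeds it into optimizer.step(lr). The functions below are minimal sketches
# of the two standard schedules, assuming a linear warmup and the usual
# power-0.9 polynomial decay; the actual helpers in train_utils may differ in
# details such as the decay power or warmup shape, and the `_sketch` names
# are hypothetical.
# ---------------------------------------------------------------------------
def lr_poly_sketch(base_lr, curr_iter, max_iter, warmup_iter=0, power=0.9):
    """Polynomial decay: base_lr * (1 - t/T)^power, with linear warmup."""
    if warmup_iter > 0 and curr_iter < warmup_iter:
        return base_lr * curr_iter / float(warmup_iter)
    return base_lr * (1 - curr_iter / float(max_iter)) ** power


def lr_step_sketch(base_lr, curr_iter, decay_iters, warmup_iter=0, decay=0.1):
    """Step decay: multiply by `decay` at each milestone in `decay_iters`."""
    if warmup_iter > 0 and curr_iter < warmup_iter:
        return base_lr * curr_iter / float(warmup_iter)
    num_decays = sum(1 for it in decay_iters if curr_iter >= it)
    return base_lr * decay ** num_decays

# Example: lr_step_sketch(0.02, 70000, [60000, 80000]) returns 0.002, since
# one milestone (60000) has passed.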