def prepare_data(args):
    train_transform_S = get_transform(train=True, dataset_name=cfg.DATASET.SOURCE)
    train_transform_T = get_transform(train=True, dataset_name=cfg.DATASET.TARGET)
    val_transform = get_transform(train=False, dataset_name=cfg.DATASET.VAL)

    train_dataset_S = eval('Dataset.%s' % cfg.DATASET.SOURCE)(
        cfg.DATASET.DATAROOT_S, cfg.DATASET.TRAIN_SPLIT_S, transform=train_transform_S)
    train_dataset_T = eval('Dataset.%s' % cfg.DATASET.TARGET)(
        cfg.DATASET.DATAROOT_T, cfg.DATASET.TRAIN_SPLIT_T, transform=train_transform_T)
    val_dataset = eval('Dataset.%s' % cfg.DATASET.VAL)(
        cfg.DATASET.DATAROOT_VAL, cfg.DATASET.VAL_SPLIT, transform=val_transform)

    # construct dataloaders
    train_dataloader_S = data_utils.get_dataloader(
        train_dataset_S, cfg.TRAIN.TRAIN_BATCH_SIZE, cfg.NUM_WORKERS,
        train=True, distributed=args.distributed, world_size=gen_utils.get_world_size())
    train_dataloader_T = data_utils.get_dataloader(
        train_dataset_T, cfg.TRAIN.TRAIN_BATCH_SIZE, cfg.NUM_WORKERS,
        train=True, distributed=args.distributed, world_size=gen_utils.get_world_size())
    val_dataloader = data_utils.get_dataloader(
        val_dataset, cfg.TRAIN.VAL_BATCH_SIZE, cfg.NUM_WORKERS,
        train=False, distributed=args.distributed, world_size=gen_utils.get_world_size())

    dataloaders = {'train_S': train_dataloader_S,
                   'train_T': train_dataloader_T,
                   'val': val_dataloader}
    return dataloaders
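# The helper data_utils.get_dataloader is not shown in this excerpt. The following is a
# minimal sketch of what such a helper could look like, assuming it only wraps
# torch.utils.data.DataLoader and attaches a DistributedSampler in the distributed case.
# The signature mirrors the call sites in prepare_data; the name and body are
# illustrative assumptions, not the project's actual implementation (world_size is
# accepted only to mirror the call site; DistributedSampler infers it from the process group).
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def get_dataloader_sketch(dataset, batch_size, num_workers, train=True,
                          distributed=False, world_size=1):
    sampler = None
    if distributed:
        # Each process sees a disjoint shard; shuffling is handled by the sampler.
        sampler = DistributedSampler(dataset, shuffle=train)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(train and sampler is None),
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=train)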
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_loss1 = AverageMeter()
    ave_aux_loss = AverageMeter()
    ave_error_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        losses, aux_loss, error_loss, _ = model(images, labels)
        # print('pred', pred[2].size())
        loss = losses.mean() + 0.4 * aux_loss.mean() + 1 * error_loss.mean()

        # Reduce the per-term means so each AverageMeter tracks a scalar across processes.
        reduced_loss = reduce_tensor(loss)
        loss1 = reduce_tensor(losses.mean())
        aux_loss = reduce_tensor(aux_loss.mean())
        error_losses = reduce_tensor(error_loss.mean())

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        ave_loss1.update(loss1.item())
        ave_aux_loss.update(aux_loss.item())
        ave_error_loss.update(error_losses.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters, i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            print_loss1 = ave_loss1.average() / world_size
            print_loss_aux = ave_aux_loss.average() / world_size
            print_error_loss = ave_error_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}, Loss_1: {:.6f}, Loss_aux: {:.6f}, error_loss: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters, batch_time.average(), lr,
                      print_loss, print_loss1, print_loss_aux, print_error_loss)
            logging.info(msg)
            writer.add_scalar('train_loss', print_loss, global_steps)
            # Advance the local step as well, so successive logs within an epoch
            # do not all land on the same TensorBoard step.
            global_steps += 1
            writer_dict['train_global_steps'] = global_steps
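# AverageMeter and adjust_learning_rate are used throughout the training loops above but
# are not shown. Below is a minimal sketch of both under the usual HRNet-style conventions:
# AverageMeter keeps a running sum/count, and adjust_learning_rate applies a polynomial
# ("poly") decay over the total number of iterations. Treat the names, the poly exponent
# (0.9), and updating every param group as assumptions, not the project's exact settings.

class AverageMeterSketch:
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    def average(self):
        return self.sum / max(self.count, 1)


def adjust_learning_rate_sketch(optimizer, base_lr, max_iters, cur_iters, power=0.9):
    # Poly schedule: lr = base_lr * (1 - t / T) ** power
    lr = base_lr * (1 - cur_iters / max_iters) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr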
def source_only(self, data_S, gt_S, data_T, gt_T, *others, **kwargs):
    self.set_domain_id(0)
    preds = self.net(data_S)['out']
    preds = F.interpolate(preds, size=data_S.shape[-2:], mode='bilinear', align_corners=False)

    ce_loss = self.CELoss([preds], gt_S)
    if cfg.TRAIN.WITH_LOV:
        if self.distributed:
            lov_loss = lovasz_softmax_multigpu(F.softmax(preds, dim=1), gt_S,
                                               classes='present', per_image=False, ignore=255)
        else:
            lov_loss = lovasz_softmax(F.softmax(preds, dim=1), gt_S,
                                      classes='present', per_image=False, ignore=255)
        ce_loss += (cfg.TRAIN.LOV_W * get_world_size() * self.iter_size) * lov_loss

    out_dict = {
        'feats_S': None,
        'feats_T': None,
        'preds_S': preds,
        'preds_T': None
    }
    return {'total': ce_loss}, out_dict
def validate(config, testloader, model, writer_dict, device):
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))
    confusion_matrix_sum = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for _, batch in enumerate(testloader):
            image, label, boundary_gt, _, _ = batch
            size = label.size()
            image = image.to(device)
            boundary_gt = boundary_gt.to(device)
            label = label.long().to(device)

            losses, aux_loss, error_loss, losses_2, aux_loss_2, error_loss_2, preds = model(
                image, label, boundary_gt.float())
            pred = F.interpolate(preds[0], size=(size[-2], size[-1]),
                                 mode='bilinear', align_corners=False)
            loss = (losses + 0.4 * aux_loss + 4 * error_loss +
                    losses_2 + 0.4 * aux_loss_2 + 4 * error_loss_2).mean()
            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            confusion_matrix += get_confusion_matrix(
                label, pred, size, config.DATASET.NUM_CLASSES, config.TRAIN.IGNORE_LABEL)

    confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
    reduced_confusion_matrix = reduce_tensor(confusion_matrix)
    confusion_matrix = reduced_confusion_matrix.cpu().numpy()

    pos = confusion_matrix.sum(1)
    res = confusion_matrix.sum(0)
    tp = np.diag(confusion_matrix)
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1
        # cv2.imwrite(str(global_steps)+'_boundary.png', (preds[0][0][0].data.cpu().numpy()*255).astype(np.uint8))
        # cv2.imwrite(str(global_steps) + '_error.png', (preds[2][0][0].data.cpu().numpy() * 255).astype(np.uint8))
        cv2.imwrite(
            str(global_steps) + '_error.png',
            (preds[2][0][0].data.cpu().numpy() * 255).astype(np.uint8))

    return print_loss, mean_IoU, IoU_array
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, lr_scheduler, model, writer_dict, device):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        losses, _ = model(images, labels, train_step=(lr_scheduler._step_count - 1))
        loss = losses.mean()
        reduced_loss = reduce_tensor(loss)

        model.zero_grad()
        loss.backward()
        optimizer.step()
        if config.TRAIN.LR_SCHEDULER != 'step':
            lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters, i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss)
            logging.info(msg)
            writer.add_scalar('train_loss', print_loss, global_steps)
            # Advance the local step as well, so successive logs within an epoch
            # do not all land on the same TensorBoard step.
            global_steps += 1
            writer_dict['train_global_steps'] = global_steps
            batch_time = AverageMeter()
def reduce_tensor(inp):
    """
    Reduce the loss from all processes so that process with rank 0
    has the summed result; callers divide by the world size to
    recover the average.
    """
    world_size = get_world_size()
    if world_size < 2:
        return inp
    with torch.no_grad():
        # Clone so the caller's tensor is not overwritten in place on rank 0.
        reduced_inp = inp.clone()
        dist.reduce(reduced_inp, dst=0)
    return reduced_inp
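# Usage note for reduce_tensor: because dist.reduce sums the tensor onto rank 0, the
# logging code in the train/validate functions divides the accumulated averages by
# world_size to recover a cross-process mean. A minimal illustration of the convention
# (assuming the process group is already initialized):
#
#     loss = criterion(outputs, targets)        # per-process scalar
#     reduced = reduce_tensor(loss)             # on rank 0: sum over all ranks
#     if get_rank() == 0:
#         mean_loss = reduced.item() / get_world_size()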
def validate(config, testloader, model, writer_dict, device):
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    tot_inter = np.zeros(config.DATASET.NUM_CLASSES)
    tot_union = np.zeros(config.DATASET.NUM_CLASSES)

    with torch.no_grad():
        for i_iter, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            label = label.long().to(device)
            image = image.to(device)

            loss, pred = model(image, label)
            if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
                pred = F.interpolate(pred, size=(size[-2], size[-1]),
                                     mode='bilinear', align_corners=False)

            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            batch_inter, batch_union = batch_intersection_union(
                pred, label, config.DATASET.NUM_CLASSES)
            tot_inter += batch_inter
            tot_union += batch_union

            if i_iter % config.PRINT_FREQ == 0 and rank == 0:
                msg = f'Iter: {i_iter}, Loss: {ave_loss.average() / world_size:.6f}'
                logging.info(msg)

    tot_inter = torch.from_numpy(tot_inter).to(device)
    tot_union = torch.from_numpy(tot_union).to(device)
    tot_inter = reduce_tensor(tot_inter).cpu().numpy()
    tot_union = reduce_tensor(tot_union).cpu().numpy()

    IoU = np.float64(1.0) * tot_inter / (np.spacing(1, dtype=np.float64) + tot_union)
    mean_IoU = IoU.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return print_loss, mean_IoU
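# batch_intersection_union is not defined in this excerpt. A common histogram-based
# formulation is sketched below: per-class intersection and union counts are computed
# from the argmax prediction (the real helper may additionally handle an ignore index).
# The name and exact argument order are assumptions taken from the call site above.
import torch

def batch_intersection_union_sketch(pred, label, num_classes):
    # pred: (N, C, H, W) logits; label: (N, H, W) integer class ids
    pred_ids = torch.argmax(pred, dim=1)
    inter = torch.zeros(num_classes, dtype=torch.float64)
    union = torch.zeros(num_classes, dtype=torch.float64)
    for cls in range(num_classes):
        p = pred_ids == cls
        g = label == cls
        inter[cls] = (p & g).sum().item()
        union[cls] = (p | g).sum().item()
    return inter.numpy(), union.numpy()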
def validate(config, testloader, model, writer_dict, device):
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for _, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            image = image.to(device)
            label = label.long().to(device)

            losses, pred = model(image, label)
            pred = F.interpolate(pred, size=(size[-2], size[-1]),
                                 mode='bilinear', align_corners=False)
            loss = losses.mean()
            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            confusion_matrix += get_confusion_matrix(
                label, pred, size, config.DATASET.NUM_CLASSES,
                config.TRAIN.IGNORE_LABEL)

    confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
    reduced_confusion_matrix = reduce_tensor(confusion_matrix)
    confusion_matrix = reduced_confusion_matrix.cpu().numpy()

    pos = confusion_matrix.sum(1)
    res = confusion_matrix.sum(0)
    tp = np.diag(confusion_matrix)
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return print_loss, mean_IoU, IoU_array
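# get_confusion_matrix is referenced above but not defined here. The sketch below follows
# the usual HRNet-style computation: predictions are taken as the channel argmax, pixels
# whose label equals ignore_label are dropped, and a (num_classes x num_classes) count
# matrix indexed by (ground truth, prediction) is returned. Consider the name and the
# assumption that pred is already resized to the label resolution (so `size` is unused)
# as illustrative, not the helper's exact source.
import numpy as np

def get_confusion_matrix_sketch(label, pred, size, num_classes, ignore_label):
    output = pred.detach().cpu().numpy()                       # (N, C, H, W) logits
    seg_pred = np.asarray(np.argmax(output, axis=1), dtype=np.int64)
    seg_gt = np.asarray(label.detach().cpu().numpy(), dtype=np.int64)

    valid = seg_gt != ignore_label
    seg_gt = seg_gt[valid]
    seg_pred = seg_pred[valid]

    index = seg_gt * num_classes + seg_pred
    counts = np.bincount(index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes).astype(np.float64)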
def train(config, epoch, num_epoch, epoch_iters, trainloader, optimizer,
          lr_scheduler, model, writer_dict, device):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader, 0):
        images, labels, _, _ = batch
        labels = labels.long().to(device)
        images = images.to(device)

        loss, _ = model(images, labels)
        reduced_loss = reduce_tensor(loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())

        lr = optimizer.param_groups[0]['lr']

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss)
            logging.info(msg)

        if rank == 0:
            writer = writer_dict['writer']
            global_steps = writer_dict['train_global_steps']
            writer.add_scalar('train_loss', ave_loss.average() / world_size, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--teacher_model", default=None, type=str, required=True)
    parser.add_argument("--student_model", default=None, type=str, required=True)
    parser.add_argument("--output_dir", default=None, type=str, required=True)
    parser.add_argument('--vocab_file', type=str, default=None, required=True,
                        help="Vocabulary mapping/file BERT was pretrained on")

    # Other parameters
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--reduce_memory", action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay')
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--steps_per_epoch', type=int, default=-1,
                        help="Number of update steps in one epoch.")
    parser.add_argument('--max_steps', type=int, default=-1,
                        help="Number of training steps.")
    parser.add_argument('--amp', action='store_true', default=False,
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--continue_train', action='store_true', default=False,
                        help='Whether to train from checkpoints')
    parser.add_argument('--disable_progress_bar', default=False, action='store_true',
                        help='Disable tqdm progress bar')
    parser.add_argument('--max_grad_norm', type=float, default=1.,
                        help="Gradient clipping threshold")

    # Additional arguments
    parser.add_argument('--eval_step', type=int, default=1000)

    # This is used for running on Huawei Cloud.
    parser.add_argument('--data_url', type=str, default="")

    # Distillation-specific arguments
    parser.add_argument('--value_state_loss', action='store_true', default=False)
    parser.add_argument('--hidden_state_loss', action='store_true', default=False)
    parser.add_argument('--use_last_layer', action='store_true', default=False)
    parser.add_argument('--use_kld', action='store_true', default=False)
    parser.add_argument('--use_cosine', action='store_true', default=False)
    parser.add_argument('--distill_config', default="distillation_config.json", type=str,
                        help="Path to the distillation config")
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of DataLoader worker processes per rank')

    args = parser.parse_args()
    logger.info('args:{}'.format(args))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
                        stream=sys.stdout)

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.amp))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # Reference params
    author_gbs = 256
    author_steps_per_epoch = 22872
    author_epochs = 3
    author_max_steps = author_steps_per_epoch * author_epochs

    # Compute present run params
    if args.max_steps == -1 or args.steps_per_epoch == -1:
        args.steps_per_epoch = author_steps_per_epoch * author_gbs // (
            args.train_batch_size * get_world_size() * args.gradient_accumulation_steps)
        args.max_steps = author_max_steps * author_gbs // (
            args.train_batch_size * get_world_size() * args.gradient_accumulation_steps)

    # Set seed
    set_seed(args.seed, n_gpu)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir) and is_main_process():
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model, do_lower_case=args.do_lower_case)

    teacher_model, teacher_config = BertModel.from_pretrained(args.teacher_model,
                                                              distill_config=args.distill_config)

    # Required to make sure the model's fwd doesn't return anything; required for DDP.
    # fwd output not being used in loss computation crashes DDP.
    teacher_model.make_teacher()

    if args.continue_train:
        student_model, student_config = BertForPreTraining.from_pretrained(
            args.student_model, distill_config=args.distill_config)
    else:
        student_model, student_config = BertForPreTraining.from_scratch(
            args.student_model, distill_config=args.distill_config)

    # We need a projection layer since teacher.hidden_size != student.hidden_size
    use_projection = student_config.hidden_size != teacher_config.hidden_size
    if use_projection:
        project = Project(student_config, teacher_config)
        if args.continue_train:
            project_model_file = os.path.join(args.student_model, "project.bin")
            project_ckpt = torch.load(project_model_file, map_location="cpu")
            project.load_state_dict(project_ckpt)

    distill_config = {"nn_module_names": []}  # Empty list since we don't want to use nn module hooks here
    distill_hooks_student, distill_hooks_teacher = DistillHooks(distill_config), DistillHooks(distill_config)

    student_model.register_forward_hook(distill_hooks_student.child_to_main_hook)
    teacher_model.register_forward_hook(distill_hooks_teacher.child_to_main_hook)

    ## Register hooks on nn.Modules
    # student_fwd_pre_hook = student_model.register_forward_pre_hook(distill_hooks_student.register_nn_module_hook)
    # teacher_fwd_pre_hook = teacher_model.register_forward_pre_hook(distill_hooks_teacher.register_nn_module_hook)

    student_model.to(device)
    teacher_model.to(device)
    if use_projection:
        project.to(device)

    if args.local_rank != -1:
        teacher_model = torch.nn.parallel.DistributedDataParallel(
            teacher_model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=False)
        student_model = torch.nn.parallel.DistributedDataParallel(
            student_model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=False)
        if use_projection:
            project = torch.nn.parallel.DistributedDataParallel(
                project, device_ids=[args.local_rank], output_device=args.local_rank,
                find_unused_parameters=False)

    size = 0
    for n, p in student_model.named_parameters():
        logger.info('n: {}'.format(n))
        logger.info('p: {}'.format(p.nelement()))
        size += p.nelement()
    logger.info('Total parameters: {}'.format(size))

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    if use_projection:
        param_optimizer += list(project.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False)
    scheduler = LinearWarmUpScheduler(optimizer, warmup=args.warmup_proportion, total_steps=args.max_steps)

    global_step = 0
    logging.info("***** Running training *****")
    logging.info("  Num examples = {}".format(args.train_batch_size * args.max_steps))
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", args.max_steps)

    # Prepare the data loader.
    if is_main_process():
        tic = time.perf_counter()
    train_dataloader = lddl.torch.get_bert_pretrain_data_loader(
        args.input_dir,
        local_rank=args.local_rank,
        vocab_file=args.vocab_file,
        data_loader_kwargs={
            'batch_size': args.train_batch_size * n_gpu,
            'num_workers': args.num_workers,
            'pin_memory': True,
        },
        base_seed=args.seed,
        log_dir=None if args.output_dir is None else os.path.join(args.output_dir, 'lddl_log'),
        log_level=logging.WARNING,
        start_epoch=0,
    )
    if is_main_process():
        print('get_bert_pretrain_data_loader took {} s!'.format(time.perf_counter() - tic))

    train_dataloader = tqdm(train_dataloader, desc="Iteration",
                            disable=args.disable_progress_bar) if is_main_process() else train_dataloader

    tr_loss, tr_att_loss, tr_rep_loss, tr_value_loss = 0., 0., 0., 0.
    nb_tr_examples, local_step = 0, 0

    student_model.train()
    scaler = torch.cuda.amp.GradScaler()

    transformer_losses = TransformerLosses(student_config, teacher_config, device, args)

    iter_start = time.time()
    while global_step < args.max_steps:
        for batch in train_dataloader:
            if global_step >= args.max_steps:
                break

            # remove forward_pre_hook after one forward pass;
            # the purpose of forward_pre_hook is to register
            # forward_hooks on nn_module_names provided in config
            # if idx == 1:
            #     student_fwd_pre_hook.remove()
            #     teacher_fwd_pre_hook.remove()
            #     # return

            # Initialize loss metrics
            if global_step % args.steps_per_epoch == 0:
                tr_loss, tr_att_loss, tr_rep_loss, tr_value_loss = 0., 0., 0., 0.
                mean_loss, mean_att_loss, mean_rep_loss, mean_value_loss = 0., 0., 0., 0.

            batch = {k: v.to(device) for k, v in batch.items()}
            input_ids, segment_ids, input_mask, lm_label_ids, is_next = (
                batch['input_ids'], batch['token_type_ids'], batch['attention_mask'],
                batch['labels'], batch['next_sentence_labels'])

            att_loss = 0.
            rep_loss = 0.
            value_loss = 0.
            with torch.cuda.amp.autocast(enabled=args.amp):
                student_model(input_ids, segment_ids, input_mask, None)

                # Gather student states extracted by hooks
                temp_model = unwrap_ddp(student_model)
                student_atts = flatten_states(temp_model.distill_states_dict, "attention_scores")
                student_reps = flatten_states(temp_model.distill_states_dict, "hidden_states")
                student_values = flatten_states(temp_model.distill_states_dict, "value_states")
                student_embeddings = flatten_states(temp_model.distill_states_dict, "embedding_states")
                bsz, attn_heads, seq_len, _ = student_atts[0].shape

                # No gradient for teacher training
                with torch.no_grad():
                    teacher_model(input_ids, segment_ids, input_mask)
                    # Gather teacher states extracted by hooks
                    temp_model = unwrap_ddp(teacher_model)
                    teacher_atts = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "attention_scores")]
                    teacher_reps = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "hidden_states")]
                    teacher_values = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "value_states")]
                    teacher_embeddings = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "embedding_states")]

                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)

                # MiniLM
                if student_config.distillation_config["student_teacher_layer_mapping"] == "last_layer":
                    if student_config.distillation_config["use_attention_scores"]:
                        student_atts = [student_atts[-1]]
                        new_teacher_atts = [teacher_atts[-1]]
                    if student_config.distillation_config["use_value_states"]:
                        student_values = [student_values[-1]]
                        new_teacher_values = [teacher_values[-1]]
                    if student_config.distillation_config["use_hidden_states"]:
                        new_teacher_reps = [teacher_reps[-1]]
                        new_student_reps = [student_reps[-1]]
                else:
                    assert teacher_layer_num % student_layer_num == 0
                    layers_per_block = int(teacher_layer_num / student_layer_num)
                    if student_config.distillation_config["use_attention_scores"]:
                        new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                                            for i in range(student_layer_num)]
                    if student_config.distillation_config["use_value_states"]:
                        new_teacher_values = [teacher_values[i * layers_per_block + layers_per_block - 1]
                                              for i in range(student_layer_num)]
                    if student_config.distillation_config["use_hidden_states"]:
                        new_teacher_reps = [teacher_reps[i * layers_per_block + layers_per_block - 1]
                                            for i in range(student_layer_num)]
                        new_student_reps = student_reps

                if student_config.distillation_config["use_attention_scores"]:
                    att_loss = transformer_losses.compute_loss(student_atts, new_teacher_atts,
                                                               loss_name="attention_loss")
                if student_config.distillation_config["use_hidden_states"]:
                    if use_projection:
                        rep_loss = transformer_losses.compute_loss(project(new_student_reps), new_teacher_reps,
                                                                   loss_name="hidden_state_loss")
                    else:
                        rep_loss = transformer_losses.compute_loss(new_student_reps, new_teacher_reps,
                                                                   loss_name="hidden_state_loss")
                if student_config.distillation_config["use_embedding_states"]:
                    if use_projection:
                        rep_loss += transformer_losses.compute_loss(project(student_embeddings), teacher_embeddings,
                                                                    loss_name="embedding_state_loss")
                    else:
                        rep_loss += transformer_losses.compute_loss(student_embeddings, teacher_embeddings,
                                                                    loss_name="embedding_state_loss")
                if student_config.distillation_config["use_value_states"]:
                    value_loss = transformer_losses.compute_loss(student_values, new_teacher_values,
                                                                 loss_name="value_state_loss")

                loss = att_loss + rep_loss + value_loss

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

            tr_att_loss += att_loss.item() / args.gradient_accumulation_steps
            if student_config.distillation_config["use_hidden_states"]:
                tr_rep_loss += rep_loss.item() / args.gradient_accumulation_steps
            if student_config.distillation_config["use_value_states"]:
                tr_value_loss += value_loss.item() / args.gradient_accumulation_steps

            if args.amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
            else:
                loss.backward()

            if use_projection:
                torch.nn.utils.clip_grad_norm_(chain(student_model.parameters(), project.parameters()),
                                               args.max_grad_norm, error_if_nonfinite=False)
            else:
                torch.nn.utils.clip_grad_norm_(student_model.parameters(),
                                               args.max_grad_norm, error_if_nonfinite=False)

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            local_step += 1

            if local_step % args.gradient_accumulation_steps == 0:
                scheduler.step()
                if args.amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

                global_step = optimizer.param_groups[0]["step"] if "step" in optimizer.param_groups[0] else 0

                if (global_step % args.steps_per_epoch) > 0:
                    mean_loss = tr_loss / (global_step % args.steps_per_epoch)
                    mean_att_loss = tr_att_loss / (global_step % args.steps_per_epoch)
                    mean_rep_loss = tr_rep_loss / (global_step % args.steps_per_epoch)
                    value_loss = tr_value_loss / (global_step % args.steps_per_epoch)

                if (global_step + 1) % args.eval_step == 0 and is_main_process():
                    result = {}
                    result['global_step'] = global_step
                    result['lr'] = optimizer.param_groups[0]["lr"]
                    result['loss'] = mean_loss
                    result['att_loss'] = mean_att_loss
                    result['rep_loss'] = mean_rep_loss
                    result['value_loss'] = value_loss
                    result['perf'] = ((global_step + 1) * get_world_size() * args.train_batch_size *
                                      args.gradient_accumulation_steps / (time.time() - iter_start))

                    output_eval_file = os.path.join(args.output_dir, "log.txt")
                    if is_main_process():
                        with open(output_eval_file, "a") as writer:
                            logger.info("***** Eval results *****")
                            for key in sorted(result.keys()):
                                logger.info("  %s = %s", key, str(result[key]))
                                writer.write("%s = %s\n" % (key, str(result[key])))

                    # Save a trained model
                    model_name = "{}".format(WEIGHTS_NAME)
                    logging.info("** ** * Saving fine-tuned model ** ** * ")
                    # Only save the model itself
                    model_to_save = student_model.module if hasattr(student_model, 'module') else student_model
                    if use_projection:
                        project_to_save = project.module if hasattr(project, 'module') else project

                    output_model_file = os.path.join(args.output_dir, model_name)
                    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                    output_project_file = os.path.join(args.output_dir, "project.bin")

                    torch.save(model_to_save.state_dict(), output_model_file)
                    if use_projection:
                        torch.save(project_to_save.state_dict(), output_project_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)

                    if oncloud:
                        logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                        logging.info(mox.file.list_directory('.', recursive=True))
                        mox.file.copy_parallel(args.output_dir, args.data_url)
                        mox.file.copy_parallel('.', args.data_url)

    # Final save after training finishes.
    model_name = "{}".format(WEIGHTS_NAME)
    logging.info("** ** * Saving fine-tuned model ** ** * ")
    model_to_save = student_model.module if hasattr(student_model, 'module') else student_model
    if use_projection:
        project_to_save = project.module if hasattr(project, 'module') else project
        output_project_file = os.path.join(args.output_dir, "project.bin")
        if is_main_process():
            torch.save(project_to_save.state_dict(), output_project_file)
    output_model_file = os.path.join(args.output_dir, model_name)
    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

    if is_main_process():
        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        if oncloud:
            logging.info(mox.file.list_directory(args.output_dir, recursive=True))
            logging.info(mox.file.list_directory('.', recursive=True))
            mox.file.copy_parallel(args.output_dir, args.data_url)
            mox.file.copy_parallel('.', args.data_url)
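# The distillation loop above relies on two small helpers that are not shown:
# unwrap_ddp (to reach the underlying module through a DistributedDataParallel wrapper)
# and flatten_states (to pull one kind of captured state out of the dict the forward
# hooks populate). The sketches below reflect one plausible reading of how they are
# used in main(); the real helpers may differ in detail.

def unwrap_ddp_sketch(model):
    # DistributedDataParallel stores the wrapped network in .module.
    return model.module if hasattr(model, 'module') else model


def flatten_states_sketch(distill_states_dict, state_name):
    # distill_states_dict is assumed to map state names (e.g. "hidden_states",
    # "attention_scores") to per-layer lists of tensors collected by the hooks;
    # flattening just returns that per-layer list for the requested state.
    return list(distill_states_dict.get(state_name, []))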
def association(self, data_S, gt_S, data_T, gt_T, **kwargs):
    if cfg.MODEL.DOMAIN_BN:
        self.set_domain_id(1)
    res_T = self.net(data_T)
    preds_T = res_T['out']
    feats_T = res_T['feat']

    if cfg.MODEL.DOMAIN_BN:
        self.set_domain_id(0)
    res_S = self.net(data_S)
    preds_S = res_S['out']
    feats_S = res_S['feat']

    total_loss = 0.0
    total_loss_dict = {}

    H, W = feats_S.shape[-2:]
    new_gt_S = F.interpolate(gt_S.type(torch.cuda.FloatTensor).unsqueeze(1),
                             size=(H, W), mode='nearest').squeeze(1)
    new_gt_T = F.interpolate(gt_T.type(torch.cuda.FloatTensor).unsqueeze(1),
                             size=(H, W), mode='nearest').squeeze(1)

    if cfg.TRAIN.USE_CROP:
        scale_factor = cfg.TRAIN.SCALE_FACTOR
        N = feats_S.size(0)
        new_H, new_W = int(scale_factor * H), int(scale_factor * W)
        feats_S, probs_S, new_gt_S = solver_utils.crop(feats_S, preds_S, new_gt_S, new_H, new_W)
        feats_T, probs_T, new_gt_T = solver_utils.crop(feats_T, preds_T, new_gt_T, new_H, new_W)
    elif cfg.TRAIN.USE_DOWNSAMPLING:
        scale_factor = cfg.TRAIN.SCALE_FACTOR
        feats_S = F.interpolate(feats_S, scale_factor=scale_factor, mode='bilinear',
                                recompute_scale_factor=False, align_corners=False)
        feats_T = F.interpolate(feats_T, scale_factor=scale_factor, mode='bilinear',
                                recompute_scale_factor=False, align_corners=False)
        new_preds_S = F.interpolate(preds_S, scale_factor=scale_factor, mode='bilinear',
                                    recompute_scale_factor=False, align_corners=False)
        new_preds_T = F.interpolate(preds_T, scale_factor=scale_factor, mode='bilinear',
                                    recompute_scale_factor=False, align_corners=False)
        H, W = feats_S.shape[-2:]
        new_gt_S = F.interpolate(gt_S.type(torch.cuda.FloatTensor).unsqueeze(1),
                                 size=(H, W), mode='nearest').squeeze(1)
        new_gt_T = F.interpolate(gt_T.type(torch.cuda.FloatTensor).unsqueeze(1),
                                 size=(H, W), mode='nearest').squeeze(1)
        probs_S, probs_T = F.softmax(new_preds_S, dim=1), F.softmax(new_preds_T, dim=1)
    else:
        probs_S, probs_T = F.softmax(preds_S, dim=1), F.softmax(preds_T, dim=1)

    ass_loss_dict = self.FeatAssociationLoss(feats_S, feats_T, new_gt_S, new_gt_T)
    ass_loss = ass_loss_dict['association']
    total_loss += cfg.TRAIN.ASSO_W * ass_loss
    total_loss_dict.update(ass_loss_dict)

    if cfg.TRAIN.APPLY_MULTILAYER_ASSOCIATION:
        ass_loss_classifier_dict = self.ClsAssociationLoss(probs_S, probs_T, new_gt_S, new_gt_T)
        ass_loss_classifier = ass_loss_classifier_dict['association']
        total_loss += cfg.TRAIN.ASSO_W * ass_loss_classifier
        ass_loss_classifier_dict = {
            key + '_cls': ass_loss_classifier_dict[key] for key in ass_loss_classifier_dict
        }
        total_loss_dict.update(ass_loss_classifier_dict)

    if cfg.TRAIN.LSR_THRES > 0.0:
        lsr_thres = cfg.TRAIN.LSR_THRES
        lsr_loss_S = solver_utils.LSR(F.log_softmax(preds_S, dim=1), dim=1, thres=cfg.TRAIN.LSR_THRES)
        lsr_loss_T = solver_utils.LSR(F.log_softmax(preds_T, dim=1), dim=1, thres=cfg.TRAIN.LSR_THRES)
        total_loss += cfg.TRAIN.LSR_W * lsr_loss_S
        total_loss += cfg.TRAIN.LSR_W * lsr_loss_T
        total_loss_dict['lsr_S'] = lsr_loss_S
        total_loss_dict['lsr_T'] = lsr_loss_T

    preds = F.interpolate(preds_S, size=gt_S.shape[-2:], mode='bilinear', align_corners=False)
    ce_loss = 1.0 * self.CELoss([preds], gt_S)
    if self.distributed:
        lov_loss = lovasz_softmax_multigpu(F.softmax(preds, dim=1), gt_S,
                                           classes='present', per_image=False, ignore=255)
    else:
        lov_loss = lovasz_softmax(F.softmax(preds, dim=1), gt_S,
                                  classes='present', per_image=False, ignore=255)
    ce_loss += (cfg.TRAIN.LOV_W * get_world_size() * self.iter_size) * lov_loss
    total_loss += ce_loss
    total_loss_dict['ce_loss'] = ce_loss
    total_loss_dict['total'] = total_loss

    preds_T = F.interpolate(preds_T, size=gt_S.shape[-2:], mode='bilinear', align_corners=False)
    out_dict = {
        'feats_S': feats_S,
        'feats_T': feats_T,
        'preds_S': preds,
        'preds_T': preds_T
    }
    return total_loss_dict, out_dict
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_loss_joints = AverageMeter()
    ave_loss_inp = AverageMeter()
    ave_acc = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, target_weight, _, name, joints, joints_vis = batch
        size = labels.size()
        # cv2.imwrite('groundtruth/gt_'+str(i_iter)+'.png', labels[0].detach().numpy())
        images = images.to(device)
        labels = labels.to(device)

        losses, losses_joints, losses_inp, pred = model(images, labels, target_weight)  # forward
        # pred = F.upsample(input=pred, size=(size[-2], size[-1]), mode='bilinear')
        # pred = pred.to('cpu')
        # cv2.imwrite('prediction/pred_'+str(i_iter)+'.png', pred[0][0].detach().numpy())
        # print("saved")

        label_joints, _ = get_max_preds(labels[:, 0:15, :, :].detach().cpu().numpy())
        pred_joints, _ = get_max_preds(pred[:, 0:15, :, :].detach().cpu().numpy())
        _, acc, _, _ = accuracy(pred[:, 0:15, :, :].detach().cpu().numpy(),
                                labels[:, 0:15, :, :].detach().cpu().numpy())

        save_batch_image_with_joints(
            images[:, 0:3, :, :], label_joints * 4, joints_vis,
            'results/full_RGBD/train/joint_gt/{}_gt.png'.format(i_iter))
        save_batch_image_with_joints(
            images[:, 0:3, :, :], pred_joints * 4, joints_vis,
            'results/full_RGBD/train/joint_pred/{}_pred.png'.format(i_iter))

        labels = F.interpolate(labels, size=(256, 256), mode='bilinear', align_corners=False)
        pred = F.interpolate(pred, size=(256, 256), mode='bilinear', align_corners=False)
        cv2.imwrite('results/full_RGBD/train/depth_gt/{}_gt.png'.format(i_iter),
                    labels[0, 15, :, :].detach().cpu().numpy())
        cv2.imwrite('results/full_RGBD/train/depth_pred/{}_pred.png'.format(i_iter),
                    pred[0, 15, :, :].detach().cpu().numpy())

        loss = losses.mean()
        loss_joints = losses_joints.mean()
        loss_inp = losses_inp.mean()
        reduced_loss = reduce_tensor(loss)
        reduced_loss_joints = reduce_tensor(loss_joints)
        reduced_loss_inp = reduce_tensor(loss_inp)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        ave_loss_joints.update(reduced_loss_joints.item())
        ave_loss_inp.update(reduced_loss_inp.item())
        ave_acc.update(acc)

        lr = adjust_learning_rate(optimizer, base_lr, num_iters, i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            print_loss_joints = ave_loss_joints.average() / world_size
            print_loss_inp = ave_loss_inp.average() / world_size
            print_acc = ave_acc.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}, {:.6f}, {:.6f}, Acc: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters, batch_time.average(), lr,
                      print_loss, print_loss_joints, print_loss_inp, print_acc)
            logging.info(msg)
            writer.add_scalar('train_loss', print_loss, global_steps)
            writer.add_scalar('train_loss_joint', print_loss_joints, global_steps)
            writer.add_scalar('train_loss_depth', print_loss_inp, global_steps)
            writer.add_scalar('train_accuracy', print_acc, global_steps)
            # Advance the local step as well, so successive logs within an epoch
            # do not all land on the same TensorBoard step.
            global_steps += 1
            writer_dict['train_global_steps'] = global_steps
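# get_max_preds converts a batch of joint heatmaps into (x, y) peak coordinates plus the
# peak value; save_batch_image_with_joints and accuracy are the usual pose-estimation
# utilities built on top of it. A minimal numpy sketch of the coordinate extraction is
# given below, under the assumption that heatmaps have shape (N, num_joints, H, W); the
# project's own helper may differ in detail.
import numpy as np

def get_max_preds_sketch(heatmaps):
    n, k, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, k, -1)
    idx = np.argmax(flat, axis=2)                        # index of the peak per joint
    maxvals = np.amax(flat, axis=2).reshape(n, k, 1)

    preds = np.zeros((n, k, 2), dtype=np.float32)
    preds[:, :, 0] = idx % w                             # x coordinate
    preds[:, :, 1] = idx // w                            # y coordinate

    # Suppress joints whose peak activation is not positive.
    preds *= (maxvals > 0.0).astype(np.float32)
    return preds, maxvals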
def validate(config, testloader, model, writer_dict, device):
    rank = get_rank()                # 0
    world_size = get_world_size()    # 1
    model.eval()
    ave_loss = AverageMeter()
    ave_loss_joints = AverageMeter()
    ave_loss_inp = AverageMeter()
    ave_accs = AverageMeter()
    ave_acc = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for i_iter, batch in enumerate(testloader):
            image, label, target_weight, _, name, joints, joints_vis = batch
            size = label.size()
            # cv2.imwrite('validation_result/groundtruth/gt_'+str(i_iter)+'.png', label[0].detach().numpy())
            image = image.to(device)
            label = label.to(device)

            losses, losses_joints, losses_inp, pred = model(image, label, target_weight)
            # pred = F.upsample(input=pred, size=(64, 64), mode='bilinear')

            label_joints, _ = get_max_preds(label[:, 0:15, :, :].detach().cpu().numpy())
            pred_joints, _ = get_max_preds(pred[:, 0:15, :, :].detach().cpu().numpy())
            accs, acc, _, _ = accuracy(pred[:, 0:15, :, :].detach().cpu().numpy(),
                                       label[:, 0:15, :, :].detach().cpu().numpy())

            save_batch_image_with_joints(
                image[:, 0:3, :, :], label_joints * 4, joints_vis,
                'results/full_RGBD/val/joint_gt/{}_gt.png'.format(i_iter))
            save_batch_image_with_joints(
                image[:, 0:3, :, :], pred_joints * 4, joints_vis,
                'results/full_RGBD/val/joint_pred/{}_pred.png'.format(i_iter))

            label = F.interpolate(label, size=(256, 256), mode='bilinear', align_corners=False)
            pred = F.interpolate(pred, size=(256, 256), mode='bilinear', align_corners=False)
            cv2.imwrite('results/full_RGBD/val/depth_gt/{}_gt.png'.format(i_iter),
                        label[0, 15, :, :].detach().cpu().numpy())
            cv2.imwrite('results/full_RGBD/val/depth_pred/{}_pred.png'.format(i_iter),
                        pred[0, 15, :, :].detach().cpu().numpy())

            loss = losses.mean()
            loss_joints = losses_joints.mean()
            loss_inp = losses_inp.mean()
            reduced_loss = reduce_tensor(loss)
            reduced_loss_joints = reduce_tensor(loss_joints)
            reduced_loss_inp = reduce_tensor(loss_inp)

            ave_loss.update(reduced_loss.item())
            ave_loss_joints.update(reduced_loss_joints.item())
            ave_loss_inp.update(reduced_loss_inp.item())
            ave_acc.update(acc)
            ave_accs.update(accs)

    print_loss = ave_loss.average() / world_size
    print_loss_joints = ave_loss_joints.average() / world_size
    print_loss_inp = ave_loss_inp.average() / world_size
    print_acc = ave_acc.average() / world_size
    print_accs = ave_accs.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_loss_joint', print_loss_joints, global_steps)
        writer.add_scalar('valid_loss_depth', print_loss_inp, global_steps)
        writer.add_scalar('valid_accuracy', print_acc, global_steps)
        for i in range(15):
            writer.add_scalar('valid_each_accuracy_' + str(i), print_accs[i], global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return print_loss, print_loss_joints, print_loss_inp, print_acc