def prepare_data(args):
    train_transform_S = get_transform(train=True,
                                      dataset_name=cfg.DATASET.SOURCE)
    train_transform_T = get_transform(train=True,
                                      dataset_name=cfg.DATASET.TARGET)
    val_transform = get_transform(train=False, dataset_name=cfg.DATASET.VAL)

    train_dataset_S = getattr(Dataset, cfg.DATASET.SOURCE)(
        cfg.DATASET.DATAROOT_S,
        cfg.DATASET.TRAIN_SPLIT_S,
        transform=train_transform_S)

    train_dataset_T = getattr(Dataset, cfg.DATASET.TARGET)(
        cfg.DATASET.DATAROOT_T,
        cfg.DATASET.TRAIN_SPLIT_T,
        transform=train_transform_T)

    val_dataset = getattr(Dataset, cfg.DATASET.VAL)(
        cfg.DATASET.DATAROOT_VAL,
        cfg.DATASET.VAL_SPLIT,
        transform=val_transform)

    # construct dataloaders
    train_dataloader_S = data_utils.get_dataloader(
        train_dataset_S,
        cfg.TRAIN.TRAIN_BATCH_SIZE,
        cfg.NUM_WORKERS,
        train=True,
        distributed=args.distributed,
        world_size=gen_utils.get_world_size())

    train_dataloader_T = data_utils.get_dataloader(
        train_dataset_T,
        cfg.TRAIN.TRAIN_BATCH_SIZE,
        cfg.NUM_WORKERS,
        train=True,
        distributed=args.distributed,
        world_size=gen_utils.get_world_size())

    val_dataloader = data_utils.get_dataloader(
        val_dataset,
        cfg.TRAIN.VAL_BATCH_SIZE,
        cfg.NUM_WORKERS,
        train=False,
        distributed=args.distributed,
        world_size=gen_utils.get_world_size())

    dataloaders = {'train_S': train_dataloader_S,
                   'train_T': train_dataloader_T,
                   'val': val_dataloader}

    return dataloaders
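
For context, data_utils.get_dataloader is defined outside this snippet. A minimal sketch of such a helper, assuming a DistributedSampler-based implementation (the per-rank batch-size convention and keyword defaults are assumptions, not the original API):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def get_dataloader(dataset, batch_size, num_workers, train=True,
                   distributed=False, world_size=1):
    # Shard the dataset across ranks when distributed; otherwise shuffle
    # locally for training. world_size is accepted for signature parity
    # with the call sites above; batch_size is treated as per-rank.
    sampler = DistributedSampler(dataset, shuffle=train) if distributed else None
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=(train and sampler is None),
                      sampler=sampler,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=train)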
Example #2
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device):

    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_loss1 = AverageMeter()
    ave_aux_loss = AverageMeter()
    ave_error_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)
        losses, aux_loss, error_loss, _ = model(images, labels)
        loss = losses.mean() + 0.4 * aux_loss.mean() + error_loss.mean()

        # reduce_tensor sums across ranks, so reduce scalar means
        reduced_loss = reduce_tensor(loss)
        loss1 = reduce_tensor(losses.mean())
        aux_loss = reduce_tensor(aux_loss.mean())
        error_losses = reduce_tensor(error_loss.mean())

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        ave_loss1.update(loss1.item())
        ave_aux_loss.update(aux_loss.item())
        ave_error_loss.update(error_losses.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            print_loss1 = ave_loss1.average() / world_size
            print_loss_aux = ave_aux_loss.average() / world_size
            print_error_loss = ave_error_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}, Loss_1: {:.6f}, ' \
                  'Loss_aux: {:.6f}, error_loss: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss, print_loss1,
                      print_loss_aux, print_error_loss)
            logging.info(msg)

            writer.add_scalar('train_loss', print_loss, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
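
AverageMeter appears in every snippet but is defined elsewhere; a minimal sketch consistent with its update()/average() usage here:

class AverageMeter:
    """Running-average tracker (a sketch, not the original class)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def average(self):
        return self.sum / max(self.count, 1)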
Example #3
    def source_only(self, data_S, gt_S, data_T, gt_T, *others, **kwargs):
        self.set_domain_id(0)
        preds = self.net(data_S)['out']
        preds = F.interpolate(preds,
                              size=data_S.shape[-2:],
                              mode='bilinear',
                              align_corners=False)
        ce_loss = self.CELoss([preds], gt_S)
        if cfg.TRAIN.WITH_LOV:
            if self.distributed:
                lov_loss = lovasz_softmax_multigpu(F.softmax(preds, dim=1),
                                                   gt_S,
                                                   classes='present',
                                                   per_image=False,
                                                   ignore=255)
            else:
                lov_loss = lovasz_softmax(F.softmax(preds, dim=1),
                                          gt_S,
                                          classes='present',
                                          per_image=False,
                                          ignore=255)

            ce_loss += (cfg.TRAIN.LOV_W * get_world_size() *
                        self.iter_size) * lov_loss

        out_dict = {
            'feats_S': None,
            'feats_T': None,
            'preds_S': preds,
            'preds_T': None
        }
        return {'total': ce_loss}, out_dict
Example #4
def validate(config, testloader, model, writer_dict, device):

    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))
    confusion_matrix_sum = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for _, batch in enumerate(testloader):
            image, label, boundary_gt, _, _ = batch
            size = label.size()
            image = image.to(device)
            boundary_gt = boundary_gt.to(device)
            label = label.long().to(device)

            losses, aux_loss, error_loss, losses_2, aux_loss_2, error_loss_2, preds = model(
                image, label, boundary_gt.float())
            # F.upsample is deprecated; F.interpolate with explicit
            # align_corners matches the old default behaviour
            pred = F.interpolate(input=preds[0],
                                 size=(size[-2], size[-1]),
                                 mode='bilinear',
                                 align_corners=False)

            loss = (losses + 0.4 * aux_loss + 4 * error_loss + losses_2 +
                    0.4 * aux_loss_2 + 4 * error_loss_2).mean()
            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            confusion_matrix += get_confusion_matrix(
                label, pred, size, config.DATASET.NUM_CLASSES,
                config.TRAIN.IGNORE_LABEL)

    confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
    reduced_confusion_matrix = reduce_tensor(confusion_matrix)

    confusion_matrix = reduced_confusion_matrix.cpu().numpy()
    pos = confusion_matrix.sum(1)
    res = confusion_matrix.sum(0)
    tp = np.diag(confusion_matrix)
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1
        cv2.imwrite(
            str(global_steps) + '_error.png',
            (preds[2][0][0].data.cpu().numpy() * 255).astype(np.uint8))
    return print_loss, mean_IoU, IoU_array
Example #5

def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, lr_scheduler, model, writer_dict, device):

    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        losses, _ = model(images, labels,
                          train_step=(lr_scheduler._step_count - 1))
        loss = losses.mean()

        reduced_loss = reduce_tensor(loss)

        model.zero_grad()
        loss.backward()
        optimizer.step()
        
        if config.TRAIN.LR_SCHEDULER != 'step':
            lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss)
            logging.info(msg)

            writer.add_scalar('train_loss', print_loss, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
            batch_time = AverageMeter()
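
adjust_learning_rate is likewise external. A common poly-decay implementation matching its call signature (the 0.9 exponent is an assumption):

def adjust_learning_rate(optimizer, base_lr, max_iters, cur_iter, power=0.9):
    # Poly schedule: lr = base_lr * (1 - t / T) ** power.
    # Returns the learning rate it applied, as the callers above expect.
    lr = base_lr * ((1.0 - float(cur_iter) / max_iters) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr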
Example #6
def reduce_tensor(inp):
    """
    Sum the tensor over all processes so that the process with rank 0
    holds the total; callers divide by the world size to obtain the
    average.
    """
    world_size = get_world_size()
    if world_size < 2:
        return inp
    with torch.no_grad():
        # dist.reduce operates in place on the given tensor
        reduced_inp = inp
        dist.reduce(reduced_inp, dst=0)
    return reduced_inp
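
Since dist.reduce sums rather than averages, the callers above divide the result by the world size. A hypothetical variant that averages on every rank via all_reduce (not part of the original code) would be:

def reduce_tensor_mean(inp):
    # Average across all ranks so every process sees the same value.
    world_size = get_world_size()
    if world_size < 2:
        return inp
    with torch.no_grad():
        reduced = inp.clone()
        dist.all_reduce(reduced)
        reduced /= world_size
    return reduced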
Example #7
def validate(config, testloader, model, writer_dict, device):
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    tot_inter = np.zeros(config.DATASET.NUM_CLASSES)
    tot_union = np.zeros(config.DATASET.NUM_CLASSES)
    with torch.no_grad():
        for i_iter, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            label = label.long().to(device)
            image = image.to(device)

            loss, pred = model(image, label)
            if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
                pred = F.interpolate(pred,
                                     size=(size[-2], size[-1]),
                                     mode='bilinear',
                                     align_corners=False)
            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            batch_inter, batch_union = batch_intersection_union(
                pred, label, config.DATASET.NUM_CLASSES)
            tot_inter += batch_inter
            tot_union += batch_union

            if i_iter % config.PRINT_FREQ == 0 and rank == 0:
                msg = f'Iter: {i_iter}, Loss: {ave_loss.average() / world_size:.6f}'
                logging.info(msg)

    tot_inter = torch.from_numpy(tot_inter).to(device)
    tot_union = torch.from_numpy(tot_union).to(device)
    tot_inter = reduce_tensor(tot_inter).cpu().numpy()
    tot_union = reduce_tensor(tot_union).cpu().numpy()
    IoU = tot_inter / (tot_union + np.spacing(1))
    mean_IoU = IoU.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return print_loss, mean_IoU
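
batch_intersection_union is not defined in this snippet. A sketch consistent with how it is consumed, ignoring any ignore-label handling the original may do:

import numpy as np

def batch_intersection_union(pred, label, num_classes):
    # pred: (N, C, H, W) logits; label: (N, H, W) class ids.
    pred = pred.argmax(dim=1)
    inter = np.zeros(num_classes)
    union = np.zeros(num_classes)
    for cls in range(num_classes):
        p = pred == cls
        l = label == cls
        inter[cls] = (p & l).sum().item()
        union[cls] = (p | l).sum().item()
    return inter, union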
Example #8
def validate(config, testloader, model, writer_dict, device):
    
    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for _, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            image = image.to(device)
            label = label.long().to(device)

            losses, pred = model(image, label)
            pred = F.interpolate(input=pred,
                                 size=(size[-2], size[-1]),
                                 mode='bilinear',
                                 align_corners=False)
            loss = losses.mean()
            reduced_loss = reduce_tensor(loss)
            ave_loss.update(reduced_loss.item())

            confusion_matrix += get_confusion_matrix(
                label,
                pred,
                size,
                config.DATASET.NUM_CLASSES,
                config.TRAIN.IGNORE_LABEL)

    confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
    reduced_confusion_matrix = reduce_tensor(confusion_matrix)

    confusion_matrix = reduced_confusion_matrix.cpu().numpy()
    pos = confusion_matrix.sum(1)
    res = confusion_matrix.sum(0)
    tp = np.diag(confusion_matrix)
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()
    print_loss = ave_loss.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1
    return print_loss, mean_IoU, IoU_array
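
For reference, a sketch of what get_confusion_matrix typically computes: rows are ground-truth classes, columns are predicted classes (the original helper's exact layout may differ):

def get_confusion_matrix(label, pred, size, num_classes, ignore_label):
    # pred is assumed already resized to the label resolution; `size` is
    # kept only for signature parity with the call sites above.
    pred = pred.argmax(dim=1).cpu().numpy().flatten()
    gt = label.cpu().numpy().flatten()
    mask = gt != ignore_label
    index = (gt[mask] * num_classes + pred[mask]).astype(np.int64)
    counts = np.bincount(index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)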
Example #9
def train(config, epoch, num_epoch, epoch_iters, trainloader, optimizer,
          lr_scheduler, model, writer_dict, device):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader, 0):
        images, labels, _, _ = batch
        labels = labels.long().to(device)
        images = images.to(device)

        loss, _ = model(images, labels)
        reduced_loss = reduce_tensor(loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())

        lr = optimizer.param_groups[0]['lr']

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss)
            logging.info(msg)

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['train_global_steps']
        writer.add_scalar('train_loss',
                          ave_loss.average() / world_size, global_steps)
        writer_dict['train_global_steps'] = global_steps + 1
Example #10
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--input_dir",
                        type=str,
                        required=True)
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument('--vocab_file',
                        type=str,
                        default=None,
                        required=True,
                        help="Vocabulary mapping/file BERT was pretrainined on")

    # Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--reduce_memory",
                        action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float, metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--steps_per_epoch',
                        type=int,
                        default=-1,
                        help="Number of updates steps to in one epoch.")
    parser.add_argument('--max_steps',
                        type=int,
                        default=-1,
                        help="Number of training steps.")
    parser.add_argument('--amp',
                        action='store_true',
                        default=False,
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--continue_train',
                        action='store_true',
                        default=False,
                        help='Whether to train from checkpoints')
    parser.add_argument('--disable_progress_bar',
                        default=False,
                        action='store_true',
                        help='Disable tqdm progress bar')
    parser.add_argument('--max_grad_norm',
                        type=float,
                        default=1.,
                        help="Gradient Clipping threshold")

    # Additional arguments
    parser.add_argument('--eval_step',
                        type=int,
                        default=1000)

    # This is used for running on Huawei Cloud.
    parser.add_argument('--data_url',
                        type=str,
                        default="")

    # Distillation-specific
    parser.add_argument('--value_state_loss',
                        action='store_true',
                        default=False)
    parser.add_argument('--hidden_state_loss',
                        action='store_true',
                        default=False)
    parser.add_argument('--use_last_layer',
                        action='store_true',
                        default=False)
    parser.add_argument('--use_kld',
                        action='store_true',
                        default=False)
    parser.add_argument('--use_cosine',
                        action='store_true',
                        default=False)
    parser.add_argument('--distill_config',
                        default="distillation_config.json",
                        type=str,
                        help="path the distillation config")
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of DataLoader worker processes per rank')

    args = parser.parse_args()
    logger.info('args:{}'.format(args))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
                        stream=sys.stdout)

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.amp))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # Reference params
    author_gbs = 256
    author_steps_per_epoch = 22872
    author_epochs = 3
    author_max_steps = author_steps_per_epoch * author_epochs
    # Compute present run params
    if args.max_steps == -1 or args.steps_per_epoch == -1:
        args.steps_per_epoch = author_steps_per_epoch * author_gbs // (args.train_batch_size * get_world_size() * args.gradient_accumulation_steps)
        args.max_steps = author_max_steps * author_gbs // (args.train_batch_size * get_world_size() * args.gradient_accumulation_steps)

    #Set seed
    set_seed(args.seed, n_gpu)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir) and is_main_process():
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model, do_lower_case=args.do_lower_case)

    teacher_model, teacher_config = BertModel.from_pretrained(args.teacher_model,
                                              distill_config=args.distill_config)

    # Ensure the teacher's forward pass returns nothing: DDP crashes when a
    # forward output is not used in the loss computation.
    teacher_model.make_teacher()

    if args.continue_train:
        student_model, student_config = BertForPreTraining.from_pretrained(args.student_model,
                                                           distill_config=args.distill_config)
    else:
        student_model, student_config = BertForPreTraining.from_scratch(args.student_model, 
                                                        distill_config=args.distill_config)

    # We need a projection layer since teacher.hidden_size != student.hidden_size
    use_projection = student_config.hidden_size != teacher_config.hidden_size
    if use_projection:
        project = Project(student_config, teacher_config)
        if args.continue_train:
            project_model_file = os.path.join(args.student_model, "project.bin")
            project_ckpt = torch.load(project_model_file, map_location="cpu")
            project.load_state_dict(project_ckpt)

    distill_config = {"nn_module_names": []} #Empty list since we don't want to use nn module hooks here
    distill_hooks_student, distill_hooks_teacher = DistillHooks(distill_config), DistillHooks(distill_config)

    student_model.register_forward_hook(distill_hooks_student.child_to_main_hook)
    teacher_model.register_forward_hook(distill_hooks_teacher.child_to_main_hook)

    ## Register hooks on nn.Modules
    # student_fwd_pre_hook = student_model.register_forward_pre_hook(distill_hooks_student.register_nn_module_hook)
    # teacher_fwd_pre_hook = teacher_model.register_forward_pre_hook(distill_hooks_teacher.register_nn_module_hook)

    student_model.to(device)
    teacher_model.to(device)
    if use_projection:
        project.to(device)
    if args.local_rank != -1:
        teacher_model = torch.nn.parallel.DistributedDataParallel(
               teacher_model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False
           )
        student_model = torch.nn.parallel.DistributedDataParallel(
               student_model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False
           )
        if use_projection:
            project = torch.nn.parallel.DistributedDataParallel(
                   project, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False
               )
    size = 0
    for n, p in student_model.named_parameters():
        logger.info('n: {}'.format(n))
        logger.info('p: {}'.format(p.nelement()))
        size += p.nelement()

    logger.info('Total parameters: {}'.format(size))

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    if use_projection:
        param_optimizer += list(project.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False)
    scheduler = LinearWarmUpScheduler(optimizer, warmup=args.warmup_proportion, total_steps=args.max_steps)    

    global_step = 0
    logging.info("***** Running training *****")
    logging.info("  Num examples = {}".format(args.train_batch_size * args.max_steps))
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", args.max_steps)

    # Prepare the data loader.
    if is_main_process():
        tic = time.perf_counter()
    train_dataloader = lddl.torch.get_bert_pretrain_data_loader(
        args.input_dir,
        local_rank=args.local_rank,
        vocab_file=args.vocab_file,
        data_loader_kwargs={
            'batch_size': args.train_batch_size * n_gpu,
            'num_workers': args.num_workers,
            'pin_memory': True,
        },
        base_seed=args.seed,
        log_dir=None if args.output_dir is None else os.path.join(args.output_dir, 'lddl_log'),
        log_level=logging.WARNING,
        start_epoch=0,
    )
    if is_main_process():
        print('get_bert_pretrain_data_loader took {} s!'.format(time.perf_counter() - tic))
    train_dataloader = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader

    tr_loss, tr_att_loss, tr_rep_loss, tr_value_loss = 0., 0., 0., 0.
    nb_tr_examples, local_step = 0, 0

    student_model.train()
    scaler = torch.cuda.amp.GradScaler()

    transformer_losses = TransformerLosses(student_config, teacher_config, device, args)
    iter_start = time.time()
    while global_step < args.max_steps:
        for batch in train_dataloader:
            if global_step >= args.max_steps:
                break

            # Remove forward_pre_hook after one forward pass; the purpose of
            # forward_pre_hook is to register forward_hooks on the
            # nn_module_names provided in the config:
            # if idx == 1:
            #     student_fwd_pre_hook.remove()
            #     teacher_fwd_pre_hook.remove()
            #     # return

            # Initialize loss metrics
            if global_step % args.steps_per_epoch == 0:
                tr_loss, tr_att_loss, tr_rep_loss, tr_value_loss = 0., 0., 0., 0.
                mean_loss, mean_att_loss, mean_rep_loss, mean_value_loss = 0., 0., 0., 0.

            batch = {k: v.to(device) for k, v in batch.items()}
            input_ids, segment_ids, input_mask, lm_label_ids, is_next = (
                batch['input_ids'], batch['token_type_ids'],
                batch['attention_mask'], batch['labels'],
                batch['next_sentence_labels'])

            att_loss = 0.
            rep_loss = 0.
            value_loss = 0.
            with torch.cuda.amp.autocast(enabled=args.amp):
                student_model(input_ids, segment_ids, input_mask, None)

                # Gather student states extracted by hooks
                temp_model = unwrap_ddp(student_model)
                student_atts = flatten_states(temp_model.distill_states_dict, "attention_scores")
                student_reps = flatten_states(temp_model.distill_states_dict, "hidden_states")
                student_values = flatten_states(temp_model.distill_states_dict, "value_states")
                student_embeddings = flatten_states(temp_model.distill_states_dict, "embedding_states")
                bsz, attn_heads, seq_len, _ = student_atts[0].shape

                # no gradients for the teacher forward pass
                with torch.no_grad():
                    teacher_model(input_ids, segment_ids, input_mask)

                # Gather teacher states extracted by hooks
                temp_model = unwrap_ddp(teacher_model)
                teacher_atts = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "attention_scores")]
                teacher_reps = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "hidden_states")]
                teacher_values = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "value_states")]
                teacher_embeddings = [i.detach() for i in flatten_states(temp_model.distill_states_dict, "embedding_states")]

                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)

                #MiniLM
                if student_config.distillation_config["student_teacher_layer_mapping"] == "last_layer":
                    if student_config.distillation_config["use_attention_scores"]:
                        student_atts = [student_atts[-1]]
                        new_teacher_atts = [teacher_atts[-1]]

                    if student_config.distillation_config["use_value_states"]:
                        student_values = [student_values[-1]]
                        new_teacher_values = [teacher_values[-1]]

                    if student_config.distillation_config["use_hidden_states"]:
                        new_teacher_reps = [teacher_reps[-1]]
                        new_student_reps = [student_reps[-1]]
                else:
                    assert teacher_layer_num % student_layer_num == 0

                    layers_per_block = teacher_layer_num // student_layer_num
                    if student_config.distillation_config["use_attention_scores"]:
                        new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                                            for i in range(student_layer_num)]

                    if student_config.distillation_config["use_value_states"]:
                        new_teacher_values = [teacher_values[i * layers_per_block + layers_per_block - 1]
                                    for i in range(student_layer_num)]

                    if student_config.distillation_config["use_hidden_states"]:
                        new_teacher_reps = [teacher_reps[i * layers_per_block + layers_per_block - 1]
                                    for i in range(student_layer_num)]
                        new_student_reps = student_reps

                if student_config.distillation_config["use_attention_scores"]:
                    att_loss = transformer_losses.compute_loss(student_atts, new_teacher_atts, loss_name="attention_loss")

                if student_config.distillation_config["use_hidden_states"]:
                    if use_projection:
                        rep_loss = transformer_losses.compute_loss(project(new_student_reps), new_teacher_reps, loss_name="hidden_state_loss")
                    else:
                        rep_loss = transformer_losses.compute_loss(new_student_reps, new_teacher_reps, loss_name="hidden_state_loss")

                if student_config.distillation_config["use_embedding_states"]:
                    if use_projection:
                        rep_loss += transformer_losses.compute_loss(project(student_embeddings), teacher_embeddings, loss_name="embedding_state_loss")
                    else:
                        rep_loss += transformer_losses.compute_loss(student_embeddings, teacher_embeddings, loss_name="embedding_state_loss")

                if student_config.distillation_config["use_value_states"]:
                    value_loss = transformer_losses.compute_loss(student_values, new_teacher_values, loss_name="value_state_loss")

                loss = att_loss + rep_loss + value_loss


            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # guard like the branches below: att_loss stays a plain float
            # (no .item()) when attention scores are not distilled
            if student_config.distillation_config["use_attention_scores"]:
                tr_att_loss += att_loss.item() / args.gradient_accumulation_steps
            if student_config.distillation_config["use_hidden_states"]:
                tr_rep_loss += rep_loss.item() / args.gradient_accumulation_steps
            if student_config.distillation_config["use_value_states"]:
                tr_value_loss += value_loss.item() / args.gradient_accumulation_steps
            if args.amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
            else:
                loss.backward()

            if use_projection:
                torch.nn.utils.clip_grad_norm_(chain(student_model.parameters(), project.parameters()), args.max_grad_norm, error_if_nonfinite=False)
            else:
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), args.max_grad_norm, error_if_nonfinite=False)

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            local_step += 1

            if local_step % args.gradient_accumulation_steps == 0:
                scheduler.step()
                if args.amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                optimizer.zero_grad()
                global_step = optimizer.param_groups[0]["step"] if "step" in optimizer.param_groups[0] else 0

                if (global_step % args.steps_per_epoch) > 0:
                    mean_loss = tr_loss / (global_step % args.steps_per_epoch)
                    mean_att_loss = tr_att_loss / (global_step % args.steps_per_epoch)
                    mean_rep_loss = tr_rep_loss / (global_step % args.steps_per_epoch)
                    mean_value_loss = tr_value_loss / (global_step % args.steps_per_epoch)

                if (global_step + 1) % args.eval_step == 0 and is_main_process():
                    result = {}
                    result['global_step'] = global_step
                    result['lr'] = optimizer.param_groups[0]["lr"]
                    result['loss'] = mean_loss
                    result['att_loss'] = mean_att_loss
                    result['rep_loss'] = mean_rep_loss
                    result['value_loss'] = mean_value_loss
                    result['perf'] = (global_step + 1) * get_world_size() * args.train_batch_size * args.gradient_accumulation_steps / (time.time() - iter_start)
                    output_eval_file = os.path.join(args.output_dir, "log.txt")
                    if is_main_process():
                        with open(output_eval_file, "a") as writer:
                            logger.info("***** Eval results *****")
                            for key in sorted(result.keys()):
                                logger.info("  %s = %s", key, str(result[key]))
                                writer.write("%s = %s\n" % (key, str(result[key])))

                        # Save a trained model
                        model_name = "{}".format(WEIGHTS_NAME)

                        logging.info("** ** * Saving fine-tuned model ** ** * ")
                        # Only save the model it-self
                        model_to_save = student_model.module if hasattr(student_model, 'module') else student_model
                        if use_projection:
                            project_to_save = project.module if hasattr(project, 'module') else project

                        output_model_file = os.path.join(args.output_dir, model_name)
                        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                        output_project_file = os.path.join(args.output_dir, "project.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        if use_projection:
                            torch.save(project_to_save.state_dict(), output_project_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        if oncloud:
                            logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                            logging.info(mox.file.list_directory('.', recursive=True))
                            mox.file.copy_parallel(args.output_dir, args.data_url)
                            mox.file.copy_parallel('.', args.data_url)

    model_name = "{}".format(WEIGHTS_NAME)
    logging.info("** ** * Saving fine-tuned model ** ** * ")
    model_to_save = student_model.module if hasattr(student_model, 'module') else student_model

    if use_projection:
        project_to_save = project.module if hasattr(project, 'module') else project
        output_project_file = os.path.join(args.output_dir, "project.bin")
        if is_main_process():
            torch.save(project_to_save.state_dict(), output_project_file)

    output_model_file = os.path.join(args.output_dir, model_name)
    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

    if is_main_process():
        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

    if oncloud:
        logging.info(mox.file.list_directory(args.output_dir, recursive=True))
        logging.info(mox.file.list_directory('.', recursive=True))
        mox.file.copy_parallel(args.output_dir, args.data_url)
        mox.file.copy_parallel('.', args.data_url)
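
unwrap_ddp and flatten_states are used above without definitions; minimal sketches consistent with their usage (the layout of distill_states_dict is an assumption):

def unwrap_ddp(model):
    # Return the wrapped module when model is a DistributedDataParallel.
    return model.module if hasattr(model, 'module') else model

def flatten_states(states_dict, key):
    # Hypothetical: collect the per-layer tensors recorded under `key`
    # by the distillation hooks into a flat, layer-ordered list.
    return [layer_states[key] for layer_states in states_dict.values()
            if key in layer_states]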
Example #11
    def association(self, data_S, gt_S, data_T, gt_T, **kwargs):
        if cfg.MODEL.DOMAIN_BN:
            self.set_domain_id(1)
        res_T = self.net(data_T)
        preds_T = res_T['out']
        feats_T = res_T['feat']

        if cfg.MODEL.DOMAIN_BN:
            self.set_domain_id(0)
        res_S = self.net(data_S)
        preds_S = res_S['out']
        feats_S = res_S['feat']

        total_loss = 0.0
        total_loss_dict = {}

        H, W = feats_S.shape[-2:]
        new_gt_S = F.interpolate(gt_S.type(
            torch.cuda.FloatTensor).unsqueeze(1),
                                 size=(H, W),
                                 mode='nearest').squeeze(1)
        new_gt_T = F.interpolate(gt_T.type(
            torch.cuda.FloatTensor).unsqueeze(1),
                                 size=(H, W),
                                 mode='nearest').squeeze(1)

        if cfg.TRAIN.USE_CROP:
            scale_factor = cfg.TRAIN.SCALE_FACTOR
            new_H, new_W = int(scale_factor * H), int(scale_factor * W)

            feats_S, probs_S, new_gt_S = solver_utils.crop(
                feats_S, preds_S, new_gt_S, new_H, new_W)
            feats_T, probs_T, new_gt_T = solver_utils.crop(
                feats_T, preds_T, new_gt_T, new_H, new_W)

        elif cfg.TRAIN.USE_DOWNSAMPLING:
            scale_factor = cfg.TRAIN.SCALE_FACTOR
            feats_S = F.interpolate(feats_S,
                                    scale_factor=scale_factor,
                                    mode='bilinear',
                                    recompute_scale_factor=False,
                                    align_corners=False)
            feats_T = F.interpolate(feats_T,
                                    scale_factor=scale_factor,
                                    mode='bilinear',
                                    recompute_scale_factor=False,
                                    align_corners=False)
            new_preds_S = F.interpolate(preds_S,
                                        scale_factor=scale_factor,
                                        mode='bilinear',
                                        recompute_scale_factor=False,
                                        align_corners=False)
            new_preds_T = F.interpolate(preds_T,
                                        scale_factor=scale_factor,
                                        mode='bilinear',
                                        recompute_scale_factor=False,
                                        align_corners=False)

            H, W = feats_S.shape[-2:]
            new_gt_S = F.interpolate(gt_S.type(
                torch.cuda.FloatTensor).unsqueeze(1),
                                     size=(H, W),
                                     mode='nearest').squeeze(1)
            new_gt_T = F.interpolate(gt_T.type(
                torch.cuda.FloatTensor).unsqueeze(1),
                                     size=(H, W),
                                     mode='nearest').squeeze(1)

            probs_S = F.softmax(new_preds_S, dim=1)
            probs_T = F.softmax(new_preds_T, dim=1)

        else:
            probs_S = F.softmax(preds_S, dim=1)
            probs_T = F.softmax(preds_T, dim=1)

        ass_loss_dict = self.FeatAssociationLoss(feats_S, feats_T, new_gt_S,
                                                 new_gt_T)
        ass_loss = ass_loss_dict['association']
        total_loss += cfg.TRAIN.ASSO_W * ass_loss
        total_loss_dict.update(ass_loss_dict)

        if cfg.TRAIN.APPLY_MULTILAYER_ASSOCIATION:
            ass_loss_classifier_dict = self.ClsAssociationLoss(
                probs_S, probs_T, new_gt_S, new_gt_T)

            ass_loss_classifier = ass_loss_classifier_dict['association']
            total_loss += cfg.TRAIN.ASSO_W * ass_loss_classifier
            ass_loss_classifier_dict = {
                key + '_cls': ass_loss_classifier_dict[key]
                for key in ass_loss_classifier_dict
            }
            total_loss_dict.update(ass_loss_classifier_dict)

            if cfg.TRAIN.LSR_THRES > 0.0:
                lsr_loss_S = solver_utils.LSR(F.log_softmax(preds_S, dim=1),
                                              dim=1,
                                              thres=cfg.TRAIN.LSR_THRES)
                lsr_loss_T = solver_utils.LSR(F.log_softmax(preds_T, dim=1),
                                              dim=1,
                                              thres=cfg.TRAIN.LSR_THRES)

                total_loss += cfg.TRAIN.LSR_W * lsr_loss_S
                total_loss += cfg.TRAIN.LSR_W * lsr_loss_T

                total_loss_dict['lsr_S'] = lsr_loss_S
                total_loss_dict['lsr_T'] = lsr_loss_T

        preds = F.interpolate(preds_S,
                              size=gt_S.shape[-2:],
                              mode='bilinear',
                              align_corners=False)
        ce_loss = self.CELoss([preds], gt_S)
        if self.distributed:
            lov_loss = lovasz_softmax_multigpu(F.softmax(preds, dim=1),
                                               gt_S,
                                               classes='present',
                                               per_image=False,
                                               ignore=255)
        else:
            lov_loss = lovasz_softmax(F.softmax(preds, dim=1),
                                      gt_S,
                                      classes='present',
                                      per_image=False,
                                      ignore=255)

        ce_loss += (cfg.TRAIN.LOV_W * get_world_size() *
                    self.iter_size) * lov_loss

        total_loss += ce_loss
        total_loss_dict['ce_loss'] = ce_loss
        total_loss_dict['total'] = total_loss

        preds_T = F.interpolate(preds_T,
                                size=gt_S.shape[-2:],
                                mode='bilinear',
                                align_corners=False)
        out_dict = {
            'feats_S': feats_S,
            'feats_T': feats_T,
            'preds_S': preds,
            'preds_T': preds_T
        }
        return total_loss_dict, out_dict
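
solver_utils.crop is external; one plausible sketch, assuming a random crop window shared across features, logits, and labels (the real helper may differ, e.g., in how it aligns feature and logit resolutions):

import random
import torch.nn.functional as F

def crop(feats, preds, gt, new_H, new_W):
    # Crop all three tensors to the same window at feature resolution and
    # return softmax probabilities for the cropped logits.
    H, W = feats.shape[-2:]
    top = random.randint(0, H - new_H)
    left = random.randint(0, W - new_W)
    feats = feats[..., top:top + new_H, left:left + new_W]
    preds = preds[..., top:top + new_H, left:left + new_W]
    gt = gt[..., top:top + new_H, left:left + new_W]
    return feats, F.softmax(preds, dim=1), gt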
Example #12
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device):

    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_loss_joints = AverageMeter()
    ave_loss_inp = AverageMeter()
    ave_acc = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    for i_iter, batch in enumerate(trainloader):
        images, labels, target_weight, _, name, joints, joints_vis = batch
        size = labels.size()
        #cv2.imwrite('groundtruth/gt_'+str(i_iter)+'.png', labels[0].detach().numpy())
        images = images.to(device)
        labels = labels.to(device)

        losses, losses_joints, losses_inp, pred = model(
            images, labels, target_weight)  # forward
        #pred = F.upsample(input=pred, size=(size[-2], size[-1]), mode='bilinear')
        #pred = pred.to('cpu')
        #cv2.imwrite('prediction/pred_'+str(i_iter)+'.png',pred[0][0].detach().numpy())
        #print("saved")

        label_joints, _ = get_max_preds(
            labels[:, 0:15, :, :].detach().cpu().numpy())
        pred_joints, _ = get_max_preds(pred[:,
                                            0:15, :, :].detach().cpu().numpy())

        _, acc, _, _ = accuracy(pred[:, 0:15, :, :].detach().cpu().numpy(),
                                labels[:, 0:15, :, :].detach().cpu().numpy())

        save_batch_image_with_joints(
            images[:, 0:3, :, :], label_joints * 4, joints_vis,
            'results/full_RGBD/train/joint_gt/{}_gt.png'.format(i_iter))
        save_batch_image_with_joints(
            images[:, 0:3, :, :], pred_joints * 4, joints_vis,
            'results/full_RGBD/train/joint_pred/{}_pred.png'.format(i_iter))

        labels = F.interpolate(input=labels, size=(256, 256),
                               mode='bilinear', align_corners=False)
        pred = F.interpolate(input=pred, size=(256, 256),
                             mode='bilinear', align_corners=False)

        cv2.imwrite(
            'results/full_RGBD/train/depth_gt/{}_gt.png'.format(i_iter),
            labels[0, 15, :, :].detach().cpu().numpy())
        cv2.imwrite(
            'results/full_RGBD/train/depth_pred/{}_pred.png'.format(i_iter),
            pred[0, 15, :, :].detach().cpu().numpy())

        loss = losses.mean()
        loss_joints = losses_joints.mean()
        loss_inp = losses_inp.mean()

        reduced_loss = reduce_tensor(loss)
        reduced_loss_joints = reduce_tensor(loss_joints)
        reduced_loss_inp = reduce_tensor(loss_inp)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        ave_loss_joints.update(reduced_loss_joints.item())
        ave_loss_inp.update(reduced_loss_inp.item())
        ave_acc.update(acc)

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and rank == 0:
            print_loss = ave_loss.average() / world_size
            print_loss_joints = ave_loss_joints.average() / world_size
            print_loss_inp = ave_loss_inp.average() / world_size
            print_acc = ave_acc.average() / world_size

            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}, {:.6f}, {:.6f}, ' \
                  'Acc: {:.6f}'.format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, print_loss,
                      print_loss_joints, print_loss_inp, print_acc)
            logging.info(msg)

            writer.add_scalar('train_loss', print_loss, global_steps)
            writer.add_scalar('train_loss_joint', print_loss_joints,
                              global_steps)
            writer.add_scalar('train_loss_depth', print_loss_inp, global_steps)
            writer.add_scalar('train_accuracy', print_acc, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
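
get_max_preds follows the common pose-estimation convention of returning per-joint argmax coordinates and confidences from a batch of heatmaps; a sketch of that convention:

import numpy as np

def get_max_preds(batch_heatmaps):
    # batch_heatmaps: (N, K, H, W) numpy array.
    # Returns (N, K, 2) x/y coordinates and (N, K, 1) peak values.
    n, k, h, w = batch_heatmaps.shape
    flat = batch_heatmaps.reshape(n, k, -1)
    idx = flat.argmax(axis=2)
    maxvals = flat.max(axis=2).reshape(n, k, 1)
    coords = np.stack([idx % w, idx // w], axis=2).astype(np.float32)
    coords *= maxvals > 0.0  # zero out joints with non-positive peaks
    return coords, maxvals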
Example #13
def validate(config, testloader, model, writer_dict, device):

    rank = get_rank()
    world_size = get_world_size()
    model.eval()
    ave_loss = AverageMeter()
    ave_loss_joints = AverageMeter()
    ave_loss_inp = AverageMeter()
    ave_accs = AverageMeter()
    ave_acc = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))

    with torch.no_grad():
        for i_iter, batch in enumerate(testloader):
            image, label, target_weight, _, name, joints, joints_vis = batch
            size = label.size()
            #cv2.imwrite('validation_result/groundtruth/gt_'+str(i_iter)+'.png', label[0].detach().numpy())
            image = image.to(device)
            label = label.to(device)

            losses, losses_joints, losses_inp, pred = model(
                image, label, target_weight)

            #pred = F.upsample(input=pred, size=(64, 64), mode='bilinear')

            label_joints, _ = get_max_preds(
                label[:, 0:15, :, :].detach().cpu().numpy())
            pred_joints, _ = get_max_preds(
                pred[:, 0:15, :, :].detach().cpu().numpy())

            accs, acc, _, _ = accuracy(
                pred[:, 0:15, :, :].detach().cpu().numpy(),
                label[:, 0:15, :, :].detach().cpu().numpy())

            save_batch_image_with_joints(
                image[:, 0:3, :, :], label_joints * 4, joints_vis,
                'results/full_RGBD/val/joint_gt/{}_gt.png'.format(i_iter))
            save_batch_image_with_joints(
                image[:, 0:3, :, :], pred_joints * 4, joints_vis,
                'results/full_RGBD/val/joint_pred/{}_pred.png'.format(i_iter))

            label = F.interpolate(input=label, size=(256, 256),
                                  mode='bilinear', align_corners=False)
            pred = F.interpolate(input=pred, size=(256, 256),
                                 mode='bilinear', align_corners=False)

            cv2.imwrite(
                'results/full_RGBD/val/depth_gt/{}_gt.png'.format(i_iter),
                label[0, 15, :, :].detach().cpu().numpy())
            cv2.imwrite(
                'results/full_RGBD/val/depth_pred/{}_pred.png'.format(i_iter),
                pred[0, 15, :, :].detach().cpu().numpy())

            loss = losses.mean()
            loss_joints = losses_joints.mean()
            loss_inp = losses_inp.mean()

            reduced_loss = reduce_tensor(loss)
            reduced_loss_joints = reduce_tensor(loss_joints)
            reduced_loss_inp = reduce_tensor(loss_inp)

            ave_loss.update(reduced_loss.item())
            ave_loss_joints.update(reduced_loss_joints.item())
            ave_loss_inp.update(reduced_loss_inp.item())
            ave_acc.update(acc)
            ave_accs.update(accs)

    print_loss = ave_loss.average() / world_size
    print_loss_joints = ave_loss_joints.average() / world_size
    print_loss_inp = ave_loss_inp.average() / world_size
    print_acc = ave_acc.average() / world_size
    print_accs = ave_accs.average() / world_size

    if rank == 0:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', print_loss, global_steps)
        writer.add_scalar('valid_loss_joint', print_loss_joints, global_steps)
        writer.add_scalar('valid_loss_depth', print_loss_inp, global_steps)
        writer.add_scalar('valid_accuracy', print_acc, global_steps)
        for i in range(15):
            writer.add_scalar('valid_each_accuracy_' + str(i), print_accs[i],
                              global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1
    return print_loss, print_loss_joints, print_loss_inp, print_acc