Example #1
def train():
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=SSDAugmentation(MEANS))
    '''
    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=BaseTransform(MEANS))
    '''

    if args.validation_epoch > 0:
        setup_eval()
        val_dataset = COCODetection(image_path=cfg.dataset.valid_images,
                                    info_file=cfg.dataset.valid_info,
                                    transform=BaseTransform(MEANS))

    # Parallel wraps the underlying module, but when saving and loading we don't want that
    yolact_net = Yolact()
    net = yolact_net
    net.train()

    # I don't use the timer during training (I use a different timing method).
    # Apparently there's a race condition with multiple GPUs.
    timer.disable_all()

    # Both of these can set args.resume to None, so do them before the check
    if args.resume == 'interrupt':
        args.resume = SavePath.get_interrupt(args.save_folder)
    elif args.resume == 'latest':
        args.resume = SavePath.get_latest(args.save_folder, cfg.name)

    if args.resume is not None:
        print('Resuming training, loading {}...'.format(args.resume))
        yolact_net.load_weights(args.resume)

        if args.start_iter == -1:
            args.start_iter = SavePath.from_str(args.resume).iteration
    else:
        print('Initializing weights...')
        yolact_net.init_weights(backbone_path=args.save_folder +
                                cfg.backbone.path)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.decay)
    criterion = MultiBoxLoss(num_classes=cfg.num_classes,
                             pos_threshold=cfg.positive_iou_threshold,
                             neg_threshold=cfg.negative_iou_threshold,
                             negpos_ratio=3)

    if args.cuda:
        cudnn.benchmark = True
        net = nn.DataParallel(net).cuda()
        criterion = nn.DataParallel(criterion).cuda()

    # loss counters
    loc_loss = 0
    conf_loss = 0
    iteration = max(args.start_iter, 0)
    last_time = time.time()

    epoch_size = len(dataset) // args.batch_size
    num_epochs = math.ceil(cfg.max_iter / epoch_size)

    # Which learning rate adjustment step are we on? lr' = lr * gamma ^ step_index
    step_index = 0

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    save_path = lambda epoch, iteration: SavePath(
        cfg.name, epoch, iteration).get_path(root=args.save_folder)
    time_avg = MovingAverage()

    global loss_types  # Forms the print order
    loss_avgs = {k: MovingAverage(100) for k in loss_types}

    print('Begin training!')
    print()
    # try-except so you can use ctrl+c to save early and stop training
    try:
        for epoch in range(num_epochs):
            # Resume from start_iter
            if (epoch + 1) * epoch_size < iteration:
                continue

            for datum in data_loader:
                # Break at the epoch boundary (needed so the epoch counter stays correct when resuming from start_iter)
                if iteration == (epoch + 1) * epoch_size:
                    break

                # Stop at the configured number of iterations even if mid-epoch
                if iteration == cfg.max_iter:
                    break

                # Change a config setting if we've reached the specified iteration
                changed = False
                for change in cfg.delayed_settings:
                    if iteration >= change[0]:
                        changed = True
                        cfg.replace(change[1])

                        # Reset the loss averages because things might have changed
                        for avg in loss_avgs.values():
                            avg.reset()

                # If a config setting was changed, remove it from the list so we don't keep checking
                if changed:
                    cfg.delayed_settings = [
                        x for x in cfg.delayed_settings if x[0] > iteration
                    ]

                # Warm up by linearly interpolating the learning rate from some smaller value
                if cfg.lr_warmup_until > 0 and iteration <= cfg.lr_warmup_until:
                    set_lr(optimizer, (args.lr - cfg.lr_warmup_init) *
                           (iteration / cfg.lr_warmup_until) +
                           cfg.lr_warmup_init)

                # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
                while step_index < len(
                        cfg.lr_steps
                ) and iteration >= cfg.lr_steps[step_index]:
                    step_index += 1
                    set_lr(optimizer, args.lr * (args.gamma**step_index))

                # Load training data
                # Note, for training on multiple gpus this will use the custom replicate and gather I wrote up there
                images, targets, masks, num_crowds = prepare_data(datum)

                # Forward Pass
                out = net(images)

                # Compute Loss
                optimizer.zero_grad()

                wrapper = ScatterWrapper(targets, masks, num_crowds)
                losses = criterion(out, wrapper, wrapper.make_mask())

                losses = {k: v.mean()
                          for k, v in losses.items()
                          }  # Mean here because Dataparallel
                loss = sum([losses[k] for k in losses])

                # Backprop
                loss.backward()  # Do this to free up vram even if loss is not finite
                if torch.isfinite(loss).item():
                    optimizer.step()

                # Add the loss to the moving average for bookkeeping
                for k in losses:
                    loss_avgs[k].add(losses[k].item())

                cur_time = time.time()
                elapsed = cur_time - last_time
                last_time = cur_time

                # Exclude graph setup from the timing information
                if iteration != args.start_iter:
                    time_avg.add(elapsed)

                if iteration % 10 == 0:
                    eta_str = str(
                        datetime.timedelta(seconds=(cfg.max_iter - iteration) *
                                           time_avg.get_avg())).split('.')[0]

                    total = sum([loss_avgs[k].get_avg() for k in losses])
                    loss_labels = sum([[k, loss_avgs[k].get_avg()]
                                       for k in loss_types if k in losses], [])

                    print(('[%3d] %7d ||' + (' %s: %.3f |' * len(losses)) +
                           ' T: %.3f || ETA: %s || timer: %.3f') %
                          tuple([epoch, iteration] + loss_labels +
                                [total, eta_str, elapsed]),
                          flush=True)

                iteration += 1

                if iteration % args.save_interval == 0 and iteration != args.start_iter:
                    if args.keep_latest:
                        latest = SavePath.get_latest(args.save_folder,
                                                     cfg.name)

                    print('Saving state, iter:', iteration)
                    yolact_net.save_weights(save_path(epoch, iteration))

                    if args.keep_latest and latest is not None:
                        if args.keep_latest_interval <= 0 or iteration % args.keep_latest_interval != args.save_interval:
                            print('Deleting old save...')
                            os.remove(latest)

            # This is done per epoch
            if args.validation_epoch > 0:
                if epoch % args.validation_epoch == 0 and epoch > 0:
                    compute_validation_map(yolact_net, val_dataset)
    except KeyboardInterrupt:
        print('Stopping early. Saving network...')

        # Delete previous copy of the interrupted network so we don't spam the weights folder
        SavePath.remove_interrupt(args.save_folder)

        yolact_net.save_weights(
            save_path(epoch,
                      repr(iteration) + '_interrupt'))
        exit()

    yolact_net.save_weights(save_path(epoch, iteration))
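
A minimal self-contained sketch of the linear LR warmup used above (cfg.lr_warmup_init, args.lr and cfg.lr_warmup_until map to warmup_init, base_lr and warmup_until):

def warmup_lr(iteration, base_lr, warmup_init, warmup_until):
    # Linear ramp from warmup_init at iteration 0 to base_lr at warmup_until.
    if warmup_until > 0 and iteration <= warmup_until:
        return (base_lr - warmup_init) * (iteration / warmup_until) + warmup_init
    return base_lr

# e.g. warmup_lr(250, 1e-3, 1e-4, 500) == 5.5e-4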
Example #2
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
    """
    Removes detections with an object confidence score lower than 'conf_thres',
    then applies Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, class_conf, class)
    """

    min_wh = 2  # (pixels) minimum box width and height

    output = [None] * len(prediction)
    for image_i, pred in enumerate(prediction):

        # Multiply conf by class conf to get combined confidence
        class_conf, class_pred = pred[:, 5:].max(1)
        pred[:, 4] *= class_conf

        # Select only suitable predictions
        i = (pred[:, 4] > conf_thres) & (
            pred[:, 2:4] > min_wh).all(1) & torch.isfinite(pred).all(1)
        pred = pred[i]

        # If none are remaining => process next image
        if len(pred) == 0:
            continue

        # Select predicted classes
        class_conf = class_conf[i]
        class_pred = class_pred[i].unsqueeze(1).float()

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        pred[:, :4] = xywh2xyxy(pred[:, :4])
        # pred[:, 4] *= class_conf  # improves mAP from 0.549 to 0.551

        # Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
        pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)

        # Get detections sorted by decreasing confidence scores
        pred = pred[(-pred[:, 4]).argsort()]

        det_max = []
        nms_style = 'MERGE'  # 'OR' (default), 'AND', 'MERGE' (experimental)
        for c in pred[:, -1].unique():
            dc = pred[pred[:, -1] == c]  # select class c
            n = len(dc)
            if n == 1:
                det_max.append(dc)  # No NMS required if only 1 prediction
                continue
            elif n > 100:
                dc = dc[:100]  # limit to first 100 boxes: https://github.com/ultralytics/yolov3/issues/117

            # Non-maximum suppression
            if nms_style == 'OR':  # default
                # METHOD1
                # ind = list(range(len(dc)))
                # while len(ind):
                #     j = ind[0]
                #     det_max.append(dc[j:j + 1])  # save highest conf detection
                #     reject = (bbox_iou(dc[j], dc[ind]) > nms_thres).nonzero()
                #     [ind.pop(i) for i in reversed(reject)]

                # METHOD2
                while dc.shape[0]:
                    det_max.append(dc[:1])  # save highest conf detection
                    if len(dc) == 1:  # Stop if we're at the last detection
                        break
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    dc = dc[1:][iou < nms_thres]  # remove ious > threshold

            elif nms_style == 'AND':  # requires overlap, single boxes erased
                while len(dc) > 1:
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    if iou.max() > 0.5:
                        det_max.append(dc[:1])
                    dc = dc[1:][iou < nms_thres]  # remove ious > threshold

            elif nms_style == 'MERGE':  # weighted mixture box
                while len(dc):
                    if len(dc) == 1:
                        det_max.append(dc)
                        break
                    i = bbox_iou(dc[0], dc) > nms_thres  # iou with other boxes
                    weights = dc[i, 4:5]
                    dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
                    det_max.append(dc[:1])
                    dc = dc[i == 0]

            elif nms_style == 'SOFT':  # soft-NMS https://arxiv.org/abs/1704.04503
                sigma = 0.5  # soft-nms sigma parameter
                while len(dc):
                    if len(dc) == 1:
                        det_max.append(dc)
                        break
                    det_max.append(dc[:1])
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    dc = dc[1:]
                    dc[:, 4] *= torch.exp(-iou**2 / sigma)  # decay confidences
                    # dc = dc[dc[:, 4] > nms_thres]  # new line per https://github.com/ultralytics/yolov3/issues/362

        if len(det_max):
            det_max = torch.cat(det_max)  # concatenate
            output[image_i] = det_max[(-det_max[:, 4]).argsort()]  # sort

    return output
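
A hypothetical call of the function above, assuming the repo's xywh2xyxy and bbox_iou helpers are in scope. Rows are (cx, cy, w, h, obj_conf, per-class scores); the third box fails the min_wh filter:

import torch

pred = torch.tensor([[[50., 50., 20., 20., 0.9, 0.8, 0.1],
                      [52., 50., 20., 20., 0.8, 0.7, 0.2],
                      [200., 200., 3., 1., 0.9, 0.9, 0.1]]])  # 1 image, 3 boxes, 2 classes
dets = non_max_suppression(pred, conf_thres=0.5, nms_thres=0.5)
# dets[0] rows: (x1, y1, x2, y2, obj_conf * class_conf, class_conf, class)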
Example #3
def test_gmsd_loss_forward_backward(prediction: torch.Tensor,
                                    target: torch.Tensor, device: str) -> None:
    prediction.requires_grad_()
    loss_value = GMSDLoss()(prediction.to(device), target.to(device))
    loss_value.backward()
    assert torch.isfinite(prediction.grad).all(), LEAF_VARIABLE_ERROR_MESSAGE
Example #4
def test_arsinh_grad(logunif_input):
    stereographic.math.arsinh(logunif_input).sum().backward()
    assert torch.isfinite(logunif_input.grad).all()
Example #5
def test_project_k_grad(logunif_input):
    vec = logunif_input[:, None] * torch.ones(logunif_input.shape[0], 10)
    k = logunif_input.detach().clone().requires_grad_()
    stereographic.math.project(vec, k=k[:, None]).sum().backward()
    assert torch.isfinite(logunif_input.grad).all()
    assert torch.isfinite(k.grad).all()
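
Examples #3 to #5 share one pattern: push a tensor through an op, backpropagate, and assert that the resulting gradients are finite. A generic version of that check (the helper name is ours):

import torch

def assert_finite_grads(fn, x):
    x = x.detach().clone().requires_grad_()
    fn(x).sum().backward()
    assert torch.isfinite(x.grad).all()

assert_finite_grads(torch.asinh, torch.linspace(-5.0, 5.0, 11))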
Example #6
def train(model, train_loader, criterion, scheduler, optimizer, epoch, params,
          args):
    start = time.time()
    total_loss = []

    model.train()
    model.is_training = True

    pbar = tqdm(train_loader, desc='==> Train', position=1)
    idx = 0
    for batch in pbar:
        images, targets = convert_batch(batch, args.device)
        #print(images.tensors.shape)
        #print(targets)

        #print('images', images.shape)
        #images, targets_a, targets_b, lam = mixup_data(images, targets,
        #                                              args.alpha, args.is_cuda)
        #images, targets_a, targets_b = map(Variable, (images, targets_a, targets_b))

        outputs = model(images)
        #print(f'{epoch}-{idx}', len(outputs))

        #loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = criterion(
            outputs, targets)

        print(
            '\rLoss {:<10.2f} XY {:<10.2f} WH {:<10.2f} Obj {:<10.2f} Cls {:<10.2f} L2 {:<10.2f}'
            .format(loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2),
            end='',
            flush=True)
        sys.stdout.flush()
        #print(loss)

        loss = loss.mean()
        # total_loss += loss.data[0]
        if loss == 0 or not torch.isfinite(loss):
            print('loss is zero or non-finite, skipping batch')
            continue

        loss.backward()
        total_loss.append(loss.item())
        mean_loss = np.mean(total_loss)
        if (idx + 1) % args.grad_accum_steps == 0:
            clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            # zero the accumulated gradients only after the optimizer step
            optimizer.zero_grad()

        iter_step(epoch, loss, mean_loss, optimizer, params, args)
        idx += 1
        pbar.update()
        pbar.set_postfix({
            'Loss': loss.item(),
            'Mean_loss': mean_loss,
        })
        # pbar.set_description()

    # end of training epoch
    scheduler.step(mean_loss)
    #scheduler.step(epoch) # for gradual warmup
    result = {'time': time.time() - start, 'loss': mean_loss}
    for key, value in result.items():
        print('    {:15s}: {}'.format(str(key), value))

    return mean_loss
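
Note the order inside the accumulation branch above: clip, step, then zero the gradients; zeroing before the step would discard everything accumulated since the last update. A minimal runnable sketch of the pattern:

import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum_steps, max_grad_norm = 4, 1.0

for idx in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / grad_accum_steps
    loss.backward()  # gradients accumulate across micro-batches
    if (idx + 1) % grad_accum_steps == 0:
        clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()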
Example #7
def do_train(cfg, model, resume=False):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(model,
                                         cfg.OUTPUT_DIR,
                                         optimizer=optimizer,
                                         scheduler=scheduler)
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    #
    scheduler.step(start_iter)
    #
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ] if comm.is_main_process() else [])

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            iteration = iteration + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {
                k: v.item()
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            if (cfg.TEST.EVAL_PERIOD > 0
                    and iteration % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration - start_iter > 5 and (iteration % 20 == 0
                                               or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
Example #8
def test():
    # Get the model parameters: args are the EfficientNet architecture parameters, params are the hyperparameters
    args, params = get_model_params('efficientnet-b0', None)
    # print(args,params)
    model = LYCNet(args, params)
    # Load the pretrained model
    # model_dict = model.state_dict()
    # pretrain_dict = torch.load('./pretrainedModel/efficientnet-b0-355c32eb.pth')
    # pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
    # model_dict.update(pretrain_dict)
    # model.load_state_dict(model_dict)
    # Data initialization
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    distill_temp = 5
    loss_func = MSDLoss(distill_temp)
    batch_size = 32
    # num_epoches = 100
    # learning_rate = 0.01

    # Load the datasets
    # train_dataset = datasets.CIFAR100(root='./dataset', train=True, nsforms=transform, download=True)
    val_dataset = datasets.CIFAR100(root='./dataset', train=False, transform=transform, download=True)
    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, drop_last=True)

    # progress_bar = tqdm(train_loader)
    # The loss function is defined here; temp denotes the distillation temperature

    iter_num = len(val_dataset) // batch_size
    loss_func = loss_func.cuda()
    model = model.cuda()

    # Define the optimizer
    # optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    # Loss function parameters
    alpha = 0.1
    # beta_begin = 0.1
    # beta_end = 0.001
    beta = 0.4515244252618963
    print(str(len(val_dataset)) + '\n')
    checkpoints = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    model.eval()
    for i, value in enumerate(checkpoints):
        model.load_state_dict(torch.load(f'LYCNet_{value}epoch.path').state_dict())
        time1 = time.time()
        epoch_loss = []
        print(f'alpha = {alpha}   beta = {beta}')
        correct_num_list = [0, 0, 0, 0, 0]
        for iter, data in enumerate(val_loader):
            imgs = data[0]
            imgs = imgs.cuda()
            # print(imgs.size())
            labels = data[1]
            labels = labels.cuda()

            # optimizer.zero_grad()
            features, outputs = model(imgs)
            correct_num_list = [x + correct_num_list[j] for j, x in enumerate(cal_correct_num(outputs, labels))]
            # print(labels)
            # print('\n--------------\n')
            # print(features[0].size())
            # print('\n--------------\n')
            # print(outputs[0].size())
            # print('\n--------------\n')
            loss = loss_func(features, outputs, labels, alpha, beta)
            # print(loss)
            if loss == 0 or not torch.isfinite(loss):
                continue
            # Compute gradients
            # loss.backward()
            # Update parameters
            # optimizer.step()

            epoch_loss.append(float(loss))
        # scheduler.step(np.mean(epoch_loss))
        print(f'checkpoint: epoch {value}: loss = {np.mean(epoch_loss)}, cost_time = {time.time() - time1}, accuracy = {[x / (iter_num * batch_size) for x in correct_num_list]}')
Example #9
def train(opt):
    train_file = 'train.keys'
    val_file = 'valid.keys'
    # Train similarity network or manipulation network independently or the whole network.
    train_simi = True
    train_mani = True
    train_fusion = True

    # According to the papers, set input_size default to 256.
    input_size = 256

    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    target_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
    ])
    train_set = USCISIDataset(opt.lmdb_dir, train_file, train_transform,
                              target_transform)
    val_set = USCISIDataset(opt.lmdb_dir, val_file, val_transform,
                            target_transform)

    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        #    'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        #   'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    training_generator = DataLoader(train_set, **training_params)
    val_generator = DataLoader(val_set, **val_params)

    model = BusterNet(image_size=input_size)

    if opt.load_weights is not None:
        try:
            # Load pretrained VGG16 from https://download.pytorch.org/models/vgg16-397923af.pth, or continue training
            if 'vgg16_bn' in opt.load_weights:
                vgg_backbone = torch.load(opt.load_weights)
                model.manipulation_net.load_state_dict(vgg_backbone,
                                                       strict=False)
                model.similarity_net.load_state_dict(vgg_backbone,
                                                     strict=False)
            else:
                model.load_state_dict(torch.load(opt.load_weights),
                                      strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
        print(f'[Info] loaded weights: {os.path.basename(opt.load_weights)}')
    else:
        print('[Info] initializing weights...')
    #     init_weights(model)

    if opt.freeze_layers is not None:
        assert isinstance(opt.freeze_layers, list), 'freeze_layers must be a list of layer-name strings'

        def freeze_layers(m):
            classname = m.__class__.__name__
            for ntl in opt.freeze_layers:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_layers)
        print('[Info] freeze layers in ', opt.freeze_layers)

    # wrap the model with the loss function, to reduce memory usage on gpu0 and speed up
    model = ModelWithLoss(model,
                          train_simi=train_simi,
                          train_mani=train_mani,
                          train_fusion=train_fusion)

    if opt.num_gpus > 1 and opt.batch_size // opt.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    os.makedirs(opt.saved_path, exist_ok=True)
    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    if opt.num_gpus > 0:
        model = model.cuda()
        if opt.num_gpus > 1:
            model = CustomDataParallel(model, opt.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    elif opt.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    last_step = 0
    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)

    try:
        for epoch in range(opt.num_epochs):

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter, data in enumerate(progress_bar):
                last_epoch = step // num_iter_per_epoch
                if iter < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs, gts, _ = data

                    if opt.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        gts = gts.cuda()

                    optimizer.zero_grad()

                    fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                    fusion_loss = fusion_loss.mean()
                    simi_loss = simi_loss.mean()
                    mani_loss = mani_loss.mean()

                    loss = fusion_loss + mani_loss + simi_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Fusion loss: {:.5f}. Mani loss: {:.5f}. Simi loss: {:.5f} Total loss: {:.5f}'
                        .format(step, epoch,
                                opt.num_epochs, iter + 1, num_iter_per_epoch,
                                fusion_loss.item(), mani_loss.item(),
                                simi_loss.item(), loss.item()))
                    writer.add_scalar('Loss', loss, step)
                    writer.add_scalar('fusion_loss', fusion_loss, step)
                    writer.add_scalar('simi_loss', simi_loss, step)
                    writer.add_scalar('mani_loss', mani_loss, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(model, f'model_{epoch}_{step}.pth')
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_fusion_ls = []
                loss_simi_ls = []
                loss_mani_ls = []
                for iter, data in enumerate(val_generator):
                    with torch.no_grad():
                        imgs, gts, _ = data

                        if opt.num_gpus == 1:
                            imgs = imgs.cuda()
                            gts = gts.cuda()

                        fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                        fusion_loss = fusion_loss.mean()
                        simi_loss = simi_loss.mean()
                        mani_loss = mani_loss.mean()

                        loss = fusion_loss + mani_loss + simi_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_fusion_ls.append(fusion_loss.item())
                        loss_simi_ls.append(simi_loss.item())
                        loss_mani_ls.append(mani_loss.item())

                fusion_loss = np.mean(loss_fusion_ls)
                simi_loss = np.mean(loss_simi_ls)
                mani_loss = np.mean(loss_mani_ls)
                loss = fusion_loss + simi_loss + mani_loss

                print(
                    'Val. Epoch: {}/{}. Fusion loss: {:1.5f}. Simi loss: {:1.5f}. Mani loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch, opt.num_epochs, fusion_loss, simi_loss,
                            mani_loss, loss))
                writer.add_scalar('Val_Loss', loss, step)
                writer.add_scalar('Val_Fusion_loss', fusion_loss, step)
                writer.add_scalar('Val_Simi_loss', simi_loss, step)
                writer.add_scalar('Val_Mani_loss', mani_loss, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(model, f'model_{epoch}_{step}.pth')

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(model, f'model_{epoch}_{step}.pth')
        writer.close()
    writer.close()
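
One detail worth flagging in the snippet above: the flag is requires_grad; the original misspelling require_grad would silently set a new attribute and freeze nothing (fixed in freeze_layers above). A quick check:

import torch.nn as nn

m = nn.Linear(4, 2)
for param in m.parameters():
    param.requires_grad = False
assert all(not p.requires_grad for p in m.parameters())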
Example #10
 def test_fn(a):
     return torch.isfinite(a) & a.ne(0)
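
For reference, torch.isfinite is elementwise and returns False exactly for NaN and +/-Inf, which is why it combines cleanly with other boolean masks as above:

import torch

t = torch.tensor([1.0, 0.0, float('inf'), float('-inf'), float('nan')])
print(torch.isfinite(t))            # tensor([ True,  True, False, False, False])
print(torch.isfinite(t) & t.ne(0))  # tensor([ True, False, False, False, False])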
Example #11
def train():
    # Get the model parameters: args are the EfficientNet architecture parameters, params are the hyperparameters
    now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    logger = get_logger(f'logging/{now}.txt')
    args, params = get_model_params('efficientnet-b0', None)
    # print(args, params)
    model = LYCNet(args, params)
    # Load the pretrained model
    model_dict = model.state_dict()
    pretrain_dict = torch.load('./pretrainedModel/efficientnet-b0-355c32eb.pth')
    pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)
    # Data initialization
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    distill_temp = 5
    loss_func = MSDLoss(distill_temp)
    batch_size = 64
    num_epoches = 100
    learning_rate = 0.01

    # Load the datasets
    train_dataset = datasets.CIFAR100(root='./dataset', train=True, transform=transform, download=True)
    val_dataset = datasets.CIFAR100(root='./dataset', train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, drop_last=True)

    # progress_bar = tqdm(train_loader)
    # The loss function is defined here; temp denotes the distillation temperature

    iter_num = len(train_dataset) // batch_size
    loss_func = loss_func.cuda()
    model = model.cuda()

    # Define the optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)


    # Loss function parameters
    alpha = 0.1
    beta_begin = 0.1
    beta_end = 0.001
    logger.info('training params: sample_numbers = {}, distill_temp = {}, batch_size = {}, num_epoches = {}, learning_rate = {}, alpha = {}, '
                'beta_begin = {}, beta_end = {}'.format(len(train_dataset), distill_temp, batch_size, num_epoches, learning_rate, alpha, beta_begin, beta_end))
    # print(str(len(train_dataset))+'\n')
    model.train()
    logger.info('start training!')
    for epoch in range(num_epoches):
        time1 = time.time()
        epoch_loss = []
        beta = beta_end + 0.5 * (beta_begin - beta_end) * (1 + cos(pi * epoch / num_epoches))  # cosine annealing from beta_begin to beta_end
        # print(f'alpha = {alpha}   beta = {beta}\n')
        for iter, data in enumerate(train_loader):
            imgs = data[0]
            imgs = imgs.cuda()
            # print(imgs.size())
            labels = data[1]
            labels = labels.cuda()

            optimizer.zero_grad()
            features, outputs = model(imgs)
            loss = loss_func(features, outputs, labels, alpha, beta)
            if loss == 0 or not torch.isfinite(loss):
                continue
            # Compute gradients
            loss.backward()
            # Update parameters
            optimizer.step()
            epoch_loss.append(float(loss))
        scheduler.step(np.mean(epoch_loss))
        # print(f'epoch:{epoch + 1}  : loss  = {np.mean(epoch_loss)}\n,cost_time = {time.time()-time1}')
        logger.info('Epoch:[{}/{}]\t time = {:.3f}\t loss = {:.5f}\t alpha = {}\t beta = {:.5f}'.format(epoch + 1, num_epoches, time.time() - time1, np.mean(epoch_loss), alpha, beta))
        if epoch % 10 == 0:
            torch.save(model, f'LYCNet_{epoch}epoch.path')
Example #12
def compute_rank_from_scores(
    true_score: torch.FloatTensor,
    all_scores: torch.FloatTensor,
) -> Dict[str, torch.FloatTensor]:
    """Compute rank and adjusted rank given scores.

    :param true_score: torch.Tensor, shape: (batch_size, 1)
        The score of the true triple.
    :param all_scores: torch.Tensor, shape: (batch_size, num_entities)
        The scores of all corrupted triples (including the true triple).
    :return: a dictionary
        {
            'best': best_rank,
            'worst': worst_rank,
            'avg': avg_rank,
            'adj': adj_rank,
        }

        where

        best_rank: shape: (batch_size,)
            The best rank is the rank when assuming all options with an equal score are placed behind the current
            test triple.
        worst_rank:
            The worst rank is the rank when assuming all options with an equal score are placed in front of current
            test triple.
        avg_rank:
            The average rank is the average of the best and worst rank, and hence the expected rank over all
            permutations of the elements with the same score as the currently considered option.
        adj_rank: shape: (batch_size,)
            The adjusted rank normalises the average rank by the expected rank a random scoring would
            achieve, which is (#number_of_options + 1)/2
    """
    # The best rank is the rank when assuming all options with an equal score are placed behind the currently
    # considered. Hence, the rank is the number of options with better scores, plus one, as the rank is one-based.
    best_rank = (all_scores > true_score).sum(dim=1) + 1

    # The worst rank is the rank when assuming all options with an equal score are placed in front of the currently
    # considered. Hence, the rank is the number of options which have at least the same score minus one (as the
    # currently considered option is included in all options). As the rank is one-based, we have to add 1, which
    # nullifies the "minus 1" from before.
    worst_rank = (all_scores >= true_score).sum(dim=1)

    # The average rank is the average of the best and worst rank, and hence the expected rank over all permutations of
    # the elements with the same score as the currently considered option.
    average_rank = (best_rank + worst_rank).float() * 0.5

    # We set values which should be ignored to NaN, hence the number of options which should be considered is given by
    number_of_options = torch.isfinite(all_scores).sum(dim=1).float()

    # The expected rank of a random scoring
    expected_rank = 0.5 * (number_of_options + 1)

    # The adjusted ranks is normalized by the expected rank of a random scoring
    adjusted_average_rank = average_rank / expected_rank
    # TODO adjusted_worst_rank
    # TODO adjusted_best_rank

    return {
        RANK_BEST: best_rank,
        RANK_WORST: worst_rank,
        RANK_AVERAGE: average_rank,
        RANK_AVERAGE_ADJUSTED: adjusted_average_rank,
    }
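
A small worked example of the rank logic (values chosen by us; the dict keys are the module's RANK_* constants): with a true score of 0.5 and all_scores [0.7, 0.5, 0.5, 0.3], one score is strictly better (best rank 2), three are at least as good (worst rank 3), the average rank is 2.5, and with 4 finite options the expected random rank is also (4 + 1) / 2 = 2.5, giving an adjusted rank of 1.0:

import torch

true_score = torch.tensor([[0.5]])
all_scores = torch.tensor([[0.7, 0.5, 0.5, 0.3]])
ranks = compute_rank_from_scores(true_score, all_scores)
# best: tensor([2]), worst: tensor([3]), avg: tensor([2.5000]), adj: tensor([1.])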
Example #13
def train(data_cfg,
          model,
          midas_model,
          train_dataloader,
          test_dataloader,
          start_epoch,
          epochs,
          img_size,
          optimizer,
          scheduler=None):
    mse_loss = nn.MSELoss()
    nb = len(train_dataloader)  # number of batches
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.autograd.set_detect_anomaly(True)
    results = (
        0, 0, 0, 0, 0, 0, 0
    )  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    print('Starting training for %g epochs...' % epochs)
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        mloss = torch.zeros(4).to(device)  # mean losses
        avg_midas_loss = 0
        print('\n-----------------Train---------------\n')
        print(('\n' + '%12s' * 12) %
              ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'yolo_total',
               'targets', ' midas_loss', 'plane_bce', 'plane_dice',
               'plane_total', 'img_size'))
        pbar = tqdm(enumerate(train_dataloader), total=nb)  # progress bar
        for i, (imgs, targets, seg_masks, paths, _) in pbar:
            # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            seg_masks = seg_masks.to(device)
            targets = targets.to(device)

            midas_target = midas_model.forward(imgs)
            # Run model
            midas_pred, plane_pred, pred = model(imgs)

            # Compute YOLO loss
            yolo_loss, yolo_loss_items = compute_yolo_loss(
                pred, targets, model)
            if not torch.isfinite(yolo_loss):
                print('WARNING: non-finite loss, ending training ',
                      yolo_loss_items)
                return results

            # Compute midas loss
            midas_loss = mse_loss(midas_pred, midas_target) / 1e5

            metrics = defaultdict(float)

            # Compute Plane Segmentation loss
            plane_seg_loss = calc_segmentation_loss(plane_pred, seg_masks,
                                                    metrics)
            plane_seg_print = planar_epoch_print_metrics(metrics, i + 1)

            loss = midas_loss + yolo_loss + plane_seg_loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Print batch results
            mloss = (mloss * i + yolo_loss_items) / (i + 1)  # update mean losses
            avg_midas_loss = (avg_midas_loss * i + midas_loss.detach()) / (i + 1)  # update mean losses
            mem = '%.3gG' % (torch.cuda.memory_reserved() /
                             1E9 if torch.cuda.is_available() else 0)  # (GB)
            s = ('%12s' * 2 + '%12g' * 10) % ('%g/%g' %
                                              (epoch, epochs - 1), mem, *mloss,
                                              len(targets), avg_midas_loss,
                                              *plane_seg_print, img_size)

            pbar.set_description(s)

            # end batch ------------------------------------------------------------------------------------------------
        if scheduler:
            scheduler.step()
        mloss = torch.zeros(4).to(device)  # mean losses
        print('\n-----------------Test---------------\n')
        print(('\n' + '%12s' * 12) %
              ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'yolo_total',
               'targets', ' midas_loss', 'plane_bce', 'plane_dice',
               'plane_total', 'img_size'))
        pbar = tqdm(enumerate(test_dataloader),
                    total=len(test_dataloader))  # progress bar
        for i, (imgs, targets, seg_masks, paths, _) in pbar:  # batch -------------------------------------------------------------
            model.eval()
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            seg_masks = seg_masks.to(device)
            targets = targets.to(device)

            midas_target = midas_model.forward(imgs)
            # Run model
            midas_pred, plane_pred, _, pred = model(imgs)

            # Compute YOLO loss
            yolo_loss, yolo_loss_items = compute_yolo_loss(
                pred, targets, model)

            # Compute midas loss
            midas_loss = mse_loss(midas_pred, midas_target) / 1e5

            metrics = defaultdict(float)

            # Compute Plane Segmentation loss
            plane_seg_loss = calc_segmentation_loss(plane_pred, seg_masks,
                                                    metrics)
            plane_seg_print = planar_epoch_print_metrics(metrics, i + 1)

            # no optimizer update during evaluation
            # ema.update(model)

            # Print batch results
            mloss = (mloss * i + yolo_loss_items) / (i + 1)  # update mean losses
            avg_midas_loss = (avg_midas_loss * i + midas_loss.detach()) / (i + 1)  # update mean losses
            mem = '%.3gG' % (torch.cuda.memory_reserved() /
                             1E9 if torch.cuda.is_available() else 0)  # (GB)
            s = ('%12s' * 2 + '%12g' * 10) % ('%g/%g' %
                                              (epoch, epochs - 1), mem, *mloss,
                                              len(targets), avg_midas_loss,
                                              *plane_seg_print, img_size)

            pbar.set_description(s)
        results, maps = test(data_cfg,
                             model=model,
                             single_cls=False,
                             iou_thres=0.6,
                             dataloader=test_dataloader)
Example #14
        desc=f"Drawing {num_samples} posterior samples",
    )

    num_sampled_total, num_remaining = 0, num_samples
    accepted, acceptance_rate = [], float("Nan")
    leakage_warning_raised = False

    # To cover cases with few samples without leakage:
    sampling_batch_size = min(num_samples, max_sampling_batch_size)
    while num_remaining > 0:

        # Sample and reject.
        candidates = posterior_nn.sample(sampling_batch_size,
                                         context=x).reshape(
                                             sampling_batch_size, -1)
        are_within_prior = torch.isfinite(prior.log_prob(candidates))
        samples = candidates[are_within_prior]
        accepted.append(samples)

        # Update.
        num_sampled_total += sampling_batch_size
        num_remaining -= samples.shape[0]
        pbar.update(samples.shape[0])

        # To avoid endless sampling when leakage is high, we raise a warning if the
        # acceptance rate is too low after the first 1_000 samples.
        acceptance_rate = (num_samples - num_remaining) / num_sampled_total

        # For remaining iterations (leakage or many samples) continue sampling with
        # fixed batch size.
        sampling_batch_size = max_sampling_batch_size
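
The accept/reject step above keeps exactly the candidates whose prior log-probability is finite. A minimal standalone version, assuming a prior whose log_prob returns -inf outside its support (torch.distributions.Uniform behaves this way with validation off):

import torch
from torch.distributions import Uniform

prior = Uniform(-1.0, 1.0, validate_args=False)
candidates = torch.randn(1000)
are_within_prior = torch.isfinite(prior.log_prob(candidates))
accepted = candidates[are_within_prior]
acceptance_rate = accepted.numel() / candidates.numel()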
Example #15
    def _line_search(self, full_step, dataset, advs, action_distrib_old, gain):
        """Do line search for a safe step size."""
        policy_params = list(self.policy.parameters())
        policy_params_sizes = [param.numel() for param in policy_params]
        policy_params_shapes = [param.shape for param in policy_params]
        step_size = 1.0
        flat_params = _flatten_and_concat_variables(policy_params).detach()

        if self.recurrent:
            seqs_states = []
            for ep in dataset:
                states = self.batch_states(
                    [transition["state"] for transition in ep], self.device,
                    self.phi)
                if self.obs_normalizer:
                    states = self.obs_normalizer(states, update=False)
                seqs_states.append(states)
            with torch.no_grad(), pfrl.utils.evaluating(self.model):
                policy_rs = concatenate_recurrent_states(
                    _collect_first_recurrent_states_of_policy(dataset))

            def evaluate_current_policy():
                distrib, _ = pack_and_forward(self.policy, seqs_states,
                                              policy_rs)
                return distrib

        else:
            states = self.batch_states(
                [transition["state"] for transition in dataset], self.device,
                self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)

            def evaluate_current_policy():
                return self.policy(states)

        flat_transitions = (flatten_sequences_time_first(dataset)
                            if self.recurrent else dataset)
        actions = torch.tensor(
            [transition["action"] for transition in flat_transitions],
            device=self.device,
        )
        log_prob_old = torch.tensor(
            [transition["log_prob"] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        for i in range(self.line_search_max_backtrack + 1):
            self.logger.info("Line search iteration: %s step size: %s", i,
                             step_size)
            new_flat_params = flat_params + step_size * full_step
            new_params = _split_and_reshape_to_ndarrays(
                new_flat_params,
                sizes=policy_params_sizes,
                shapes=policy_params_shapes,
            )
            _replace_params_data(policy_params, new_params)
            with torch.no_grad(), pfrl.utils.evaluating(self.policy):
                new_action_distrib = evaluate_current_policy()
                new_gain = self._compute_gain(
                    log_prob=new_action_distrib.log_prob(actions),
                    log_prob_old=log_prob_old,
                    entropy=new_action_distrib.entropy(),
                    advs=advs,
                )
                new_kl = torch.mean(
                    torch.distributions.kl_divergence(action_distrib_old,
                                                      new_action_distrib))

            improve = float(new_gain) - float(gain)
            self.logger.info("Surrogate objective improve: %s", improve)
            self.logger.info("KL divergence: %s", float(new_kl))
            if not torch.isfinite(new_gain):
                self.logger.info(
                    "Surrogate objective is not finite. Bakctracking...")
            elif not torch.isfinite(new_kl):
                self.logger.info(
                    "KL divergence is not finite. Bakctracking...")
            elif improve < 0:
                self.logger.info(
                    "Surrogate objective didn't improve. Bakctracking...")
            elif float(new_kl) > self.max_kl:
                self.logger.info(
                    "KL divergence exceeds max_kl. Bakctracking...")
            else:
                self.kl_record.append(float(new_kl))
                self.policy_step_size_record.append(step_size)
                break
            step_size *= 0.5
        else:
            self.logger.info(
                "Line search coundn't find a good step size. The policy was not"
                " updated.")
            self.policy_step_size_record.append(0.0)
            _replace_params_data(
                policy_params,
                _split_and_reshape_to_ndarrays(flat_params,
                                               sizes=policy_params_sizes,
                                               shapes=policy_params_shapes),
            )
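
Stripped of the RL machinery, the loop above is plain backtracking: halve the step until the candidate passes every acceptance test, and restore the old parameters if none does. A toy version over a scalar objective:

import torch

def backtracking_search(f, x, direction, max_backtrack=10):
    base = f(x)
    step_size = 1.0
    for _ in range(max_backtrack + 1):
        candidate = x + step_size * direction
        value = f(candidate)
        if torch.isfinite(value) and value < base:  # acceptance tests
            return candidate, step_size
        step_size *= 0.5
    return x, 0.0  # no safe step found; keep the old parameters

x = torch.tensor([3.0])
print(backtracking_search(lambda v: (v ** 2).sum(), x, torch.tensor([-10.0])))  # (tensor([-2.]), 0.5)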
Example #16
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = '{}'.format(value)
                self.max_width = max(self.max_width, len(value_str))

        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view,
                torch.isfinite(tensor_view) & tensor_view.ne(0))

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = nonzero_finite_vals.abs().double()
            nonzero_finite_min = nonzero_finite_abs.min().double()
            nonzero_finite_max = nonzero_finite_abs.max().double()

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
                # to indicate that the tensor is of floating type. add 1 to the len to account for this.
                if nonzero_finite_max / nonzero_finite_min > 1000. or nonzero_finite_max > 1.e8:
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = ('{{:.{}e}}').format(
                            PRINT_OPTS.precision).format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = ('{:.0f}').format(value)
                        self.max_width = max(self.max_width,
                                             len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if nonzero_finite_max / nonzero_finite_min > 1000.\
                        or nonzero_finite_max > 1.e8\
                        or nonzero_finite_min < 1.e-4:
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = ('{{:.{}e}}').format(
                            PRINT_OPTS.precision).format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = ('{{:.{}f}}').format(
                            PRINT_OPTS.precision).format(value)
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode
Example #17
 def _check_finite(self, loss: torch.Tensor) -> None:
     if not torch.isfinite(loss).all():
         raise ValueError(
             f'The loss returned in `training_step` is {loss}.')
     model = self.trainer.lightning_module
     detect_nan_parameters(model)
Example #18
 def is_consistent(tensor):
     max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
     return (torch.isfinite(tensor).all()
             or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())
Example #19
    def update_policy_net(self, data):
        # Get loss and info values before update
        theta_old = U.get_flat_params_from(self.ac.pi.net)
        self.ac.pi.net.zero_grad()
        loss_pi, pi_info = self.compute_loss_pi(data=data)
        self.loss_pi_before = loss_pi.item()
        self.loss_v_before = self.compute_loss_v(data['obs'],
                                                 data['target_v']).item()
        p_dist = self.ac.pi.dist(data['obs'])
        # Train policy with multiple steps of gradient descent
        loss_pi.backward()
        g_flat = U.get_flat_gradients_from(self.ac.pi.net)

        # flip sign since policy_loss = -(ratio * adv)
        g_flat *= -1

        x = U.conjugate_gradients(self.Fvp, g_flat, self.cg_iters)
        assert torch.isfinite(x).all()
        # Note that xHx = g^T x, but calculating xHx is faster than g^T x
        xHx = torch.dot(x, self.Fvp(x))  # equivalent to : g^T x
        assert xHx.item() >= 0, 'No negative values'

        # perform descent direction
        alpha = torch.sqrt(2 * self.target_kl / (xHx + 1e-8))
        step_direction = alpha * x
        assert torch.isfinite(step_direction).all()

        # determine step direction and apply SGD step after grads where set
        # TRPO uses custom backtracking line search
        final_step_dir, accept_step = self.adjust_step_direction(
            step_dir=step_direction,
            g_flat=g_flat,
            p_dist=p_dist,
            data=data,
        )
        # update actor network parameters
        new_theta = theta_old + final_step_dir
        U.set_param_values_to_model(self.ac.pi.net, new_theta)

        with torch.no_grad():
            q_dist = self.ac.pi.dist(data['obs'])
            kl = torch.distributions.kl.kl_divergence(p_dist,
                                                      q_dist).mean().item()
            loss_pi, pi_info = self.compute_loss_pi(data=data)

        self.logger.store(
            **{
                'Values/Adv': data['act'].numpy(),
                'Entropy': pi_info['ent'],
                'KL': kl,
                'PolicyRatio': pi_info['ratio'],
                'Loss/Pi': self.loss_pi_before,
                'Loss/DeltaPi': loss_pi.item() - self.loss_pi_before,
                'Misc/AcceptanceStep': accept_step,
                'Misc/Alpha': alpha.item(),
                'Misc/StopIter': 1,
                'Misc/FinalStepNorm': torch.norm(final_step_dir).numpy(),
                'Misc/xHx': xHx.item(),
                'Misc/gradient_norm': torch.norm(g_flat).numpy(),
                'Misc/H_inv_g': x.norm().item(),
            })
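
The update above relies on U.conjugate_gradients to solve H x = g using only Fisher-vector products (self.Fvp), without ever materializing H. A minimal sketch of such a solver, with illustrative names (the actual utility may differ):

import torch

def conjugate_gradients(Avp, b, n_iters, residual_tol=1e-10):
    """Solve A x = b given only the matrix-vector product Avp(v) = A v."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual; equals b because x starts at zero
    p = b.clone()  # search direction
    rdotr = torch.dot(r, r)
    for _ in range(n_iters):
        Ap = Avp(p)
        alpha = rdotr / torch.dot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x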
Example #20
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):  # delayed update loop
            sample, is_dummy_batch = self._prepare_sample(sample)

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.data_parallel_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.cfg.distributed_training.distributed_world_size == 1:
                        return None
                else:
                    raise e

            if self.tpu and i < len(samples) - 1:
                # tpu-comment: every XLA operation before marking step is
                # appended to the IR graph, and processing too many batches
                # before marking step can lead to OOM errors.
                # To handle gradient accumulation use case, we explicitly
                # mark step here for every forward pass without a backward pass
                import torch_xla.core.xla_model as xm

                xm.mark_step()

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            train_time = self._local_cumulative_training_time()
            logging_outputs, (
                sample_size,
                ooms,
                total_train_time,
            ) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ooms,
                train_time,
                ignore=is_dummy_batch,
            )
            self._cumulative_training_time = (total_train_time /
                                              self.data_parallel_world_size)

        overflow = False
        try:
            with torch.autograd.profiler.record_function("reduce-grads"):
                # reduce gradients across workers
                self.optimizer.all_reduce_grads(self.model)
                if utils.has_parameters(self.criterion):
                    self.optimizer.all_reduce_grads(self.criterion)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (data_parallel_size / sample_size) since
                # DDP normalizes by the number of data parallel workers for
                # improved fp16 precision.
                # Thus we get (sum_of_gradients / sample_size) at the end.
                # In case of fp16, this step also undoes loss scaling.
                # (Debugging note: Some optimizers perform this scaling on the
                # fly, so inspecting model.parameters() or optimizer.params may
                # still show the original, unscaled gradients.)
                numer = (self.data_parallel_world_size
                         if not self.cfg.optimization.use_bmuf
                         or self._sync_stats() else 1)
                self.optimizer.multiply_grads(numer / (sample_size or 1.0))
                # Note: (sample_size or 1.0) handles the case of a zero gradient, in a
                # way that avoids CPU/device transfers in case sample_size is a GPU or
                # TPU object. The assumption is that the gradient itself is also 0.

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(
                    self.cfg.optimization.clip_norm)

            # check that grad norms are consistent across workers
            # on tpu check tensor is slow
            if not self.tpu:
                if (not self.cfg.optimization.use_bmuf
                        and self.cfg.distributed_training.distributed_wrapper
                        != "SlowMo"):
                    self._check_grad_norms(grad_norm)
                if not torch.isfinite(grad_norm).all():
                    # check local gradnorm single GPU case, trigger NanDetector
                    raise FloatingPointError("gradients are Nan/Inf")

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.task.optimizer_step(self.optimizer,
                                         model=self.model,
                                         update_num=self.get_num_updates())

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            self.zero_grad()
            with NanDetector(self.get_model()):
                for _, sample in enumerate(samples):
                    sample, _ = self._prepare_sample(sample)
                    self.task.train_step(
                        sample,
                        self.model,
                        self.criterion,
                        self.optimizer,
                        self.get_num_updates(),
                        ignore_grad=False,
                    )
            raise
        except OverflowError as e:
            overflow = True
            logger.info(
                f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
            )
            grad_norm = torch.tensor(0.0).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, "perform_additional_optimizer_actions"):
            if hasattr(self.optimizer, "fp32_params"):
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer)

        logging_output = None
        if (not overflow or self.cfg.distributed_training.distributed_wrapper
                == "SlowMo"):
            self.set_num_updates(self.get_num_updates() + 1)

            if self.tpu:
                # mark step on TPUs
                import torch_xla.core.xla_model as xm

                xm.mark_step()

                # only log stats every log_interval steps
                # this causes wps to be misreported when log_interval > 1
                logging_output = {}
                if self.get_num_updates() % self.cfg.common.log_interval == 0:
                    # log memory usage
                    mem_info = xm.get_memory_info(self.device)
                    gb_free = mem_info["kb_free"] / 1024 / 1024
                    gb_total = mem_info["kb_total"] / 1024 / 1024
                    metrics.log_scalar(
                        "gb_free",
                        gb_free,
                        priority=1500,
                        round=1,
                        weight=0,
                    )
                    metrics.log_scalar(
                        "gb_total",
                        gb_total,
                        priority=1600,
                        round=1,
                        weight=0,
                    )

                    logging_output = self._reduce_and_log_stats(
                        logging_outputs,
                        sample_size,
                        grad_norm,
                    )

                # log whenever there's an XLA compilation, since these
                # slow down training and may indicate opportunities for
                # optimization
                self._check_xla_compilation()
            else:
                # log stats
                logging_output = self._reduce_and_log_stats(
                    logging_outputs,
                    sample_size,
                    grad_norm,
                )

                # clear CUDA cache to reduce memory fragmentation
                if (self.cuda and self.cfg.common.empty_cache_freq > 0
                        and ((self.get_num_updates() +
                              self.cfg.common.empty_cache_freq - 1) %
                             self.cfg.common.empty_cache_freq) == 0):
                    torch.cuda.empty_cache()

        if self.cfg.common.fp16:
            metrics.log_scalar(
                "loss_scale",
                self.optimizer.scaler.loss_scale,
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")
        return logging_output
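
The maybe_no_sync pattern above generalizes to any DDP training loop: skip the gradient all-reduce on every backward pass except the last one of the accumulation window. A minimal sketch, assuming model is wrapped in torch.nn.parallel.DistributedDataParallel and compute_loss is a user-supplied closure:

import contextlib

def accumulate_and_step(model, batches, compute_loss, optimizer):
    optimizer.zero_grad()
    for i, batch in enumerate(batches):
        # no_sync() suppresses the all-reduce; only the final backward syncs
        ctx = model.no_sync() if i < len(batches) - 1 else contextlib.ExitStack()
        with ctx:
            compute_loss(model, batch).backward()
    optimizer.step()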
Example #21
0
def non_max_suppression(prediction,
                        conf_thres=0.1,
                        iou_thres=0.6,
                        multi_label=True,
                        classes=None,
                        agnostic=False):
    """
    Removes detections with object confidence score lower than 'conf_thres',
    then applies Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, conf, class)
    """
    # NMS methods https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch'

    # Box constraints
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height

    method = 'vision_batch'
    batched = 'batch' in method  # run once per image, all classes simultaneously
    nc = prediction[0].shape[1] - 5  # number of classes
    multi_label &= nc > 1  # multiple labels per box
    output = [None] * len(prediction)
    for image_i, pred in enumerate(prediction):
        # Apply conf constraint
        pred = pred[pred[:, 4] > conf_thres]

        # Apply width-height constraint
        pred = pred[((pred[:, 2:4] > min_wh) & (pred[:, 2:4] < max_wh)).all(1)]

        # If none remain process next image
        if not pred.shape[0]:
            continue

        # Compute conf
        pred[..., 5:] *= pred[..., 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(pred[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (pred[:, 5:] > conf_thres).nonzero().t()
            pred = torch.cat(
                (box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)),
                1)
        else:  # best class only
            conf, j = pred[:, 5:].max(1)
            pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)),
                             1)

        # Filter by class
        if classes:
            pred = pred[(j.view(-1, 1) == torch.tensor(
                classes, device=j.device)).any(1)]

        # Apply finite constraint
        if not torch.isfinite(pred).all():
            pred = pred[torch.isfinite(pred).all(1)]

        # If none remain process next image
        if not pred.shape[0]:
            continue

        # Sort by confidence
        if not method.startswith('vision'):
            pred = pred[pred[:, 4].argsort(descending=True)]

        # Batched NMS
        if batched:
            c = pred[:, 5] * 0 if agnostic else pred[:, 5]  # class-agnostic NMS
            boxes, scores = pred[:, :4].clone(), pred[:, 4]
            boxes += c.view(-1, 1) * max_wh
            if method == 'vision_batch':
                i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
            elif method == 'fast_batch':  # FastNMS from https://github.com/dbolya/yolact
                iou = box_iou(boxes, boxes).triu_(
                    diagonal=1)  # upper triangular iou matrix
                i = iou.max(dim=0)[0] < iou_thres

            output[image_i] = pred[i]
            continue

        # All other NMS methods
        det_max = []
        cls = pred[:, -1]
        for c in cls.unique():
            dc = pred[cls == c]  # select class c
            n = len(dc)
            if n == 1:
                det_max.append(dc)  # No NMS required if only 1 prediction
                continue
            elif n > 500:
                dc = dc[:500]  # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117

            if method == 'vision':
                det_max.append(dc[torchvision.ops.boxes.nms(
                    dc[:, :4], dc[:, 4], iou_thres)])

            elif method == 'or':  # default
                # METHOD1
                # ind = list(range(len(dc)))
                # while len(ind):
                #     j = ind[0]
                #     det_max.append(dc[j:j + 1])  # save highest conf detection
                #     reject = (bbox_iou(dc[j], dc[ind]) > iou_thres).nonzero()
                #     [ind.pop(i) for i in reversed(reject)]

                # METHOD2
                while dc.shape[0]:
                    det_max.append(dc[:1])  # save highest conf detection
                    if len(dc) == 1:  # Stop if we're at the last detection
                        break
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    dc = dc[1:][iou < iou_thres]  # remove ious > threshold

            elif method == 'and':  # requires overlap, single boxes erased
                while len(dc) > 1:
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    if iou.max() > 0.5:
                        det_max.append(dc[:1])
                    dc = dc[1:][iou < iou_thres]  # remove ious > threshold

            elif method == 'merge':  # weighted mixture box
                while len(dc):
                    if len(dc) == 1:
                        det_max.append(dc)
                        break
                    i = bbox_iou(dc[0], dc) > iou_thres  # iou with other boxes
                    weights = dc[i, 4:5]
                    dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
                    det_max.append(dc[:1])
                    dc = dc[i == 0]

            elif method == 'soft':  # soft-NMS https://arxiv.org/abs/1704.04503
                sigma = 0.5  # soft-nms sigma parameter
                while len(dc):
                    if len(dc) == 1:
                        det_max.append(dc)
                        break
                    det_max.append(dc[:1])
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    dc = dc[1:]
                    dc[:, 4] *= torch.exp(-iou**2 / sigma)  # decay confidences
                    dc = dc[
                        dc[:, 4] >
                        conf_thres]  # https://github.com/ultralytics/yolov3/issues/362

        if len(det_max):
            det_max = torch.cat(det_max)  # concatenate
            output[image_i] = det_max[(-det_max[:, 4]).argsort()]  # sort

    return output
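
The NMS functions in these examples call xywh2xyxy to convert (center-x, center-y, width, height) boxes to corner format before computing IoU. A standard implementation of that helper (the one in the source repo may differ slightly):

def xywh2xyxy(x):
    # x: (n, 4) tensor of (cx, cy, w, h) boxes -> (x1, y1, x2, y2)
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # x1 = cx - w/2
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # y1 = cy - h/2
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # x2 = cx + w/2
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # y2 = cy + h/2
    return y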
def train(opt):
    # saving setting
    opt.saved_path = opt.saved_path + opt.project
    opt.log_path = os.path.join(opt.saved_path, 'tensorboard')
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    # gpu setting
    os.environ["CUDA_VISIBLE_DEVICES"] = '2, 3, 4, 5, 6'
    gpu_number = torch.cuda.device_count()

    # dataset setting
    n_classes = 17
    n_img_all_gpu = opt.batch_size * gpu_number
    cropsize = [448, 448]
    data_root = '/home/data2/DATASET/vschallenge'
    num_workers = opt.num_workers

    ds = FaceMask(data_root, cropsize=cropsize, mode='train')
    dl = DataLoader(ds,
                    batch_size=n_img_all_gpu,
                    shuffle=True,
                    num_workers=num_workers,
                    drop_last=True)
    ds_eval = FaceMask(data_root, cropsize=cropsize, mode='val')
    dl_eval = DataLoader(ds_eval,
                         batch_size=n_img_all_gpu,
                         shuffle=True,
                         num_workers=num_workers,
                         drop_last=True)

    ignore_idx = -100
    net = BiSeNet(n_classes=n_classes)
    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except Exception:
            last_step = 0
        try:
            ret = net.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights '
                'with different number of classes. The rest of the weights should be loaded already.'
            )

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    net = net.cuda()
    net = nn.DataParallel(net)

    score_thres = 0.7
    n_min = n_img_all_gpu * cropsize[0] * cropsize[1] // opt.batch_size
    LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)

    # optimizer
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = opt.lr
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
    optim = Optimizer(model=net.module,
                      lr0=lr_start,
                      momentum=momentum,
                      wd=weight_decay,
                      warmup_steps=warmup_steps,
                      warmup_start_lr=warmup_start_lr,
                      max_iter=max_iter,
                      power=power)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim.optim,
                                                           patience=3,
                                                           verbose=True)
    # train loop
    loss_avg = []
    step = max(0, last_step)
    max_iter = len(dl)
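    # NOTE: `max_iter` is rebound here to the number of batches per epoch; the
    # 80000 above only configures the Optimizer's warmup/decay schedule.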
    best_epoch = 0
    epoch = 0
    best_loss = 1e5
    net.train()
    try:
        for epoch in range(opt.num_epochs):
            last_epoch = step // max_iter
            if epoch < last_epoch:
                continue
            epoch_loss = []
            progress_bar = tqdm(dl)
            for iter, data in enumerate(progress_bar):
                if iter < step - last_epoch * max_iter:
                    progress_bar.update()
                    continue
                try:
                    im = data['img']
                    lb = data['label']
                    lb = torch.squeeze(lb, 1)
                    im = im.cuda()
                    lb = lb.cuda()

                    optim.zero_grad()
                    out, out16, out32 = net(im)
                    lossp = LossP(out, lb)
                    loss2 = Loss2(out16, lb)
                    loss3 = Loss3(out32, lb)
                    loss = lossp + loss2 + loss3
                    if loss == 0 or not torch.isfinite(loss):
                        continue
                    loss.backward()
                    optim.step()
                    loss_avg.append(loss.item())
                    epoch_loss.append(loss.item())  # feeds the LR scheduler below
                    #  print training log message
                    # progress_bar.set_description(
                    #     'Epoch: {}/{}. Iteration: {}/{}. p_loss: {:.5f}. 2_loss: {:.5f}. 3_loss: {:.5f}. loss_avg: {:.5f}'.format(
                    #         epoch, opt.num_epochs, iter + 1, max_iter, lossp.item(),
                    #         loss2.item(), loss3.item(), loss.item()))
                    print(
                        'p_loss: {:.5f}. 2_loss: {:.5f}. 3_loss: {:.5f}. loss_avg: {:.5f}'
                        .format(lossp.item(), loss2.item(), loss3.item(),
                                loss.item()))

                    writer.add_scalars('Lossp', {'train': lossp}, step)
                    writer.add_scalars('loss2', {'train': loss2}, step)
                    writer.add_scalars('loss3', {'train': loss3}, step)
                    writer.add_scalars('loss_avg', {'train': loss}, step)

                    # log learning_rate
                    lr = optim.lr
                    writer.add_scalar('learning_rate', lr, step)
                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(net, f'Bisenet_{epoch}_{step}.pth')
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                net.eval()
                loss_p = []
                loss_2 = []
                loss_3 = []
                for iter, data in enumerate(dl_eval):
                    with torch.no_grad():
                        im = data['img']
                        lb = data['label']
                        lb = torch.squeeze(lb, 1)
                        im = im.cuda()
                        lb = lb.cuda()

                        out, out16, out32 = net(im)
                        lossp = LossP(out, lb)
                        loss2 = Loss2(out16, lb)
                        loss3 = Loss3(out32, lb)
                        loss = lossp + loss2 + loss3
                        if loss == 0 or not torch.isfinite(loss):
                            continue
                        loss_p.append(lossp.item())
                        loss_2.append(loss2.item())
                        loss_3.append(loss3.item())
                lossp = np.mean(loss_p)
                loss2 = np.mean(loss_2)
                loss3 = np.mean(loss_3)
                loss = lossp + loss2 + loss3
                print(
                    'Val. Epoch: {}/{}. p_loss: {:1.5f}. 2_loss: {:1.5f}. 3_loss: {:1.5f}. Total_loss: {:1.5f}'
                    .format(epoch, opt.num_epochs, lossp, loss2, loss3, loss))
                writer.add_scalars('Total_loss', {'val': loss}, step)
                writer.add_scalars('p_loss', {'val': lossp}, step)
                writer.add_scalars('2_loss', {'val': loss2}, step)
                writer.add_scalars('3_loss', {'val': loss3}, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(net, f'Bisenet_{epoch}_{step}.pth')

                net.train()  # back to training mode after validation
                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(net, f'Bisenet_{epoch}_{step}.pth')
        writer.close()
    writer.close()
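
The three loss terms above use OHEM (online hard example mining) cross-entropy: keep every pixel whose loss exceeds -log(thresh), but never fewer than n_min pixels. A minimal sketch of such a loss (the real OhemCELoss may differ; assumes n_min is smaller than the number of pixels per batch):

import torch
import torch.nn.functional as F

class OhemCELoss(torch.nn.Module):
    def __init__(self, thresh, n_min, ignore_lb=-100):
        super().__init__()
        self.thresh = -torch.log(torch.tensor(thresh)).item()
        self.n_min = n_min
        self.ignore_lb = ignore_lb

    def forward(self, logits, labels):
        # per-pixel CE, flattened; ignored labels contribute zero loss
        loss = F.cross_entropy(logits, labels, ignore_index=self.ignore_lb,
                               reduction='none').view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]  # all sufficiently hard pixels
        else:
            loss = loss[:self.n_min]         # at least the n_min hardest pixels
        return torch.mean(loss)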
Example #23
0
def test_sin_k_grad(logunif_input):
    k = logunif_input.detach().clone().requires_grad_()
    stereographic.math.sin_k(logunif_input[None], k[:, None]).sum().backward()
    assert torch.isfinite(logunif_input.grad).all()
    assert torch.isfinite(k.grad).all()
Example #24
0
    def backward(ctx, grad_output):
        mask = tr.isfinite(grad_output)
        grad = tr.zeros_like(grad_output)
        grad[mask] = grad_output[mask]
        return grad
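
A self-contained variant of the same pattern: a custom autograd Function that is the identity in the forward pass but zeroes any non-finite incoming gradients in the backward pass (illustrative, not the original class):

import torch

class FiniteGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        mask = torch.isfinite(grad_output)
        grad = torch.zeros_like(grad_output)
        grad[mask] = grad_output[mask]  # pass through only finite gradients
        return grad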
Example #25
0
def test_sproj(manifold, a):
    ma = manifold.sproj(manifold.inv_sproj(a))
    np.testing.assert_allclose(ma.detach(), a.detach(), atol=1e-5)
    ma.sum().backward()
    assert torch.isfinite(a.grad).all()
    assert torch.isfinite(manifold.k.grad).all()
Example #26
0
    def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
        if self.ctc_type == "builtin":
            th_pred = th_pred.log_softmax(2)
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

            if loss.requires_grad and self.ignore_nan_grad:
                # ctc_grad: (L, B, O)
                ctc_grad = loss.grad_fn(torch.ones_like(loss))
                ctc_grad = ctc_grad.sum([0, 2])
                indices = torch.isfinite(ctc_grad)
                size = indices.long().sum()
                if size == 0:
                    # Return as is
                    logging.warning(
                        "All samples in this mini-batch got nan grad."
                        " Returning nan value instead of CTC loss")
                elif size != th_pred.size(1):
                    logging.warning(
                        f"{th_pred.size(1) - size}/{th_pred.size(1)}"
                        " samples got nan grad."
                        " These were ignored for CTC loss.")

                    # Create mask for target
                    target_mask = torch.full(
                        [th_target.size(0)],
                        1,
                        dtype=torch.bool,
                        device=th_target.device,
                    )
                    s = 0
                    for ind, le in enumerate(th_olen):
                        if not indices[ind]:
                            target_mask[s:s + le] = 0
                        s += le

                    # Recalculate the loss using the masked data
                    loss = self.ctc_loss(
                        th_pred[:, indices, :],
                        th_target[target_mask],
                        th_ilen[indices],
                        th_olen[indices],
                    )
            else:
                size = th_pred.size(1)

            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss

        elif self.ctc_type == "warpctc":
            # warpctc only supports float32
            th_pred = th_pred.to(dtype=torch.float32)

            th_target = th_target.cpu().int()
            th_ilen = th_ilen.cpu().int()
            th_olen = th_olen.cpu().int()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            if self.reduce:
                # NOTE: sum() is needed to keep consistency since warpctc
                # return as tensor w/ shape (1,)
                # but builtin return as tensor w/o shape (scalar).
                loss = loss.sum()
            return loss

        elif self.ctc_type == "gtnctc":
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")

        else:
            raise NotImplementedError
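
For context, the "builtin" branch above follows the standard torch.nn.CTCLoss conventions; a self-contained sketch of the expected shapes (values are illustrative):

import torch

T, B, C, S = 50, 4, 20, 10  # input length, batch, vocab size, label length
ctc = torch.nn.CTCLoss(reduction='sum')          # class 0 is the blank by default
log_probs = torch.randn(T, B, C).log_softmax(2)  # (T, B, C), like th_pred above
targets = torch.randint(1, C, (B * S,))          # flat concatenation, like th_target
input_lengths = torch.full((B,), T, dtype=torch.long)   # th_ilen
target_lengths = torch.full((B,), S, dtype=torch.long)  # th_olen
loss = ctc(log_probs, targets, input_lengths, target_lengths)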
Example #27
0
def non_max_suppression(prediction,
                        conf_thres=0.1,
                        iou_thres=0.6,
                        multi_label=True,
                        classes=None,
                        agnostic=False):
    """
    Performs Non-Maximum Suppression on inference results
    Returns detections with shape:
        nx6 (x1, y1, x2, y2, conf, cls)
    """

    # Box constraints
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height

    method = 'merge'
    nc = prediction[0].shape[1] - 5  # number of classes
    multi_label &= nc > 1  # multiple labels per box
    output = [None] * len(prediction)

    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply conf constraint
        x = x[x[:, 4] > conf_thres]

        # Apply width-height constraint
        x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)]

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[..., 5:] *= x[..., 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero().t()
            x = torch.cat(
                (box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1)
            x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)

        # Filter by class
        if classes:
            x = x[(j.view(-1, 1) == torch.tensor(classes,
                                                 device=j.device)).any(1)]

        # Apply finite constraint
        if not torch.isfinite(x).all():
            x = x[torch.isfinite(x).all(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Sort by confidence
        # if method == 'fast_batch':
        #    x = x[x[:, 4].argsort(descending=True)]

        # Batched NMS
        c = x[:, 5] * 0 if agnostic else x[:, 5]  # classes (zeroed for class-agnostic NMS)
        boxes = x[:, :4].clone() + c.view(-1, 1) * max_wh  # boxes offset by class
        scores = x[:, 4]
        if method == 'merge':  # Merge NMS (boxes merged using weighted mean)
            i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
            if n < 1E4:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                # weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1)  # box weights
                # weights /= weights.sum(0)  # normalize
                # x[:, :4] = torch.mm(weights.T, x[:, :4])
                weights = (box_iou(boxes[i], boxes) >
                           iou_thres) * scores[None]  # box weights
                x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True),
                                    x[:, :4]).float()  # merged boxes
        elif method == 'vision':
            i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        elif method == 'fast':  # FastNMS from https://github.com/dbolya/yolact
            iou = box_iou(boxes, boxes).triu_(
                diagonal=1)  # upper triangular iou matrix
            i = iou.max(0)[0] < iou_thres

        output[xi] = x[i]
    return output
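
The box_iou helper used by the merge step computes the pairwise IoU matrix between two sets of corner-format boxes; torchvision ships an equivalent, so a quick sanity check looks like:

import torch
from torchvision.ops import box_iou

b = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.], [20., 20., 30., 30.]])
iou = box_iou(b, b)  # (3, 3) matrix; iou[0, 1] == 1.0, iou[0, 2] == 0.0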
Example #28
0
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
    """
    Removes detections with object confidence score lower than 'conf_thres',
    then applies Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, class_conf, class)
    """

    min_wh = 2  # (pixels) minimum box width and height

    output = [None] * len(prediction)
    for image_i, pred in enumerate(prediction):
        # Multiply conf by class conf to get combined confidence
        class_conf, class_pred = pred[:, 5:].max(1)
        pred[:, 4] *= class_conf

        # Select only suitable predictions
        i = (pred[:, 4] > conf_thres) & (
            pred[:, 2:4] > min_wh).all(1) & torch.isfinite(pred).all(1)
        pred = pred[i]

        # If none are remaining => process next image
        if len(pred) == 0:
            continue

        # Select predicted classes
        class_conf = class_conf[i]
        class_pred = class_pred[i].unsqueeze(1).float()

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        pred[:, :4] = xywh2xyxy(pred[:, :4])
        # pred[:, 4] *= class_conf  # improves mAP from 0.549 to 0.551

        # Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
        pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)

        # Get detections sorted by decreasing confidence scores
        pred = pred[(-pred[:, 4]).argsort()]

        det_max = []
        nms_style = 'MERGE'  # 'OR' (default), 'AND', 'MERGE' (experimental)
        for c in pred[:, -1].unique():
            dc = pred[pred[:, -1] == c]  # select class c
            n = len(dc)
            if n == 1:
                det_max.append(dc)  # No NMS required if only 1 prediction
                continue
            elif n > 100:
                dc = dc[:100]

            # Non-maximum suppression
            if nms_style == 'OR':  # default
                while dc.shape[0]:
                    det_max.append(dc[:1])  # save highest conf detection
                    if len(dc) == 1:  # Stop if we're at the last detection
                        break
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    dc = dc[1:][iou < nms_thres]  # remove ious > threshold

            elif nms_style == 'AND':  # requires overlap, single boxes erased
                while len(dc) > 1:
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    if iou.max() > 0.5:
                        det_max.append(dc[:1])
                    dc = dc[1:][iou < nms_thres]  # remove ious > threshold

            elif nms_style == 'MERGE':  # weighted mixture box
                while len(dc):
                    if len(dc) == 1:
                        det_max.append(dc)
                        break
                    i = bbox_iou(dc[0], dc) > nms_thres  # iou with other boxes
                    weights = dc[i, 4:5]
                    dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
                    det_max.append(dc[:1])
                    dc = dc[i == 0]

        if len(det_max):
            det_max = torch.cat(det_max)  # concatenate
            output[image_i] = det_max[(-det_max[:, 4]).argsort()]  # sort

    return output
Example #29
0
    def search(self, start_predictions: torch.Tensor, start_state: StateType,
               step: StepFunctionType) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a starting state and a step function, apply beam search to find the
        most likely target sequences.

        Notes
        -----
        If your step function returns `-inf` for some log probabilities
        (like if you're using a masked log-softmax) then some of the "best"
        sequences returned may also have `-inf` log probability. Specifically
        this happens when the beam size is larger than the number of actions
        with finite log probability (non-zero probability) returned by the step function.
        Therefore if you're using a mask you may want to check the results from `search`
        and potentially discard sequences with non-finite log probability.

        # Parameters

        start_predictions : `torch.Tensor`
            A tensor containing the initial predictions with shape `(batch_size,)`.
            Usually the initial predictions are just the index of the "start" token
            in the target vocabulary.
        start_state : `StateType`
            The initial state passed to the `step` function. Each value of the state dict
            should be a tensor of shape `(batch_size, *)`, where `*` means any other
            number of dimensions.
        step : `StepFunctionType`
            A function that is responsible for computing the next most likely tokens,
            given the current state and the predictions from the last time step.
            The function should accept two arguments. The first being a tensor
            of shape `(group_size,)`, representing the index of the predicted
            tokens from the last time step, and the second being the current state.
            The `group_size` will be `batch_size * beam_size`, except in the initial
            step, for which it will just be `batch_size`.
            The function is expected to return a tuple, where the first element
            is a tensor of shape `(group_size, target_vocab_size)` containing
            the log probabilities of the tokens for the next step, and the second
            element is the updated state. The tensor in the state should have shape
            `(group_size, *)`, where `*` means any other number of dimensions.

        # Returns

        Tuple[torch.Tensor, torch.Tensor]
            Tuple of `(predictions, log_probabilities)`, where `predictions`
            has shape `(batch_size, beam_size, max_steps)` and `log_probabilities`
            has shape `(batch_size, beam_size)`.
        """
        batch_size = start_predictions.size()[0]

        # List of (batch_size, beam_size) tensors. One for each time step. Does not
        # include the start symbols, which are implicit.
        predictions: List[torch.Tensor] = []

        # List of (batch_size, beam_size) tensors. One for each time step. None for
        # the first.  Stores the index n for the parent prediction, i.e.
        # predictions[t-1][i][n], that it came from.
        backpointers: List[torch.Tensor] = []

        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input (the output from the
        # encoder) to the top `beam_size` decoder outputs. On the other hand,
        # within the main loop we are going from the `beam_size` elements of the
        # beam to `beam_size`^2 candidates from which we will select the top
        # `beam_size` elements for the next iteration.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions,
                                                    start_state)

        num_classes = start_class_log_probabilities.size()[1]

        # Make sure `per_node_beam_size` is not larger than `num_classes`.
        if self.per_node_beam_size > num_classes:
            raise ConfigurationError(
                f"Target vocab size ({num_classes:d}) too small "
                f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
                f"Please decrease beam_size or per_node_beam_size.")

        # shape: (batch_size, beam_size), (batch_size, beam_size)
        start_top_log_probabilities, start_predicted_classes = start_class_log_probabilities.topk(
            self.beam_size)
        if self.beam_size == 1 and (start_predicted_classes
                                    == self._end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(
                -1), start_top_log_probabilities

        # The log probabilities for the last time step.
        # shape: (batch_size, beam_size)
        last_log_probabilities = start_top_log_probabilities

        # The length of current sequence
        last_length = torch.where(
            start_predicted_classes != self._end_index,
            torch.ones_like(start_predicted_classes),
            torch.zeros_like(start_predicted_classes)).to(
                last_log_probabilities.device)
        last_length = last_length.float()

        # shape: [(batch_size, beam_size)]
        predictions.append(start_predicted_classes)

        # Log probability tensor that mandates that the end token is selected.
        # shape: (batch_size * beam_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.beam_size, num_classes), float("-inf"))
        log_probs_after_end[:, self._end_index] = 0.0

        # Set the same state for each element in the beam.
        for key, state_tensor in state.items():
            _, *last_dims = state_tensor.size()
            # shape: (batch_size * beam_size, *)
            state[key] = (state_tensor.unsqueeze(1).expand(
                batch_size, self.beam_size,
                *last_dims).reshape(batch_size * self.beam_size, *last_dims))

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[-1].reshape(batch_size *
                                                       self.beam_size)

            # If every predicted token from the last step is `self._end_index`,
            # then we can stop early.
            if (last_predictions == self._end_index).all():
                break

            # Take a step. This gets the predicted log probs of the next classes
            # and updates the state.
            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state)

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * self.beam_size, num_classes)

            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            # shape: (batch_size * beam_size, num_classes)
            cleaned_log_probabilities = torch.where(
                last_predictions_expanded == self._end_index,
                log_probs_after_end,
                class_log_probabilities,
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_log_probabilities, predicted_classes = cleaned_log_probabilities.topk(
                self.per_node_beam_size)
            current_top_length = torch.where(
                predicted_classes != self._end_index,
                torch.ones_like(predicted_classes),
                torch.zeros_like(predicted_classes)).to(
                    last_log_probabilities.device)
            current_top_length = current_top_length.float()

            # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
            # so that we can add them to the current log probs for this timestep.
            # This lets us maintain the log probability of each element on the beam.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_log_probabilities = (
                last_log_probabilities.unsqueeze(2).expand(
                    batch_size, self.beam_size,
                    self.per_node_beam_size).reshape(
                        batch_size * self.beam_size, self.per_node_beam_size))

            expanded_last_length = (last_length.unsqueeze(2).expand(
                batch_size, self.beam_size,
                self.per_node_beam_size).reshape(batch_size * self.beam_size,
                                                 self.per_node_beam_size))

            # shape: (batch_size * beam_size, per_node_beam_size)
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
            summed_top_length = current_top_length + expanded_last_length
            summed_top_length = summed_top_length.float()

            if self.alpha > 0 and self.length_penalty:
                length_penalty = ((5.0 + summed_top_length)**self.alpha) / (
                    (5.0 + 1.0)**self.alpha)
                # print(length_penalty)
                normalized_summed_top_log_probabilities = summed_top_log_probabilities / length_penalty
            else:
                normalized_summed_top_log_probabilities = summed_top_log_probabilities

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_log_probabilities.reshape(
                batch_size, self.beam_size * self.per_node_beam_size)
            reshaped_normalized_summed = normalized_summed_top_log_probabilities.reshape(
                batch_size, self.beam_size * self.per_node_beam_size)
            reshaped_length = summed_top_length.reshape(
                batch_size, self.beam_size * self.per_node_beam_size)

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size)

            # Keep only the top `beam_size` beam indices.
            # shape: (batch_size, beam_size), (batch_size, beam_size)
            restricted_beam_log_probs, restricted_beam_indices = reshaped_normalized_summed.topk(
                self.beam_size)

            # Use the beam indices to extract the corresponding classes.
            # shape: (batch_size, beam_size)
            restricted_predicted_classes = reshaped_predicted_classes.gather(
                1, restricted_beam_indices)
            restricted_summed_probs = reshaped_summed.gather(
                1, restricted_beam_indices)
            restricted_length = reshaped_length.gather(
                1, restricted_beam_indices)

            predictions.append(restricted_predicted_classes)

            # shape: (batch_size, beam_size)
            last_log_probabilities = restricted_summed_probs
            last_length = restricted_length

            # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
            # indices with a common ancestor are grouped together. Hence
            # floor-dividing by per_node_beam_size gives the ancestor (the
            # tensor is a LongTensor, so use integer division explicitly).
            # shape: (batch_size, beam_size)
            backpointer = restricted_beam_indices // self.per_node_beam_size

            backpointers.append(backpointer)

            # Keep only the pieces of the state tensors corresponding to the
            # ancestors created this iteration.
            for key, state_tensor in state.items():
                _, *last_dims = state_tensor.size()
                # shape: (batch_size, beam_size, *)
                expanded_backpointer = backpointer.view(
                    batch_size, self.beam_size,
                    *([1] * len(last_dims))).expand(batch_size, self.beam_size,
                                                    *last_dims)

                # shape: (batch_size * beam_size, *)
                state[key] = (state_tensor.reshape(
                    batch_size, self.beam_size,
                    *last_dims).gather(1, expanded_backpointer).reshape(
                        batch_size * self.beam_size, *last_dims))

        if not torch.isfinite(last_log_probabilities).all():
            warnings.warn(
                "Infinite log probabilities encountered. Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        # Reconstruct the sequences.
        # shape: [(batch_size, beam_size, 1)]
        reconstructed_predictions = [predictions[-1].unsqueeze(2)]

        # shape: (batch_size, beam_size)
        cur_backpointers = backpointers[-1]

        for timestep in range(len(predictions) - 2, 0, -1):
            # shape: (batch_size, beam_size, 1)
            cur_preds = predictions[timestep].gather(
                1, cur_backpointers).unsqueeze(2)

            reconstructed_predictions.append(cur_preds)

            # shape: (batch_size, beam_size)
            cur_backpointers = backpointers[timestep - 1].gather(
                1, cur_backpointers)

        # shape: (batch_size, beam_size, 1)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)

        reconstructed_predictions.append(final_preds)

        # shape: (batch_size, beam_size, max_steps)
        all_predictions = torch.cat(list(reversed(reconstructed_predictions)),
                                    2)

        if self.length_penalty:
            valid_length = all_predictions.ne(self._end_index)
            valid_length = torch.sum(valid_length, dim=-1)
            fill_value = torch.ones_like(valid_length)
            valid_length = torch.where(valid_length == 0, fill_value,
                                       valid_length)  # confirm length is not 0
            valid_length = valid_length.float()
            if self.alpha > 0:
                length_penalty = ((5.0 + valid_length)**self.alpha) / (
                    (5.0 + 1.0)**self.alpha)
                last_log_probabilities = last_log_probabilities / length_penalty
            else:
                last_log_probabilities = last_log_probabilities / valid_length

        # Re-sort by the final (possibly length-normalized) log probabilities.
        # shape: (batch_size, beam_size)
        sorted_last_log_probabilities, sorted_index = torch.topk(
            last_log_probabilities, k=self.beam_size, dim=-1)

        sorted_index = sorted_index.unsqueeze(-1).expand_as(all_predictions)
        sorted_all_predictions = all_predictions.gather(1, sorted_index)

        return sorted_all_predictions, sorted_last_log_probabilities
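
The normalization applied in this search is the GNMT-style length penalty, lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha (Wu et al., 2016); a quick numeric check:

def length_penalty(length, alpha=0.6):
    return ((5.0 + length) ** alpha) / ((5.0 + 1.0) ** alpha)

length_penalty(1)   # 1.0   -- single-token sequences are unpenalized
length_penalty(20)  # ~2.35 -- longer hypotheses are divided by a larger factor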
Example #30
0
def train(opt):
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)
    training_params = {
        'batch_size': opt.train_batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.val_batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]

    training_set = RotationCocoDataset(
        root_dir=os.path.join(opt.data_path, params.project_name),
        set=params.train_set,
        transform=transforms.Compose([
            Normalizer(mean=params.mean, std=params.std),
            Resizer(input_sizes[opt.compound_coef])
        ]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = RotationCocoDataset(root_dir=os.path.join(
        opt.data_path, params.project_name),
                                  set=params.val_set,
                                  transform=transforms.Compose([
                                      Normalizer(mean=params.mean,
                                                 std=params.std),
                                      Resizer(input_sizes[opt.compound_coef])
                                  ]))
    val_generator = DataLoader(val_set, **val_params)

    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except Exception:
            last_step = 0

        try:
            ret = model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, '
                'this might be because you load a pretrained weights with different number of classes. '
                'The rest of the weights should be loaded already.')

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] froze backbone')
    """
    `https://github.com/vacancy/Synchronized-BatchNorm-PyTorch`
    apply sync_bn when using multiple gpu and batch_size per gpu is lower than 4
    useful when gpu memory is limited.
    because when bn is disable, the training will be very unstable or slow to converge,
    apply sync_bn can solve it,
    by packing all mini-batc across all gpus as one batch and normalize, then send it back to all gpus.
    but it would also slow down the training by a little bit."""
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # wrap the model with the loss function to reduce memory usage on gpu0 and speed up
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)

    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)
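    # stepped once per epoch on the mean training loss (see scheduler.step
    # below); with the default factor, the LR drops 10x after `patience`
    # epochs without improvement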

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()
    num_iter_per_epoch = len(training_generator)
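    # resume bookkeeping: whole epochs already covered by `step` are skipped
    # below, then the leftover iterations within the partially finished epoch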

    try:
        for epoch in range(opt.num_epochs):
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for it, data in enumerate(progress_bar):
                if it < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs = data['img']
                    annot = data['annot']

                    if params.num_gpus == 1:
                        # with a single gpu, send the batch to cuda:0 here;
                        # with multiple gpus, CustomDataParallel handles the scatter
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs,
                                               annot,
                                               obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
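                    # skip degenerate batches: a zero loss (e.g. no usable
                    # annotations) or a non-finite loss would poison training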
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()

                    # grad clip (Optional)
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch, opt.num_epochs, it + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classification_loss',
                                       {'train': cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                for data in val_generator:
                    with torch.no_grad():
                        imgs = data['img']
                        annot = data['annot']

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs,
                                                   annot,
                                                   obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch, opt.num_epochs, cls_loss, reg_loss, loss))
                writer.add_scalars('Loss', {'val': loss}, step)
                writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
                writer.add_scalars('Classification_loss', {'val': cls_loss},
                                   step)

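                # count it as an improvement only if the val loss beats the
                # previous best by at least opt.es_min_delta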
                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(
                        model,
                        f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                    )

                model.train()

                # Early stopping
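                # chained comparison: fires only when patience is exceeded and
                # opt.es_patience > 0 (a value of 0 disables early stopping)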
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
    writer.close()
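
# Minimal, hypothetical entry point for this example — the original script
# builds `opt` (CLI args) and `params` (the project YAML) elsewhere, so the
# names below are assumptions, not the repo's actual API:
#
#     if __name__ == '__main__':
#         opt = get_args()
#         train(opt)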