def train():
    """Train a student detector by adversarial knowledge distillation from a teacher.

    Reads the module-level ``config``, ``device``, ``mixed_precision`` and
    ``tb_writer`` globals.  While ``epoch < config['second_stage']`` every batch:

    * updates the discriminator ``D_models`` to label teacher feature maps
      (``config['teacher_indexes']``) as real and student feature maps
      (``config['student_indexes']``) as fake;
    * computes a generator loss that rewards the student when D labels its
      features as real, and back-propagates it into the student together with
      the detection loss (scaled by 0.05 in this stage).

    From ``second_stage`` on only the detection loss trains the student.

    Returns:
        The last ``results`` tuple
        (P, R, mAP, F1, val GIoU, val Objectness, val Classification); returned
        early if a non-finite loss is encountered.
    """
    data = config['data']
    img_size, img_size_test = config['img_size'] if len(
        config['img_size']) == 2 else config['img_size'] * 2  # train, test sizes
    epochs = config['epochs']  # 500200 batches at bs 64, 117263 images = 273 epochs
    batch_size = config['batch_size']
    accumulate = config['accumulate']  # effective bs = batch_size * accumulate = 16 * 4 = 64

    # Initialize
    init_seeds(config['seed'])
    if config['multi_scale']:
        img_sz_min = round(img_size / 32 / 1.5)
        img_sz_max = round(img_size / 32 * 1.5)
        img_size = img_sz_max * 32  # initiate with maximum multi_scale size
        print('Using multi-scale %g - %g' % (img_sz_min * 32, img_size))

    # Configure run
    data_dict = parse_data_cfg(data)
    nc = int(data_dict['classes'])  # number of classes

    # Initialize Teacher
    if config['teacher_darknet'] == 'default':
        teacher = Darknet(cfg=config['teacher_cfg'],
                          arc=config['teacher_arc']).to(device)
    elif config['teacher_darknet'] == 'soft':
        teacher = SoftDarknet(cfg=config['teacher_cfg'],
                              arc=config['teacher_arc']).to(device)
    # Initialize Student
    if config['student_darknet'] == 'default':
        if 'nano' in config['student_cfg']:
            print('Using a YOLO Nano arc')
            student = YOLO_Nano(config['student_cfg']).to(device)
        else:
            student = Darknet(cfg=config['student_cfg']).to(device)
    elif config['student_darknet'] == 'soft':
        student = SoftDarknet(cfg=config['student_cfg'],
                              arc=config['student_arc']).to(device)
    # Create Discriminators over the teacher feature maps selected by teacher_indexes
    D_models = None
    if len(config['teacher_indexes']):
        D_models = Discriminator(config['teacher_indexes'], teacher,
                                 config['D_kernel_size'], False).to(device)

    G_optim = create_optimizer(student, config)
    D_optim = create_optimizer(D_models, config, is_D=True)
    GAN_criterion = torch.nn.BCEWithLogitsLoss()

    mask = None
    if ('mask' in config and config['mask']) or ('mask_path' in config
                                                 and config['mask_path']):
        print('Creating mask')
        mask = create_mask_LTH(teacher).to(device)

    start_epoch, best_fitness, teacher, student, mask, D_models, G_optim, D_optim = load_kd_checkpoints(
        config, teacher, student, mask, D_models, G_optim, D_optim, device)

    if mask is not None:
        print('Applying mask in teacher')
        apply_mask_LTH(teacher, mask)
        del mask
        torch.cuda.empty_cache()

    if config['xavier_norm']:
        initialize_model(student, torch.nn.init.xavier_normal_)
    elif config['xavier_uniform']:
        initialize_model(student, torch.nn.init.xavier_uniform_)

    G_scheduler = create_scheduler(config, G_optim, start_epoch)
    D_scheduler = create_scheduler(config, D_optim, start_epoch)

    # Mixed precision training https://github.com/NVIDIA/apex
    # amp.initialize takes the optimizer (not the LR scheduler) as its second argument.
    if mixed_precision:
        student, G_optim = amp.initialize(student,
                                          G_optim,
                                          opt_level='O1',
                                          verbosity=0)

    # Initialize distributed training
    if device.type != 'cpu' and torch.cuda.device_count(
    ) > 1 and torch.distributed.is_available():
        dist.init_process_group(
            backend='nccl',  # 'distributed backend'
            init_method='tcp://127.0.0.1:9999',  # distributed training init method
            world_size=1,  # number of nodes for distributed training
            rank=0)  # distributed training node rank
        teacher = torch.nn.parallel.DistributedDataParallel(
            teacher, find_unused_parameters=True)
        teacher.yolo_layers = teacher.module.yolo_layers  # move yolo layer indices to top level
        student = torch.nn.parallel.DistributedDataParallel(
            student, find_unused_parameters=True)
        student.yolo_layers = student.module.yolo_layers  # move yolo layer indices to top level

    trainloader, validloader = create_dataloaders(config)

    # Start training
    nb = len(trainloader)
    prebias = start_epoch == 0
    student.nc = nc  # attach number of classes to student
    teacher.nc = nc

    student.arc = config['student_arc']  # attach yolo architecture
    teacher.arc = config['teacher_arc']

    student.hyp = config['hyp']  # attach hyperparameters to student
    teacher.hyp = config['hyp']

    student.class_weights = labels_to_class_weights(
        trainloader.dataset.labels, nc).to(device)  # attach class weights
    teacher.class_weights = student.class_weights

    maps = np.zeros(nc)  # mAP per class
    # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    results = (0, 0, 0, 0, 0, 0, 0)
    t0 = time.time()
    torch_utils.model_info(student, report='summary')  # 'full' or 'summary'
    print('Starting training for %g epochs...' % epochs)

    teacher.train()
    max_wo_best = 0
    ###############
    # Start epoch #
    ###############
    for epoch in range(start_epoch, epochs):
        student.train()
        student.gr = 1 - (1 + math.cos(min(epoch * 2, epochs) * math.pi /
                                       epochs)) / 2  # GIoU <-> 1.0 loss ratio

        # Prebias
        if prebias:
            ne = max(round(30 / nb), 3)  # number of prebias epochs
            ps = np.interp(epoch, [0, ne], [0.1, config['hyp']['lr0'] * 2]), \
                np.interp(epoch, [0, ne], [0.9, config['hyp']['momentum']])  # prebias settings (lr=0.1, momentum=0.9)
            if epoch == ne:
                print_model_biases(student)
                prebias = False

            # Bias optimizer settings
            G_optim.param_groups[2]['lr'] = ps[0]
            if G_optim.param_groups[2].get('momentum') is not None:  # for SGD but not Adam
                G_optim.param_groups[2]['momentum'] = ps[1]

        # Update image weights (optional)
        if trainloader.dataset.image_weights:
            w = student.class_weights.cpu().numpy() * (1 - maps)**2  # class weights
            image_weights = labels_to_image_weights(trainloader.dataset.labels,
                                                    nc=nc,
                                                    class_weights=w)
            trainloader.dataset.indices = random.choices(
                range(trainloader.dataset.n),
                weights=image_weights,
                k=trainloader.dataset.n)  # rand weighted idx

        mloss = torch.zeros(9).to(device)  # mean losses
        print(('\n' + '%10s' * 13) %
              ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'G_loss', 'D_loss',
               'D_x', 'D_g_z1', 'D_g_z2', 'total', 'targets', 'img_size'))
        pbar = tqdm(enumerate(trainloader), total=nb)  # progress bar
        ####################
        # Start mini-batch #
        ####################
        for i, (imgs, targets, paths, _) in pbar:
            # Soft (noisy) labels for the discriminator
            real_data_label = ft(imgs.shape[0],
                                 device=device).uniform_(.7, 1.0)
            fake_data_label = ft(imgs.shape[0], device=device).uniform_(.0, .3)

            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            # Plot images with bounding boxes
            if ni < 1:
                f = config['sub_working_dir'] + 'train_batch%g.png' % i  # filename
                plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)
                if tb_writer:
                    tb_writer.add_image(f,
                                        cv2.imread(f)[:, :, ::-1],
                                        dataformats='HWC')

            # Multi-Scale training
            if config['multi_scale']:
                if ni / accumulate % 1 == 0:  #  adjust img_size (67% - 150%) every 1 batch
                    img_size = random.randrange(img_sz_min,
                                                img_sz_max + 1) * 32
                sf = img_size / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [
                        math.ceil(x * sf / 32.) * 32 for x in imgs.shape[2:]
                    ]  # new shape (stretched to 32-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode='bilinear',
                                         align_corners=False)

            # Run student
            if len(config['student_indexes']) and epoch < config['second_stage']:
                pred_std, fts_std = student(imgs, config['student_indexes'])
                if 'nano' in config['student_cfg']:  # YOLO Nano outputs in the reversed order
                    fts_std.reverse()
            else:
                pred_std = student(imgs)

            ###################################################
            # Update D: maximize log(D(x)) + log(1 - D(G(z))) #
            ###################################################
            D_loss_real, D_loss_fake, D_x, D_g_z1 = ft([.0]), ft([.0]), ft(
                [.0]), ft([.0])
            if epoch < config['second_stage']:
                # Run teacher
                with torch.no_grad():
                    _, fts_tch = teacher(imgs, config['teacher_indexes'])

                # Adding noise to Discriminator: flipping labels
                if random.random() < .05:
                    real_data_label, fake_data_label = fake_data_label, real_data_label

                # Discriminate the real data
                real_data_discrimination = D_models(fts_tch)
                for output in real_data_discrimination:
                    D_x += output.mean().item() / len(real_data_discrimination)
                # Discriminate the fake data (detached: this pass only trains D)
                fake_data_discrimination = D_models(
                    [x.detach() for x in fts_std])
                for output in fake_data_discrimination:
                    D_g_z1 += output.mean().item() / len(fake_data_discrimination)

                # Compute loss
                for x in real_data_discrimination:
                    D_loss_real += GAN_criterion(x.view(-1), real_data_label)
                for x in fake_data_discrimination:
                    D_loss_fake += GAN_criterion(x.view(-1), fake_data_label)

                # Scale loss by nominal batch_size of 64
                D_loss_real *= batch_size / 64
                D_loss_fake *= batch_size / 64

                # Compute gradient
                D_loss_real.backward()
                D_loss_fake.backward()

                # Optimize accumulated gradient
                if ni % accumulate == 0:
                    D_optim.step()
                    D_optim.zero_grad()

            ###################################
            # Update G: maximize log(D(G(z))) #
            ###################################
            G_loss, D_g_z2 = ft([.0]), ft([.0])
            if epoch < config['second_stage']:
                # Freeze D for the generator pass so G_loss only produces
                # gradients for the student; D's gradients come from D_loss alone.
                D_trainable = [p for p in D_models.parameters() if p.requires_grad]
                for p in D_trainable:
                    p.requires_grad_(False)

                # Since we already update D, perform another forward with fake batch through D.
                # The student features are NOT detached here: the adversarial signal
                # must back-propagate into the student, otherwise G_loss trains nothing.
                fake_data_discrimination = D_models(fts_std)

                # Re-enable D; the graph recorded above already excludes D's parameters.
                for p in D_trainable:
                    p.requires_grad_(True)

                for output in fake_data_discrimination:
                    D_g_z2 += output.mean().item() / len(fake_data_discrimination)

                # Compute loss
                real_data_label = torch.ones(imgs.shape[0], device=device)
                for x in fake_data_discrimination:
                    G_loss += GAN_criterion(
                        x.view(-1), real_data_label
                    )  # fake labels are real for generator cost

                # Scale loss by nominal batch_size of 64
                G_loss *= batch_size / 64

            # Compute loss
            obj_detec_loss, loss_items = compute_loss(pred_std, targets,
                                                      student)

            # Scale loss by nominal batch_size of 64
            obj_detec_loss *= batch_size / 64

            if epoch < config['second_stage']:
                obj_detec_loss *= .05

            # Compute gradient for the student: detection loss plus (first stage
            # only) the generator loss, in ONE backward pass because both share
            # the same student forward graph (a second pass would need
            # retain_graph and breaks under DistributedDataParallel).
            (obj_detec_loss + G_loss).backward()

            # Optimize accumulated gradient
            if ni % accumulate == 0:
                G_optim.step()
                G_optim.zero_grad()

            D_loss = D_loss_real + D_loss_fake
            total_loss = obj_detec_loss + D_loss + G_loss
            all_losses = torch.cat([
                loss_items[:3], G_loss, D_loss, D_x, D_g_z1, D_g_z2, total_loss
            ]).detach()
            if not torch.isfinite(total_loss):
                print('WARNING: non-finite loss, ending training ', all_losses)
                return results

            # Print batch results
            mloss = (mloss * i + all_losses) / (i + 1)  # update mean losses
            mem = '%.3gG' % (torch.cuda.memory_cached() /
                             1E9 if torch.cuda.is_available() else 0)  # (GB)
            s = ('%10s' * 2 + '%10.3g' * 11) % ('%g/%g' %
                                                (epoch, epochs - 1), mem,
                                                *mloss, len(targets), img_size)
            pbar.set_description(s)
        ##################
        # End mini-batch #
        ##################

        # Update scheduler
        G_scheduler.step()
        D_scheduler.step()

        final_epoch = epoch + 1 == config['epochs']
        if not config['notest'] or final_epoch:  # Calculate mAP
            results, maps = guarantee_test(student, config, device,
                                           config['student_cfg'], data,
                                           batch_size, img_size_test,
                                           validloader, final_epoch, test.test)

        # Write epoch results
        with open(config['results_file'], 'a') as f:
            f.write(s + '%10.3g' * 7 % results +
                    '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
        if len(config['name']) and config['bucket']:
            os.system('gsutil cp results.txt gs://%s/results/results%s.txt' %
                      (config['bucket'], config['name']))

        # Write Tensorboard results (9 mean train losses followed by 7 results)
        if tb_writer:
            x = list(mloss) + list(results)
            titles = [
                'GIoU', 'Objectness', 'Classification', 'Generator Loss',
                'Discriminator Loss', 'D_x', 'D_g_z1', 'D_g_z2',
                'Train Loss', 'Precision', 'Recall', 'mAP', 'F1', 'val GIoU',
                'val Objectness', 'val Classification'
            ]
            for xi, title in zip(x, titles):
                tb_writer.add_scalar(title, xi, epoch)

        # Update best mAP
        fi = fitness(np.array(results).reshape(
            1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
        if fi > best_fitness:
            best_fitness = fi
            max_wo_best = 0
        else:
            max_wo_best += 1
            if config['early_stop'] and max_wo_best == config['early_stop']:
                print('Ending training due to early stop')

        # Save training results
        save = (not config['nosave']) or (final_epoch and not config['evolve'])
        if save:
            with open(config['results_file'], 'r') as f:
                training_results = f.read()
            # Create checkpoint
            chkpt = {
                'epoch': epoch,
                'best_fitness': best_fitness,
                'training_results': training_results,
                'model': student.module.state_dict()
                if isinstance(student, nn.parallel.DistributedDataParallel)
                else student.state_dict(),
                'D': D_models.state_dict(),
                'G_optim': None if final_epoch else G_optim.state_dict(),
                'D_optim': None if final_epoch else D_optim.state_dict()
            }

            # Save last checkpoint
            torch.save(chkpt, config['last'])

            # Save best checkpoint (separate file while still in the GAN stage)
            if best_fitness == fi:
                torch.save(
                    chkpt, config['best_gan']
                    if epoch < config['second_stage'] else config['best'])

            # Delete checkpoint
            del chkpt
            torch.cuda.empty_cache()

        if config['early_stop'] and max_wo_best == config['early_stop']:
            break
    #############
    # End epoch #
    #############

    n = config['name']
    if len(n):
        n = '_' + n if not n.isnumeric() else n
        fresults, flast, fbest = 'results%s.txt' % n, 'last%s.pt' % n, 'best%s.pt' % n
        os.rename(config['results_file'], config['sub_working_dir'] + fresults)
        if os.path.exists(config['last']):
            os.rename(config['last'], config['sub_working_dir'] + flast)
        if os.path.exists(config['best']):
            os.rename(config['best'], config['sub_working_dir'] + fbest)
        # Updating results, last and best
        config['results_file'] = config['sub_working_dir'] + fresults
        config['last'] = config['sub_working_dir'] + flast
        config['best'] = config['sub_working_dir'] + fbest

        if config['bucket']:  # save to cloud
            os.system('gsutil cp %s gs://%s/results' %
                      (fresults, config['bucket']))
            os.system('gsutil cp %s gs://%s/weights' %
                      (config['sub_working_dir'] + flast, config['bucket']))

    if not config['evolve']:
        plot_results(folder=config['sub_working_dir'])

    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1,
                                                    (time.time() - t0) / 3600))
    # Only tear down a process group that was actually initialised above.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
    torch.cuda.empty_cache()

    return results
# Dummy input used to run/profile the model. Zeros instead of torch.Tensor(...):
# the latter is uninitialised memory (non-deterministic, may contain NaN/inf).
x = torch.zeros(1, 3, 416, 416).to(device)

# Initialize model
if 'soft' in args['cfg']:
    model = SoftDarknet(args['cfg']).to(device)
elif 'nano' in args['cfg']:
    model = YOLO_Nano(args['cfg']).to(device)
else:
    model = Darknet(args['cfg']).to(device)

checkpoint = None
if args['model']:
    checkpoint = torch.load(args['model'], map_location=device)
    try:
        # Checkpoint dict with the weights under 'model'
        model.load_state_dict(checkpoint['model'])
    except KeyError:
        # Plain state dict
        model.load_state_dict(checkpoint)

if args['mask'] or args['embbed']:
    mask = create_mask_LTH(model)
    if args['mask']:
        mask.load_state_dict(torch.load(args['mask'], map_location=device))
    else:
        # An embedded mask lives inside the model checkpoint.
        if checkpoint is None:
            raise ValueError("args['embbed'] requires args['model'] to point to a checkpoint containing a 'mask' entry")
        mask.load_state_dict(checkpoint['mask'])
    apply_mask_LTH(model, mask)

elif 'soft' in args['cfg']:
    model.ticket = True
    model.temp = 1.
    # NOTE(review): the model is in train mode here, so this dummy forward may
    # update BatchNorm running statistics — confirm that is acceptable.
    _ = model(x)
if not (args['mask'] or args['embbed']
        or 'soft' in args['cfg']) or args['macs_reduction'] == 'none':
    print('using model with no MACs reduction')
    total_ops, total_params = profile(model, (x, ),
Example #3
0
def test(cfg,
         data,
         weights=None,
         batch_size=16,
         img_size=416,
         conf_thres=0.001,
         iou_thres=0.6,  # for nms
         save_json=False,
         single_cls=False,
         model=None,
         dataloader=None,
         folder='',
         mask=None,
         mask_weight=None,
         architecture='default'):
    """Evaluate a YOLO detector and report P, R, mAP@0.5 and F1.

    Args:
        cfg: Model cfg path; 'soft' in the name selects SoftDarknet, 'nano'
            selects YOLO_Nano, otherwise Darknet. Only used when ``model`` is None.
        data: Data cfg path; its 'test' entry (or 'valid') is evaluated.
        weights: '.pt' file (a dict with a 'model' entry, or a plain state
            dict) or darknet weights. Only used when ``model`` is None.
        batch_size, img_size: Dataloader settings when ``dataloader`` is None.
        conf_thres, iou_thres: Thresholds passed to non_max_suppression.
        save_json: Also dump COCO-format detections and score them with
            pycocotools (if installed).
        single_cls: Treat the data set as having a single class.
        model: Already-built model (training-time evaluation), used as-is.
        dataloader: Optional pre-built dataloader.
        folder: Prefix for output files (plots, per-class results, JSON).
        mask: If truthy, load an LTH mask from the 'mask' entry of ``weights``.
        mask_weight: Path to a separate LTH mask state dict.
        architecture: Unused; kept for call compatibility.

    Returns:
        ``((mp, mr, map, mf1, val GIoU, val obj, val cls), maps)`` where
        ``maps[c]`` is the AP of class ``c`` (mean mAP for classes without
        statistics).
    """
    # Initialize/load model and set device
    if model is None:
        device = torch_utils.select_device(args['device'], batch_size=batch_size)
        verbose = args['task'] == 'test'

        # Remove previous
        for f in glob.glob(folder + 'test_batch*.png'):
            os.remove(f)

        # Initialize model
        if 'soft' in cfg:
            model = SoftDarknet(cfg=cfg).to(device)
        elif 'nano' in cfg:
            model = YOLO_Nano(cfg).to(device)
        else:
            model = Darknet(cfg=cfg).to(device)

        if mask or mask_weight:
            msk = create_mask_LTH(model)

        # Load weights (the .pt file is read once and reused for the mask below)
        ckpt = None
        attempt_download(weights)
        if weights.endswith('.pt'):  # pytorch format
            ckpt = torch.load(weights, map_location=device)
            try:
                model.load_state_dict(ckpt['model'])  # checkpoint dict
            except KeyError:
                model.load_state_dict(ckpt)  # plain state dict
        else:  # darknet format
            load_darknet_weights(model, weights)

        # Counting parameters.
        # This must run BEFORE any nn.DataParallel wrapping: the wrapper does not
        # forward attributes such as `ticket`, `temp` or `mask_modules`.
        if 'soft' in cfg:
            model.ticket = True
            model.temp = 1.
            # Zeros instead of uninitialised torch.Tensor(...) memory (may hold NaN/inf).
            # NOTE(review): the model is still in train mode here, so this dummy pass
            # may update BatchNorm running statistics — confirm this is harmless.
            x = torch.zeros(1, 3, 416, 416).to(device)
            _ = model(x)
            masks = [m.mask for m in model.mask_modules]
            print(f"Evaluating model with {compute_removed_weights(masks)} parameters removed.")
        elif mask or mask_weight:
            # Loading LTH mask
            initial_weights = sum_of_the_weights(msk)
            if mask:
                if ckpt is None:
                    ckpt = torch.load(weights, map_location=device)
                msk.load_state_dict(ckpt['mask'])
            else:
                msk.load_state_dict(torch.load(mask_weight, map_location=device))
            final_weights = sum_of_the_weights(msk)
            apply_mask_LTH(model, msk)
            print(f'Evaluating model with initial weights number of {initial_weights} and final of {final_weights}. \nReduction of {final_weights * 100. / initial_weights}%.')
            del msk
        del ckpt  # release the loaded checkpoint tensors before evaluation

        if device.type != 'cpu' and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

    else:  # called by train.py
        device = next(model.parameters()).device  # get model device
        verbose = False

    # Configure run
    data = parse_data_cfg(data)
    nc = 1 if single_cls else int(data['classes'])  # number of classes
    path = data['test'] if 'test' in data else data['valid']  # path to test images
    names = load_classes(data['names'])  # class names
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    iouv = iouv[0].view(1)  # comment for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if dataloader is None:
        dataset = LoadImagesAndLabels(path, img_size, batch_size, rect=True, single_cls=single_cls, cache_labels=True)
        batch_size = min(batch_size, len(dataset))
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]),
                                pin_memory=True,
                                collate_fn=dataset.collate_fn)

    seen = 0
    model.eval()
    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%10s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@0.5', 'F1')
    p, r, f1, mp, mr, map, mf1, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []
    for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        _, _, height, width = imgs.shape  # batch size, channels, height, width
        whwh = torch.Tensor([width, height, width, height]).to(device)

        # Plot images with bounding boxes
        f = folder + 'test_batch%g.png' % batch_i  # filename
        if batch_i < 1 and not os.path.exists(f):
            plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)

        # Disable gradients
        with torch.no_grad():
            # Run model
            t = torch_utils.time_synchronized()
            inf_out, train_out = model(imgs)  # inference and training outputs
            t0 += torch_utils.time_synchronized() - t

            # Compute loss
            if hasattr(model, 'hyp'):  # if model has loss hyperparameters
                loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls

            # Run NMS
            t = torch_utils.time_synchronized()
            output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
            t1 += torch_utils.time_synchronized() - t

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            seen += 1

            if pred is None:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = int(Path(paths[si]).stem.split('_')[-1])
                box = pred[:, :4].clone()  # xyxy
                scale_coords(imgs[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                box = xyxy2xywh(box)  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for di, d in enumerate(pred):
                    jdict.append({'image_id': image_id,
                                  'category_id': coco91class[int(d[5])],
                                  'bbox': [floatn(x, 3) for x in box[di]],
                                  'score': floatn(d[4], 5)})

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5]) * whwh

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero().view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero().view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices

                        # Append detections
                        for j in (ious > iouv[0]).nonzero():
                            d = ti[i[j]]  # detected target
                            if d not in detected:
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats):
        p, r, ap, f1, ap_class = ap_per_class(*stats)
        if niou > 1:
            p, r, ap, f1 = p[:, 0], r[:, 0], ap.mean(1), ap[:, 0]  # [P, R, mAP@0.5:0.95, mAP@0.5]
        mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    pf = '%20s' + '%10.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1))

    # Save the overall and per-class evaluations
    with open(folder + 'per_class_evaluations.txt', 'w') as class_results:
        print(s)
        print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), file=class_results)
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]), file=class_results)

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))

    # Save JSON
    if save_json and map and len(jdict):
        print('\nCOCO mAP with pycocotools...')
        imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files]
        with open(folder + 'results.json', 'w') as file:
            json.dump(jdict, file)

        try:
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval
        except ImportError:
            # Keep the locally computed metrics instead of crashing on COCO below.
            print('WARNING: missing pycocotools package, can not compute official COCO mAP. See requirements.txt.')
        else:
            # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            cocoGt = COCO(glob.glob('../COCO2014/annotations/instances_val*.json')[0])  # initialize COCO ground truth api
            cocoDt = cocoGt.loadRes(folder + 'results.json')  # initialize COCO pred api

            cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
            cocoEval.params.imgIds = imgIds  # [:32]  # only evaluate these images
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            mf1, map = cocoEval.stats[:2]  # update to pycocotools results (mAP@0.5:0.95, mAP@0.5)

    # Print speeds (skipped when no image was seen, to avoid dividing by zero)
    if verbose and seen:
        t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (img_size, img_size, batch_size)  # tuple
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

    # Return results
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map, mf1, *(loss.cpu() / len(dataloader)).tolist()), maps
Example #4
0
# Take one image from the dataset and turn it into a normalised 1xCxHxW batch.
path, img, im0s, _ = next(iter(dataset))
img = torch.from_numpy(img).to(device)
img = img.float()
img /= 255.0  # 0 - 255 to 0.0 - 1.0
if img.ndimension() == 3:
    img = img.unsqueeze(0)  # add batch dimension

# Build the pruned dense model: a SoftDarknet ticket when is_CS is set
# (presumably "continuous sparsification" — confirm), otherwise a Darknet with an LTH mask.
if is_CS:
    yolo = SoftDarknet(cfg='cfg/voc_yolov3_soft_orig-output.cfg').to(device)
    yolo.load_state_dict(ck_model['model'])
    yolo.ticket = True
    _ = yolo(img)
else:
    yolo = Darknet(cfg='cfg/voc_yolov3.cfg').to(device)
    yolo.load_state_dict(ck_model['model'])
    mask = create_mask_LTH(yolo)
    mask.load_state_dict(ck_mask)
    apply_mask_LTH(yolo, mask)

# Channel-wise sparse counterpart built from the pruned model.
sparse = Ch_Wise_SparseYOLO(yolo).to(device)

yolo.eval()
sparse.eval()
# Inference: no autograd graph is needed for the comparison.
with torch.no_grad():
    pred1 = yolo(img)[0]
    pred2 = sparse(img)[0]

# Apply NMS
pred1 = non_max_suppression(pred1, 0.3, 0.6)
pred2 = non_max_suppression(pred2, 0.3, 0.6)
def _detections_to_targets(detections_per_image, height, width, device):
    """Convert per-image teacher NMS detections into normalized training targets.

    Args:
        detections_per_image: sequence with one entry per image of the batch, as
            returned by ``non_max_suppression``: ``None``/empty, or a tensor whose
            first four columns are x1, y1, x2, y2 in input-image pixels and whose
            last column is the class index (assumed layout
            x1, y1, x2, y2, conf, cls -- confirm against non_max_suppression).
        height: input image height in pixels (``imgs.shape[2]``).
        width: input image width in pixels (``imgs.shape[3]``).
        device: device for the returned tensor.

    Returns:
        Float tensor of shape ``(n, 6)`` with rows
        ``[image_index, class, x_center, y_center, w, h]``. Box values are
        normalized by the image size. The tensor is ``(0, 6)`` when there are
        no detections.
    """
    rows = []
    for j, detections in enumerate(detections_per_image):
        if detections is None or not len(detections):
            continue
        # x coordinates scale with the width, y coordinates with the height.
        scale = torch.tensor([width, height, width, height],
                             dtype=detections.dtype, device=detections.device)
        boxes = detections[:, :4] / scale
        labels = torch.zeros((len(detections), 6), device=detections.device)
        labels[:, 0] = j  # the j-th image of the batch
        labels[:, 1] = detections[:, -1]  # classes (confidence is ignored)
        labels[:, 2:] = xyxy2xywh(boxes)  # bboxes in darknet format
        rows.append(labels)
    if not rows:
        return torch.zeros((0, 6), device=device)
    return torch.cat(rows).to(device)


def train():
    """Distill a teacher detector into a student detector.

    All settings come from the module-level ``config`` dict. For every batch the
    student loss is the sum of:
      * the box loss plus a teacher-bounded regression term (Equation 4);
      * the objectness loss on the ground truth;
      * a class loss that mixes hard (ground-truth) and soft (teacher NMS)
        labels, weighted by ``mu`` (Equation 2);
      * an L1 hint loss between adapted student features and teacher features
        (Equation 6), used only when ``config['teacher_indexes']`` is non-empty.

    Results, checkpoints and plots are written under
    ``config['sub_working_dir']``.

    Returns:
        The last validation ``results`` tuple
        (P, R, mAP, F1, val GIoU, val Objectness, val Classification), or the
        results so far if a non-finite loss aborts training.
    """
    data = config['data']
    img_size, img_size_test = config['img_size'] if len(config['img_size']) == 2 else config['img_size'] * 2  # train, test sizes
    epochs = config['epochs']  # 500200 batches at bs 64, 117263 images = 273 epochs
    batch_size = config['batch_size']
    accumulate = config['accumulate']  # effective bs = batch_size * accumulate = 16 * 4 = 64

    # Initialize
    init_seeds(config['seed'])
    if config['multi_scale']:
        img_sz_min = round(img_size / 32 / 1.5)
        img_sz_max = round(img_size / 32 * 1.5)
        img_size = img_sz_max * 32  # initiate with maximum multi_scale size
        print('Using multi-scale %g - %g' % (img_sz_min * 32, img_size))

    # Configure run
    data_dict = parse_data_cfg(data)
    nc = int(data_dict['classes'])  # number of classes
    config['single_cls'] = nc == 1

    # Initialize Teacher
    if config['teacher_darknet'] == 'default':
        teacher = Darknet(cfg=config['teacher_cfg'], arc=config['teacher_arc']).to(device)
    elif config['teacher_darknet'] == 'soft':
        teacher = SoftDarknet(cfg=config['teacher_cfg'], arc=config['teacher_arc']).to(device)
    # Initialize Student
    if config['student_darknet'] == 'default':
        if 'nano' in config['student_cfg']:
            print('Using a YOLO Nano arc')
            student = YOLO_Nano(config['student_cfg']).to(device)
        else:
            student = Darknet(cfg=config['student_cfg']).to(device)
    elif config['student_darknet'] == 'soft':
        student = SoftDarknet(cfg=config['student_cfg'], arc=config['student_arc']).to(device)
    # Create Hint Layers (only when teacher feature indexes are requested)
    hint_models = None
    if len(config['teacher_indexes']):
        hint_models = HintModel(config, teacher, student).to(device)

    optimizer = create_optimizer(student, config)
    if len(config['teacher_indexes']):
        add_to_optimizer(config, hint_models, optimizer)

    HINT = nn.L1Loss()

    mask = None
    if ('mask' in config and config['mask']) or ('mask_path' in config and config['mask_path']):
        print('Creating mask')
        mask = create_mask_LTH(teacher).to(device)

    start_epoch, best_fitness, teacher, student, mask, hint_models, optimizer, _ = load_kd_checkpoints(
        config,
        teacher, student,
        mask, hint_models,
        optimizer, None, device
    )

    if mask is not None:
        print('Applying mask in teacher')
        apply_mask_LTH(teacher, mask)
        del mask
        torch.cuda.empty_cache()

    if config['xavier_norm']:
        initialize_model(student, torch.nn.init.xavier_normal_)
    elif config['xavier_uniform']:
        initialize_model(student, torch.nn.init.xavier_uniform_)

    scheduler = create_scheduler(config, optimizer, start_epoch)

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        student, optimizer = amp.initialize(student, optimizer, opt_level='O1', verbosity=0)

    # Initialize distributed training
    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
        dist.init_process_group(backend='nccl',  # 'distributed backend'
                                init_method='tcp://127.0.0.1:9999',  # distributed training init method
                                world_size=1,  # number of nodes for distributed training
                                rank=0)  # distributed training node rank
        teacher = torch.nn.parallel.DistributedDataParallel(teacher, find_unused_parameters=True)
        teacher.yolo_layers = teacher.module.yolo_layers  # move yolo layer indices to top level
        student = torch.nn.parallel.DistributedDataParallel(student, find_unused_parameters=True)
        student.yolo_layers = student.module.yolo_layers  # move yolo layer indices to top level

    trainloader, validloader = create_dataloaders(config)

    # Start training
    nb = len(trainloader)
    prebias = start_epoch == 0
    student.nc = nc  # attach number of classes to student
    teacher.nc = nc

    student.arc = config['student_arc']  # attach yolo architecture
    teacher.arc = config['teacher_arc']

    student.hyp = config['hyp']  # attach hyperparameters to student
    teacher.hyp = config['hyp']  # attach hyperparameters to teacher
    # NOTE(review): `h` and `ft` are not defined in this function; they are
    # presumably module-level (a hyperparameter dict and a float-tensor
    # constructor) -- confirm.
    mu = ft([h['mu']])  # mu variable to weight the hard lcls and soft lcls in Eq: 2 (value not informed)
    # nu: weight of the teacher bounded regression loss (Eq: 4). It was named
    # `ni` before, and the batch counter `ni` in the loop below overwrote it.
    nu = ft([h['ni']])
    margin = ft([h['margin']])  # m variable used as margin in teacher bounded regression loss. (value not informed)

    student.class_weights = labels_to_class_weights(trainloader.dataset.labels, nc).to(device)  # attach class weights
    teacher.class_weights = student.class_weights

    maps = np.zeros(nc)  # mAP per class
    # torch.autograd.set_detect_anomaly(True)
    results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    t0 = time.time()
    torch_utils.model_info(student, report='summary')  # 'full' or 'summary'
    print('Starting training for %g epochs...' % epochs)

    teacher.eval()
    max_wo_best = 0
    ###############
    # Start epoch #
    ###############
    for epoch in range(start_epoch, epochs):
        student.train()
        student.gr = 1 - (1 + math.cos(min(epoch * 2, epochs) * math.pi / epochs)) / 2  # GIoU <-> 1.0 loss ratio

        # Prebias
        if prebias:
            ne = max(round(30 / nb), 3)  # number of prebias epochs
            ps = np.interp(epoch, [0, ne], [0.1, config['hyp']['lr0'] * 2]), \
                np.interp(epoch, [0, ne], [0.9, config['hyp']['momentum']])  # prebias settings (lr=0.1, momentum=0.9)
            if epoch == ne:
                print_model_biases(student)
                prebias = False

            # Bias optimizer settings
            optimizer.param_groups[2]['lr'] = ps[0]
            if optimizer.param_groups[2].get('momentum') is not None:  # for SGD but not Adam
                optimizer.param_groups[2]['momentum'] = ps[1]

        # Update image weights (optional)
        if trainloader.dataset.image_weights:
            w = student.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
            image_weights = labels_to_image_weights(trainloader.dataset.labels, nc=nc, class_weights=w)
            trainloader.dataset.indices = random.choices(range(trainloader.dataset.n), weights=image_weights, k=trainloader.dataset.n)  # rand weighted idx

        mloss = torch.zeros(5).to(device)  # mean losses
        print(('\n' + '%10s' * 9) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'hint', 'total', 'targets', 'img_size'))
        pbar = tqdm(enumerate(trainloader), total=nb)  # progress bar
        ####################
        # Start mini-batch #
        ####################
        for i, (imgs, targets, paths, _) in pbar:
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            # Plot images with bounding boxes
            if ni < 1:
                f = config['sub_working_dir'] + 'train_batch%g.png' % i  # filename
                plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)
                if tb_writer:
                    tb_writer.add_image(f, cv2.imread(f)[:, :, ::-1], dataformats='HWC')

            # Multi-Scale training
            if config['multi_scale']:
                if ni / accumulate % 1 == 0:  #  adjust img_size (67% - 150%) every 1 batch
                    img_size = random.randrange(img_sz_min, img_sz_max + 1) * 32
                sf = img_size / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / 32.) * 32 for x in imgs.shape[2:]]  # new shape (stretched to 32-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Run teacher
            with torch.no_grad():
                inf_out, tch_train_output, fts_tch = teacher(imgs, config['teacher_indexes'])
                tch_loss = compute_loss(tch_train_output, targets, teacher, True)
                bboxes_tch = non_max_suppression(inf_out, conf_thres=.1, iou_thres=0.6)
                # Create soft labels from the teacher detections. imgs is
                # (batch, channels, height, width), so height = shape[2] and
                # width = shape[3].
                targets_tch = _detections_to_targets(bboxes_tch, imgs.shape[2], imgs.shape[3], device)

            # Run student
            pred_std, fts_std = student(imgs, config['student_indexes'])

            ################
            # Compute loss #
            ################
            hard_loss = compute_loss(pred_std, targets, student, True)
            soft_loss = compute_loss(pred_std, targets_tch, student, True)

            # Teacher bounded regression: add the student's box loss only while
            # it is not better than the teacher's by at least `margin`.
            if hard_loss[0] + margin > tch_loss[0]:
                upper_bound_lreg = hard_loss[0]
            else:
                upper_bound_lreg = torch.zeros_like(hard_loss[0])
            lbox = hard_loss[0] + nu * upper_bound_lreg  # Equation 4
            lobj = hard_loss[1]
            lcls = mu * hard_loss[2] + (1. - mu) * soft_loss[2]  # Equation 2
            # Device-neutral accumulator (torch.cuda.FloatTensor failed on CPU).
            lhint = torch.zeros(1, device=device)
            if hint_models is not None:
                # Run hint layers (adapt student features to teacher features)
                fts_guided = hint_models(fts_std)
                for hint, guided in zip(fts_tch, fts_guided):
                    lhint += HINT(guided, hint)  # Equation 6
            loss = lbox + lobj + lcls + lhint
            loss_items = torch.cat((lbox, lobj, lcls, lhint, loss)).detach()

            if not torch.isfinite(loss):
                print('WARNING: non-finite loss, ending training ', loss_items)
                return results

            # Scale loss by nominal batch_size of 64
            loss *= batch_size / 64

            # Compute gradient
            if mixed_precision:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Optimize accumulated gradient
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Print batch results
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
            s = ('%10s' * 2 + '%10.3g' * 7) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size)
            pbar.set_description(s)
        ##################
        # End mini-batch #
        ##################

        # Update scheduler
        scheduler.step()

        final_epoch = epoch + 1 == epochs
        if not config['notest'] or final_epoch:  # Calculate mAP
            # Move the models not needed for evaluation off the GPU while testing.
            teacher = teacher.to('cpu')
            if hint_models is not None:
                hint_models = hint_models.to('cpu')
            results, maps = guarantee_test(
                student, config, device, config['cfg'], data,
                batch_size, img_size_test, validloader,
                final_epoch, test.test
            )
            teacher = teacher.to(device)
            if hint_models is not None:
                hint_models = hint_models.to(device)

        # Write epoch results
        with open(config['results_file'], 'a') as f:
            f.write(s + '%10.3g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
        if len(config['name']) and config['bucket']:
            os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (config['bucket'], config['name']))

        # Write Tensorboard results
        if tb_writer:
            x = list(mloss) + list(results)
            titles = ['GIoU', 'Objectness', 'Classification', 'Hint', 'Train loss',
                      'Precision', 'Recall', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification']
            for xi, title in zip(x, titles):
                tb_writer.add_scalar(title, xi, epoch)

        # Update best mAP
        fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
        if fi > best_fitness:
            best_fitness = fi
            max_wo_best = 0
        else:
            max_wo_best += 1
            if config['early_stop'] and max_wo_best == config['early_stop']:
                print('Ending training due to early stop')

        # Save training results
        save = (not config['nosave']) or (final_epoch and not config['evolve'])
        if save:
            with open(config['results_file'], 'r') as f:
                # Create checkpoint
                chkpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'training_results': f.read(),
                    'model': student.module.state_dict() if type(student) is nn.parallel.DistributedDataParallel
                        else student.state_dict(),
                    'hint': None if hint_models is None
                        else hint_models.module.state_dict() if type(hint_models) is nn.parallel.DistributedDataParallel
                        else hint_models.state_dict(),
                    'optimizer': None if final_epoch else optimizer.state_dict()}

            # Save last checkpoint
            torch.save(chkpt, config['last'])

            # Save best checkpoint
            if best_fitness == fi:
                torch.save(chkpt, config['best'])

            # Delete checkpoint
            del chkpt
            torch.cuda.empty_cache()

        if config['early_stop'] and max_wo_best == config['early_stop']:
            break
    #############
    # End epoch #
    #############

    n = config['name']
    if len(n):
        n = '_' + n if not n.isnumeric() else n
        fresults, flast, fbest = 'results%s.txt' % n, 'last%s.pt' % n, 'best%s.pt' % n
        os.rename(config['results_file'], config['sub_working_dir'] + fresults)
        os.rename(config['last'], config['sub_working_dir'] + flast) if os.path.exists(config['last']) else None
        os.rename(config['best'], config['sub_working_dir'] + fbest) if os.path.exists(config['best']) else None
        # Updating results, last and best
        config['results_file'] = config['sub_working_dir'] + fresults
        config['last'] = config['sub_working_dir'] + flast
        config['best'] = config['sub_working_dir'] + fbest

        if config['bucket']:  # save to cloud
            os.system('gsutil cp %s gs://%s/results' % (fresults, config['bucket']))
            os.system('gsutil cp %s gs://%s/weights' % (config['sub_working_dir'] + flast, config['bucket']))
            # os.system('gsutil cp %s gs://%s/weights' % (config['sub_working_dir'] + fbest, config['bucket']))

    if not config['evolve']:
        plot_results(folder=config['sub_working_dir'])

    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
    dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()

    return results
Exemple #6
0
def train():
    """Train a Darknet model with iterative magnitude pruning (lottery ticket).

    All settings come from the module-level ``config`` dict. Each of the
    ``config['iterations']`` rounds works as follows:
      1. Train the masked model for up to ``epochs`` epochs. The mask is
         re-applied before every mini-batch.
      2. Save the current mask and model weights.
      3. Except after the last round, prune with ``config['prune_kind']``
         ('IMP_LOCAL' or 'IMP_GLOBAL') at ``config['pruning_rate']``.
      4. Rewind the weights with ``rewind_weights`` to the backup taken at the
         start of epoch ``config['reseting'] - 1`` (late resetting).

    Returns:
        The last validation ``results`` tuple
        (P, R, mAP, F1, val GIoU, val Objectness, val Classification), or the
        results so far if a non-finite loss aborts training.

    Raises:
        RuntimeError: if a rewind is needed but this run never took a weight
            backup (e.g. resuming past the reset epoch, or early stopping
            before it).
    """
    cfg = config['cfg']
    img_size, img_size_test = config['img_size'] if len(config['img_size']) == 2 else config['img_size'] * 2  # train, test sizes
    epochs = config['epochs']  # 500200 batches at bs 64, 117263 images = 273 epochs
    accumulate = config['accumulate']  # effective bs = batch_size * accumulate = 16 * 4 = 64

    # Initialize
    init_seeds(config['seed'])
    if config['multi_scale']:
        img_sz_min = round(img_size / 32 / 1.5)
        img_sz_max = round(img_size / 32 * 1.5)
        img_size = img_sz_max * 32  # initiate with maximum multi_scale size
        print('Using multi-scale %g - %g' % (img_sz_min * 32, img_size))

    # Configure run
    data_dict = parse_data_cfg(config['data'])
    nc = int(data_dict['classes'])  # number of classes

    # Initialize model
    model = Darknet(cfg, arc=config['arc']).to(device)
    mask = create_mask_LTH(model)

    optimizer = create_optimizer(model, config)
    start_iteration, start_epoch, best_fitness, model, mask, optimizer = load_checkpoints_mask(
        config, model, mask, optimizer, device, attempt_download,
        load_darknet_weights)
    scheduler = create_scheduler(config, optimizer, start_epoch)

    # Kind of initialization
    if config['xavier_norm']:
        initialize_model(model, torch.nn.init.xavier_normal_)
    elif config['xavier_uniform']:
        initialize_model(model, torch.nn.init.xavier_uniform_)

    # Initialize distributed training
    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
        dist.init_process_group(
            backend='nccl',  # 'distributed backend'
            init_method='tcp://127.0.0.1:9999',  # distributed training init method
            world_size=1,  # number of nodes for distributed training
            rank=0)  # distributed training node rank
        model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
        model.yolo_layers = model.module.yolo_layers  # move yolo layer indices to top level

    trainloader, validloader = create_dataloaders(config)

    # Start training
    nb = len(trainloader)
    prebias = start_epoch == 0
    model.nc = nc  # attach number of classes to model
    config['single_cls'] = nc == 1
    model.arc = config['arc']  # attach yolo architecture
    model.hyp = config['hyp']  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(trainloader.dataset.labels, nc).to(device)  # attach class weights
    maps = np.zeros(nc)  # mAP per class
    # torch.autograd.set_detect_anomaly(True)
    results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    t0 = time.time()
    torch_utils.model_info(model, report='summary')  # 'full' or 'summary'
    print('Starting training for %g epochs...' % epochs)

    counter = 0  # global epoch counter used as the TensorBoard step
    backup = None  # weights to rewind to; taken at epoch config['reseting'] - 1
    ###################
    # Start Iteration #
    ###################
    for it in range(start_iteration, config['iterations']):

        config['last'] = config['sub_working_dir'] + 'last_it_{}.pt'.format(it)
        config['best'] = config['sub_working_dir'] + 'best_it_{}.pt'.format(it)
        max_wo_best = 0
        ###############
        # Start epoch #
        ###############
        for epoch in range(start_epoch, epochs):
            model.train()
            model.gr = 1 - (1 + math.cos(min(epoch * 2, epochs) * math.pi / epochs)) / 2  # GIoU <-> 1.0 loss ratio

            # Prebias
            if prebias:
                ne = max(round(30 / nb), 3)  # number of prebias epochs
                ps = np.interp(epoch, [0, ne], [0.1, config['hyp']['lr0'] * 2]), \
                    np.interp(epoch, [0, ne], [0.9, config['hyp']['momentum']])  # prebias settings (lr=0.1, momentum=0.9)
                if epoch == ne:
                    print_model_biases(model)
                    prebias = False

                # Bias optimizer settings
                optimizer.param_groups[2]['lr'] = ps[0]
                if optimizer.param_groups[2].get('momentum') is not None:  # for SGD but not Adam
                    optimizer.param_groups[2]['momentum'] = ps[1]

            # Update image weights (optional)
            if trainloader.dataset.image_weights:
                w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
                image_weights = labels_to_image_weights(trainloader.dataset.labels, nc=nc, class_weights=w)
                trainloader.dataset.indices = random.choices(
                    range(trainloader.dataset.n),
                    weights=image_weights,
                    k=trainloader.dataset.n)  # rand weighted idx

            mloss = torch.zeros(4).to(device)  # mean losses
            print(('\n' + '%10s' * 9) %
                  ('Iter', 'Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
            pbar = tqdm(enumerate(trainloader), total=nb)  # progress bar

            # Backup for late reseting: snapshot the weights before training this epoch.
            # NOTE(review): this runs in every round, so from the second round on
            # the rewind point is the pruned network after `reseting - 1` more
            # epochs, not the original dense snapshot. Confirm this is intended.
            if epoch == config['reseting'] - 1:
                mask = mask.to('cpu')
                backup = create_backup(model)
                torch.save(
                    backup.state_dict(),
                    config['sub_working_dir'] + 'bckp_it-{}_epoch-{}.pt'.format(it + 1, epoch + 1))
                backup = backup.to('cpu')
                mask = mask.to(device)

            ####################
            # Start mini-batch #
            ####################
            for i, (imgs, targets, paths, _) in pbar:
                ni = i + nb * epoch  # number integrated batches (since train start)

                ##############
                # Apply mask #
                ##############
                apply_mask_LTH(model, mask)

                imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
                targets = targets.to(device)

                # Plot images with bounding boxes
                if ni < 1:
                    f = config['sub_working_dir'] + 'train_batch%g.png' % i  # filename
                    plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)
                    if tb_writer:
                        tb_writer.add_image(f, cv2.imread(f)[:, :, ::-1], dataformats='HWC')

                # Multi-Scale training
                if config['multi_scale']:
                    if ni / accumulate % 1 == 0:  #  adjust img_size (67% - 150%) every 1 batch
                        img_size = random.randrange(img_sz_min, img_sz_max + 1) * 32
                    sf = img_size / max(imgs.shape[2:])  # scale factor
                    if sf != 1:
                        ns = [math.ceil(x * sf / 32.) * 32 for x in imgs.shape[2:]]  # new shape (stretched to 32-multiple)
                        imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

                # Run model
                pred = model(imgs)

                # Compute loss
                loss, loss_items = compute_loss(pred, targets, model)
                if not torch.isfinite(loss):
                    print('WARNING: non-finite loss, ending training ', loss_items)
                    return results

                # Scale loss by nominal batch_size of 64
                loss *= config['batch_size'] / 64

                # Compute gradient
                loss.backward()

                # Optimize accumulated gradient
                if ni % accumulate == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                # Print batch results
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 3 + '%10.3g' * 6) % (
                    '%g/%g' % (it, config['iterations'] - 1),
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size)
                pbar.set_description(s)
            ##################
            # End mini-batch #
            ##################

            # Update scheduler
            scheduler.step()

            final_epoch = epoch + 1 == epochs
            if not config['notest'] or final_epoch:  # Calculate mAP
                results, maps = guarantee_test(model, config, device, cfg,
                                               config['data'],
                                               config['batch_size'],
                                               img_size_test, validloader,
                                               final_epoch, test.test)

            # Write epoch results
            with open(config['results_file'], 'a') as f:
                f.write(s + '%10.3g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
            if len(config['name']) and config['bucket']:
                os.system('gsutil cp results.txt gs://%s/results/results%s.txt' %
                          (config['bucket'], config['name']))

            # Write Tensorboard results
            if tb_writer:
                x = list(mloss) + list(results)
                titles = [
                    'GIoU', 'Objectness', 'Classification', 'Train loss',
                    'Precision', 'Recall', 'mAP', 'F1', 'val GIoU',
                    'val Objectness', 'val Classification'
                ]
                for xi, title in zip(x, titles):
                    tb_writer.add_scalar(title, xi, counter)

            counter += 1

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
            if fi > best_fitness:
                best_fitness = fi
                max_wo_best = 0
            else:
                max_wo_best += 1
                if config['early_stop'] and max_wo_best == config['early_stop']:
                    print('Ending training due to early stop')

            # Save training results
            save = (not config['nosave']) or (final_epoch and not config['evolve'])
            if save:
                with open(config['results_file'], 'r') as f:
                    # Create checkpoint
                    chkpt = {
                        'iteration': it,
                        'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': f.read(),
                        'model': model.module.state_dict()
                        if type(model) is nn.parallel.DistributedDataParallel
                        else model.state_dict(),
                        'mask': mask.state_dict(),
                        'optimizer': None if final_epoch else optimizer.state_dict()
                    }

                # Save last checkpoint
                torch.save(chkpt, config['last'])

                # Save best checkpoint
                if best_fitness == fi:
                    torch.save(chkpt, config['best'])

                # Delete checkpoint
                del chkpt
                torch.cuda.empty_cache()

            if config['early_stop'] and max_wo_best == config['early_stop']:
                break
        #############
        # End epoch #
        #############

        # Saving current mask before prune
        torch.save(
            mask.state_dict(),
            config['sub_working_dir'] + 'mask_{}_{}.pt'.format(
                config['pruning_time'],
                'prune' if config['pruning_time'] == 1 else 'prunes'))
        # Saving current model before prune (unwrapped like the checkpoints so
        # the keys carry no DDP 'module.' prefix)
        model_state = (model.module.state_dict()
                       if type(model) is nn.parallel.DistributedDataParallel
                       else model.state_dict())
        torch.save(model_state,
                   config['sub_working_dir'] + 'model_it_{}.pt'.format(it + 1))

        if it < config['iterations'] - 1:  # the last iteration trains once more without pruning
            if config['prune_kind'] == 'IMP_LOCAL':
                print(f"Applying IMP Local with {config['pruning_rate'] * 100}%.")
                IMP_LOCAL(model, mask, config['pruning_rate'])
            elif config['prune_kind'] == 'IMP_GLOBAL':
                print(f"Applying IMP Global with {config['pruning_rate'] * 100}%.")
                IMP_GLOBAL(model, mask, config['pruning_rate'])

            if backup is None:
                # Previously an opaque UnboundLocalError; explain instead.
                raise RuntimeError(
                    'Cannot rewind weights: no backup was taken at epoch {} in this run '
                    '(resumed past it or stopped early before it).'.format(config['reseting'] - 1))

            mask = mask.to('cpu')
            print('Rewind weights.')
            backup = backup.to(device)
            rewind_weights(model, backup)
            backup = backup.to('cpu')
            mask = mask.to(device)
            config['pruning_time'] += 1

        optimizer = create_optimizer(model, config)
        start_epoch = 0
        best_fitness = .0
        scheduler = create_scheduler(config, optimizer, start_epoch)
    #################
    # End Iteration #
    #################

    n = config['name']
    if len(n):
        n = '_' + n if not n.isnumeric() else n
        fresults, flast, fbest = 'results%s.txt' % n, 'last%s.pt' % n, 'best%s.pt' % n
        os.rename(config['results_file'], config['sub_working_dir'] + fresults)
        os.rename(config['last'], config['sub_working_dir'] + flast) if os.path.exists(config['last']) else None
        os.rename(config['best'], config['sub_working_dir'] + fbest) if os.path.exists(config['best']) else None
        # Updating results, last and best
        config['results_file'] = config['sub_working_dir'] + fresults
        config['last'] = config['sub_working_dir'] + flast
        config['best'] = config['sub_working_dir'] + fbest

        if config['bucket']:  # save to cloud
            os.system('gsutil cp %s gs://%s/results' % (fresults, config['bucket']))
            os.system('gsutil cp %s gs://%s/weights' % (config['sub_working_dir'] + flast, config['bucket']))
            # os.system('gsutil cp %s gs://%s/weights' % (config['sub_working_dir'] + fbest, config['bucket']))

    if not config['evolve']:
        plot_results(folder=config['sub_working_dir'])

    # NOTE(review): start_epoch is reset to 0 after every round, so this counts
    # the epochs of the last round only.
    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
    dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()

    return results
def detect(save_img=False):
    """Run detection over ``opt.source`` and write per-class result files.

    All settings come from the module-level ``opt`` namespace. The output
    folder ``opt.output`` is deleted and recreated. One file named
    ``comp3_det_test_<class name>.txt`` is written inside it for every name in
    ``opt.names``. Each kept detection is written to its class file as one
    space-separated line::

        <image id> <confidence> <x1> <y1> <x2> <y2>

    The image id is the source file name without its last extension. Box
    coordinates are rescaled to the original image and truncated to int.

    When ``ONNX_EXPORT`` is set, the fused model is exported to
    ``<weights path without extension>.onnx``. The export is validated and
    printed, and the function returns without running inference or creating
    result files.

    Args:
        save_img: if True, draw boxes on each image and save it into
            ``opt.output`` (only when ``dataset.mode == 'images'``).

    Pressing 'q' while ``opt.view_img`` is enabled stops inference early.
    The results gathered so far are kept.
    """
    import contextlib

    img_size = (320, 192) if ONNX_EXPORT else opt.img_size  # (320, 192) or (416, 256) or (608, 352) for (height, width)
    out, source, weights, view_img = opt.output, opt.source, opt.weights, opt.view_img

    # Initialize
    device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device)
    # Get names and colors
    names = load_classes(opt.names)
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Recreate an empty output folder
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder

    # Initialize model
    if 'soft' in opt.cfg:
        model = SoftDarknet(cfg=opt.cfg).to(device)
        model.ticket = True
    else:
        model = Darknet(cfg=opt.cfg).to(device)

    # Load weights
    attempt_download(weights)
    if weights.endswith('.pt'):  # pytorch format
        # Read the checkpoint once: it holds the weights and, with opt.mask,
        # also the pruning mask.
        ckpt = torch.load(weights, map_location=device)
        model.load_state_dict(ckpt['model'])
        if opt.mask or opt.mask_weight:
            mask = create_mask_LTH(model)
            if opt.mask:
                mask.load_state_dict(ckpt['mask'])
            else:
                mask.load_state_dict(torch.load(opt.mask_weight, map_location=device))
            apply_mask_LTH(model, mask)
            del mask
        del ckpt
    else:  # darknet format
        load_darknet_weights(model, weights)

    # Second-stage classifier (disabled)
    classify = False
    if classify:
        modelc = torch_utils.load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
        modelc.to(device).eval()

    # Eval mode
    model.to(device).eval()

    # Export mode
    if ONNX_EXPORT:
        model.fuse()
        img = torch.zeros((1, 3) + img_size)  # (1, 3, 320, 192)
        # Swap only the final extension. str.replace on the extension text
        # would also rewrite matching substrings elsewhere in the path
        # (e.g. 'weights/x.weights' -> 'onnx/x.onnx').
        onnx_path = os.path.splitext(opt.weights)[0] + '.onnx'
        torch.onnx.export(model, img, onnx_path, verbose=False, opset_version=11)

        # Validate exported model
        import onnx
        onnx_model = onnx.load(onnx_path)  # Load the ONNX model
        onnx.checker.check_model(onnx_model)  # Check that the IR is well formed
        print(onnx.helper.printable_graph(onnx_model.graph))  # Print a human readable representation of the graph
        return

    # Set Dataloader
    dataset = LoadImages(source, img_size=img_size)

    save_path = None  # last annotated image path (used by the macOS 'open' call)
    t0 = time.time()
    with contextlib.ExitStack() as stack:
        # One result file per class, indexed by class id. The ExitStack
        # closes them even if inference raises.
        files = [stack.enter_context(open(os.path.join(out, f'comp3_det_test_{name}.txt'), 'w'))
                 for name in names]

        # Run inference
        quit_requested = False
        for path, img, im0s, _ in dataset:
            ID = Path(path).stem  # file name without its last extension
            t = time.time()

            img = torch.from_numpy(img).to(device)
            img = img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            pred = model(img)[0]

            # Apply NMS
            pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)

            # Apply Classifier
            if classify:
                pred = apply_classifier(pred, modelc, img, im0s)

            # Process detections
            for i, det in enumerate(pred):  # detections per image
                p, s, im0 = path, '', im0s

                save_path = str(Path(out) / Path(p).name)
                s += '%gx%g ' % img.shape[2:]  # print string
                if det is not None and len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                    # Print results
                    for c in det[:, -1].unique():
                        n = (det[:, -1] == c).sum()  # detections per class
                        s += '%g %ss, ' % (n, names[int(c)])  # add to string

                    # Write results
                    for *xyxy, conf, cls in det:
                        x1, y1, x2, y2 = (int(v.item()) for v in xyxy)
                        print(ID, conf.item(), x1, y1, x2, y2, sep=' ', file=files[int(cls)])

                        if save_img or view_img:  # Add bbox to image
                            label = '%s %.2f' % (names[int(cls)], conf)
                            plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])

                # Print time (inference + NMS)
                print('%sDone. (%.3fs)' % (s, time.time() - t))

                # Stream results
                if view_img:
                    cv2.imshow(p, im0)
                    if cv2.waitKey(1) == ord('q'):  # q to quit
                        # Stop cleanly instead of leaking StopIteration to
                        # the caller. The result files are still closed.
                        quit_requested = True
                        break

                # Save results (image with detections)
                if save_img and dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)

            if quit_requested:
                break

    if save_img:
        print('Results saved to %s' % os.path.join(os.getcwd(), out))
        if platform == 'darwin' and save_path is not None:  # MacOS
            os.system('open ' + out + ' ' + save_path)

    print('Done. (%.3fs)' % (time.time() - t0))