Example #1
    def build_model(self) -> nn.Module:

        opt = get_cli_args(batch_size=pedl_batch_size,
                           prebias=pedl_prebias,
                           accumulate=pedl_accumulate)
        hyp = get_hyp(lr0=pedl_init_lr)

        # Initialize model
        model = Darknet(opt.cfg, arc=opt.arc)  # .to(device)

        # Fetch starting weights
        # TODO Once download_data_fn is implemented this should go into download_data
        attempt_download(opt.weights)
        chkpt = torch.load(opt.weights)

        # load model
        try:
            chkpt["model"] = {
                k: v
                for k, v in chkpt["model"].items()
                if model.state_dict()[k].numel() == v.numel()
            }
            model.load_state_dict(chkpt["model"], strict=False)
        except KeyError as e:
            s = (
                "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. "
                "See https://github.com/ultralytics/yolov3/issues/657" %
                (opt.weights, opt.cfg, opt.weights))
            raise KeyError(s) from e

        del chkpt

        data_dict = get_data_cfg()
        nc = 1 if opt.single_cls else int(data_dict["classes"])
        model.nc = nc  # attach number of classes to model

        model.arc = opt.arc  # attach yolo architecture
        model.hyp = hyp  # attach hyperparameters to model

        train_dataset = LazyModule.get()

        # The model class weights depend on the dataset labels
        model.class_weights = labels_to_class_weights(
            train_dataset.labels, nc)  # attach class weights

        return model
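
A minimal, self-contained sketch of the shape-compatibility filtering used in the try block above. The helper name `filter_compatible` and the toy `nn.Linear` modules are illustrative only, and the sketch guards on key presence instead of catching `KeyError` as the example does:

import torch
import torch.nn as nn

def filter_compatible(src_state, dst_model):
    """Keep only checkpoint tensors whose element count matches the target model."""
    dst_state = dst_model.state_dict()
    return {k: v for k, v in src_state.items()
            if k in dst_state and dst_state[k].numel() == v.numel()}

src, dst = nn.Linear(4, 2), nn.Linear(4, 3)  # deliberately mismatched shapes
kept = filter_compatible(src.state_dict(), dst)
dst.load_state_dict(kept, strict=False)      # nothing compatible here; no crash
print(sorted(kept))                          # []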
Example #2
def greedy_channel_select(origin_model, prune_cfg, origin_weights,
                          select_layer, device, aux_util, data_loader,
                          pruned_rate):
    init_state_dict = mask_converted(prune_cfg, origin_weights, target=None)

    prune_model = Darknet(prune_cfg).to(device)
    prune_model.load_state_dict(init_state_dict, strict=True)
    del init_state_dict
    solve_sub_problem_optimizer = optim.SGD(
        prune_model.module_list[int(select_layer)].MaskConv2d.parameters(),
        lr=hyp['lr0'],
        momentum=hyp['momentum'])
    hook_util = HookUtils()
    handles = []

    info = aux_util.layer_info[int(select_layer)]
    in_channels = info['in_channels']
    remove_k = math.floor(in_channels * pruned_rate)
    k = in_channels - remove_k

    for name, child in origin_model.module_list.named_children():
        if name == select_layer:
            handles.append(
                child.BatchNorm2d.register_forward_hook(
                    hook_util.hook_origin_input))

    aux_idx = aux_util.conv_layer_dict[select_layer]
    hook_layer_aux = aux_util.down_sample_layer[aux_idx]
    for name, child in prune_model.module_list.named_children():
        if name == select_layer:
            handles.append(
                child.BatchNorm2d.register_forward_hook(
                    hook_util.hook_prune_input))
        elif name == hook_layer_aux:
            handles.append(
                child.register_forward_hook(hook_util.hook_prune_input))

    aux_net = aux_util.creat_aux_list(416,
                                      device,
                                      conv_layer_name=select_layer)
    chkpt_aux = torch.load(aux_weight, map_location=device)
    aux_net.load_state_dict(chkpt_aux['aux{}'.format(aux_idx)])
    del chkpt_aux

    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        prune_model = torch.nn.parallel.DistributedDataParallel(
            prune_model, find_unused_parameters=True)
        prune_model.yolo_layers = prune_model.module.yolo_layers
        aux_net = torch.nn.parallel.DistributedDataParallel(
            aux_net, find_unused_parameters=True)

    nb = len(data_loader)
    prune_model.nc = 80
    prune_model.hyp = hyp
    prune_model.arc = 'default'
    prune_model.eval()
    aux_net.eval()
    MSE = nn.MSELoss(reduction='mean')

    greedy = torch.zeros(k)
    for i_k in range(k):
        pbar = tqdm(enumerate(data_loader), total=nb)
        print(('\n' + '%10s' * 8) % ('Stage', 'gpu_mem', 'iter', 'MSELoss',
                                     'PdLoss', 'AuxLoss', 'Total', 'targets'))
        for i, (imgs, targets, _, _) in pbar:
            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0
            targets = targets.to(device)

            with torch.no_grad():
                _ = origin_model(imgs)

            _, pruning_pred = prune_model(imgs)
            pruning_loss, _ = compute_loss(pruning_pred, targets, prune_model)
            hook_util.cat_to_gpu0('prune')

            aux_pred = aux_net(hook_util.prune_features['gpu0'][1])
            aux_loss, _ = AuxNetUtils.compute_loss_for_aux(
                aux_pred, aux_net, targets)

            mse_loss = torch.zeros(1).to(device)
            mse_loss += MSE(hook_util.prune_features['gpu0'][0],
                            hook_util.origin_features['gpu0'][0])

            loss = hyp['joint_loss'] * mse_loss + pruning_loss + aux_loss

            loss.backward()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0
            s = ('%10s' * 3 + '%10.3g' * 5) % (
                'Pruning ' + select_layer, '%.3gG' % mem, '%g/%g' % (i_k, k),
                mse_loss, pruning_loss, aux_loss, loss, len(targets))
            pbar.set_description(s)

            hook_util.clean_hook_out('origin')
            hook_util.clean_hook_out('prune')

        grad = prune_model.module.module_list[int(
            select_layer)].MaskConv2d.weight.grad.detach().clone()**2
        grad = grad.sum((2, 3)).sqrt().sum(0)

        if i_k == 0:
            prune_model.module.module_list[int(
                select_layer)].MaskConv2d.selected_channels_mask[:] = 1e-5
            _, non_greedy_indices = torch.topk(grad, k)
            logger.info('non greedy layer{}: selected==>{}'.format(
                select_layer, str(non_greedy_indices)))

        selected_channels_mask = prune_model.module.module_list[int(
            select_layer)].MaskConv2d.selected_channels_mask
        _, indices = torch.topk(grad * (1 - selected_channels_mask), 1)
        prune_model.module.module_list[int(
            select_layer)].MaskConv2d.selected_channels_mask[indices] = 1
        greedy[i_k] = indices
        logger.info('greedy layer{} iter{}: indices==>{}'.format(
            select_layer, str(i_k), str(indices)))

        prune_model.zero_grad()

        pbar = tqdm(enumerate(data_loader), total=nb)
        mloss = torch.zeros(4).to(device)
        print(('\n' + '%10s' * 8) % ('Stage', 'gpu_mem', 'iter', 'MSELoss',
                                     'PdLoss', 'AuxLoss', 'Total', 'targets'))
        for i, (imgs, targets, _, _) in pbar:

            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0
            targets = targets.to(device)

            with torch.no_grad():
                _ = origin_model(imgs)

            _, pruning_pred = prune_model(imgs)
            pruning_loss, _ = compute_loss(pruning_pred, targets, prune_model)
            hook_util.cat_to_gpu0('prune')

            aux_pred = aux_net(hook_util.prune_features['gpu0'][1])
            aux_loss, _ = AuxNetUtils.compute_loss_for_aux(
                aux_pred, aux_net, targets)

            mse_loss = torch.zeros(1).to(device)
            mse_loss += MSE(hook_util.prune_features['gpu0'][0],
                            hook_util.origin_features['gpu0'][0])

            loss = hyp['joint_loss'] * mse_loss + pruning_loss + aux_loss

            loss.backward()

            solve_sub_problem_optimizer.step()
            solve_sub_problem_optimizer.zero_grad()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0
            mloss = (mloss * i + torch.cat(
                [mse_loss, pruning_loss, aux_loss, loss]).detach()) / (i + 1)
            s = ('%10s' * 3 + '%10.3g' * 5) % (
                'SubProm ' + select_layer, '%.3gG' % mem, '%g/%g' % (i_k, k),
                *mloss, len(targets))
            pbar.set_description(s)

            hook_util.clean_hook_out('origin')
            hook_util.clean_hook_out('prune')

    for handle in handles:
        handle.remove()

    logger.info(
        ("greedy layer{}: selected==>{}".format(select_layer, str(greedy))))
Example #3
def fine_tune(prune_cfg,
              data,
              aux_util,
              device,
              train_loader,
              test_loader,
              epochs=10):
    with open(progress_result, 'a') as f:
        f.write(('\n' + '%10s' * 10 + '\n') %
                ('Stage', 'Epoch', 'DIoU', 'obj', 'cls', 'Total', 'P', 'R',
                 'mAP@0.5', 'F1'))

    batch_size = train_loader.batch_size
    img_size = train_loader.dataset.img_size
    accumulate = 64 // batch_size
    hook_util = HookUtils()

    pruned_model = Darknet(prune_cfg, img_size=(img_size, img_size)).to(device)

    chkpt = torch.load(progress_chkpt, map_location=device)
    pruned_model.load_state_dict(chkpt['model'], strict=True)

    current_layer = chkpt['current_layer']
    aux_in_layer = aux_util.conv_layer_dict[current_layer]
    aux_model = aux_util.creat_aux_model(aux_in_layer)
    aux_model.to(device)

    aux_model.load_state_dict(chkpt['aux_in{}'.format(aux_in_layer)],
                              strict=True)
    aux_loss_scalar = max(0.01, pow((int(aux_in_layer) + 1) / 75, 2))

    start_epoch = chkpt['epoch'] + 1

    if start_epoch == epochs:
        return current_layer  # fine-tuning finished; return the name of the layer to prune

    pg0, pg1 = [], []  # optimizer parameter groups
    for k, v in dict(pruned_model.named_parameters()).items():
        if 'MaskConv2d.weight' in k:
            pg1 += [v]  # parameter group 1 (apply weight_decay)
        else:
            pg0 += [v]  # parameter group 0

    for v in aux_model.parameters():
        pg0 += [v]  # parameter group 0

    optimizer = optim.SGD(pg0,
                          lr=hyp['lr0'],
                          momentum=hyp['momentum'],
                          nesterov=True)
    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    del pg0, pg1

    if chkpt['optimizer'] is not None:
        optimizer.load_state_dict(chkpt['optimizer'])

    del chkpt

    scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=[epochs // 3, 2 * (epochs // 3)], gamma=0.1)
    scheduler.last_epoch = start_epoch - 1

    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        pruned_model = nn.parallel.DistributedDataParallel(
            pruned_model, find_unused_parameters=True)
        pruned_model.yolo_layers = pruned_model.module.yolo_layers

    # -------------start train-------------
    nb = len(train_loader)
    pruned_model.nc = 80
    pruned_model.hyp = hyp
    pruned_model.arc = 'default'
    for epoch in range(start_epoch, epochs):

        # -------------register hook for model-------------
        for name, child in pruned_model.module.module_list.named_children():
            if name == aux_in_layer:
                handle = child.register_forward_hook(
                    hook_util.hook_prune_output)

        # -------------register hook for model-------------

        pruned_model.train()
        aux_model.train()

        print(('\n' + '%10s' * 7) %
              ('Stage', 'Epoch', 'gpu_mem', 'DIoU', 'obj', 'cls', 'total'))

        # -------------start batch-------------
        mloss = torch.zeros(4).to(device)
        pbar = tqdm(enumerate(train_loader), total=nb)
        for i, (img, targets, _, _) in pbar:
            if len(targets) == 0:
                continue

            ni = nb * epoch + i
            img = img.to(device).float() / 255.0
            targets = targets.to(device)

            pruned_pred = pruned_model(img)
            pruned_loss, pruned_loss_items = compute_loss(
                pruned_pred, targets, pruned_model)
            pruned_loss *= batch_size / 64

            hook_util.cat_to_gpu0()

            aux_pred = aux_model(hook_util.prune_features['gpu0'][0], targets)

            aux_loss = compute_loss_for_DCP(aux_pred, targets)
            aux_loss *= aux_loss_scalar * batch_size / 64

            loss = pruned_loss + aux_loss
            loss.backward()

            hook_util.clean_hook_out()
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()

            pruned_loss_items[2] += aux_loss.item()
            mloss = (mloss * i + pruned_loss_items) / (i + 1)
            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0
            s = ('%10s' * 3 + '%10.3g' * 4) % (
                'FiTune ' + current_layer, '%g/%g' % (epoch, epochs - 1),
                '%.3gG' % mem, *mloss)
            pbar.set_description(s)
        # -------------end batch-------------

        scheduler.step()
        handle.remove()

        results, _ = test.test(prune_cfg,
                               data,
                               batch_size=batch_size * 2,
                               img_size=416,
                               model=pruned_model,
                               conf_thres=0.1,
                               iou_thres=0.5,
                               save_json=False,
                               dataloader=test_loader)
        """
        chkpt = {'current_layer':
                 'epoch':
                 'model': 
                 'optimizer': 
                 'aux_in12': 
                 'aux_in37':
                 'aux_in62':
                 'aux_in75':
                 'prune_guide':}
        """
        chkpt = torch.load(progress_chkpt, map_location=device)
        chkpt['current_layer'] = current_layer
        chkpt['epoch'] = epoch
        chkpt['model'] = (pruned_model.module.state_dict()
                          if type(pruned_model) is nn.parallel.DistributedDataParallel
                          else pruned_model.state_dict())
        chkpt['optimizer'] = None if epoch == epochs - 1 else optimizer.state_dict()
        chkpt['aux_in{}'.format(aux_in_layer)] = aux_model.state_dict()

        torch.save(chkpt, progress_chkpt)

        torch.save(chkpt, last)

        if epoch == epochs - 1:
            torch.save(chkpt,
                       '../weights/DCP/backup{}.pt'.format(current_layer))

        del chkpt

        with open(progress_result, 'a') as f:
            f.write(('%10s' * 2 + '%10.3g' * 8) %
                    ('FiTune ' + current_layer, '%g/%g' %
                     (epoch, epochs - 1), *mloss, *results[:4]) + '\n')
    # -------------end train-------------
    torch.cuda.empty_cache()
    return current_layer
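
The training loop above steps the optimizer only every `accumulate` batches and scales each loss by `batch_size / 64`, emulating an effective batch size of 64 on smaller hardware. A minimal sketch of that accumulation pattern on a toy model (names illustrative):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
opt = optim.SGD(model.parameters(), lr=0.01)
batch_size, accumulate = 16, 64 // 16

for ni in range(8):                       # fake training iterations
    x, y = torch.randn(batch_size, 10), torch.randn(batch_size, 1)
    loss = nn.functional.mse_loss(model(x), y) * batch_size / 64
    loss.backward()                       # gradients accumulate in .grad
    if ni % accumulate == 0:              # step once per effective batch
        opt.step()
        opt.zero_grad()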
Example #4
def channels_select(prune_cfg, data, origin_model, aux_util, device,
                    data_loader, select_layer, pruned_rate):
    with open(progress_result, 'a') as f:
        f.write(('\n' + '%10s' * 9 + '\n') %
                ('Stage', 'Change', 'MSELoss', 'AuxLoss', 'Total', 'P', 'R',
                 'mAP@0.5', 'F1'))
    logger.info(('%10s' * 6) %
                ('Stage', 'Channels', 'Batch', 'MSELoss', 'AuxLoss', 'Total'))

    batch_size = data_loader.batch_size
    img_size = data_loader.dataset.img_size
    accumulate = 64 // batch_size
    hook_util = HookUtils()
    handles = []
    n_iter = math.floor(500 / batch_size)

    pruning_model = Darknet(prune_cfg,
                            img_size=(img_size, img_size)).to(device)
    chkpt = torch.load(progress_chkpt, map_location=device)
    pruning_model.load_state_dict(chkpt['model'], strict=True)

    aux_in_layer = aux_util.conv_layer_dict[select_layer]
    aux_model = aux_util.creat_aux_model(aux_in_layer)
    aux_model.to(device)

    aux_model.load_state_dict(chkpt['aux_in{}'.format(aux_in_layer)],
                              strict=True)
    aux_loss_scalar = max(0.01, pow((int(aux_in_layer) + 1) / 75, 2))

    del chkpt

    solve_sub_problem_optimizer = optim.SGD(
        pruning_model.module_list[int(aux_in_layer)].MaskConv2d.parameters(),
        lr=hyp['lr0'],
        momentum=hyp['momentum'])

    for name, child in origin_model.module_list.named_children():
        if name == aux_in_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_origin_output))
        if name == select_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_origin_output))

    for name, child in pruning_model.module_list.named_children():
        if name == aux_in_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_prune_output))
        if name == select_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_prune_output))

    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        origin_model = torch.nn.parallel.DistributedDataParallel(
            origin_model, find_unused_parameters=True)
        origin_model.yolo_layers = origin_model.module.yolo_layers
        pruning_model = torch.nn.parallel.DistributedDataParallel(
            pruning_model, find_unused_parameters=True)
        pruning_model.yolo_layers = pruning_model.module.yolo_layers

    retain_channels_num = math.floor(
        aux_util.layer_info[select_layer]["in_channels"] * (1 - pruned_rate))
    pruning_model.nc = 80
    pruning_model.hyp = hyp
    pruning_model.arc = 'default'
    pruning_model.eval()
    aux_model.eval()
    MSE = nn.MSELoss(reduction='mean')
    mloss = torch.zeros(3).to(device)

    for i_k in range(retain_channels_num):

        data_iter = iter(data_loader)
        pbar = tqdm(range(n_iter), total=n_iter)
        print(('\n' + '%10s' * 6) %
              ('Stage', 'gpu_mem', 'channels', 'MSELoss', 'AuxLoss', 'Total'))
        for i in pbar:

            imgs, targets, _, _ = next(data_iter)

            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0
            targets = targets.to(device)

            with torch.no_grad():
                _ = origin_model(imgs)

            _, pruning_pred = pruning_model(imgs)
            pruning_loss, _ = compute_loss(pruning_pred, targets,
                                           pruning_model)

            hook_util.cat_to_gpu0()
            mse_loss = torch.zeros(1, device=device)

            aux_pred = aux_model(hook_util.prune_features['gpu0'][1], targets)
            aux_loss = compute_loss_for_DCP(aux_pred, targets)
            mse_loss += MSE(hook_util.prune_features['gpu0'][0],
                            hook_util.origin_features['gpu0'][0])

            loss = hyp['joint_loss'] * mse_loss + aux_loss + 0 * pruning_loss

            loss.backward()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0
            s = ('%10s' * 3 + '%10.3g' * 3) % (
                'Prune ' + select_layer, '%.3gG' % mem,
                '%g/%g' % (i_k, retain_channels_num),
                hyp['joint_loss'] * mse_loss, aux_loss, loss)
            pbar.set_description(s)

            # if (i + 1) % 10 == 0:
            #     logger.info(('%10s' * 3 + '%10.3g' * 3) %
            #                 ('Prune' + select_layer, str(i_k), '%g/%g' % (i, n_iter), hyp['joint_loss'] * mse_loss,
            #                  aux_loss, loss))

            hook_util.clean_hook_out()

        grad = pruning_model.module.module_list[int(
            select_layer)].MaskConv2d.weight.grad.detach()**2
        grad = grad.sum((2, 3)).sqrt().sum(0)

        if i_k == 0:
            pruning_model.module.module_list[int(
                select_layer)].MaskConv2d.selected_channels_mask[:] = 1e-5
            if select_layer in aux_util.sync_guide.keys():
                sync_layer = aux_util.sync_guide[select_layer]
                pruning_model.module.module_list[int(
                    sync_layer)].MaskConv2d.selected_channels_mask[(
                        -1 * aux_util.layer_info[select_layer]["in_channels"]
                    ):] = 1e-5

        selected_channels_mask = pruning_model.module.module_list[int(
            select_layer)].MaskConv2d.selected_channels_mask
        _, indices = torch.topk(grad * (1 - selected_channels_mask), 1)

        pruning_model.module.module_list[int(
            select_layer)].MaskConv2d.selected_channels_mask[indices] = 1
        if select_layer in aux_util.sync_guide.keys():
            pruning_model.module.module_list[int(
                sync_layer)].MaskConv2d.selected_channels_mask[-(
                    aux_util.layer_info[select_layer]["in_channels"] -
                    indices)] = 1

        pruning_model.zero_grad()

        pbar = tqdm(range(n_iter), total=n_iter)
        print(('\n' + '%10s' * 6) %
              ('Stage', 'gpu_mem', 'channels', 'MSELoss', 'AuxLoss', 'Total'))
        for i in pbar:

            imgs, targets, _, _ = next(data_iter)

            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0
            targets = targets.to(device)

            with torch.no_grad():
                _ = origin_model(imgs)

            _, pruning_pred = pruning_model(imgs)
            pruning_loss, _ = compute_loss(pruning_pred, targets,
                                           pruning_model)

            hook_util.cat_to_gpu0()
            mse_loss = torch.zeros(1, device=device)

            aux_pred = aux_model(hook_util.prune_features['gpu0'][1], targets)
            aux_loss = compute_loss_for_DCP(aux_pred, targets)
            mse_loss += MSE(hook_util.prune_features['gpu0'][0],
                            hook_util.origin_features['gpu0'][0])

            loss = (hyp['joint_loss'] * mse_loss +
                    aux_loss_scalar * aux_loss + 0 * pruning_loss)

            loss.backward()

            if i % accumulate == 0:
                solve_sub_problem_optimizer.step()
                solve_sub_problem_optimizer.zero_grad()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0
            mloss = (mloss * i + torch.cat(
                [hyp['joint_loss'] * mse_loss, aux_loss, loss]).detach()) / (i + 1)
            s = ('%10s' * 3 + '%10.3g' * 3) % (
                'SubProm ' + select_layer, '%.3gG' % mem, '%g/%g' %
                (i_k, retain_channels_num), *mloss)
            pbar.set_description(s)

            if (i + 1) % n_iter == 0:
                logger.info(('%10s' * 3 + '%10.3g' * 3) %
                            ('SubPro' + select_layer, str(i_k), '%g/%g' %
                             (i, n_iter), *mloss))

            hook_util.clean_hook_out()

    for handle in handles:
        handle.remove()

    greedy_indices = pruning_model.module.module_list[int(
        select_layer)].MaskConv2d.selected_channels_mask < 1
    pruning_model.module.module_list[int(
        select_layer)].MaskConv2d.selected_channels_mask[greedy_indices] = 0

    res, _ = test.test(prune_cfg,
                       data,
                       batch_size=batch_size * 2,
                       img_size=416,
                       model=pruning_model,
                       conf_thres=0.1,
                       iou_thres=0.5,
                       save_json=False,
                       dataloader=None)

    chkpt = torch.load(progress_chkpt, map_location=device)
    chkpt['current_layer'] = aux_util.next_prune_layer(select_layer)
    chkpt['epoch'] = -1
    chkpt['model'] = (pruning_model.module.state_dict()
                      if type(pruning_model) is nn.parallel.DistributedDataParallel
                      else pruning_model.state_dict())
    chkpt['optimizer'] = None

    torch.save(chkpt, progress_chkpt)

    torch.save(chkpt, last)
    del chkpt

    with open(progress_result, 'a') as f:
        f.write(('%10s' * 2 + '%10.3g' * 7) %
                ('Pruning ' + select_layer,
                 str(aux_util.layer_info[select_layer]['in_channels']) + '->' +
                 str(retain_channels_num), *mloss, *res[:4]) + '\n')

    torch.cuda.empty_cache()
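
Examples #2 and #4 both index a `MaskConv2d` whose `selected_channels_mask` zeroes out pruned input channels; its real definition is not shown in this listing. A plausible minimal sketch of the idea, for illustration only:

import torch
import torch.nn as nn

class MaskedConv2d(nn.Module):
    """Plain Conv2d plus a per-input-channel mask buffer (illustrative)."""
    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=k // 2)
        self.register_buffer("selected_channels_mask", torch.ones(in_ch))

    def forward(self, x):
        # broadcast mask over (N, C, H, W): unselected channels contribute 0
        return self.conv(x * self.selected_channels_mask.view(1, -1, 1, 1))

m = MaskedConv2d(8, 16)
m.selected_channels_mask[:] = 0
m.selected_channels_mask[[0, 3, 5]] = 1    # "keep" 3 of 8 input channels
print(m(torch.randn(2, 8, 32, 32)).shape)  # torch.Size([2, 16, 32, 32])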
Example #5
def train_aux_for_LCP(cfg, backbone, neck, data_loader, weights, aux_weight,
                      hyp, device, resume, epochs):
    init_seeds()
    batch_size = data_loader.batch_size
    accumulate = 64 // batch_size

    model = Darknet(cfg).to(device)
    model_chkpt = torch.load(weights, map_location=device)
    model.load_state_dict(model_chkpt['model'], strict=True)
    del model_chkpt
    aux_util = AuxNetUtils(model, hyp, backbone, neck, strategy="LCP")
    hook_util = HookUtils()

    start_epoch = 0

    aux_model_list = []
    pg = []
    for layer in aux_util.aux_in_layer:
        aux_model = aux_util.creat_aux_model(layer)
        aux_model.to(device)
        for v in aux_model.parameters():
            pg += [v]
        aux_model_list.append(aux_model)

    optimizer = optim.SGD(pg,
                          lr=hyp['lr0'],
                          momentum=hyp['momentum'],
                          nesterov=True)
    del pg

    if resume:
        chkpt = torch.load(aux_weight, map_location=device)

        for i, layer in enumerate(aux_util.aux_in_layer):
            aux_model_list[i].load_state_dict(chkpt['aux_in{}'.format(layer)],
                                              strict=True)

        if chkpt['optimizer'] is not None:
            optimizer.load_state_dict(chkpt['optimizer'])

        start_epoch = chkpt['epoch'] + 1

    scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=[epochs // 3, 2 * epochs // 3], gamma=0.1)
    scheduler.last_epoch = start_epoch - 1

    handles = []  # handles must be removed once training finishes
    for name, child in model.module_list.named_children():
        if name in aux_util.aux_in_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_origin_output))

    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        model = nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
        model.yolo_layers = model.module.yolo_layers

    nb = len(data_loader)
    model.nc = 80
    model.hyp = hyp
    model.arc = 'default'
    print('Starting training for %g epochs...' % epochs)
    for epoch in range(start_epoch, epochs):

        for aux_model in aux_model_list:
            aux_model.train()
        print(('\n' + '%10s' * 6) %
              ('Stage', 'Epoch', 'gpu_mem', 'AuxID', 'cls', 'targets'))

        # -----------------start batch-----------------
        pbar = tqdm(enumerate(data_loader), total=nb)
        model.train()
        for i, (imgs, targets, _, _) in pbar:

            if len(targets) == 0:
                continue

            ni = i + nb * epoch
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0
            targets = targets.to(device)

            with torch.no_grad():
                prediction = model(imgs)

            hook_util.cat_to_gpu0()
            for aux_idx, aux_model in enumerate(aux_model_list):
                pred, loc_loss = aux_model(
                    hook_util.origin_features['gpu0'][aux_idx], targets,
                    prediction)
                loss = compute_loss_for_LCP(pred, loc_loss, targets)

                loss *= batch_size / 64

                loss.backward()

                mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s = ('%10s' * 3 + '%10.3g' * 3) % ('Train Aux', '%g/%g' %
                                                   (epoch, epochs - 1),
                                                   '%.3gG' % mem, aux_idx,
                                                   loss, len(targets))
                pbar.set_description(s)
            # clear the hook outputs after every batch
            hook_util.clean_hook_out()
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()
        # -----------------end batches-----------------

        scheduler.step()
        final_epoch = epoch + 1 == epochs
        chkpt = {
            'epoch': epoch,
            'optimizer': None if final_epoch else optimizer.state_dict()
        }
        for i, layer in enumerate(aux_util.aux_in_layer):
            chkpt['aux_in{}'.format(layer)] = aux_model_list[i].state_dict()

        torch.save(chkpt, aux_weight)

        torch.save(chkpt, "../weights/LCP/aux-coco.pt")
        del chkpt

        with open("./LCP/aux_result.txt", 'a') as f:
            f.write(s + '\n')

    # finally, remove all hooks
    for handle in handles:
        handle.remove()
    torch.cuda.empty_cache()
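
`HookUtils` in these examples wraps PyTorch forward hooks to capture intermediate features for the auxiliary heads. A minimal sketch of that capture pattern with plain `register_forward_hook` (names below are illustrative):

import torch
import torch.nn as nn

features = {}

def save_output(name):
    def hook(module, inputs, output):
        features[name] = output.detach()  # stash the layer output by name
    return hook

body = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
handle = body[0].register_forward_hook(save_output("layer0"))

_ = body(torch.randn(1, 3, 32, 32))
print(features["layer0"].shape)  # torch.Size([1, 8, 30, 30])
handle.remove()                  # hooks must be removed after use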
Example #6
File: fai.py  Project: tomassa/yolov3
# %%

# Combine sample sets back together
samples = positive_samp  # + negative_samp

# %%

# Load the model
img_size = (352, 608)
device = 'cuda:0'
arc = 'default'
cfg = 'cfg/yolov3-tiny-anchors.cfg'
weights = 'weights/best.pt'
device = torch_utils.select_device(device, apex=False, batch_size=64)
model = Darknet(cfg, img_size=img_size, arc=arc).to(device)
model.arc = 'default'
model.nc = 1  # num classes
model.hyp = hyp
d = torch.load(weights, map_location=device)
m = d['model']
model.load_state_dict(m)

# Build the paths and pass them to the FastAI ObjectItemList
posix_paths = json_to_paths(samples)
lst = ObjectItemList(posix_paths, label_cls=YoloCategoryList)
YoloCategoryList.anchors = [
    model.module_list[l].anchors for l in model.yolo_layers
]
YoloCategoryList.img_size = img_size

# %%
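
The anchor wiring above indexes `model.module_list` with the indices stored in `model.yolo_layers`, each of which points at a detection layer exposing an `anchors` attribute. A toy sketch of that gather pattern (`ToyYolo` is illustrative, not the repo's YOLO layer):

import torch
import torch.nn as nn

class ToyYolo(nn.Module):
    def __init__(self, anchors):
        super().__init__()
        self.anchors = torch.tensor(anchors, dtype=torch.float32)

layers = nn.ModuleList([nn.Conv2d(3, 8, 3), ToyYolo([[10, 14], [23, 27]]),
                        nn.Conv2d(8, 8, 3), ToyYolo([[37, 58], [81, 82]])])
yolo_layers = [1, 3]  # indices of the detection heads
anchors = [layers[l].anchors for l in yolo_layers]
print([a.shape for a in anchors])  # [torch.Size([2, 2]), torch.Size([2, 2])]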