def _zero_channel_ids(mask):
    """Return the sorted indices at which the filter ``mask`` equals zero.

    Assumes ``mask`` is a 1-D tensor with one entry per filter -- TODO confirm
    against ``obtain_filters_mask_l1_dict_percent``.
    """
    return np.flatnonzero(mask.cpu().numpy() == 0).tolist()


def _drop_channels(tensor, channel_ids, axis):
    """Return a CPU copy of ``tensor`` with ``channel_ids`` removed along ``axis``."""
    return torch.from_numpy(
        np.delete(tensor.cpu().numpy(), channel_ids, axis=axis))


def _prune_state_dict(model, filters_mask, prune_idx, prune_bn_bias_idx):
    """Build a state dict in which the masked-out channels have been removed.

    Args:
        model: module whose ``state_dict()`` is pruned (the module itself is
            not modified).
        filters_mask: sequence of per-layer masks; a zero entry marks a filter
            to delete.
        prune_idx: list of positions (in ``state_dict`` order) of the weights
            whose output filters are pruned.
        prune_bn_bias_idx: positions of the per-channel entries (BN parameters,
            biases) that have to be shrunk to match.

    Returns:
        dict mapping every ``state_dict`` key to its (possibly shrunk) tensor,
        in the original key order.
    """
    # Build the state dict once; every call to state_dict() creates a new dict.
    state = model.state_dict()

    # Pass 1: remove the pruned *output* filters (axis 0) of each pruned weight.
    prune_weight_dict = {}
    idx_name = {}
    mask_pos = 0
    for idx, (name, tensor) in enumerate(state.items()):
        if idx in prune_idx:
            idx_name[idx] = name
            prune_weight_dict[name] = _drop_channels(
                tensor, _zero_channel_ids(filters_mask[mask_pos]), axis=0)
            mask_pos += 1
        else:
            prune_weight_dict[name] = tensor

    print("prune idx is ", prune_idx)

    # Pass 2: remove the matching *input* channels (axis 1) of the following
    # pruned weight, and shrink the per-channel entries (axis 0).
    # Only existing keys are reassigned, so iterating the dict stays valid.
    id_mask = 0
    for i, layer_name in enumerate(prune_weight_dict):
        if i in prune_idx:
            try:
                next_idx = prune_idx[prune_idx.index(i) + 1]
            except IndexError:
                # Last pruned layer: there is no following pruned weight.
                # id_mask is deliberately not advanced here.
                print("current idx is", i)
                print("len prune idx is", len(prune_idx))
                continue
            next_name = idx_name[next_idx]
            is_downsample = "downsample" in next_name
            # A downsample weight uses the mask three positions back --
            # presumably the mask of the layer feeding the residual block;
            # TODO confirm against the network definition.
            mask = filters_mask[id_mask - 3 if is_downsample else id_mask]
            current_zero_id = _zero_channel_ids(mask)
            prune_weight_dict[next_name] = _drop_channels(
                prune_weight_dict[next_name], current_zero_id, axis=1)
            id_mask += 1

            if is_downsample:
                print(" idx_name[next_idx] is", next_name)
                print("current_zero_id is", current_zero_id)
        elif i in prune_bn_bias_idx:
            if id_mask == 0:
                # No pruned weight has been processed yet; nothing to shrink.
                continue
            mask = filters_mask[id_mask - 1]
            # id_mask was not advanced for the last pruned layer (see above),
            # so the entries of layer4.2.bn3 take the final mask explicitly.
            if "layer4.2.bn3" in layer_name:
                print("ori del mask is", filters_mask[id_mask])
                mask = filters_mask[-1]
                print("last del mask is", mask)
            current_zero_id = _zero_channel_ids(mask)
            prune_weight_dict[layer_name] = _drop_channels(
                prune_weight_dict[layer_name], current_zero_id, axis=0)
            if layer_name == 'layer4.2.bn3.running_var':
                # The first deconv layer consumes layer4's output, so the same
                # channels are removed from axis 0 of its weight.
                print("before deconv_layers.0.weight shape is",
                      prune_weight_dict["deconv_layers.0.weight"].shape)
                print("current zero id is", current_zero_id)
                prune_weight_dict["deconv_layers.0.weight"] = _drop_channels(
                    prune_weight_dict["deconv_layers.0.weight"],
                    current_zero_id,
                    axis=0)
                print("last deconv_layers.0.weight shape is",
                      prune_weight_dict["deconv_layers.0.weight"].shape)
                print("finished del deconv input")
    return prune_weight_dict


def main(opt):
    """Prune a trained 'res_101' model by filter mask and save the result.

    Loads the checkpoint named by the module-level ``train_model``, obtains
    per-layer filter masks for the module-level ``percent``, removes the
    masked channels from the state dict, writes the pruned weights to
    'exp/ctdet/default/l1_prune_model.pt' and loads them into a freshly
    created 'resPrune_101' model.

    Args:
        opt: parsed options object; dataset information and heads are filled
            in by ``opts().update_dataset_info_and_set_heads``.
    """
    torch.manual_seed(opt.seed)
    torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
    Dataset = get_dataset(opt.dataset, opt.task)
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')
    print('Creating model...')
    model = create_model('res_101', opt.heads, opt.head_conv)

    optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    # Only the restored model is needed; the optimizer and epoch are unused.
    model, _, _ = load_model(model,
                             train_model,
                             optimizer=optimizer,
                             resume=opt.resume,
                             lr=opt.lr,
                             lr_step=opt.lr_step)

    from utils.prune_util import obtain_filters_mask_l1_dict_percent
    num_filters, filters_mask, prune_idx, prune_bn_bias_idx = \
        obtain_filters_mask_l1_dict_percent(model, percent)
    print("num_filters is", num_filters)
    print("prune idx is", prune_idx)
    print("prune bn and bias idx is", prune_bn_bias_idx)

    prune_weight_dict = _prune_state_dict(model, filters_mask, prune_idx,
                                          prune_bn_bias_idx)

    save_path = 'exp/ctdet/default/l1_prune_model.pt'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(prune_weight_dict, save_path)
    prune_model = create_model_101_prune('resPrune_101',
                                         opt.heads,
                                         opt.head_conv,
                                         percent=percent)
    prune_model.load_state_dict(torch.load(save_path), strict=False)
示例#2
0
from utils.debugger import Debugger
import cv2
import os

# Module-level configuration for the detection demo.
# num_classes = 3
# Upper bound of the class ids (1..num_classes) iterated by show_results.
num_classes = 5
# Passed through to debugger.show_all_imgs(pause=...) in show_results.
pause = True
# vis_thresh = 0.01
# Detections whose score (bbox[4]) is not above this value are not drawn.
vis_thresh = 0.3
import os
# Set before the detector is constructed below.
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
# MODEL_PATH ='/home/pcl/pytorch_work/my_github/centernet_simple/weights/model_best_dla34.pth'
MODEL_PATH ='/home/pcl/pytorch_work/my_github/centernet_simple/exp/ctdet/coco_res_prune/model_best_map_0.48.pth'
TASK = 'ctdet' # or 'multi_pose' for human pose estimation
# opt = opts().init('{} --load_model {} --flip_test'.format(TASK, MODEL_PATH).split(' '))
# Build the options from an argv-style list: [TASK, '--load_model', MODEL_PATH].
opt = opts().init('{} --load_model {}'.format(TASK, MODEL_PATH).split(' '))
print(opt)
# Look up the detector class for the task and instantiate it with the options.
detector = detector_factory[opt.task](opt)
# Directory of input images for the demo.
img_dir = '/home/pcl/pytorch_work/CenterNet-master/dianli_images/'

def show_results(debugger, image, results):
    """Draw the detections that score above ``vis_thresh`` and show the image.

    Args:
        debugger: object providing ``add_img``, ``add_coco_bbox`` and
            ``show_all_imgs``.
        image: image registered with the debugger under the id 'ctdet'.
        results: mapping from class id (1..``num_classes``) to an iterable of
            detections; each detection is indexed as ``[:4]`` for the box and
            ``[4]`` for the score.
    """
    debugger.add_img(image, img_id='ctdet')
    class_id = 1
    while class_id <= num_classes:
        for detection in results[class_id]:
            if not detection[4] > vis_thresh:
                continue
            # Class ids in ``results`` are 1-based; the debugger gets 0-based.
            debugger.add_coco_bbox(detection[:4],
                                   class_id - 1,
                                   detection[4],
                                   img_id='ctdet')
        class_id += 1
    debugger.show_all_imgs(pause=pause)

def inference(img_dir):
    results_imgs = []
    for img in os.listdir(img_dir):
    # def prune_and_eval(model, CBL_idx, CBLidx2mask):
    #     import copy
    #     model_copy = copy.deepcopy(model)
    #     for idx, name in enumerate(model_copy.state_dict()):
    #         if idx in prune_idx:
    #             bn_module = model_copy.state_dict()[name]
    #             mask = CBLidx2mask[idx].cuda()
    #             for i in range(mask.shape[0]):
    #                 if mask[i].cpu().numpy() == 0:
    #                     bn_module[i,:,:,:].zero_()
    #     val_loader = torch.utils.data.DataLoader(Dataset(opt, 'val'), batch_size=1, shuffle=False, num_workers=1,
    #                                              pin_memory=True)
    #     trainer = CtTrainer(opt, model_copy, optimizer)
    #     trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
    #     with torch.no_grad():
    #         log_dict_val, preds = trainer.val(0, val_loader)
    #         val_loader.dataset.run_eval(preds, opt.save_dir)
    #     # print(f'mask the gamma as zero, mAP of the model is {mAP:.4f}')
    #
    # prune_and_eval(model, CBL_idx, CBLidx2mask)
    #
    # for i in CBLidx2mask:
    #     CBLidx2mask[i] = CBLidx2mask[i].clone().cpu().numpy()

    # pruned_model = prune_model_keep_size2(model, prune_idx, CBL_idx, CBLidx2mask)


# Script entry point: parse the options and run main() with them.
if __name__ == '__main__':
    opt = opts().parse()
    main(opt)
    return:
        pred_content: 2d list.
    '''
    image_id = id_list[number]
    inference_img_name = name_list[number]
    pred_content = []
    results = inference_img(inference_img_name)
    for j in range(1, cfg.class_num + 1):
        for bbox in results[j]:
            if bbox[4] > cfg.score_th:
                x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
                score = bbox[4]
                label = j - 1
                pred_content.append(
                    [image_id, x_min, y_min, x_max, y_max, score, label])
    print("obj is", pred_content)
    # debugger.add_coco_bbox(bbox[:4], j - 1, bbox[4], img_id='ctdet')
    return pred_content


# Script entry point: build a 'ctdet' detector from a fixed checkpoint and run
# inference over a fixed image directory.
if __name__ == "__main__":
    MODEL_PATH = '/home/pcl/tf_work/map/weights/model_best_dla34.pth'
    import os
    # Options are built from an argv-style list: ['ctdet', '--load_model', MODEL_PATH].
    opt = opts().init('{} --load_model {}'.format("ctdet",
                                                  MODEL_PATH).split(' '))
    # Set before the detector is constructed below.
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'
    TASK = 'ctdet'  # or 'multi_pose' for human pose estimation
    detector_test = detector_factory[TASK](opt)
    img_dir = '/home/pcl/pytorch_work/CenterNet-master/dianli_images/'
    # NOTE(review): the `inference` definition accepting a `model` keyword is
    # not visible in this excerpt -- confirm the signature.
    inference(img_dir=img_dir, model=detector_test)
def _zero_channel_ids(mask):
    """Return the sorted indices at which the filter ``mask`` equals zero.

    Assumes ``mask`` is a 1-D tensor with one entry per filter -- TODO confirm
    against ``obtain_filters_mask_l1_dict_percent``.
    """
    return np.flatnonzero(mask.cpu().numpy() == 0).tolist()


def _drop_channels(tensor, channel_ids, axis):
    """Return a CPU copy of ``tensor`` with ``channel_ids`` removed along ``axis``."""
    return torch.from_numpy(
        np.delete(tensor.cpu().numpy(), channel_ids, axis=axis))


def _prune_state_dict(model, filters_mask, prune_idx, prune_bn_bias_idx):
    """Build a state dict in which the masked-out channels have been removed.

    Args:
        model: module whose ``state_dict()`` is pruned (the module itself is
            not modified).
        filters_mask: sequence of per-layer masks; a zero entry marks a filter
            to delete.
        prune_idx: list of positions (in ``state_dict`` order) of the weights
            whose output filters are pruned.
        prune_bn_bias_idx: positions of the per-channel entries (BN parameters,
            biases) that have to be shrunk to match.

    Returns:
        dict mapping every ``state_dict`` key to its (possibly shrunk) tensor,
        in the original key order.
    """
    # Build the state dict once; every call to state_dict() creates a new dict.
    state = model.state_dict()

    # Pass 1: remove the pruned *output* filters (axis 0) of each pruned weight.
    prune_weight_dict = {}
    idx_name = {}
    mask_pos = 0
    for idx, (name, tensor) in enumerate(state.items()):
        if idx in prune_idx:
            idx_name[idx] = name
            prune_weight_dict[name] = _drop_channels(
                tensor, _zero_channel_ids(filters_mask[mask_pos]), axis=0)
            mask_pos += 1
        else:
            prune_weight_dict[name] = tensor

    print("prune idx is ", prune_idx)

    # Pass 2: remove the matching *input* channels (axis 1) of the following
    # pruned weight, and shrink the per-channel entries (axis 0).
    # Only existing keys are reassigned, so iterating the dict stays valid.
    id_mask = 0
    for i, layer_name in enumerate(prune_weight_dict):
        if i in prune_idx:
            try:
                next_idx = prune_idx[prune_idx.index(i) + 1]
            except IndexError:
                # Last pruned layer: there is no following pruned weight.
                # id_mask is deliberately not advanced here.
                print("current idx is", i)
                print("len prune idx is", len(prune_idx))
                continue
            next_name = idx_name[next_idx]
            is_downsample = "downsample" in next_name
            # A downsample weight uses the mask three positions back --
            # presumably the mask of the layer feeding the residual block;
            # TODO confirm against the network definition.
            mask = filters_mask[id_mask - 3 if is_downsample else id_mask]
            current_zero_id = _zero_channel_ids(mask)
            prune_weight_dict[next_name] = _drop_channels(
                prune_weight_dict[next_name], current_zero_id, axis=1)
            id_mask += 1

            if is_downsample:
                print(" idx_name[next_idx] is", next_name)
                print("current_zero_id is", current_zero_id)
        elif i in prune_bn_bias_idx:
            if id_mask == 0:
                # No pruned weight has been processed yet; nothing to shrink.
                continue
            mask = filters_mask[id_mask - 1]
            # id_mask was not advanced for the last pruned layer (see above),
            # so the entries of layer4.2.bn3 take the final mask explicitly.
            if "layer4.2.bn3" in layer_name:
                print("ori del mask is", filters_mask[id_mask])
                mask = filters_mask[-1]
                print("last del mask is", mask)
            current_zero_id = _zero_channel_ids(mask)
            prune_weight_dict[layer_name] = _drop_channels(
                prune_weight_dict[layer_name], current_zero_id, axis=0)
            if layer_name == 'layer4.2.bn3.running_var':
                # The first deconv layer consumes layer4's output, so the same
                # channels are removed from axis 0 of its weight.
                print("before deconv_layers.0.weight shape is",
                      prune_weight_dict["deconv_layers.0.weight"].shape)
                print("current zero id is", current_zero_id)
                prune_weight_dict["deconv_layers.0.weight"] = _drop_channels(
                    prune_weight_dict["deconv_layers.0.weight"],
                    current_zero_id,
                    axis=0)
                print("last deconv_layers.0.weight shape is",
                      prune_weight_dict["deconv_layers.0.weight"].shape)
                print("finished del deconv input")
    return prune_weight_dict


def main(opt):
    """Prune a 'resPrune_101' model once more and fine-tune the result.

    Loads the checkpoint named by the module-level ``train_model`` into a
    model built with ``ori_prune_cnt``, removes the channels selected for the
    module-level ``percent_rate``, stores the pruned weights as
    ``prune_model_name`` inside ``opt.save_dir``, loads them (strictly) into a
    model built with ``prune_cnt`` and trains that model from epoch 1.

    Args:
        opt: parsed options object; dataset information and heads are filled
            in by ``opts().update_dataset_info_and_set_heads``.
    """
    mkdir_pth_dir(opt.save_dir)
    torch.manual_seed(opt.seed)
    torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
    Dataset = get_dataset(opt.dataset, opt.task)
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')
    print('Creating model...')
    model = create_model('resPrune_101',
                         opt.heads,
                         opt.head_conv,
                         percent_rate=percent_rate,
                         prune_cnt=ori_prune_cnt)
    optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    # Only the restored model is needed; a new optimizer is built for the
    # pruned model below and training restarts from epoch 1.
    model, _, _ = load_model(model,
                             train_model,
                             optimizer=optimizer,
                             resume=opt.resume,
                             lr=opt.lr,
                             lr_step=opt.lr_step)

    from utils.prune_util import obtain_filters_mask_l1_dict_percent
    num_filters, filters_mask, prune_idx, prune_bn_bias_idx = \
        obtain_filters_mask_l1_dict_percent(model, percent_rate)
    print("num_filters is", num_filters)
    print("prune idx is", prune_idx)
    print("prune bn and bias idx is", prune_bn_bias_idx)

    prune_weight_dict = _prune_state_dict(model, filters_mask, prune_idx,
                                          prune_bn_bias_idx)

    pruned_path = os.path.join(opt.save_dir, prune_model_name)
    torch.save(prune_weight_dict, pruned_path)
    prune_model = create_model_101_prune('resPrune_101',
                                         opt.heads,
                                         opt.head_conv,
                                         percent_rate=percent_rate,
                                         prune_cnt=prune_cnt)
    optimizer = torch.optim.Adam(prune_model.parameters(), opt.lr)
    # strict=True: every key and shape must match the pruned architecture.
    prune_model.load_state_dict(torch.load(pruned_path), strict=True)
    start_epoch = 0
    train_loader = torch.utils.data.DataLoader(Dataset(opt, 'train'),
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(Dataset(opt, 'val'),
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=True)
    trainer = CtTrainer(opt, prune_model, optimizer)
    trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
    best_val_loss = 1e10
    best_ap = 1e-10
    for epoch in range(start_epoch + 1, opt.num_epochs + 1):
        trainer.train(epoch, train_loader)
        if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
            save_model(
                os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                epoch, prune_model, optimizer)
            with torch.no_grad():
                log_dict_val, preds = trainer.val(epoch, val_loader)
                val_loader.dataset.run_eval(preds, opt.save_dir)
                result_json_pth = os.path.join(opt.save_dir, "results.json")
                anno_json_pth = '/home/pcl/pytorch_work/my_github/centernet_simple/data/dianli/annotations/test.json'
                ap_list, mean_ap = trainer.run_epoch_voc(
                    result_json_pth,
                    anno_json_pth,
                    score_th=0.01,
                    class_num=opt.num_classes)
                print(ap_list, mean_ap)
            # Keep a checkpoint for the best validation metric ...
            if log_dict_val[opt.metric] <= best_val_loss:
                best_val_loss = log_dict_val[opt.metric]
                save_model(
                    os.path.join(
                        opt.save_dir, 'model_best_val_loss_' +
                        str(round(best_val_loss, 2)) + '.pth'), epoch,
                    prune_model)
            # ... and one for the best mAP.
            if mean_ap > best_ap:
                best_ap = mean_ap
                save_model(
                    os.path.join(
                        opt.save_dir,
                        'model_best_map_' + str(round(best_ap, 3)) + '.pth'),
                    epoch, prune_model)
        else:
            save_model(os.path.join(opt.save_dir, 'model_last.pth'), epoch,
                       prune_model, optimizer)
        if epoch in opt.lr_step:
            save_model(
                os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                epoch, prune_model, optimizer)
            # Step decay: multiply the base rate by 0.1 per milestone reached.
            lr = opt.lr * (0.1**(opt.lr_step.index(epoch) + 1))
            print('Drop LR to', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
示例#6
0
def main(opt):
    """Fine-tune a pruned 'resPrune_101' model from its saved weights.

    Builds the pruned architecture from the module-level ``percent_rate`` and
    ``prune_cnt``, restores it via ``load_model`` from ``prune_model_name``
    inside ``opt.save_dir``, then trains it. On validation epochs it
    validates, computes per-class AP / mAP and keeps best-loss and best-mAP
    checkpoints.

    Args:
        opt: parsed options object; dataset information and heads are filled
            in by ``opts().update_dataset_info_and_set_heads``.
    """
    torch.manual_seed(opt.seed)
    torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
    Dataset = get_dataset(opt.dataset, opt.task)
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')
    print('Creating model...')
    prune_model = create_model_101_prune('resPrune_101',
                                         opt.heads,
                                         opt.head_conv,
                                         percent_rate=percent_rate,
                                         prune_cnt=prune_cnt)
    optimizer = torch.optim.Adam(prune_model.parameters(), opt.lr)
    prune_model, optimizer, start_epoch = load_model(
        prune_model,
        os.path.join(opt.save_dir, prune_model_name),
        optimizer=optimizer,
        resume=opt.resume,
        lr=opt.lr,
        lr_step=opt.lr_step)
    train_loader = torch.utils.data.DataLoader(Dataset(opt, 'train'),
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(Dataset(opt, 'val'),
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=True)
    trainer = CtTrainer(opt, prune_model, optimizer)
    trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
    best_val_loss = 1e10
    best_ap = 1e-10
    for epoch in range(start_epoch + 1, opt.num_epochs + 1):
        trainer.train(epoch, train_loader)
        if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
            save_model(
                os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                epoch, prune_model, optimizer)
            with torch.no_grad():
                log_dict_val, preds = trainer.val(epoch, val_loader)
                val_loader.dataset.run_eval(preds, opt.save_dir)
                result_json_pth = os.path.join(opt.save_dir, "results.json")
                anno_json_pth = '/home/pcl/pytorch_work/my_github/centernet_simple/data/dianli/annotations/test.json'
                ap_list, mean_ap = trainer.run_epoch_voc(
                    result_json_pth,
                    anno_json_pth,
                    score_th=0.01,
                    class_num=opt.num_classes)
                print(ap_list, mean_ap)
            # Keep a checkpoint for the best validation metric ...
            if log_dict_val[opt.metric] <= best_val_loss:
                best_val_loss = log_dict_val[opt.metric]
                save_model(
                    os.path.join(
                        opt.save_dir, 'model_best_val_loss_' +
                        str(round(best_val_loss, 2)) + '.pth'), epoch,
                    prune_model)
            # ... and one for the best mAP.
            if mean_ap > best_ap:
                best_ap = mean_ap
                save_model(
                    os.path.join(
                        opt.save_dir,
                        'model_best_map_' + str(round(best_ap, 3)) + '.pth'),
                    epoch, prune_model)
        else:
            save_model(os.path.join(opt.save_dir, 'model_last.pth'), epoch,
                       prune_model, optimizer)
        if epoch in opt.lr_step:
            save_model(
                os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                epoch, prune_model, optimizer)
            # Step decay: multiply the base rate by 0.1 per milestone reached.
            lr = opt.lr * (0.1**(opt.lr_step.index(epoch) + 1))
            print('Drop LR to', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
示例#7
0
def main(opt):
    """Train the model selected by ``opt.arch``.

    Optionally restores weights from ``opt.load_model`` and then trains. On
    validation epochs it validates, computes per-class AP / mAP and keeps
    best-loss and best-mAP checkpoints in ``opt.save_dir``.

    Args:
        opt: parsed options object; dataset information and heads are filled
            in by ``opts().update_dataset_info_and_set_heads``.
    """
    torch.manual_seed(opt.seed)
    torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
    Dataset = get_dataset(opt.dataset, opt.task)
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')
    # The logger is not written to below; it is still constructed in case its
    # constructor has side effects -- TODO confirm whether it can be dropped.
    logger = Logger(opt)

    print('Creating model...')
    model = create_model(opt.arch, opt.heads, opt.head_conv)
    optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    start_epoch = 0
    if opt.load_model != '':
        model, optimizer, start_epoch = load_model(model, opt.load_model,
                                                   optimizer, opt.resume,
                                                   opt.lr, opt.lr_step)

    train_loader = torch.utils.data.DataLoader(Dataset(opt, 'train'),
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(Dataset(opt, 'val'),
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=True)
    trainer = CtTrainer(opt, model, optimizer)
    trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
    best_val_loss = 1e10
    best_ap = 1e-10
    for epoch in range(start_epoch + 1, opt.num_epochs + 1):
        log_dict_train, _ = trainer.train(epoch, train_loader)
        print("log_dict_train is: ", log_dict_train)
        if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
            save_model(
                os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                epoch, model, optimizer)
            with torch.no_grad():
                log_dict_val, preds = trainer.val(epoch, val_loader)
                val_loader.dataset.run_eval(preds, opt.save_dir)
                result_json_pth = os.path.join(opt.save_dir, "results.json")
                anno_json_pth = '/home/pcl/pytorch_work/my_github/centernet_simple/data/dianli/annotations/test.json'
                ap_list, mean_ap = trainer.run_epoch_voc(
                    result_json_pth,
                    anno_json_pth,
                    score_th=0.01,
                    class_num=opt.num_classes)
                print(ap_list, mean_ap)
            # Keep a checkpoint for the best validation metric ...
            if log_dict_val[opt.metric] <= best_val_loss:
                best_val_loss = log_dict_val[opt.metric]
                save_model(
                    os.path.join(
                        opt.save_dir, 'model_best_val_loss_' +
                        str(round(best_val_loss, 2)) + '.pth'), epoch, model)
            # ... and one for the best mAP.
            if mean_ap > best_ap:
                best_ap = mean_ap
                save_model(
                    os.path.join(
                        opt.save_dir,
                        'model_best_map_' + str(round(best_ap, 2)) + '.pth'),
                    epoch, model)
        else:
            save_model(os.path.join(opt.save_dir, 'model_last.pth'), epoch,
                       model, optimizer)
        if epoch in opt.lr_step:
            save_model(
                os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                epoch, model, optimizer)
            # Step decay: multiply the base rate by 0.1 per milestone reached.
            lr = opt.lr * (0.1**(opt.lr_step.index(epoch) + 1))
            print('Drop LR to', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr