コード例 #1
0
    def build(self, backbone, cfg):
        """Instantiate the pose model for ``backbone`` and store it on ``self.model``.

        Args:
            backbone: Backbone name; one of "mobilenet", "seresnet101",
                "efficientnet", "shufflenet", "seresnet18" or "seresnet50".
            cfg: Either a value accepted directly by the backbone's
                ``createModel`` or a key into that backbone's config table
                from ``config.model_cfg``.

        Side effects:
            Sets ``self.backbone`` and ``self.model``. For every backbone
            except "efficientnet" it also sets ``self.feature_layer_num`` and
            ``self.feature_layer_name``. Moves the model to CUDA when
            ``self.device`` is not "cpu".

        Raises:
            ValueError: if ``backbone`` is not one of the supported names.
        """
        self.backbone = backbone
        # Backbone modules are imported lazily so only the selected one is loaded.
        if backbone == "mobilenet":
            from models.mobilenet.MobilePose import createModel
            from config.model_cfg import mobile_opt as model_ls
            self.feature_layer_num, self.feature_layer_name = 155, "features"
        elif backbone == "seresnet101":
            from models.seresnet101.FastPose import createModel
            from config.model_cfg import seresnet_cfg as model_ls
            self.feature_layer_num, self.feature_layer_name = 327, "seresnet101"
        elif backbone == "efficientnet":
            from models.efficientnet.EfficientPose import createModel
            from config.model_cfg import efficientnet_cfg as model_ls
            # NOTE(review): feature_layer_num / feature_layer_name are not set
            # for efficientnet; any later read of them raises AttributeError.
        elif backbone == "shufflenet":
            from models.shufflenet.ShufflePose import createModel
            from config.model_cfg import shufflenet_cfg as model_ls
            self.feature_layer_num, self.feature_layer_name = 167, "shuffle"
        elif backbone == "seresnet18":
            from models.seresnet18.FastPose import createModel
            from config.model_cfg import seresnet18_cfg as model_ls
            self.feature_layer_num, self.feature_layer_name = 75, "seresnet18"
        elif backbone == "seresnet50":
            from models.seresnet50.FastPose import createModel
            from config.model_cfg import seresnet50_cfg as model_ls
            self.feature_layer_num, self.feature_layer_name = 75, "seresnet50"
        else:
            raise ValueError("Your model name is wrong")

        try:
            self.model = createModel(cfg)
        except Exception:
            # ``cfg`` was not directly usable by createModel; treat it as a key
            # into the backbone's config table instead. Catching Exception (not
            # a bare except) lets KeyboardInterrupt/SystemExit propagate.
            self.model = createModel(model_ls[cfg])
        if self.device != "cpu":
            self.model.cuda()
コード例 #2
0
ファイル: test.py プロジェクト: CheungBH/PoseTrainingPytorch
def main(structure, cfg, data_info, weight, batch=4):
    """Evaluate a trained pose model on ``data_info`` and measure its cost.

    Args:
        structure: Backbone name ("mobilenet", "seresnet101", "efficientnet"
            or "shufflenet").
        cfg: Key into the backbone's config table from ``config.model_cfg``.
        data_info: Dataset description passed to ``TestDataset``.
        weight: Path to the state dict to evaluate; also stored in
            ``opt.loadModel``.
        batch: Evaluation batch size.

    Returns:
        ``((flops, params, inf_time), (loss, acc, dist, auc, pr),
        (pt_acc, pt_dist, pt_auc, pt_pr), thresh)``. The middle groups are
        exactly what ``test`` returns; the ``pt_*`` values are presumably
        per-keypoint metrics (confirm against ``test``).

    Raises:
        ValueError: if ``structure`` is not a supported backbone name.
    """
    if structure == "mobilenet":
        from models.mobilenet.MobilePose import createModel
        from config.model_cfg import mobile_opt as model_ls
    elif structure == "seresnet101":
        from models.seresnet.FastPose import createModel
        from config.model_cfg import seresnet_cfg as model_ls
    elif structure == "efficientnet":
        from models.efficientnet.EfficientPose import createModel
        from config.model_cfg import efficientnet_cfg as model_ls
    elif structure == "shufflenet":
        from models.shufflenet.ShufflePose import createModel
        from config.model_cfg import shufflenet_cfg as model_ls
    else:
        raise ValueError("Your model name is wrong")
    model_cfg = model_ls[cfg]
    opt.loadModel = weight

    # Model Initialize
    if device != "cpu":
        m = createModel(cfg=model_cfg).cuda()
        state_dict = torch.load(weight)
    else:
        m = createModel(cfg=model_cfg).cpu()
        # Remap storages to CPU so checkpoints saved on a GPU still load on
        # CPU-only machines.
        state_dict = torch.load(weight, map_location="cpu")
    m.load_state_dict(state_dict)

    flops = print_model_param_flops(m)
    params = print_model_param_nums(m)
    inf_time = get_inference_time(m,
                                  height=opt.outputResH,
                                  width=opt.outputResW)

    # Model Transfer
    if device != "cpu":
        criterion = torch.nn.MSELoss().cuda()
    else:
        criterion = torch.nn.MSELoss()

    # NOTE(review): the evaluation dataset is built with train=True; confirm
    # this is intended for TestDataset.
    test_dataset = TestDataset(data_info, train=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch,
                                              num_workers=0,
                                              pin_memory=True)

    loss, acc, dist, auc, pr, pt_acc, pt_dist, pt_auc, pt_pr, thresh = test(
        test_loader, m, criterion)
    cost = (flops, params, inf_time)
    overall = (loss, acc, dist, auc, pr)
    per_point = (pt_acc, pt_dist, pt_auc, pt_pr)
    return cost, overall, per_point, thresh
コード例 #3
0
def detect_sparse(weight, sparse_file, thresh=(50,99), device="cpu"):
    """Append one CSV row of BN-weight thresholds for a trained model.

    For each integer percent ``p`` in ``thresh[0]..thresh[1]`` (inclusive),
    computes ``obtain_bn_threshold(model, sorted_bn, p / 100)`` and records
    it. If ``sparse_file`` does not exist yet, a header row
    ``Model_name,thresh[0],...,thresh[1]`` is written first.

    Args:
        weight: Path to the state dict; its last two path components
            (split on "/") form the row's model name.
        sparse_file: CSV file the row is appended to.
        thresh: ``(first_percent, last_percent)``, both inclusive.
        device: "cpu" loads the checkpoint onto the CPU; anything else keeps
            ``torch.load``'s default placement.

    Raises:
        ValueError: if ``opt.backbone`` is not a supported backbone name.
    """
    if opt.backbone == "mobilenet":
        from models.mobilenet.MobilePose import createModel
        from config.model_cfg import mobile_opt as model_ls
    elif opt.backbone == "seresnet101":
        from models.seresnet.FastPose import createModel
        from config.model_cfg import seresnet_cfg as model_ls
    elif opt.backbone == "efficientnet":
        from models.efficientnet.EfficientPose import createModel
        from config.model_cfg import efficientnet_cfg as model_ls
    elif opt.backbone == "shufflenet":
        from models.shufflenet.ShufflePose import createModel
        from config.model_cfg import shufflenet_cfg as model_ls
    else:
        raise ValueError("Your model name is wrong")
    model_cfg = model_ls[opt.struct]
    opt.loadModel = weight

    if device == "cpu":
        model = createModel(cfg=model_cfg).cpu()
        model.load_state_dict(torch.load(weight, map_location="cpu"))
    else:
        model = createModel(cfg=model_cfg)
        model.load_state_dict(torch.load(weight))

    # obtain_prune_idx parses the printed model structure from this file, so
    # it must be fully written and closed before being read.
    tmp = "./model.txt"
    with open(tmp, 'w') as model_txt:
        print(model, file=model_txt)
    prune_idx = obtain_prune_idx(tmp)
    sorted_bn = sort_bn(model, prune_idx)

    # Inclusive upper bound so the data columns line up with the header row.
    percent_ls = range(thresh[0], thresh[1] + 1)
    if not os.path.exists(sparse_file):
        with open(sparse_file, "a+") as f:
            f.write("Model_name," + ",".join(str(p) for p in percent_ls) + "\n")

    path_parts = weight.split("/")
    model_name = path_parts[-2] + "-" + path_parts[-1]
    thresholds = []
    for percent in percent_ls:
        threshold = obtain_bn_threshold(model, sorted_bn, percent/100)
        print("{}---->{}".format(percent, threshold))
        thresholds.append(str(threshold.tolist()))

    # Same cell count as the header: the name plus one cell per percent.
    with open(sparse_file, "a+") as f:
        f.write(model_name + "," + ",".join(thresholds) + '\n')
コード例 #4
0
ファイル: prune.py プロジェクト: CheungBH/PoseTrainingPytorch
def pruning(weight,
            compact_model_path,
            compact_model_cfg="cfg.txt",
            thresh=80,
            device="cpu"):
    """Channel-prune a trained SE-ResNet pose model and save the compact model.

    Loads ``weight`` into the backbone selected by ``opt.backbone`` and derives
    a BN-weight threshold at ``thresh`` percent (``obtain_bn_threshold``). It
    then masks channels with ``obtain_filters_mask`` and writes the surviving
    channel counts (comma-separated) to ``compact_model_cfg``. It also writes a
    JSON model description to ``buffer/cfg_{backbone}.json``, builds the
    compact model, copies the kept weights into it, and saves its state dict
    to ``compact_model_path``.

    Args:
        weight: Path to the trained (unpruned) state dict.
        compact_model_path: Output path for the pruned model's state dict.
        compact_model_cfg: Text file receiving the kept channel counts; it is
            also passed to ``createModel`` to build the compact model.
        thresh: Pruning percentage (divided by 100 before use).
        device: "cpu" keeps the loose model on CPU; otherwise it is moved to
            CUDA.

    Raises:
        ValueError: if ``opt.backbone`` is unknown, or if it is not one of the
            prunable backbones ("seresnet18", "seresnet50", "seresnet101").
    """
    import os

    if opt.backbone == "mobilenet":
        from models.mobilenet.MobilePose import createModel
        from config.model_cfg import mobile_opt as model_ls
    elif opt.backbone == "seresnet101":
        from models.seresnet101.FastPose import createModel
        from config.model_cfg import seresnet_cfg as model_ls
    elif opt.backbone == "seresnet18":
        from models.seresnet18.FastPose import createModel
        # seresnet18 has its own table; seresnet_cfg is the seresnet101 one.
        from config.model_cfg import seresnet18_cfg as model_ls
    elif opt.backbone == "efficientnet":
        from models.efficientnet.EfficientPose import createModel
        from config.model_cfg import efficientnet_cfg as model_ls
    elif opt.backbone == "shufflenet":
        from models.shufflenet.ShufflePose import createModel
        from config.model_cfg import shufflenet_cfg as model_ls
    elif opt.backbone == "seresnet50":
        from models.seresnet50.FastPose import createModel
        from config.model_cfg import seresnet50_cfg as model_ls
    else:
        raise ValueError("Your model name is wrong")

    try:
        model_cfg = model_ls[opt.struct]
        model = createModel(cfg=model_cfg)
    except Exception:
        # opt.struct is not a usable key of the config table; hand it to
        # createModel directly (e.g. a config file path).
        model = createModel(cfg=opt.struct)

    if device == "cpu":
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        model.load_state_dict(torch.load(weight, map_location="cpu"))
        model.cpu()
    else:
        model.load_state_dict(torch.load(weight))
        model.cuda()

    os.makedirs("buffer", exist_ok=True)
    tmp = "./buffer/model.txt"
    with open(tmp, 'w') as model_txt:
        print(model, file=model_txt)

    if opt.backbone == "seresnet18":
        all_bn_id, normal_idx, shortcut_idx, downsample_idx, head_idx = obtain_prune_idx2(
            model)
    elif opt.backbone == "seresnet50" or opt.backbone == "seresnet101":
        all_bn_id, normal_idx, shortcut_idx, downsample_idx, head_idx = obtain_prune_idx_50(
            model)
    else:
        raise ValueError("Not a correct name")
    prune_idx = normal_idx + head_idx
    sorted_bn = sort_bn(model, prune_idx)

    threshold = obtain_bn_threshold(model, sorted_bn, thresh / 100)
    pruned_filters, pruned_maskers = obtain_filters_mask(
        model, prune_idx, threshold)
    # Keys are shifted by one: entry ``idx - 1`` describes the layer preceding
    # the BN layer with id ``idx``.
    CBLidx2mask = {
        idx - 1: mask.astype('float32')
        for idx, mask in zip(all_bn_id, pruned_maskers)
    }
    CBLidx2filter = {
        idx - 1: filter_num
        for idx, filter_num in zip(all_bn_id, pruned_filters)
    }

    for head in head_idx:
        adjust_mask(CBLidx2mask, CBLidx2filter, model, head)

    valid_filter = {
        k: v
        for k, v in CBLidx2filter.items() if k + 1 in prune_idx
    }
    kept_filters = list(valid_filter.values())
    channel_str = ",".join(str(x) for x in kept_filters)
    # Close (and flush) before createModel reads this file below.
    with open(compact_model_cfg, "w") as cfg_file:
        print(channel_str, file=cfg_file)

    m_cfg = {
        'backbone': opt.backbone,
        'keypoints': opt.kps,
        'se_ratio': opt.se_ratio,
        "first_conv": CBLidx2filter[all_bn_id[0] - 1],
        'residual': get_residual_channel(kept_filters, opt.backbone),
        'channels': get_channel_dict(kept_filters, opt.backbone),
        "head_type": "pixel_shuffle",
        "head_channel": [CBLidx2filter[i - 1] for i in head_idx]
    }
    write_cfg(m_cfg, "buffer/cfg_{}.json".format(opt.backbone))

    # NOTE(review): the compact model is built from the channel-list file, not
    # from the JSON config written just above; confirm which one createModel
    # expects.
    compact_model = createModel(cfg=compact_model_cfg).cpu()
    with open("buffer/pruned.txt", 'w') as pruned_txt:
        print(compact_model, file=pruned_txt)

    if opt.backbone == "seresnet18":
        init_weights_from_loose_model(compact_model, model, CBLidx2mask,
                                      valid_filter, downsample_idx, head_idx)
    elif opt.backbone == "seresnet50" or opt.backbone == "seresnet101":
        init_weights_from_loose_model50(compact_model, model, CBLidx2mask,
                                        valid_filter, downsample_idx, head_idx)
    torch.save(compact_model.state_dict(), compact_model_path)
コード例 #5
0
def main():
    """Train a pose model end to end from the module-level configuration.

    Reads its settings from globals defined elsewhere in the file (``opt``,
    ``config``, ``device``, ``model_cfg``, ``folder``, ``save_ID``,
    ``optimize``, ``mix_precision``, ``warm_up_epoch``, ``patience_decay``,
    ``open_source_dataset``, ...). It:

    * builds the model and optionally freezes feature layers / 1-D params;
    * trains with warm-up and early-stopping-driven LR decay;
    * checkpoints the best models under ``exp/{folder}/{save_ID}/``;
    * appends a one-line summary to ``result/{expFolder}_result_{computer}.csv``.

    A ``ZeroDivisionError`` or ``KeyboardInterrupt`` raised inside the training
    phase is recorded in the summary file instead of propagating.

    Raises:
        ValueError: for an invalid stopping accuracy, an unknown backbone when
            freezing, or an unknown optimizer name.
    """
    separator = "----------------------------------------------------------------------------------------------------"

    cmd_ls = sys.argv[1:]
    cmd = generate_cmd(cmd_ls)
    # These boolean options are only switched off when explicitly passed as "False".
    if "--freeze_bn False" in cmd:
        opt.freeze_bn = False
    if "--addDPG False" in cmd:
        opt.addDPG = False

    print(separator)
    print("This is the model with id {}".format(save_ID))
    print(opt)
    print("Training backbone is: {}".format(opt.backbone))
    print("Training data is: {}".format(",".join(config.train_info)))
    print("Warm up end at {}".format(warm_up_epoch))
    # bad_epochs maps an epoch to the minimum validation accuracy required at
    # that epoch, so a value above 1 can never be met.
    for min_acc in config.bad_epochs.values():
        if min_acc > 1:
            raise ValueError("Wrong stopping accuracy!")
    print(separator)

    exp_dir = os.path.join("exp/{}/{}".format(folder, save_ID))
    log_dir = os.path.join(exp_dir, "{}".format(save_ID))
    os.makedirs(log_dir, exist_ok=True)
    log_name = os.path.join(log_dir, "{}.txt".format(save_ID))
    train_log_name = os.path.join(log_dir, "{}_train.xlsx".format(save_ID))
    bn_file = os.path.join(log_dir, "{}_bn.txt".format(save_ID))

    # Checkpoints are remapped to CPU when training on CPU.
    map_location = "cpu" if device == "cpu" else None

    # Model Initialize
    if device != "cpu":
        m = createModel(cfg=model_cfg).cuda()
    else:
        m = createModel(cfg=model_cfg).cpu()
    with open("model.txt", "w") as model_txt:
        print(m, file=model_txt)

    begin_epoch = 0
    pre_train_model = opt.loadModel
    flops = print_model_param_flops(m)
    print("FLOPs of current model is {}".format(flops))
    params = print_model_param_nums(m)
    print("Parameters of current model is {}".format(params))
    inf_time = get_inference_time(m,
                                  height=opt.outputResH,
                                  width=opt.outputResW)
    print("Inference time is {}".format(inf_time))
    print(separator)

    if opt.freeze > 0 or opt.freeze_bn:
        # Per backbone: number of parameter tensors in the feature extractor and
        # the name substring that identifies them. The first ``opt.freeze``
        # fraction of those tensors is frozen.
        if opt.backbone == "mobilenet":
            feature_layer_num, feature_layer_name = 155, "features"
        elif opt.backbone == "seresnet101":
            # NOTE(review): another backbone table in this project uses
            # "seresnet101" as the layer name here; confirm "preact".
            feature_layer_num, feature_layer_name = 327, "preact"
        elif opt.backbone == "seresnet18":
            feature_layer_num, feature_layer_name = 75, "seresnet18"
        elif opt.backbone == "shufflenet":
            feature_layer_num, feature_layer_name = 167, "shuffle"
        else:
            raise ValueError("Not a correct name")

        feature_num = int(opt.freeze * feature_layer_num)

        for idx, (n, p) in enumerate(m.named_parameters()):
            if len(p.shape) == 1 and opt.freeze_bn:
                # freeze_bn freezes every 1-D parameter tensor.
                p.requires_grad = False
            elif feature_layer_name in n and idx < feature_num:
                p.requires_grad = False
            else:
                p.requires_grad = True

    writer = SummaryWriter('exp/{}/{}'.format(folder, save_ID), comment=cmd)

    rnd_inps = torch.rand(3, 3, 224, 224)
    if device != "cpu":
        rnd_inps = rnd_inps.cuda()
    try:
        writer.add_graph(m, (rnd_inps, ))
    except Exception:
        # Graph logging is best-effort; a failure here must not stop training.
        pass

    # If any training dataset is not an open-source one, the validation set
    # reuses the img/bbox/part lists built by the training dataset.
    shuffle_dataset = any(k not in open_source_dataset for k in config.train_info)

    train_dataset = MyDataset(config.train_info, train=True)
    val_dataset = MyDataset(config.train_info, train=False)
    if shuffle_dataset:
        val_dataset.img_val, val_dataset.bbox_val, val_dataset.part_val = \
            train_dataset.img_val, train_dataset.bbox_val, train_dataset.part_val

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.trainBatch,
                                               shuffle=True,
                                               num_workers=opt.trainNW,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.validBatch,
                                             shuffle=True,
                                             num_workers=opt.valNW,
                                             pin_memory=True)

    os.makedirs("exp/{}/{}".format(folder, save_ID), exist_ok=True)
    if pre_train_model:
        if "duc_se.pth" not in pre_train_model:
            if "pretrain" not in pre_train_model:
                # Resuming one of our own checkpoints: restore the iteration
                # counters and derive the next epoch from the "..._{epoch}.pkl"
                # file name.
                try:
                    info_path = os.path.join("exp", folder, save_ID,
                                             "option.pkl")
                    info = torch.load(info_path)
                    opt.trainIters = info.trainIters
                    opt.valIters = info.valIters
                    begin_epoch = int(pre_train_model.split("_")[-1][:-4]) + 1
                except Exception:
                    with open(log_name, "a+") as f:
                        f.write(cmd)

            print('Loading Model from {}'.format(pre_train_model))
            m.load_state_dict(torch.load(pre_train_model, map_location=map_location))
        else:
            # duc_se.pth: load it, then replace the output conv with one sized
            # for opt.kps keypoints.
            with open(log_name, "a+") as f:
                f.write(cmd)
            print('Loading Model from {}'.format(pre_train_model))
            m.load_state_dict(torch.load(pre_train_model, map_location=map_location))
            m.conv_out = nn.Conv2d(m.DIM,
                                   opt.kps,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)
            if device != "cpu":
                m.conv_out.cuda()
            os.makedirs("exp/{}/{}".format(folder, save_ID), exist_ok=True)
    else:
        print('Create new model')
        with open(log_name, "a+") as f:
            f.write(cmd)
            print(opt, file=f)
            f.write("FLOPs of current model is {}\n".format(flops))
            f.write("Parameters of current model is {}\n".format(params))

    with open(os.path.join(log_dir, "tb.py"), "w") as pyfile:
        pyfile.write("import os\n")
        pyfile.write("os.system('conda init bash')\n")
        pyfile.write("os.system('conda activate py36')\n")
        pyfile.write(
            "os.system('tensorboard --logdir=../../../../exp/{}/{}')".format(
                folder, save_ID))

    all_params = list(m.parameters())
    params_to_update = [p for p in all_params if p.requires_grad]
    layers = len(all_params)
    print("Training {} layers out of {}".format(len(params_to_update), layers))

    if optimize == 'rmsprop':
        optimizer = torch.optim.RMSprop(params_to_update,
                                        lr=opt.LR,
                                        momentum=opt.momentum,
                                        weight_decay=opt.weightDecay)
    elif optimize == 'adam':
        optimizer = torch.optim.Adam(params_to_update,
                                     lr=opt.LR,
                                     weight_decay=opt.weightDecay)
    elif optimize == 'sgd':
        optimizer = torch.optim.SGD(params_to_update,
                                    lr=opt.LR,
                                    momentum=opt.momentum,
                                    weight_decay=opt.weightDecay)
    else:
        raise ValueError("Unknown optimizer: {}".format(optimize))

    if mix_precision:
        m, optimizer = amp.initialize(m, optimizer, opt_level="O1")

    # Model Transfer
    if device != "cpu":
        m = torch.nn.DataParallel(m).cuda()
        criterion = torch.nn.MSELoss().cuda()
    else:
        m = torch.nn.DataParallel(m)
        criterion = torch.nn.MSELoss()

    early_stopping = EarlyStopping(patience=opt.patience, verbose=True)
    train_acc, val_acc, train_loss, val_loss, best_epoch, train_dist, val_dist, train_auc, val_auc, train_PR, val_PR = \
        0, 0, float("inf"), float("inf"), 0, float("inf"), float("inf"), 0, 0, 0, 0
    train_acc_ls, val_acc_ls, train_loss_ls, val_loss_ls, train_dist_ls, val_dist_ls, train_auc_ls, val_auc_ls, \
        train_pr_ls, val_pr_ls, epoch_ls, lr_ls = [], [], [], [], [], [], [], [], [], [], [], []
    decay, decay_epoch, lr, i = 0, [], opt.LR, begin_epoch
    stop = False
    m_best = m

    train_log = open(train_log_name, "w", newline="")
    bn_log = open(bn_file, "w")
    csv_writer = csv.writer(train_log)
    csv_writer.writerow(write_csv_title())
    begin_time = time.time()

    os.makedirs("result", exist_ok=True)
    result = os.path.join(
        "result", "{}_result_{}.csv".format(opt.expFolder, config.computer))
    exist = os.path.exists(result)

    def _record_abort(reason):
        # Close the training logs and append a partial summary row ending in
        # ``reason`` when training stops abnormally.
        with open(result, "a+") as f:
            training_time = time.time() - begin_time
            writer.close()
            train_log.close()
            info_str = "{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}, ,{},{},{}\n". \
                format(save_ID, opt.backbone, opt.struct, opt.DUC, params, flops, inf_time, opt.loss_allocate, opt.addDPG,
                       opt.kps, opt.trainBatch, opt.optMethod, opt.freeze_bn, opt.freeze, opt.sparse_s, opt.sparse_decay,
                       opt.nEpochs, opt.LR, opt.hmGauss, opt.ratio, opt.weightDecay, opt.loadModel, config.computer,
                       os.path.join(folder, save_ID), training_time, reason)
            f.write(info_str)

    # Start Training
    try:
        for i in range(opt.nEpochs)[begin_epoch:]:

            opt.epoch = i
            epoch_ls.append(i)
            train_log_tmp = [save_ID, i, lr]

            log = open(log_name, "a+")
            print('############# Starting Epoch {} #############'.format(i))
            log.write(
                '############# Starting Epoch {} #############\n'.format(i))

            loss, acc, dist, auc, pr, pt_acc, pt_dist, pt_auc, pt_pr = \
                train(train_loader, m, criterion, optimizer, writer)
            train_log_tmp.append(" ")
            train_log_tmp.append(loss)
            train_log_tmp.append(acc.tolist())
            train_log_tmp.append(dist.tolist())
            train_log_tmp.append(auc)
            train_log_tmp.append(pr)
            for a in pt_acc:
                train_log_tmp.append(a.tolist())
            train_log_tmp.append(" ")
            for d in pt_dist:
                train_log_tmp.append(d.tolist())
            train_log_tmp.append(" ")
            for ac in pt_auc:
                train_log_tmp.append(ac)
            train_log_tmp.append(" ")
            for p in pt_pr:
                train_log_tmp.append(p)
            train_log_tmp.append(" ")

            train_acc_ls.append(acc)
            train_loss_ls.append(loss)
            train_dist_ls.append(dist)
            train_auc_ls.append(auc)
            train_pr_ls.append(pr)
            train_acc = acc if acc > train_acc else train_acc
            train_loss = loss if loss < train_loss else train_loss
            train_dist = dist if dist < train_dist else train_dist
            train_auc = auc if auc > train_auc else train_auc
            train_PR = pr if pr > train_PR else train_PR

            log.write(
                'Train:{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f} | dist:{dist:.4f} | AUC: {AUC:.4f} | PR: {PR:.4f}\n'
                .format(
                    idx=i,
                    loss=loss,
                    acc=acc,
                    dist=dist,
                    AUC=auc,
                    PR=pr,
                ))

            opt.acc = acc
            opt.loss = loss
            m_dev = m.module

            loss, acc, dist, auc, pr, pt_acc, pt_dist, pt_auc, pt_pr = valid(
                val_loader, m, criterion, writer)
            # Validation summary goes right after the training summary columns.
            train_log_tmp.insert(9, loss)
            train_log_tmp.insert(10, acc.tolist())
            train_log_tmp.insert(11, dist.tolist())
            train_log_tmp.insert(12, auc)
            train_log_tmp.insert(13, pr)
            train_log_tmp.insert(14, " ")
            for a in pt_acc:
                train_log_tmp.append(a.tolist())
            train_log_tmp.append(" ")
            for d in pt_dist:
                train_log_tmp.append(d.tolist())
            train_log_tmp.append(" ")
            for ac in pt_auc:
                train_log_tmp.append(ac)
            train_log_tmp.append(" ")
            for p in pt_pr:
                train_log_tmp.append(p)
            train_log_tmp.append(" ")

            val_acc_ls.append(acc)
            val_loss_ls.append(loss)
            val_dist_ls.append(dist)
            val_auc_ls.append(auc)
            val_pr_ls.append(pr)
            if acc > val_acc:
                best_epoch = i
                val_acc = acc
                torch.save(
                    m_dev.state_dict(),
                    'exp/{0}/{1}/{1}_best_acc.pkl'.format(folder, save_ID))
                m_best = copy.deepcopy(m)
            val_loss = loss if loss < val_loss else val_loss
            if dist < val_dist:
                val_dist = dist
                torch.save(
                    m_dev.state_dict(),
                    'exp/{0}/{1}/{1}_best_dist.pkl'.format(folder, save_ID))
            if auc > val_auc:
                val_auc = auc
                torch.save(
                    m_dev.state_dict(),
                    'exp/{0}/{1}/{1}_best_auc.pkl'.format(folder, save_ID))
            if pr > val_PR:
                val_PR = pr
                torch.save(
                    m_dev.state_dict(),
                    'exp/{0}/{1}/{1}_best_pr.pkl'.format(folder, save_ID))

            log.write(
                'Valid:{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f} | dist:{dist:.4f} | AUC: {AUC:.4f} | PR: {PR:.4f}\n'
                .format(
                    idx=i,
                    loss=loss,
                    acc=acc,
                    dist=dist,
                    AUC=auc,
                    PR=pr,
                ))

            # Mean absolute BN scale per channel. Detached so this logging
            # statistic builds no autograd graph; logged as a plain number.
            bn_sum, bn_num = 0, 0
            for mod in m.modules():
                if isinstance(mod, nn.BatchNorm2d):
                    bn_num += mod.num_features
                    bn_sum += mod.weight.detach().abs().sum().item()
                    writer.add_histogram("bn_weight",
                                         mod.weight.data.cpu().numpy(), i)

            bn_ave = bn_sum / bn_num
            bn_log.write("{} --> {}".format(i, bn_ave))
            print("Current bn : {} --> {}".format(i, bn_ave))
            bn_log.write("\n")
            log.close()
            csv_writer.writerow(train_log_tmp)

            writer.add_scalar("lr", lr, i)
            print("epoch {}: lr {}".format(i, lr))
            lr_ls.append(lr)

            torch.save(opt, 'exp/{}/{}/option.pkl'.format(folder, save_ID))
            if i % opt.save_interval == 0 and i != 0:
                torch.save(
                    m_dev.state_dict(),
                    'exp/{0}/{1}/{1}_{2}.pkl'.format(folder, save_ID, i))

            if i < warm_up_epoch:
                optimizer, lr = warm_up_lr(optimizer, i)
            elif i == warm_up_epoch:
                lr = opt.LR
                early_stopping(acc)
            else:
                early_stopping(acc)
                if early_stopping.early_stop:
                    optimizer, lr = lr_decay(optimizer, lr)
                    decay += 1
                    if decay > opt.lr_decay_time:
                        stop = True
                    else:
                        decay_epoch.append(i)
                        early_stopping.reset(
                            int(opt.patience * patience_decay[decay]))
                        # Roll back to the best weights IN PLACE: the optimizer
                        # holds references to m's parameters, so rebinding m to
                        # the deep copy would leave it updating stale tensors.
                        m.load_state_dict(m_best.state_dict())

            for epo, ac in config.bad_epochs.items():
                if i == epo and val_acc < ac:
                    stop = True
            if stop:
                print("Training finished at epoch {}".format(i))
                break

        training_time = time.time() - begin_time
        writer.close()
        train_log.close()

        draw_graph(epoch_ls, train_loss_ls, val_loss_ls, "loss", log_dir)
        draw_graph(epoch_ls, train_acc_ls, val_acc_ls, "acc", log_dir)
        draw_graph(epoch_ls, train_auc_ls, val_auc_ls, "AUC", log_dir)
        draw_graph(epoch_ls, train_dist_ls, val_dist_ls, "dist", log_dir)
        draw_graph(epoch_ls, train_pr_ls, val_pr_ls, "PR", log_dir)

        with open(result, "a+") as f:
            if not exist:
                title_str = "id,backbone,structure,DUC,params,flops,time,loss_param,addDPG,kps,batch_size,optimizer," \
                            "freeze_bn,freeze,sparse,sparse_decay,epoch_num,LR,Gaussian,thresh,weightDecay,loadModel," \
                            "model_location, ,folder_name,training_time,train_acc,train_loss,train_dist,train_AUC," \
                            "train_PR,val_acc,val_loss,val_dist,val_AUC,val_PR,best_epoch,final_epoch"
                title_str = write_decay_title(len(decay_epoch), title_str)
                f.write(title_str)
            info_str = "{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}, ,{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n".\
                format(save_ID, opt.backbone, opt.struct, opt.DUC, params, flops, inf_time, opt.loss_allocate, opt.addDPG,
                       opt.kps, opt.trainBatch, opt.optMethod, opt.freeze_bn, opt.freeze, opt.sparse_s, opt.sparse_decay,
                       opt.nEpochs, opt.LR, opt.hmGauss, opt.ratio, opt.weightDecay, opt.loadModel, config.computer,
                       os.path.join(folder, save_ID), training_time, train_acc, train_loss, train_dist, train_auc,
                       train_PR, val_acc, val_loss, val_dist, val_auc, val_PR, best_epoch, i)
            info_str = write_decay_info(decay_epoch, info_str)
            f.write(info_str)
    except ZeroDivisionError:
        _record_abort("Gradient flow")
    except KeyboardInterrupt:
        _record_abort("Be killed by someone")
    finally:
        # The BN log was never closed on any path before.
        bn_log.close()

    print("Model {} training finished".format(save_ID))
    print(separator)
コード例 #6
0
def main():
    """Train the pose model described by the module-level configuration.

    Steps:
      1. Build train/val ``DataLoader``s from ``config.train_info``.
      2. Create the model with ``createModel(cfg=model_cfg)`` and report its
         FLOPs / parameter count.
      3. Optionally restore weights from ``config.loadModel`` and resume at the
         epoch encoded in its file name (``..._<epoch>.pkl``).
      4. Build the optimizer selected by ``optimize`` (and wrap with apex amp
         when ``mix_precision`` is set).
      5. Run ``train``/``valid`` for every remaining epoch. Log to
         ``log/<dataset>/<save_folder>.txt`` and TensorBoard. Every
         ``config.save_interval`` epochs, save weights, ``opt`` and the
         optimizer under ``exp/<dataset>/<save_folder>``.

    Relies on globals defined elsewhere in this file: ``config``, ``opt``,
    ``device``, ``dataset``, ``save_folder``, ``optimize``, ``mix_precision``,
    ``model_cfg`` and the imported helpers (``MyDataset``, ``createModel``,
    ``train``, ``valid``, ``SummaryWriter``, ``amp``, ...).

    Raises:
        ValueError: if ``optimize`` is neither ``'rmsprop'`` nor ``'adam'``.
    """
    use_cuda = device != "cpu"

    # Prepare Dataset
    train_dataset = MyDataset(config.train_info, train=True)
    val_dataset = MyDataset(config.train_info, train=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch,
        shuffle=True,
        num_workers=config.train_mum_worker,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config.val_batch,
        shuffle=True,
        num_workers=config.val_num_worker,
        pin_memory=True)

    # Model Initialize
    m = createModel(cfg=model_cfg)
    m = m.cuda() if use_cuda else m.cpu()

    begin_epoch = 0
    pre_train_model = config.loadModel
    flops = print_model_param_flops(m)
    print("FLOPs of current model is {}".format(flops))
    params = print_model_param_nums(m)
    print("Parameters of current model is {}".format(params))

    exp_dir = "exp/{}/{}".format(dataset, save_folder)
    if pre_train_model:
        print('Loading Model from {}'.format(pre_train_model))
        # Load onto CPU first so a GPU-saved checkpoint also works on CPU-only
        # runs; load_state_dict copies each tensor onto the model's device.
        m.load_state_dict(torch.load(pre_train_model, map_location="cpu"))
        # Checkpoints are saved as ".../model_<epoch>.pkl": resume after it.
        begin_epoch = int(pre_train_model.split("_")[-1][:-4]) + 1
        # These depend on begin_epoch, so they must come after it is parsed.
        opt.trainIters = config.train_batch * (begin_epoch - 1)
        opt.valIters = config.val_batch * (begin_epoch - 1)
    else:
        print('Create new model')
        os.makedirs("log", exist_ok=True)
        with open("log/{}.txt".format(save_folder), "a+") as f:
            f.write("FLOPs of current model is {}\n".format(flops))
            f.write("Parameters of current model is {}\n".format(params))
    os.makedirs(exp_dir, exist_ok=True)

    if optimize == 'rmsprop':
        optimizer = torch.optim.RMSprop(m.parameters(),
                                        lr=config.lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weightDecay)
    elif optimize == 'adam':
        optimizer = torch.optim.Adam(m.parameters(),
                                     lr=config.lr,
                                     weight_decay=config.weightDecay)
    else:
        raise ValueError("Unsupported optimizer: {}".format(optimize))

    if mix_precision:
        m, optimizer = amp.initialize(m, optimizer, opt_level="O1")

    writer = SummaryWriter('tensorboard/{}/{}'.format(dataset, save_folder))
    try:
        # Model Transfer
        m = torch.nn.DataParallel(m)
        criterion = torch.nn.MSELoss()
        if use_cuda:
            m = m.cuda()
            criterion = criterion.cuda()

        # Dummy batch used only to trace the network graph for TensorBoard.
        # NOTE(review): 224x224 is kept from the original; confirm it matches
        # the input size the model expects.
        rnd_inps = torch.rand(2, 3, 224, 224)
        if use_cuda:
            rnd_inps = rnd_inps.cuda()
        # Trace the wrapped module rather than the DataParallel wrapper.
        writer.add_graph(m.module, rnd_inps)

        # Start Training
        os.makedirs("log/{}".format(dataset), exist_ok=True)
        for i in range(begin_epoch, config.epochs):
            with open("log/{}/{}.txt".format(dataset, save_folder), "a+") as log:
                print('############# Starting Epoch {} #############'.format(i))
                log.write('############# Starting Epoch {} #############\n'.format(i))

                for name, param in m.named_parameters():
                    writer.add_histogram(name, param.detach().cpu().numpy(), i)

                loss, acc = train(train_loader, m, criterion, optimizer, writer)
                print('Train-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}'.format(
                    idx=i, loss=loss, acc=acc))
                log.write(
                    'Train-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}\n'.format(
                        idx=i, loss=loss, acc=acc))

                # Training metrics are stored on opt, which is pickled below.
                opt.acc = acc
                opt.loss = loss

                loss, acc = valid(val_loader, m, criterion, optimizer, writer)
                print('Valid:-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}'.format(
                    idx=i, loss=loss, acc=acc))
                log.write(
                    'Valid:-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}\n'.format(
                        idx=i, loss=loss, acc=acc))

            if i % config.save_interval == 0:
                # Save the unwrapped module so keys carry no "module." prefix.
                torch.save(m.module.state_dict(),
                           '{}/model_{}.pkl'.format(exp_dir, i))
                torch.save(opt, '{}/option.pkl'.format(exp_dir))
                torch.save(optimizer, '{}/optimizer.pkl'.format(exp_dir))
    finally:
        writer.close()